Unverified Commit cded5b80 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
......@@ -16,7 +16,7 @@ namespace impl {
namespace {
/*! \brief COORemove implementation for COOMatrix with default consecutive edge IDs */
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COORemoveConsecutive(
COOMatrix coo,
IdArray entries,
......@@ -47,7 +47,7 @@ void COORemoveConsecutive(
}
/*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COORemoveShuffled(
COOMatrix coo,
IdArray entries,
......@@ -73,7 +73,7 @@ void COORemoveShuffled(
}; // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
const int64_t nnz = coo.row->shape[0];
const int64_t n_entries = entries->shape[0];
......@@ -98,8 +98,8 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
IdArray::FromVector(new_eids));
}
template COOMatrix COORemove<kDLCPU, int32_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDLCPU, int64_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDGLCPU, int64_t>(COOMatrix coo, IdArray entries);
}; // namespace impl
}; // namespace aten
......
......@@ -167,7 +167,7 @@ namespace impl {
///////////////////////////// COOSort_ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) {
const int64_t nnz = coo->row->shape[0];
IdType* coo_row = coo->row.Ptr<IdType>();
......@@ -208,13 +208,13 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
coo->col_sorted = sort_column;
}
template void COOSort_<kDLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDLCPU, int64_t>(COOMatrix*, bool);
template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);
///////////////////////////// COOIsSorted /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0];
IdType* row = coo.row.Ptr<IdType>();
......@@ -230,8 +230,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return {row_sorted, col_sorted};
}
template std::pair<bool, bool> COOIsSorted<kDLCPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLCPU, int64_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten
......
......@@ -17,7 +17,7 @@ using runtime::parallel_for;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col,
std::vector<IdType> *ret_vec) {
......@@ -38,7 +38,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
}
}
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0];
......@@ -59,7 +59,7 @@ NDArray CSRGetData(
const int64_t retlen = std::max(rowlen, collen);
const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype.";
NDArray ret = Full(filler, retlen, rows->ctx);
......@@ -106,19 +106,19 @@ NDArray CSRGetData(
return ret;
}
template NDArray CSRGetData<kDLCPU, int32_t, float>(
template NDArray CSRGetData<kDGLCPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int64_t, float>(
template NDArray CSRGetData<kDGLCPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int32_t, double>(
template NDArray CSRGetData<kDGLCPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLCPU, int64_t, double>(
template NDArray CSRGetData<kDGLCPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(
template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(
template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl
......
......@@ -134,13 +134,13 @@ std::pair<CSRMatrix, NDArray> CSRMM(
C_weights};
}
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
}; // namespace aten
......
......@@ -15,7 +15,7 @@ namespace impl {
namespace {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRRemoveConsecutive(
CSRMatrix csr,
IdArray entries,
......@@ -48,7 +48,7 @@ void CSRRemoveConsecutive(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRRemoveShuffled(
CSRMatrix csr,
IdArray entries,
......@@ -77,7 +77,7 @@ void CSRRemoveShuffled(
}; // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t nnz = csr.indices->shape[0];
......@@ -103,8 +103,8 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
IdArray::FromVector(new_eids));
}
template CSRMatrix CSRRemove<kDLCPU, int32_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDLCPU, int64_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDGLCPU, int64_t>(CSRMatrix csr, IdArray entries);
}; // namespace impl
}; // namespace aten
......
......@@ -14,7 +14,7 @@ namespace aten {
namespace impl {
///////////////////////////// CSRIsSorted /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
......@@ -31,12 +31,12 @@ bool CSRIsSorted(CSRMatrix csr) {
[](bool a, bool b) { return a && b; });
}
template bool CSRIsSorted<kDLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLCPU, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);
///////////////////////////// CSRSort /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
typedef std::pair<IdType, IdType> ShufflePair;
const int64_t num_rows = csr->num_rows;
......@@ -79,10 +79,10 @@ void CSRSort_(CSRMatrix* csr) {
csr->sorted = true;
}
template void CSRSort_<kDLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDLCPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix* csr);
template <DLDeviceType XPU, typename IdType, typename TagType>
template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
......@@ -143,13 +143,13 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
return std::make_pair(output, tag_pos);
}
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int64_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int64_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int32_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int32_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int64_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int64_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int32_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int32_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
} // namespace impl
......
......@@ -130,13 +130,13 @@ std::pair<CSRMatrix, NDArray> CSRSum(
C_weights};
}
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
}; // namespace aten
......
......@@ -12,7 +12,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
if (!csr.sorted)
csr = CSRSort(csr);
......@@ -67,8 +67,8 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
return std::make_tuple(res_csr, edge_count, eids_remapped);
}
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int32_t>(CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int64_t>(CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int32_t>(CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int64_t>(CSRMatrix);
} // namespace impl
} // namespace aten
......
......@@ -14,7 +14,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
std::vector<IdType> res_indptr;
std::vector<IdType> res_indices;
......@@ -109,8 +109,8 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
sorted);
}
template CSRMatrix UnionCsr<kDLCPU, int64_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDLCPU, int32_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDGLCPU, int64_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDGLCPU, int32_t>(const std::vector<CSRMatrix>&);
} // namespace impl
} // namespace aten
......
......@@ -26,7 +26,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
IdArray prefix_src_arr = NewIdArray(
coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);
......@@ -52,7 +52,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
CumSum(prefix_elm_arr, true));
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
bool has_data = false;
bool row_sorted = true;
......@@ -118,8 +118,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted);
}
template COOMatrix DisjointUnionCoo<kDLCPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLCPU, int64_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCPU, int64_t>(const std::vector<COOMatrix>& coos);
} // namespace impl
} // namespace aten
......
......@@ -62,74 +62,74 @@ void GatherMMScatter(const NDArray A,
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
template void GatherMM<kDLCPU, int32_t, 16>(
template void GatherMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 16>(
template void GatherMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 32>(
template void GatherMM<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 32>(
template void GatherMM<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 64>(
template void GatherMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 64>(
template void GatherMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDLCPU, int32_t, 16>(
template void GatherMMScatter<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 16>(
template void GatherMMScatter<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 32>(
template void GatherMMScatter<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 32>(
template void GatherMMScatter<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 64>(
template void GatherMMScatter<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 64>(
template void GatherMMScatter<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDLCPU, int32_t, 16>(
template void SegmentMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 16>(
template void SegmentMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 32>(
template void SegmentMM<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 32>(
template void SegmentMM<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 64>(
template void SegmentMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 64>(
template void SegmentMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDLCPU, int32_t, 16>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 16>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 32>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 32>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 64>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 64>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten
......
......@@ -17,7 +17,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix &csr,
int64_t num_samples,
......@@ -61,9 +61,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return {row.CreateView({num_sampled}, row->dtype), col.CreateView({num_sampled}, col->dtype)};
}
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int32_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int32_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int64_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int64_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
}; // namespace impl
......
......@@ -97,13 +97,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
// significant. (minjie)
IdArray picked_row = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_col = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_idx = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data);
IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data);
......
......@@ -117,7 +117,7 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn(
/////////////////////////////// CSR ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
CHECK(prob.defined());
......@@ -125,16 +125,16 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) {
......@@ -143,28 +143,28 @@ COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) {
......@@ -172,12 +172,12 @@ COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
......@@ -191,22 +191,22 @@ COOMatrix CSRRowWiseSamplingBiased(
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
/////////////////////////////// COO ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
CHECK(prob.defined());
......@@ -214,16 +214,16 @@ COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, float>(
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, float>(
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, double>(
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) {
......@@ -232,28 +232,28 @@ COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int32_t>(
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, int64_t, bool);
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, int64_t, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) {
......@@ -261,9 +261,9 @@ COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
} // namespace impl
......
......@@ -55,52 +55,52 @@ inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending)
} // namespace
template <DLDeviceType XPU, typename IdxType, typename DType>
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return CSRRowWisePick(mat, rows, k, false, pick_fn);
}
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int32_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int32_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int64_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int64_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template <DLDeviceType XPU, typename IdxType, typename DType>
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return COORowWisePick(mat, rows, k, false, pick_fn);
}
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int32_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int32_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int64_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int64_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int64_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, float>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, float>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, double>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, double>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, NDArray, bool);
} // namespace impl
......
......@@ -102,67 +102,67 @@ void SDDMMCsrHetero(const std::string& op,
});
}
template void SDDMMCsr<kDLCPU, int32_t, 16>(
template void SDDMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 16>(
template void SDDMMCsr<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 32>(
template void SDDMMCsr<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 32>(
template void SDDMMCsr<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 64>(
template void SDDMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 64>(
template void SDDMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDLCPU, int32_t, 16>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 16>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 32>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 32>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 64>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 64>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......@@ -217,67 +217,67 @@ void SDDMMCooHetero(const std::string& op,
});
}
template void SDDMMCoo<kDLCPU, int32_t, 16>(
template void SDDMMCoo<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 16>(
template void SDDMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 32>(
template void SDDMMCoo<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 32>(
template void SDDMMCoo<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 64>(
template void SDDMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 64>(
template void SDDMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDLCPU, int32_t, 16>(
template void SDDMMCooHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 16>(
template void SDDMMCooHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 32>(
template void SDDMMCooHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 32>(
template void SDDMMCooHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 64>(
template void SDDMMCooHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 64>(
template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......
......@@ -74,113 +74,113 @@ void BackwardSegmentCmp(
});
}
template void SegmentReduce<kDLCPU, int32_t, 16>(
template void SegmentReduce<kDGLCPU, int32_t, 16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 16>(
template void SegmentReduce<kDGLCPU, int64_t, 16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, 32>(
template void SegmentReduce<kDGLCPU, int32_t, 32>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 32>(
template void SegmentReduce<kDGLCPU, int64_t, 32>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, 64>(
template void SegmentReduce<kDGLCPU, int32_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 64>(
template void SegmentReduce<kDGLCPU, int64_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDLCPU, int32_t, 16>(
template void ScatterAdd<kDGLCPU, int32_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 16>(
template void ScatterAdd<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 32>(
template void ScatterAdd<kDGLCPU, int32_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 32>(
template void ScatterAdd<kDGLCPU, int64_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 64>(
template void ScatterAdd<kDGLCPU, int32_t, 64>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 64>(
template void ScatterAdd<kDGLCPU, int64_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 16>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 32>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 32>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 64>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 64>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
......
......@@ -29,7 +29,7 @@ namespace impl {
///////////////////////////// COOIsNonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
......@@ -42,10 +42,10 @@ bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
return false;
}
template bool COOIsNonZero<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
......@@ -67,12 +67,12 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
return rst;
}
template NDArray COOIsNonZero<kDLCPU, int32_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDLCPU, int64_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);
///////////////////////////// COOHasDuplicate /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo) {
std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
const IdType* src_data = static_cast<IdType*>(coo.row->data);
......@@ -89,12 +89,12 @@ bool COOHasDuplicate(COOMatrix coo) {
return false;
}
template bool COOHasDuplicate<kDLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDLCPU, int64_t>(COOMatrix coo);
template bool COOHasDuplicate<kDGLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOGetRowNNZ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
......@@ -106,10 +106,10 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
return result;
}
template int64_t COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
CHECK_SAME_DTYPE(coo.col, rows);
const auto len = rows->shape[0];
......@@ -123,12 +123,12 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
return rst;
}
template NDArray COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);
///////////////////////////// COOGetRowDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
COOMatrix coo, int64_t row) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
......@@ -151,13 +151,13 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
}
template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int32_t>(COOMatrix, int64_t);
COOGetRowDataAndIndices<kDGLCPU, int32_t>(COOMatrix, int64_t);
template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int64_t>(COOMatrix, int64_t);
COOGetRowDataAndIndices<kDGLCPU, int64_t>(COOMatrix, int64_t);
///////////////////////////// COOGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
......@@ -211,12 +211,12 @@ IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
return ret;
}
template IdArray COOGetData<kDLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDLCPU, int64_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDGLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);
///////////////////////////// COOGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
NDArray cols) {
CHECK_SAME_DTYPE(coo.col, rows);
......@@ -286,20 +286,20 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
NDArray::FromVector(ret_data)};
}
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int32_t>(
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int32_t>(
COOMatrix coo, NDArray rows, NDArray cols);
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int64_t>(
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int64_t>(
COOMatrix coo, NDArray rows, NDArray cols);
///////////////////////////// COOTranspose /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo) {
return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};
}
template COOMatrix COOTranspose<kDLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDLCPU, int64_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDGLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOToCSR /////////////////////////////
namespace {
......@@ -615,7 +615,7 @@ P^2).
degree), UnSortedDenseCOOToCSR<> is applied. Time: O(NNZ/P + N/P), space O(NNZ +
N*P).
*/
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
if (!coo.row_sorted) {
const int64_t num_threads = omp_get_num_threads();
......@@ -632,12 +632,12 @@ CSRMatrix COOToCSR(COOMatrix coo) {
return SortedCOOToCSR<IdType>(coo);
}
template CSRMatrix COOToCSR<kDLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLCPU, int64_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
// TODO(minjie): use binary search when coo.row_sorted is true
CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
......@@ -669,10 +669,10 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
coo.col_sorted);
}
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
......@@ -703,12 +703,12 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
coo.row_sorted, coo.col_sorted};
}
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix , NDArray);
///////////////////////////// COOSliceMatrix /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
......@@ -740,15 +740,15 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray
coo.row_sorted, coo.col_sorted);
}
template COOMatrix COOSliceMatrix<kDLCPU, int32_t>(
template COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template COOMatrix COOSliceMatrix<kDLCPU, int64_t>(
template COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
///////////////////////////// COOReorder /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
runtime::NDArray new_col_id_arr) {
CHECK_SAME_DTYPE(coo.row, new_row_id_arr);
......@@ -785,9 +785,9 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);
}
template COOMatrix COOReorder<kDLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids,
template COOMatrix COOReorder<kDGLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
template COOMatrix COOReorder<kDLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids,
template COOMatrix COOReorder<kDGLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
} // namespace impl
......
......@@ -21,7 +21,7 @@ namespace impl {
///////////////////////////// CSRIsNonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
......@@ -39,10 +39,10 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
return false;
}
template bool CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
......@@ -62,12 +62,12 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
return rst;
}
template NDArray CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
///////////////////////////// CSRHasDuplicate /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
......@@ -85,21 +85,21 @@ bool CSRHasDuplicate(CSRMatrix csr) {
return false;
}
template bool CSRHasDuplicate<kDLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLCPU, int64_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRGetRowNNZ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
return indptr_data[row + 1] - indptr_data[row];
}
template int64_t CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const auto len = rows->shape[0];
......@@ -114,12 +114,12 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
return rst;
}
template NDArray CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, NDArray);
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
......@@ -127,12 +127,12 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}
template NDArray CSRGetRowColumnIndices<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLCPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetRowData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
......@@ -143,13 +143,13 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}
template NDArray CSRGetRowData<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLCPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetData /////////////////////////////
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col,
std::vector<IdType> *col_vec,
......@@ -172,7 +172,7 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
// TODO(minjie): more efficient implementation for matrix without duplicate entries
const int64_t rowlen = rows->shape[0];
......@@ -224,16 +224,16 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
NDArray::FromVector(ret_data, csr.data->ctx)};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int64_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRTranspose /////////////////////////////
// for a matrix of shape (N, M) and NNZ
// complexity: time O(NNZ + max(N, M)), space O(1)
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) {
const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols;
......@@ -281,11 +281,11 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
}
template CSRMatrix CSRTranspose<kDLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLCPU, int64_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRToCOO /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
const int64_t nnz = csr.indices->shape[0];
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
......@@ -303,11 +303,11 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
true, csr.sorted);
}
template COOMatrix CSRToCOO<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCPU, int64_t>(CSRMatrix csr);
// complexity: time O(NNZ), space O(1)
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols;
......@@ -333,12 +333,12 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
return COOMatrix(N, M, ret_row, ret_col);
}
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const int64_t num_rows = end - start;
......@@ -362,10 +362,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
csr.sorted);
}
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
......@@ -467,12 +467,12 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
return ret;
}
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRSliceMatrix /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0];
......@@ -521,14 +521,14 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
sub_data_arr};
}
template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t>(
template CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t>(
template CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
///////////////////////////// CSRReorder /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
runtime::NDArray new_col_id_arr) {
CHECK_SAME_DTYPE(csr.indices, new_row_id_arr);
......@@ -599,9 +599,9 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
out_indptr_arr, out_indices_arr, out_data_arr);
}
template CSRMatrix CSRReorder<kDLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
template CSRMatrix CSRReorder<kDGLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
template CSRMatrix CSRReorder<kDLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
template CSRMatrix CSRReorder<kDGLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
} // namespace impl
......
......@@ -124,67 +124,67 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
}
}
template void SpMMCsr<kDLCPU, int32_t, 16>(
template void SpMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>(
template void SpMMCsr<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 32>(
template void SpMMCsr<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>(
template void SpMMCsr<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>(
template void SpMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
template void SpMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
template void SpMMCsrHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
template void SpMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
template void SpMMCsrHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
template void SpMMCsrHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
template void SpMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
......@@ -222,52 +222,52 @@ void Edge_softmax_csr_backward(const std::string& op,
});
}
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 16>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 16>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 32>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 32>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 64>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 64>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 16>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 16>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 32>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 32>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 64>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 64>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
......@@ -303,27 +303,27 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
}
template void SpMMCoo<kDLCPU, int32_t, 16>(
template void SpMMCoo<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>(
template void SpMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 32>(
template void SpMMCoo<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>(
template void SpMMCoo<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>(
template void SpMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
template void SpMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment