Unverified commit cded5b80, authored by Xin Yao and committed by GitHub
Browse files

[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
parent f1689ad0
...@@ -16,7 +16,7 @@ namespace impl { ...@@ -16,7 +16,7 @@ namespace impl {
namespace { namespace {
/*! \brief COORemove implementation for COOMatrix with default consecutive edge IDs */ /*! \brief COORemove implementation for COOMatrix with default consecutive edge IDs */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void COORemoveConsecutive( void COORemoveConsecutive(
COOMatrix coo, COOMatrix coo,
IdArray entries, IdArray entries,
...@@ -47,7 +47,7 @@ void COORemoveConsecutive( ...@@ -47,7 +47,7 @@ void COORemoveConsecutive(
} }
/*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */ /*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void COORemoveShuffled( void COORemoveShuffled(
COOMatrix coo, COOMatrix coo,
IdArray entries, IdArray entries,
...@@ -73,7 +73,7 @@ void COORemoveShuffled( ...@@ -73,7 +73,7 @@ void COORemoveShuffled(
}; // namespace }; // namespace
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries) { COOMatrix COORemove(COOMatrix coo, IdArray entries) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const int64_t n_entries = entries->shape[0]; const int64_t n_entries = entries->shape[0];
...@@ -98,8 +98,8 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) { ...@@ -98,8 +98,8 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
IdArray::FromVector(new_eids)); IdArray::FromVector(new_eids));
} }
template COOMatrix COORemove<kDLCPU, int32_t>(COOMatrix coo, IdArray entries); template COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDLCPU, int64_t>(COOMatrix coo, IdArray entries); template COOMatrix COORemove<kDGLCPU, int64_t>(COOMatrix coo, IdArray entries);
}; // namespace impl }; // namespace impl
}; // namespace aten }; // namespace aten
......
...@@ -167,7 +167,7 @@ namespace impl { ...@@ -167,7 +167,7 @@ namespace impl {
///////////////////////////// COOSort_ ///////////////////////////// ///////////////////////////// COOSort_ /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) { void COOSort_(COOMatrix* coo, bool sort_column) {
const int64_t nnz = coo->row->shape[0]; const int64_t nnz = coo->row->shape[0];
IdType* coo_row = coo->row.Ptr<IdType>(); IdType* coo_row = coo->row.Ptr<IdType>();
...@@ -208,13 +208,13 @@ void COOSort_(COOMatrix* coo, bool sort_column) { ...@@ -208,13 +208,13 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
coo->col_sorted = sort_column; coo->col_sorted = sort_column;
} }
template void COOSort_<kDLCPU, int32_t>(COOMatrix*, bool); template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDLCPU, int64_t>(COOMatrix*, bool); template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);
///////////////////////////// COOIsSorted ///////////////////////////// ///////////////////////////// COOIsSorted /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo) { std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
IdType* row = coo.row.Ptr<IdType>(); IdType* row = coo.row.Ptr<IdType>();
...@@ -230,8 +230,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) { ...@@ -230,8 +230,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return {row_sorted, col_sorted}; return {row_sorted, col_sorted};
} }
template std::pair<bool, bool> COOIsSorted<kDLCPU, int32_t>(COOMatrix coo); template std::pair<bool, bool> COOIsSorted<kDGLCPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLCPU, int64_t>(COOMatrix coo); template std::pair<bool, bool> COOIsSorted<kDGLCPU, int64_t>(COOMatrix coo);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -17,7 +17,7 @@ using runtime::parallel_for; ...@@ -17,7 +17,7 @@ using runtime::parallel_for;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data, void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col, const IdType start, const IdType end, const IdType col,
std::vector<IdType> *ret_vec) { std::vector<IdType> *ret_vec) {
...@@ -38,7 +38,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data, ...@@ -38,7 +38,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
} }
} }
template <DLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData( NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) { CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
...@@ -59,7 +59,7 @@ NDArray CSRGetData( ...@@ -59,7 +59,7 @@ NDArray CSRGetData(
const int64_t retlen = std::max(rowlen, collen); const int64_t retlen = std::max(rowlen, collen);
const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>(); const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
if (return_eids) if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) << BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype."; "DType does not match row's dtype.";
NDArray ret = Full(filler, retlen, rows->ctx); NDArray ret = Full(filler, retlen, rows->ctx);
...@@ -106,19 +106,19 @@ NDArray CSRGetData( ...@@ -106,19 +106,19 @@ NDArray CSRGetData(
return ret; return ret;
} }
template NDArray CSRGetData<kDLCPU, int32_t, float>( template NDArray CSRGetData<kDGLCPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int64_t, float>( template NDArray CSRGetData<kDGLCPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int32_t, double>( template NDArray CSRGetData<kDGLCPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLCPU, int64_t, double>( template NDArray CSRGetData<kDGLCPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray) // For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>( template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>( template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl } // namespace impl
......
...@@ -134,13 +134,13 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -134,13 +134,13 @@ std::pair<CSRMatrix, NDArray> CSRMM(
C_weights}; C_weights};
} }
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, float>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, double>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, double>( template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray); const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
}; // namespace aten }; // namespace aten
......
...@@ -15,7 +15,7 @@ namespace impl { ...@@ -15,7 +15,7 @@ namespace impl {
namespace { namespace {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRRemoveConsecutive( void CSRRemoveConsecutive(
CSRMatrix csr, CSRMatrix csr,
IdArray entries, IdArray entries,
...@@ -48,7 +48,7 @@ void CSRRemoveConsecutive( ...@@ -48,7 +48,7 @@ void CSRRemoveConsecutive(
} }
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRRemoveShuffled( void CSRRemoveShuffled(
CSRMatrix csr, CSRMatrix csr,
IdArray entries, IdArray entries,
...@@ -77,7 +77,7 @@ void CSRRemoveShuffled( ...@@ -77,7 +77,7 @@ void CSRRemoveShuffled(
}; // namespace }; // namespace
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) { CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CHECK_SAME_DTYPE(csr.indices, entries); CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t nnz = csr.indices->shape[0]; const int64_t nnz = csr.indices->shape[0];
...@@ -103,8 +103,8 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) { ...@@ -103,8 +103,8 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
IdArray::FromVector(new_eids)); IdArray::FromVector(new_eids));
} }
template CSRMatrix CSRRemove<kDLCPU, int32_t>(CSRMatrix csr, IdArray entries); template CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDLCPU, int64_t>(CSRMatrix csr, IdArray entries); template CSRMatrix CSRRemove<kDGLCPU, int64_t>(CSRMatrix csr, IdArray entries);
}; // namespace impl }; // namespace impl
}; // namespace aten }; // namespace aten
......
...@@ -14,7 +14,7 @@ namespace aten { ...@@ -14,7 +14,7 @@ namespace aten {
namespace impl { namespace impl {
///////////////////////////// CSRIsSorted ///////////////////////////// ///////////////////////////// CSRIsSorted /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) { bool CSRIsSorted(CSRMatrix csr) {
const IdType* indptr = csr.indptr.Ptr<IdType>(); const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>(); const IdType* indices = csr.indices.Ptr<IdType>();
...@@ -31,12 +31,12 @@ bool CSRIsSorted(CSRMatrix csr) { ...@@ -31,12 +31,12 @@ bool CSRIsSorted(CSRMatrix csr) {
[](bool a, bool b) { return a && b; }); [](bool a, bool b) { return a && b; });
} }
template bool CSRIsSorted<kDLCPU, int64_t>(CSRMatrix csr); template bool CSRIsSorted<kDGLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLCPU, int32_t>(CSRMatrix csr); template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);
///////////////////////////// CSRSort ///////////////////////////// ///////////////////////////// CSRSort /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) { void CSRSort_(CSRMatrix* csr) {
typedef std::pair<IdType, IdType> ShufflePair; typedef std::pair<IdType, IdType> ShufflePair;
const int64_t num_rows = csr->num_rows; const int64_t num_rows = csr->num_rows;
...@@ -79,10 +79,10 @@ void CSRSort_(CSRMatrix* csr) { ...@@ -79,10 +79,10 @@ void CSRSort_(CSRMatrix* csr) {
csr->sorted = true; csr->sorted = true;
} }
template void CSRSort_<kDLCPU, int64_t>(CSRMatrix* csr); template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDLCPU, int32_t>(CSRMatrix* csr); template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix* csr);
template <DLDeviceType XPU, typename IdType, typename TagType> template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag( std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) { const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
const auto indptr_data = static_cast<const IdType *>(csr.indptr->data); const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
...@@ -143,13 +143,13 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -143,13 +143,13 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
return std::make_pair(output, tag_pos); return std::make_pair(output, tag_pos);
} }
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int64_t>( template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int64_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags); const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int32_t>( template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int32_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags); const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int64_t>( template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int64_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags); const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int32_t>( template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int32_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags); const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
} // namespace impl } // namespace impl
......
...@@ -130,13 +130,13 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -130,13 +130,13 @@ std::pair<CSRMatrix, NDArray> CSRSum(
C_weights}; C_weights};
} }
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, float>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, float>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, double>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, double>( template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&); const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
}; // namespace aten }; // namespace aten
......
...@@ -12,7 +12,7 @@ namespace dgl { ...@@ -12,7 +12,7 @@ namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) { std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
if (!csr.sorted) if (!csr.sorted)
csr = CSRSort(csr); csr = CSRSort(csr);
...@@ -67,8 +67,8 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) { ...@@ -67,8 +67,8 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
return std::make_tuple(res_csr, edge_count, eids_remapped); return std::make_tuple(res_csr, edge_count, eids_remapped);
} }
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int32_t>(CSRMatrix); template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int32_t>(CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int64_t>(CSRMatrix); template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int64_t>(CSRMatrix);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -14,7 +14,7 @@ namespace dgl { ...@@ -14,7 +14,7 @@ namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) { CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
std::vector<IdType> res_indptr; std::vector<IdType> res_indptr;
std::vector<IdType> res_indices; std::vector<IdType> res_indices;
...@@ -109,8 +109,8 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) { ...@@ -109,8 +109,8 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
sorted); sorted);
} }
template CSRMatrix UnionCsr<kDLCPU, int64_t>(const std::vector<CSRMatrix>&); template CSRMatrix UnionCsr<kDGLCPU, int64_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDLCPU, int32_t>(const std::vector<CSRMatrix>&); template CSRMatrix UnionCsr<kDGLCPU, int32_t>(const std::vector<CSRMatrix>&);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -26,7 +26,7 @@ using runtime::NDArray; ...@@ -26,7 +26,7 @@ using runtime::NDArray;
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) { std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
IdArray prefix_src_arr = NewIdArray( IdArray prefix_src_arr = NewIdArray(
coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits); coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);
...@@ -52,7 +52,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa ...@@ -52,7 +52,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
CumSum(prefix_elm_arr, true)); CumSum(prefix_elm_arr, true));
} }
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
bool has_data = false; bool has_data = false;
bool row_sorted = true; bool row_sorted = true;
...@@ -118,8 +118,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { ...@@ -118,8 +118,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted); col_sorted);
} }
template COOMatrix DisjointUnionCoo<kDLCPU, int32_t>(const std::vector<COOMatrix>& coos); template COOMatrix DisjointUnionCoo<kDGLCPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLCPU, int64_t>(const std::vector<COOMatrix>& coos); template COOMatrix DisjointUnionCoo<kDGLCPU, int64_t>(const std::vector<COOMatrix>& coos);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -62,74 +62,74 @@ void GatherMMScatter(const NDArray A, ...@@ -62,74 +62,74 @@ void GatherMMScatter(const NDArray A,
LOG(FATAL) << "Unsupported CPU kernel for GatherMM."; LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
} }
template void GatherMM<kDLCPU, int32_t, 16>( template void GatherMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 16>( template void GatherMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 32>( template void GatherMM<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 32>( template void GatherMM<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 64>( template void GatherMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 64>( template void GatherMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDLCPU, int32_t, 16>( template void GatherMMScatter<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 16>( template void GatherMMScatter<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 32>( template void GatherMMScatter<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 32>( template void GatherMMScatter<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 64>( template void GatherMMScatter<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 64>( template void GatherMMScatter<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDLCPU, int32_t, 16>( template void SegmentMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 16>( template void SegmentMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 32>( template void SegmentMM<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 32>( template void SegmentMM<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 64>( template void SegmentMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 64>( template void SegmentMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDLCPU, int32_t, 16>( template void SegmentMMBackwardB<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 16>( template void SegmentMMBackwardB<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 32>( template void SegmentMMBackwardB<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 32>( template void SegmentMMBackwardB<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 64>( template void SegmentMMBackwardB<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 64>( template void SegmentMMBackwardB<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten } // namespace aten
......
...@@ -17,7 +17,7 @@ namespace dgl { ...@@ -17,7 +17,7 @@ namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix &csr, const CSRMatrix &csr,
int64_t num_samples, int64_t num_samples,
...@@ -61,9 +61,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling( ...@@ -61,9 +61,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return {row.CreateView({num_sampled}, row->dtype), col.CreateView({num_sampled}, col->dtype)}; return {row.CreateView({num_sampled}, row->dtype), col.CreateView({num_sampled}, col->dtype)};
} }
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int32_t>( template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int32_t>(
const CSRMatrix&, int64_t, int, bool, bool, double); const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int64_t>( template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int64_t>(
const CSRMatrix&, int64_t, int, bool, bool, double); const CSRMatrix&, int64_t, int, bool, bool, double);
}; // namespace impl }; // namespace impl
......
...@@ -97,13 +97,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows, ...@@ -97,13 +97,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more // [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
// significant. (minjie) // significant. (minjie)
IdArray picked_row = NDArray::Empty({num_rows * num_picks}, IdArray picked_row = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1}, DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx); ctx);
IdArray picked_col = NDArray::Empty({num_rows * num_picks}, IdArray picked_col = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1}, DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx); ctx);
IdArray picked_idx = NDArray::Empty({num_rows * num_picks}, IdArray picked_idx = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1}, DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx); ctx);
IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data); IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data);
IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data); IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data);
......
...@@ -117,7 +117,7 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn( ...@@ -117,7 +117,7 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn(
/////////////////////////////// CSR /////////////////////////////// /////////////////////////////// CSR ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples, COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) { FloatArray prob, bool replace) {
CHECK(prob.defined()); CHECK(prob.defined());
...@@ -125,16 +125,16 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples, ...@@ -125,16 +125,16 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
} }
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, float>( template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, float>( template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, double>( template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>( template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) { FloatArray prob, bool replace, bool etype_sorted) {
...@@ -143,28 +143,28 @@ COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes ...@@ -143,28 +143,28 @@ COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>( template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows, COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
int64_t num_samples, bool replace) { int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace); auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
} }
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>( template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, int64_t, bool); CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>( template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool); CSRMatrix, IdArray, int64_t, bool);
template <DLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes, COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) { bool replace, bool etype_sorted) {
...@@ -172,12 +172,12 @@ COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray ...@@ -172,12 +172,12 @@ COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>( template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>( template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased( COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat, CSRMatrix mat,
IdArray rows, IdArray rows,
...@@ -191,22 +191,22 @@ COOMatrix CSRRowWiseSamplingBiased( ...@@ -191,22 +191,22 @@ COOMatrix CSRRowWiseSamplingBiased(
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn); return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
} }
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, float>( template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, float>( template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, double>( template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, double>( template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool); CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
/////////////////////////////// COO /////////////////////////////// /////////////////////////////// COO ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples, COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) { FloatArray prob, bool replace) {
CHECK(prob.defined()); CHECK(prob.defined());
...@@ -214,16 +214,16 @@ COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples, ...@@ -214,16 +214,16 @@ COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
return COORowWisePick(mat, rows, num_samples, replace, pick_fn); return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
} }
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, float>( template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool); COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, float>( template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool); COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, double>( template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool); COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>( template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool); COOMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType> template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) { FloatArray prob, bool replace, bool etype_sorted) {
...@@ -232,28 +232,28 @@ COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes ...@@ -232,28 +232,28 @@ COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>( template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows, COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
int64_t num_samples, bool replace) { int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace); auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn); return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
} }
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int32_t>( template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, int64_t, bool); COOMatrix, IdArray, int64_t, bool);
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>( template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, int64_t, bool); COOMatrix, IdArray, int64_t, bool);
template <DLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes, COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) { bool replace, bool etype_sorted) {
...@@ -261,9 +261,9 @@ COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray ...@@ -261,9 +261,9 @@ COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn); return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
} }
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>( template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>( template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool); COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
} // namespace impl } // namespace impl
......
...@@ -55,52 +55,52 @@ inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending) ...@@ -55,52 +55,52 @@ inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending)
} // namespace } // namespace
template <DLDeviceType XPU, typename IdxType, typename DType> template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWiseTopk( COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) { CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending); auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return CSRRowWisePick(mat, rows, k, false, pick_fn); return CSRRowWisePick(mat, rows, k, false, pick_fn);
} }
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int32_t>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int32_t>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int64_t>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int64_t>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, float>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, float>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, double>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, double>( template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, bool); CSRMatrix, IdArray, int64_t, NDArray, bool);
template <DLDeviceType XPU, typename IdxType, typename DType> template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWiseTopk( COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) { COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending); auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return COORowWisePick(mat, rows, k, false, pick_fn); return COORowWisePick(mat, rows, k, false, pick_fn);
} }
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int32_t>( template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int32_t>( template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int64_t>( template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int64_t>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int64_t>( template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, float>( template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, float>( template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, double>( template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, double>( template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, NDArray, bool); COOMatrix, IdArray, int64_t, NDArray, bool);
} // namespace impl } // namespace impl
......
...@@ -102,67 +102,67 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -102,67 +102,67 @@ void SDDMMCsrHetero(const std::string& op,
}); });
} }
template void SDDMMCsr<kDLCPU, int32_t, 16>( template void SDDMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 16>( template void SDDMMCsr<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 32>( template void SDDMMCsr<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 32>( template void SDDMMCsr<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 64>( template void SDDMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 64>( template void SDDMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDLCPU, int32_t, 16>( template void SDDMMCsrHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 16>( template void SDDMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 32>( template void SDDMMCsrHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 32>( template void SDDMMCsrHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 64>( template void SDDMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 64>( template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
...@@ -217,67 +217,67 @@ void SDDMMCooHetero(const std::string& op, ...@@ -217,67 +217,67 @@ void SDDMMCooHetero(const std::string& op,
}); });
} }
template void SDDMMCoo<kDLCPU, int32_t, 16>( template void SDDMMCoo<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 16>( template void SDDMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 32>( template void SDDMMCoo<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 32>( template void SDDMMCoo<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 64>( template void SDDMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 64>( template void SDDMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDLCPU, int32_t, 16>( template void SDDMMCooHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 16>( template void SDDMMCooHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 32>( template void SDDMMCooHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 32>( template void SDDMMCooHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 64>( template void SDDMMCooHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 64>( template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......
...@@ -74,113 +74,113 @@ void BackwardSegmentCmp( ...@@ -74,113 +74,113 @@ void BackwardSegmentCmp(
}); });
} }
template void SegmentReduce<kDLCPU, int32_t, 16>( template void SegmentReduce<kDGLCPU, int32_t, 16>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 16>( template void SegmentReduce<kDGLCPU, int64_t, 16>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, 32>( template void SegmentReduce<kDGLCPU, int32_t, 32>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 32>( template void SegmentReduce<kDGLCPU, int64_t, 32>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, 64>( template void SegmentReduce<kDGLCPU, int32_t, 64>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 64>( template void SegmentReduce<kDGLCPU, int64_t, 64>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
/*!
 * \brief Explicit instantiations of ScatterAdd for the CPU device tag.
 *
 * Covers every combination of {int32_t, int64_t} as the second template
 * argument and {16, 32, 64} as the third.
 * NOTE(review): the second argument is presumably the index type and the third
 * the feature bit-width -- confirm against the template declaration.
 */
template void ScatterAdd<kDLCPU, int32_t, 16>( template void ScatterAdd<kDGLCPU, int32_t, 16>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 16>( template void ScatterAdd<kDGLCPU, int64_t, 16>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 32>( template void ScatterAdd<kDGLCPU, int32_t, 32>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 32>( template void ScatterAdd<kDGLCPU, int64_t, 32>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 64>( template void ScatterAdd<kDGLCPU, int32_t, 64>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
// NOTE(review): this instantiation names its second parameter `arg` while the
// five above use `idx`. Parameter names carry no meaning in an explicit
// instantiation, so this is cosmetic, but it looks like a copy/paste slip.
template void ScatterAdd<kDLCPU, int64_t, 64>( template void ScatterAdd<kDGLCPU, int64_t, 64>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
/*!
 * \brief Explicit instantiations of UpdateGradMinMax_hetero for the CPU device tag.
 *
 * Covers every combination of {int32_t, int64_t} as the second template
 * argument and {16, 32, 64} as the third. Each takes the graph, an op name,
 * three vectors of NDArray (feat, idx, idx_etype) and a pointer to the output
 * vector.
 * NOTE(review): the second argument is presumably the index type and the third
 * the feature bit-width -- confirm against the template declaration.
 */
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 16>( template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 16>( template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 32>( template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 32>( template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 32>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 64>( template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 64>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 64>( template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 64>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
/*!
 * \brief Explicit instantiations of BackwardSegmentCmp for the CPU device tag.
 *
 * Covers every combination of {int32_t, int64_t} as the second template
 * argument and {16, 32, 64} as the third; each takes (feat, arg, out).
 * NOTE(review): the second argument is presumably the index type and the third
 * the feature bit-width -- confirm against the template declaration.
 */
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>( template void BackwardSegmentCmp<kDGLCPU, int32_t, 16>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 16>( template void BackwardSegmentCmp<kDGLCPU, int64_t, 16>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 32>( template void BackwardSegmentCmp<kDGLCPU, int32_t, 32>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 32>( template void BackwardSegmentCmp<kDGLCPU, int64_t, 32>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 64>( template void BackwardSegmentCmp<kDGLCPU, int32_t, 64>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 64>( template void BackwardSegmentCmp<kDGLCPU, int64_t, 64>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
......
...@@ -29,7 +29,7 @@ namespace impl { ...@@ -29,7 +29,7 @@ namespace impl {
///////////////////////////// COOIsNonZero ///////////////////////////// ///////////////////////////// COOIsNonZero /////////////////////////////
/*!
 * \brief Scalar COOIsNonZero(coo, row, col), instantiated for CPU with int32_t
 *        and int64_t IdType.
 *
 * The visible part asserts via CHECK that row lies in [0, coo.num_rows) and
 * col in [0, coo.num_cols), and the function ends by returning false.
 * NOTE(review): the search between those points is elided by the hunk marker;
 * presumably it returns true on a matching entry -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) { bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row; CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col; CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
...@@ -42,10 +42,10 @@ bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) { ...@@ -42,10 +42,10 @@ bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
return false; return false;
} }
template bool COOIsNonZero<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t); template bool COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t); template bool COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
/*!
 * \brief Array overload of COOIsNonZero(coo, row, col), instantiated for CPU
 *        with int32_t and int64_t IdType.
 *
 * Reads the lengths of row and col from shape[0] and returns the NDArray rst.
 * NOTE(review): how the two lengths are reconciled and how rst is filled is
 * elided by the hunk marker -- confirm against the full source.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) { NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
const auto rowlen = row->shape[0]; const auto rowlen = row->shape[0];
const auto collen = col->shape[0]; const auto collen = col->shape[0];
...@@ -67,12 +67,12 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) { ...@@ -67,12 +67,12 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
return rst; return rst;
} }
template NDArray COOIsNonZero<kDLCPU, int32_t>(COOMatrix, NDArray, NDArray); template NDArray COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDLCPU, int64_t>(COOMatrix, NDArray, NDArray); template NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);
///////////////////////////// COOHasDuplicate ///////////////////////////// ///////////////////////////// COOHasDuplicate /////////////////////////////
/*!
 * \brief COOHasDuplicate(coo), instantiated for CPU with int32_t and int64_t IdType.
 *
 * Declares an unordered_set of (IdType, IdType) pairs hashed with PairHash and
 * reads coo.row->data as IdType; the function ends by returning false.
 * NOTE(review): the scan is elided by the hunk marker; presumably it returns
 * true as soon as a (row, col) pair is already in the set -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo) { bool COOHasDuplicate(COOMatrix coo) {
std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap; std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
const IdType* src_data = static_cast<IdType*>(coo.row->data); const IdType* src_data = static_cast<IdType*>(coo.row->data);
...@@ -89,12 +89,12 @@ bool COOHasDuplicate(COOMatrix coo) { ...@@ -89,12 +89,12 @@ bool COOHasDuplicate(COOMatrix coo) {
return false; return false;
} }
template bool COOHasDuplicate<kDLCPU, int32_t>(COOMatrix coo); template bool COOHasDuplicate<kDGLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDLCPU, int64_t>(COOMatrix coo); template bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOGetRowNNZ ///////////////////////////// ///////////////////////////// COOGetRowNNZ /////////////////////////////
/*!
 * \brief Scalar COOGetRowNNZ(coo, row), instantiated for CPU with int32_t and
 *        int64_t IdType.
 *
 * Asserts via CHECK that row lies in [0, coo.num_rows), reads coo.row->data as
 * IdType and returns the local `result`.
 * NOTE(review): the counting loop is elided by the hunk marker; presumably
 * result is the number of entries whose row id equals `row` -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) { int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row; CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
...@@ -106,10 +106,10 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) { ...@@ -106,10 +106,10 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
return result; return result;
} }
template int64_t COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, int64_t); template int64_t COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, int64_t); template int64_t COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, int64_t);
/*!
 * \brief Array overload of COOGetRowNNZ(coo, rows), instantiated for CPU with
 *        int32_t and int64_t IdType.
 *
 * Requires rows to have the same dtype as coo.col (CHECK_SAME_DTYPE), takes
 * the number of queries from rows->shape[0] and returns the NDArray rst.
 * NOTE(review): how rst is filled is elided by the hunk marker -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) { NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
CHECK_SAME_DTYPE(coo.col, rows); CHECK_SAME_DTYPE(coo.col, rows);
const auto len = rows->shape[0]; const auto len = rows->shape[0];
...@@ -123,12 +123,12 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) { ...@@ -123,12 +123,12 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
return rst; return rst;
} }
template NDArray COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, NDArray); template NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, NDArray); template NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);
///////////////////////////// COOGetRowDataAndIndices ///////////////////////////// ///////////////////////////// COOGetRowDataAndIndices /////////////////////////////
/*!
 * \brief COOGetRowDataAndIndices(coo, row), instantiated for CPU with int32_t
 *        and int64_t IdType.
 *
 * Asserts via CHECK that row lies in [0, coo.num_rows) and returns a pair of
 * NDArrays.
 * NOTE(review): the body is elided by the hunk marker; the name suggests the
 * pair holds the data and column indices of that row -- confirm the order.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<NDArray, NDArray> COOGetRowDataAndIndices( std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
COOMatrix coo, int64_t row) { COOMatrix coo, int64_t row) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row; CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
...@@ -151,13 +151,13 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices( ...@@ -151,13 +151,13 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
} }
template std::pair<NDArray, NDArray> template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int32_t>(COOMatrix, int64_t); COOGetRowDataAndIndices<kDGLCPU, int32_t>(COOMatrix, int64_t);
template std::pair<NDArray, NDArray> template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int64_t>(COOMatrix, int64_t); COOGetRowDataAndIndices<kDGLCPU, int64_t>(COOMatrix, int64_t);
///////////////////////////// COOGetData ///////////////////////////// ///////////////////////////// COOGetData /////////////////////////////
/*!
 * \brief COOGetData(coo, rows, cols), instantiated for CPU with int32_t and
 *        int64_t IdType.
 *
 * Reads the query lengths from rows->shape[0] and cols->shape[0] and returns
 * the IdArray ret.
 * NOTE(review): the lookup logic is elided by the hunk marker -- confirm how
 * missing entries and unequal lengths are handled.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) { IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
...@@ -211,12 +211,12 @@ IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) { ...@@ -211,12 +211,12 @@ IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
return ret; return ret;
} }
template IdArray COOGetData<kDLCPU, int32_t>(COOMatrix, IdArray, IdArray); template IdArray COOGetData<kDGLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDLCPU, int64_t>(COOMatrix, IdArray, IdArray); template IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);
///////////////////////////// COOGetDataAndIndices ///////////////////////////// ///////////////////////////// COOGetDataAndIndices /////////////////////////////
/*!
 * \brief COOGetDataAndIndices(coo, rows, cols), instantiated for CPU with
 *        int32_t and int64_t IdType.
 *
 * Requires rows to have the same dtype as coo.col (CHECK_SAME_DTYPE) and
 * returns a vector of NDArrays whose last element is built from ret_data.
 * NOTE(review): the last element uses NDArray::FromVector(ret_data) with no
 * explicit context argument; the rest of the body is elided by the hunk
 * marker -- confirm the element order and the intended device.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows, std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
NDArray cols) { NDArray cols) {
CHECK_SAME_DTYPE(coo.col, rows); CHECK_SAME_DTYPE(coo.col, rows);
...@@ -286,20 +286,20 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows, ...@@ -286,20 +286,20 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
NDArray::FromVector(ret_data)}; NDArray::FromVector(ret_data)};
} }
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int32_t>( template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int32_t>(
COOMatrix coo, NDArray rows, NDArray cols); COOMatrix coo, NDArray rows, NDArray cols);
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int64_t>( template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int64_t>(
COOMatrix coo, NDArray rows, NDArray cols); COOMatrix coo, NDArray rows, NDArray cols);
///////////////////////////// COOTranspose ///////////////////////////// ///////////////////////////// COOTranspose /////////////////////////////
/*!
 * \brief Transpose a COO matrix by exchanging its row and column roles;
 *        instantiated for CPU with int32_t and int64_t IdType.
 *
 * The result is brace-initialized from coo.num_cols, coo.num_rows, coo.col,
 * coo.row and coo.data, in that order; there is no per-element loop.
 * NOTE(review): only five fields are passed, so row_sorted/col_sorted are not
 * forwarded and take whatever COOMatrix defaults to -- confirm this is intended.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo) { COOMatrix COOTranspose(COOMatrix coo) {
return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data}; return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};
} }
template COOMatrix COOTranspose<kDLCPU, int32_t>(COOMatrix coo); template COOMatrix COOTranspose<kDGLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDLCPU, int64_t>(COOMatrix coo); template COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOToCSR ///////////////////////////// ///////////////////////////// COOToCSR /////////////////////////////
namespace { namespace {
...@@ -615,7 +615,7 @@ P^2). ...@@ -615,7 +615,7 @@ P^2).
degree), UnSortedDenseCOOToCSR<> is applied. Time: O(NNZ/P + N/P), space O(NNZ + degree), UnSortedDenseCOOToCSR<> is applied. Time: O(NNZ/P + N/P), space O(NNZ +
N*P). N*P).
*/ */
// COOToCSR entry point, instantiated for CPU with int32_t and int64_t IdType.
// When coo.row_sorted is false it first reads the OpenMP thread count; the
// last visible statement returns SortedCOOToCSR<IdType>(coo).
// NOTE(review): the selection logic between those points is elided by the
// hunk marker -- confirm which conditions reach which helper.
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) { CSRMatrix COOToCSR(COOMatrix coo) {
if (!coo.row_sorted) { if (!coo.row_sorted) {
// NOTE(review): omp_get_num_threads() reports the size of the current team and
// is 1 outside an active parallel region. If the intent is the number of
// threads a later parallel region may use, omp_get_max_threads() is
// presumably what is wanted -- confirm against how num_threads is used.
const int64_t num_threads = omp_get_num_threads(); const int64_t num_threads = omp_get_num_threads();
...@@ -632,12 +632,12 @@ CSRMatrix COOToCSR(COOMatrix coo) { ...@@ -632,12 +632,12 @@ CSRMatrix COOToCSR(COOMatrix coo) {
return SortedCOOToCSR<IdType>(coo); return SortedCOOToCSR<IdType>(coo);
} }
template CSRMatrix COOToCSR<kDLCPU, int32_t>(COOMatrix coo); template CSRMatrix COOToCSR<kDGLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLCPU, int64_t>(COOMatrix coo); template CSRMatrix COOToCSR<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOSliceRows ///////////////////////////// ///////////////////////////// COOSliceRows /////////////////////////////
/*!
 * \brief Range overload of COOSliceRows(coo, start, end), instantiated for CPU
 *        with int32_t and int64_t IdType.
 *
 * Asserts via CHECK that start lies in [0, coo.num_rows); the last visible
 * constructor argument of the result is coo.col_sorted.
 * NOTE(review): validation of `end` and the selection loop are elided by the
 * hunk marker -- confirm whether `end` is inclusive or exclusive.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) { COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
// TODO(minjie): use binary search when coo.row_sorted is true // TODO(minjie): use binary search when coo.row_sorted is true
CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start; CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
...@@ -669,10 +669,10 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) { ...@@ -669,10 +669,10 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
coo.col_sorted); coo.col_sorted);
} }
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t); template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t); template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
/*!
 * \brief Array overload of COOSliceRows(coo, rows), instantiated for CPU with
 *        int32_t and int64_t IdType.
 *
 * Reads coo.row->data and coo.col->data as IdType; the result forwards
 * coo.row_sorted and coo.col_sorted as its last two initializers.
 * NOTE(review): the selection and relabeling logic is elided by the hunk
 * marker -- confirm against the full source.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) { COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
...@@ -703,12 +703,12 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) { ...@@ -703,12 +703,12 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
coo.row_sorted, coo.col_sorted}; coo.row_sorted, coo.col_sorted};
} }
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix , NDArray); template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix , NDArray); template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix , NDArray);
///////////////////////////// COOSliceMatrix ///////////////////////////// ///////////////////////////// COOSliceMatrix /////////////////////////////
/*!
 * \brief COOSliceMatrix(coo, rows, cols), instantiated for CPU with int32_t
 *        and int64_t IdType.
 *
 * Reads coo.row->data and coo.col->data as IdType; the result forwards
 * coo.row_sorted and coo.col_sorted as its last two constructor arguments.
 * NOTE(review): the selection and relabeling logic is elided by the hunk
 * marker -- confirm against the full source.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) { COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
...@@ -740,15 +740,15 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray ...@@ -740,15 +740,15 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray
coo.row_sorted, coo.col_sorted); coo.row_sorted, coo.col_sorted);
} }
template COOMatrix COOSliceMatrix<kDLCPU, int32_t>( template COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template COOMatrix COOSliceMatrix<kDLCPU, int64_t>( template COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
///////////////////////////// COOReorder ///////////////////////////// ///////////////////////////// COOReorder /////////////////////////////
/*!
 * \brief COOReorder(coo, new_row_id_arr, new_col_id_arr), instantiated for CPU
 *        with int64_t and int32_t IdType.
 *
 * Requires new_row_id_arr to have the same dtype as coo.row (CHECK_SAME_DTYPE)
 * and returns a COOMatrix built from num_rows, num_cols and the freshly named
 * out_row_arr, out_col_arr and out_data_arr.
 * NOTE(review): the remapping loop is elided by the hunk marker. The explicit
 * instantiations below name the COOMatrix parameter `csr`; that is harmless
 * but misleading.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr, COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
runtime::NDArray new_col_id_arr) { runtime::NDArray new_col_id_arr) {
CHECK_SAME_DTYPE(coo.row, new_row_id_arr); CHECK_SAME_DTYPE(coo.row, new_row_id_arr);
...@@ -785,9 +785,9 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr, ...@@ -785,9 +785,9 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr); return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);
} }
template COOMatrix COOReorder<kDLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids, template COOMatrix COOReorder<kDGLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids); runtime::NDArray new_col_ids);
template COOMatrix COOReorder<kDLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids, template COOMatrix COOReorder<kDGLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids); runtime::NDArray new_col_ids);
} // namespace impl } // namespace impl
......
...@@ -21,7 +21,7 @@ namespace impl { ...@@ -21,7 +21,7 @@ namespace impl {
///////////////////////////// CSRIsNonZero ///////////////////////////// ///////////////////////////// CSRIsNonZero /////////////////////////////
/*!
 * \brief Scalar CSRIsNonZero(csr, row, col), instantiated for CPU with int32_t
 *        and int64_t IdType.
 *
 * Reads csr.indptr->data and csr.indices->data as IdType; the function ends by
 * returning false.
 * NOTE(review): the search is elided by the hunk marker; presumably it returns
 * true when col is found within the row's index range -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) { bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
...@@ -39,10 +39,10 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) { ...@@ -39,10 +39,10 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
return false; return false;
} }
template bool CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t); template bool CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t); template bool CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
/*!
 * \brief Array overload of CSRIsNonZero(csr, row, col), instantiated for CPU
 *        with int32_t and int64_t IdType.
 *
 * Reads the lengths of row and col from shape[0] and returns the NDArray rst.
 * NOTE(review): how the two lengths are reconciled and how rst is filled is
 * elided by the hunk marker -- confirm against the full source.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) { NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0]; const auto rowlen = row->shape[0];
const auto collen = col->shape[0]; const auto collen = col->shape[0];
...@@ -62,12 +62,12 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) { ...@@ -62,12 +62,12 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
return rst; return rst;
} }
template NDArray CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, NDArray, NDArray); template NDArray CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, NDArray, NDArray); template NDArray CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
///////////////////////////// CSRHasDuplicate ///////////////////////////// ///////////////////////////// CSRHasDuplicate /////////////////////////////
/*!
 * \brief CSRHasDuplicate(csr), instantiated for CPU with int32_t and int64_t IdType.
 *
 * Reads csr.indptr->data and csr.indices->data as IdType; the function ends by
 * returning false.
 * NOTE(review): the per-row scan is elided by the hunk marker; presumably it
 * returns true on the first repeated column index within a row -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) { bool CSRHasDuplicate(CSRMatrix csr) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
...@@ -85,21 +85,21 @@ bool CSRHasDuplicate(CSRMatrix csr) { ...@@ -85,21 +85,21 @@ bool CSRHasDuplicate(CSRMatrix csr) {
return false; return false;
} }
template bool CSRHasDuplicate<kDLCPU, int32_t>(CSRMatrix csr); template bool CSRHasDuplicate<kDGLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLCPU, int64_t>(CSRMatrix csr); template bool CSRHasDuplicate<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRGetRowNNZ ///////////////////////////// ///////////////////////////// CSRGetRowNNZ /////////////////////////////
/*!
 * \brief Number of stored entries in one CSR row; instantiated for CPU with
 *        int32_t and int64_t IdType.
 *
 * Computed as indptr[row + 1] - indptr[row], with indptr read as IdType and
 * the difference returned as int64_t.
 * NOTE(review): row is not range-checked in this function; presumably the
 * caller validates it -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) { int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
return indptr_data[row + 1] - indptr_data[row]; return indptr_data[row + 1] - indptr_data[row];
} }
template int64_t CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, int64_t); template int64_t CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t); template int64_t CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, int64_t);
/*!
 * \brief Array overload of CSRGetRowNNZ(csr, rows), instantiated for CPU with
 *        int32_t and int64_t IdType.
 *
 * Requires rows to have the same dtype as csr.indices (CHECK_SAME_DTYPE),
 * takes the number of queries from rows->shape[0] and returns the NDArray rst.
 * NOTE(review): how rst is filled is elided by the hunk marker -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) { NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows); CHECK_SAME_DTYPE(csr.indices, rows);
const auto len = rows->shape[0]; const auto len = rows->shape[0];
...@@ -114,12 +114,12 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) { ...@@ -114,12 +114,12 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
return rst; return rst;
} }
template NDArray CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, NDArray); template NDArray CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, NDArray); template NDArray CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, NDArray);
///////////////////////////// CSRGetRowColumnIndices ///////////////////////////// ///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
/*!
 * \brief CSRGetRowColumnIndices(csr, row), instantiated for CPU with int32_t
 *        and int64_t IdType.
 *
 * Obtains the row length from impl::CSRGetRowNNZ and returns a view created
 * on csr.indices with shape {len}, the same dtype, and `offset`.
 * NOTE(review): the line computing `offset` is elided by the hunk marker;
 * presumably it is derived from indptr[row] -- confirm its unit (elements
 * versus bytes).
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) { NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row); const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
...@@ -127,12 +127,12 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) { ...@@ -127,12 +127,12 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
return csr.indices.CreateView({len}, csr.indices->dtype, offset); return csr.indices.CreateView({len}, csr.indices->dtype, offset);
} }
template NDArray CSRGetRowColumnIndices<kDLCPU, int32_t>(CSRMatrix, int64_t); template NDArray CSRGetRowColumnIndices<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLCPU, int64_t>(CSRMatrix, int64_t); template NDArray CSRGetRowColumnIndices<kDGLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetRowData ///////////////////////////// ///////////////////////////// CSRGetRowData /////////////////////////////
/*!
 * \brief CSRGetRowData(csr, row), instantiated for CPU with int32_t and
 *        int64_t IdType.
 *
 * Obtains the row length from impl::CSRGetRowNNZ. The last visible statement
 * returns aten::Range(offset, offset + len, ...) using the bit width and
 * context of csr.indptr.
 * NOTE(review): lines between are elided by the hunk marker; presumably there
 * is an earlier return for matrices carrying an explicit data array, with the
 * Range acting as the fallback -- confirm.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) { NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row); const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
...@@ -143,13 +143,13 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) { ...@@ -143,13 +143,13 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx); return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
} }
template NDArray CSRGetRowData<kDLCPU, int32_t>(CSRMatrix, int64_t); template NDArray CSRGetRowData<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLCPU, int64_t>(CSRMatrix, int64_t); template NDArray CSRGetRowData<kDGLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetData ///////////////////////////// ///////////////////////////// CSRGetData /////////////////////////////
///////////////////////////// CSRGetDataAndIndices ///////////////////////////// ///////////////////////////// CSRGetDataAndIndices /////////////////////////////
/*!
 * \brief Helper CollectDataIndicesFromSorted.
 *
 * The visible parameters are the indices and data arrays, a [start, end) style
 * pair of IdType bounds, the queried col, and an output vector col_vec.
 * NOTE(review): the rest of the parameter list and the whole body are elided
 * by the hunk marker; the name suggests it relies on sorted column indices --
 * confirm the precondition and the remaining outputs.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data, void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col, const IdType start, const IdType end, const IdType col,
std::vector<IdType> *col_vec, std::vector<IdType> *col_vec,
...@@ -172,7 +172,7 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data ...@@ -172,7 +172,7 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data
} }
} }
/*!
 * \brief CSRGetDataAndIndices(csr, rows, cols), instantiated for CPU with
 *        int32_t and int64_t IdType.
 *
 * Reads the query length from rows->shape[0] and returns a vector of NDArrays
 * whose last element is built from ret_data on csr.data->ctx.
 * NOTE(review): the lookup logic is elided by the hunk marker -- confirm the
 * element order of the returned vector.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) { std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
// TODO(minjie): more efficient implementation for matrix without duplicate entries // TODO(minjie): more efficient implementation for matrix without duplicate entries
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
...@@ -224,16 +224,16 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c ...@@ -224,16 +224,16 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
NDArray::FromVector(ret_data, csr.data->ctx)}; NDArray::FromVector(ret_data, csr.data->ctx)};
} }
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>( template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols); CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int64_t>( template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols); CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRTranspose ///////////////////////////// ///////////////////////////// CSRTranspose /////////////////////////////
// for a matrix of shape (N, M) and NNZ // for a matrix of shape (N, M) and NNZ
// complexity: time O(NNZ + max(N, M)), space O(1) // complexity: time O(NNZ + max(N, M)), space O(1)
// CSRTranspose(csr), instantiated for CPU with int32_t and int64_t IdType.
// Takes N and M from csr.num_rows / csr.num_cols and returns a CSRMatrix with
// the two dimensions swapped, built from ret_indptr, ret_indices and ret_data.
// NOTE(review): the construction of those three arrays is elided by the hunk
// marker, and only five fields are brace-initialized, so any sortedness flag
// presumably takes its default -- confirm.
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) { CSRMatrix CSRTranspose(CSRMatrix csr) {
const int64_t N = csr.num_rows; const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols; const int64_t M = csr.num_cols;
...@@ -281,11 +281,11 @@ CSRMatrix CSRTranspose(CSRMatrix csr) { ...@@ -281,11 +281,11 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data}; return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
} }
template CSRMatrix CSRTranspose<kDLCPU, int32_t>(CSRMatrix csr); template CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLCPU, int64_t>(CSRMatrix csr); template CSRMatrix CSRTranspose<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRToCOO ///////////////////////////// ///////////////////////////// CSRToCOO /////////////////////////////
/*!
 * \brief CSRToCOO(csr), instantiated for CPU with int32_t and int64_t IdType.
 *
 * Takes nnz from csr.indices->shape[0] and reads csr.indptr->data as IdType;
 * the result's last two constructor arguments are `true` and csr.sorted.
 * NOTE(review): the expansion of indptr into row ids is elided by the hunk
 * marker; the trailing arguments are presumably row_sorted and col_sorted --
 * confirm against the COOMatrix constructor.
 */
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) { COOMatrix CSRToCOO(CSRMatrix csr) {
const int64_t nnz = csr.indices->shape[0]; const int64_t nnz = csr.indices->shape[0];
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
...@@ -303,11 +303,11 @@ COOMatrix CSRToCOO(CSRMatrix csr) { ...@@ -303,11 +303,11 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
true, csr.sorted); true, csr.sorted);
} }
template COOMatrix CSRToCOO<kDLCPU, int32_t>(CSRMatrix csr); template COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr); template COOMatrix CSRToCOO<kDGLCPU, int64_t>(CSRMatrix csr);
// complexity: time O(NNZ), space O(1) // complexity: time O(NNZ), space O(1)
// CSRToCOODataAsOrder(csr), instantiated for CPU with int32_t and int64_t
// IdType. Takes N and M from csr.num_rows / csr.num_cols and returns
// COOMatrix(N, M, ret_row, ret_col), i.e. without an explicit data array.
// NOTE(review): how ret_row / ret_col are filled is elided by the hunk marker;
// the name suggests entries are placed in the order given by csr.data --
// confirm.
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) { COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
const int64_t N = csr.num_rows; const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols; const int64_t M = csr.num_cols;
...@@ -333,12 +333,12 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) { ...@@ -333,12 +333,12 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
return COOMatrix(N, M, ret_row, ret_col); return COOMatrix(N, M, ret_row, ret_col);
} }
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int32_t>(CSRMatrix csr); template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int64_t>(CSRMatrix csr); template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRSliceRows ///////////////////////////// ///////////////////////////// CSRSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) { CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
const IdType* indptr = static_cast<IdType*>(csr.indptr->data); const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const int64_t num_rows = end - start; const int64_t num_rows = end - start;
...@@ -362,10 +362,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) { ...@@ -362,10 +362,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
csr.sorted); csr.sorted);
} }
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t); template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t); template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows); CHECK_SAME_DTYPE(csr.indices, rows);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
...@@ -467,12 +467,12 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -467,12 +467,12 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
return ret; return ret;
} }
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix , NDArray); template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray); template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRSliceMatrix ///////////////////////////// ///////////////////////////// CSRSliceMatrix /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) { CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
IdHashMap<IdType> hashmap(cols); IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0]; const int64_t new_nrows = rows->shape[0];
...@@ -521,14 +521,14 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -521,14 +521,14 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
sub_data_arr}; sub_data_arr};
} }
template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t>( template CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t>( template CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
///////////////////////////// CSRReorder ///////////////////////////// ///////////////////////////// CSRReorder /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr, CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
runtime::NDArray new_col_id_arr) { runtime::NDArray new_col_id_arr) {
CHECK_SAME_DTYPE(csr.indices, new_row_id_arr); CHECK_SAME_DTYPE(csr.indices, new_row_id_arr);
...@@ -599,9 +599,9 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr, ...@@ -599,9 +599,9 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
out_indptr_arr, out_indices_arr, out_data_arr); out_indptr_arr, out_indices_arr, out_data_arr);
} }
template CSRMatrix CSRReorder<kDLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids, template CSRMatrix CSRReorder<kDGLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids); runtime::NDArray new_col_ids);
template CSRMatrix CSRReorder<kDLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids, template CSRMatrix CSRReorder<kDGLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids); runtime::NDArray new_col_ids);
} // namespace impl } // namespace impl
......
...@@ -124,67 +124,67 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -124,67 +124,67 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
} }
} }
template void SpMMCsr<kDLCPU, int32_t, 16>( template void SpMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>( template void SpMMCsr<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 32>( template void SpMMCsr<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>( template void SpMMCsr<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>( template void SpMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>( template void SpMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDLCPU, int32_t, 16>( template void SpMMCsrHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>( template void SpMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>( template void SpMMCsrHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>( template void SpMMCsrHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>( template void SpMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids, const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids); const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>( template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
...@@ -222,52 +222,52 @@ void Edge_softmax_csr_backward(const std::string& op, ...@@ -222,52 +222,52 @@ void Edge_softmax_csr_backward(const std::string& op,
}); });
} }
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 16>( template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 16>( template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 32>( template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 32>( template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 64>( template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 64>( template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 16>( template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 16>( template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 32>( template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 32>( template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 64>( template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 64>( template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out); NDArray ufeat, NDArray efeat, NDArray out);
...@@ -303,27 +303,27 @@ void SpMMCoo(const std::string& op, const std::string& reduce, ...@@ -303,27 +303,27 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
} }
} }
template void SpMMCoo<kDLCPU, int32_t, 16>( template void SpMMCoo<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>( template void SpMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 32>( template void SpMMCoo<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>( template void SpMMCoo<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>( template void SpMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>( template void SpMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment