Unverified commit 8ae50c42, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] clang-format auto fix. (#4804)



* [Misc] clang-format auto fix.

* manual

* manual

* manual

* manual

* todo

* fix
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 81831111
@@ -6,10 +6,11 @@
#include <dgl/array.h>
#include <dgl/array_iterator.h>
#include <dgl/random.h>
#include <dgl/runtime/parallel_for.h>

#include <algorithm>
#include <utility>

using namespace dgl::runtime;
@@ -19,15 +20,12 @@ namespace impl {
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  const int64_t num_row = csr.num_rows;
  const int64_t num_col = csr.num_cols;
  const int64_t num_actual_samples =
      static_cast<int64_t>(num_samples * (1 + redundancy));
  IdArray row = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdArray col = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdType* row_data = row.Ptr<IdType>();
@@ -48,23 +46,30 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
  });

  PairIterator<IdType> begin(row_data, col_data);
  PairIterator<IdType> end = std::remove_if(
      begin, begin + num_actual_samples,
      [](const std::pair<IdType, IdType>& val) { return val.first == -1; });
  if (!replace) {
    std::sort(
        begin, end,
        [](const std::pair<IdType, IdType>& a,
           const std::pair<IdType, IdType>& b) {
          return a.first < b.first ||
                 (a.first == b.first && a.second < b.second);
        });
    end = std::unique(begin, end);
  }
  int64_t num_sampled =
      std::min(static_cast<int64_t>(end - begin), num_samples);
  return {
      row.CreateView({num_sampled}, row->dtype),
      col.CreateView({num_sampled}, col->dtype)};
}

template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCPU, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCPU, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double);

};  // namespace impl
};  // namespace aten
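The routine above over-draws num_samples * (1 + redundancy) candidate pairs, marks rejected candidates with -1, compacts them with std::remove_if, sorts and deduplicates when replace is false, and finally truncates to num_samples. A standalone, single-threaded sketch of that oversample-then-compact pattern, using plain std::pair vectors and illustrative names in place of DGL's IdArray/PairIterator, might look like this:

#include <algorithm>
#include <cstdint>
#include <random>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> SampleNegativePairs(
    int64_t num_rows, int64_t num_cols, int64_t num_samples, double redundancy,
    bool replace) {
  std::mt19937_64 rng(0);
  std::uniform_int_distribution<int64_t> row_dist(0, num_rows - 1);
  std::uniform_int_distribution<int64_t> col_dist(0, num_cols - 1);
  const int64_t num_actual =
      static_cast<int64_t>(num_samples * (1 + redundancy));
  std::vector<std::pair<int64_t, int64_t>> cand(num_actual);
  for (auto& p : cand) p = {row_dist(rng), col_dist(rng)};  // draw candidates
  // A rejection step against the existing adjacency would mark bad pairs here.
  if (!replace) {
    std::sort(cand.begin(), cand.end());  // group duplicates together
    cand.erase(std::unique(cand.begin(), cand.end()), cand.end());
  }
  if (static_cast<int64_t>(cand.size()) > num_samples) cand.resize(num_samples);
  return cand;
}

The actual kernel additionally rejects candidates that collide with existing CSR edges (and, when exclude_self_loops is set, pairs with equal endpoints) before this compaction step.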
@@ -9,6 +9,7 @@
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>

#include "../selector.h"

namespace dgl {
@@ -25,37 +26,40 @@ namespace cpu {
 * \note it uses node parallel strategy, different threads are responsible
 * for the computation of different nodes.
 */
template <
    typename IdType, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,
    NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
  DType* O = out.Ptr<DType>();
  runtime::parallel_for(0, csr.num_rows, [=](IdType b, IdType e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        DType* out_off = O + eid * dim;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
              Op::use_lhs
                  ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +
                        lhs_add * reduce_size
                  : nullptr;
          const DType* rhs_off =
              Op::use_rhs
                  ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +
                        rhs_add * reduce_size
                  : nullptr;
          out_off[k] = Op::Call(lhs_off, rhs_off, reduce_size);
        }
@@ -74,35 +78,38 @@ void SDDMMCsr(const BcastOff& bcast,
 * \note it uses edge parallel strategy, different threads are responsible
 * for the computation of different edges.
 */
template <
    typename IdType, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
    NDArray out) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
  DType* O = out.Ptr<DType>();
#pragma omp parallel for
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + eid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
          Op::use_lhs ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +
                            lhs_add * reduce_size
                      : nullptr;
      const DType* rhs_off =
          Op::use_rhs ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +
                            rhs_add * reduce_size
                      : nullptr;
      out_off[k] = Op::Call(lhs_off, rhs_off, bcast.reduce_size);
    }
  }
@@ -110,12 +117,13 @@ void SDDMMCoo(const BcastOff& bcast,
namespace op {

////////////////////////// binary operators on CPU /////////////////////////////
template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off + *rhs_off;
  }
};
@@ -124,7 +132,8 @@ template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off - *rhs_off;
  }
};
@@ -133,7 +142,8 @@ template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off * *rhs_off;
  }
};
@@ -142,7 +152,8 @@ template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off / *rhs_off;
  }
};
@@ -151,7 +162,8 @@ template <typename DType>
struct CopyLhs {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  inline static DType Call(
      const DType* lhs_off, const DType*, int64_t len = 1) {
    return *lhs_off;
  }
};
@@ -160,7 +172,8 @@ template <typename DType>
struct CopyRhs {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType*, const DType* rhs_off, int64_t len = 1) {
    return *rhs_off;
  }
};
@@ -169,7 +182,8 @@ template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    DType rst = 0;
    for (int64_t l = 0; l < len; ++l) {
      rst += lhs_off[l] * rhs_off[l];
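The SDDMM kernels above walk either the CSR rows or the COO edge list and, for every edge, combine the endpoint features through one of the op functors (Dot reduces over reduce_size elements). A standalone sketch of the edge-parallel case with a fixed dot-product reducer, using plain buffers and illustrative names rather than DGL's NDArray/BcastOff/Op machinery:

#include <cstdint>
#include <vector>

// out[e] = dot(X[row[e]], Y[col[e]]) for every edge e of a COO graph.
void SddmmCooDot(
    const std::vector<int64_t>& row, const std::vector<int64_t>& col,
    const std::vector<float>& X, const std::vector<float>& Y, int64_t feat_dim,
    std::vector<float>* out) {
  const int64_t nnz = static_cast<int64_t>(row.size());
  out->assign(nnz, 0.f);
#pragma omp parallel for
  for (int64_t e = 0; e < nnz; ++e) {
    const float* x = X.data() + row[e] * feat_dim;  // source feature row
    const float* y = Y.data() + col[e] * feat_dim;  // destination feature row
    float acc = 0.f;
    for (int64_t k = 0; k < feat_dim; ++k) acc += x[k] * y[k];
    (*out)[e] = acc;
  }
}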
@@ -7,10 +7,11 @@
#define DGL_ARRAY_CPU_SEGMENT_REDUCE_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/parallel_for.h>

#include <string>
#include <vector>

namespace dgl {
namespace aten {
@@ -26,11 +27,10 @@ template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {
@@ -51,16 +51,14 @@ void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
 * used in backward phase.
 */
template <typename IdType, typename DType, typename Cmp>
void SegmentCmp(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  IdType* arg_data = arg.Ptr<IdType>();
  std::fill(out_data, out_data + out.NumElements(), Cmp::zero);
  std::fill(arg_data, arg_data + arg.NumElements(), -1);
  runtime::parallel_for(0, n, [=](int b, int e) {
@@ -89,8 +87,7 @@ template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
@@ -114,24 +111,26 @@ void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
 * \param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    HeteroGraphPtr graph, const std::string& op,
    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,
    const std::vector<NDArray>& list_idx_types,
    std::vector<NDArray>* list_out) {
  if (op == "copy_lhs" || op == "copy_rhs") {
    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_ntype = pair.first;  // graph is reversed
      const dgl_id_t src_ntype = pair.second;
      auto same_src_dst_ntype = std::find(
          std::begin(src_dst_ntypes[dst_ntype]),
          std::end(src_dst_ntypes[dst_ntype]), src_ntype);
      // if op is "copy_lhs", relation type with same src and dst node type will
      // be updated once
      if (op == "copy_lhs" &&
          same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
        continue;
      src_dst_ntypes[dst_ntype].push_back(src_ntype);
      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
@@ -149,7 +148,8 @@ void UpdateGradMinMax_hetero(HeteroGraphPtr graph,
          if (type == idx_type_data[i * dim + k]) {
            const int write_row = idx_data[i * dim + k];
#pragma omp atomic
            out_data[write_row * dim + k] +=
                feat_data[i * dim + k];  // feat = dZ
          }
        }
      }
@@ -170,8 +170,7 @@ template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
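The segment and scatter primitives above share a simple layout: SegmentSum accumulates the feat rows in [offsets[i], offsets[i + 1]) into output row i, and ScatterAdd adds feat row i into output row idx[i]. A minimal single-threaded reference with plain vectors (illustrative names, not the DGL API):

#include <cstdint>
#include <vector>

// out holds offsets.size() - 1 segments of width dim; segment i sums the feat
// rows in the half-open range [offsets[i], offsets[i + 1]).
void SegmentSumRef(
    const std::vector<float>& feat, const std::vector<int64_t>& offsets,
    int64_t dim, std::vector<float>* out) {
  const int64_t num_segments = static_cast<int64_t>(offsets.size()) - 1;
  out->assign(num_segments * dim, 0.f);
  for (int64_t i = 0; i < num_segments; ++i)
    for (int64_t j = offsets[i]; j < offsets[i + 1]; ++j)
      for (int64_t k = 0; k < dim; ++k)
        (*out)[i * dim + k] += feat[j * dim + k];
}

// out[idx[i]] += feat[i], row-wise.
void ScatterAddRef(
    const std::vector<float>& feat, const std::vector<int64_t>& idx,
    int64_t dim, std::vector<float>* out) {
  for (int64_t i = 0; i < static_cast<int64_t>(idx.size()); ++i)
    for (int64_t k = 0; k < dim; ++k)
      (*out)[idx[i] * dim + k] += feat[i * dim + k];
}

A parallel version of the scatter variant needs an atomic update (as in the kernel above) or per-thread buffers, since several input rows can target the same output row.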
@@ -3,13 +3,15 @@
 * \file array/cpu/spmat_op_impl.cc
 * \brief CPU implementation of COO sparse matrix operators
 */
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <numeric>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "array_utils.h"

namespace dgl {
@@ -33,11 +35,10 @@ template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row && coo_col_data[i] == col) return true;
  }
  return false;
}
@@ -51,9 +52,9 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
  const auto collen = col->shape[0];
  const auto rstlen = std::max(rowlen, collen);
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  IdType *rst_data = static_cast<IdType *>(rst->data);
  const IdType *row_data = static_cast<IdType *>(row->data);
  const IdType *col_data = static_cast<IdType *>(col->data);
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const int64_t kmax = std::max(rowlen, collen);
@@ -61,7 +62,8 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
    for (auto k = b; k < e; ++k) {
      int64_t i = row_stride * k;
      int64_t j = col_stride * k;
      rst_data[k] =
          COOIsNonZero<XPU, IdType>(coo, row_data[i], col_data[j]) ? 1 : 0;
    }
  });
  return rst;
@@ -75,11 +77,11 @@ template NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo) {
  std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
  const IdType *src_data = static_cast<IdType *>(coo.row->data);
  const IdType *dst_data = static_cast<IdType *>(coo.col->data);
  const auto nnz = coo.row->shape[0];
  for (IdType eid = 0; eid < nnz; ++eid) {
    const auto &p = std::make_pair(src_data[eid], dst_data[eid]);
    if (hashmap.count(p)) {
      return true;
    } else {
@@ -97,11 +99,10 @@ template bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);
template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  int64_t result = 0;
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row) ++result;
  }
  return result;
}
@@ -113,9 +114,9 @@ template <DGLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
  CHECK_SAME_DTYPE(coo.col, rows);
  const auto len = rows->shape[0];
  const IdType *vid_data = static_cast<IdType *>(rows->data);
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType *rst_data = static_cast<IdType *>(rst->data);
#pragma omp parallel for
  for (int64_t i = 0; i < len; ++i) {
    rst_data[i] = COOGetRowNNZ<XPU, IdType>(coo, vid_data[i]);
@@ -126,16 +127,17 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
template NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);

////////////////////////// COOGetRowDataAndIndices /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> indices;
  std::vector<IdType> data;
@@ -147,13 +149,14 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    }
  }

  return std::make_pair(
      NDArray::FromVector(data), NDArray::FromVector(indices));
}

template std::pair<NDArray, NDArray> COOGetRowDataAndIndices<kDGLCPU, int32_t>(
    COOMatrix, int64_t);
template std::pair<NDArray, NDArray> COOGetRowDataAndIndices<kDGLCPU, int64_t>(
    COOMatrix, int64_t);

///////////////////////////// COOGetData /////////////////////////////
@@ -165,31 +168,32 @@ IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
      << "Invalid row and col Id array:" << rows << " " << cols;
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType *row_data = rows.Ptr<IdType>();
  const IdType *col_data = cols.Ptr<IdType>();
  const IdType *coo_row = coo.row.Ptr<IdType>();
  const IdType *coo_col = coo.col.Ptr<IdType>();
  const IdType *data = COOHasData(coo) ? coo.data.Ptr<IdType>() : nullptr;
  const int64_t nnz = coo.row->shape[0];
  const int64_t retlen = std::max(rowlen, collen);
  IdArray ret = Full(-1, retlen, rows->dtype.bits, rows->ctx);
  IdType *ret_data = ret.Ptr<IdType>();

  // TODO(minjie): We might need to consider sorting the COO beforehand
  // especially when the number of (row, col) pairs is large. Need more
  // benchmarks to justify the choice.
  if (coo.row_sorted) {
    parallel_for(0, retlen, [&](size_t b, size_t e) {
      for (auto p = b; p < e; ++p) {
        const IdType row_id = row_data[p * row_stride],
                     col_id = col_data[p * col_stride];
        auto it = std::lower_bound(coo_row, coo_row + nnz, row_id);
        for (; it < coo_row + nnz && *it == row_id; ++it) {
          const auto idx = it - coo_row;
          if (coo_col[idx] == col_id) {
            ret_data[p] = data ? data[idx] : idx;
            break;
          }
        }
@@ -198,10 +202,11 @@ IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
  } else {
#pragma omp parallel for
    for (int64_t p = 0; p < retlen; ++p) {
      const IdType row_id = row_data[p * row_stride],
                   col_id = col_data[p * col_stride];
      for (int64_t idx = 0; idx < nnz; ++idx) {
        if (coo_row[idx] == row_id && coo_col[idx] == col_id) {
          ret_data[p] = data ? data[idx] : idx;
          break;
        }
      }
@@ -217,8 +222,8 @@ template IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);
///////////////////////////// COOGetDataAndIndices /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices(
    COOMatrix coo, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(coo.col, rows);
  CHECK_SAME_DTYPE(coo.col, cols);
  const int64_t rowlen = rows->shape[0];
@@ -230,12 +235,13 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType *row_data = static_cast<IdType *>(rows->data);
  const IdType *col_data = static_cast<IdType *>(cols->data);
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> ret_rows, ret_cols;
  std::vector<IdType> ret_data;
@@ -244,21 +250,27 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
  ret_data.reserve(len);

  // NOTE(BarclayII): With a small number of lookups, linear scan is faster.
  // The threshold 200 comes from benchmarking both algorithms on a P3.8x
  // instance. I also tried sorting plus binary search. The speed gain is only
  // significant for medium-sized graphs and lookups, so I didn't include it.
  if (len >= 200) {
    // TODO(BarclayII) Ideally we would want to cache this object. However I'm
    // not sure what is the best way to do so since this object is valid for CPU
    // only.
    std::unordered_multimap<std::pair<IdType, IdType>, IdType, PairHash>
        pair_map;
    pair_map.reserve(coo.row->shape[0]);
    for (int64_t k = 0; k < coo.row->shape[0]; ++k)
      pair_map.emplace(
          std::make_pair(coo_row_data[k], coo_col_data[k]), data ? data[k] : k);

    for (int64_t i = 0, j = 0; i < rowlen && j < collen;
         i += row_stride, j += col_stride) {
      const IdType row_id = row_data[i], col_id = col_data[j];
      CHECK(row_id >= 0 && row_id < coo.num_rows)
          << "Invalid row index: " << row_id;
      CHECK(col_id >= 0 && col_id < coo.num_cols)
          << "Invalid col index: " << col_id;
      auto range = pair_map.equal_range({row_id, col_id});
      for (auto it = range.first; it != range.second; ++it) {
        ret_rows.push_back(row_id);
@@ -267,10 +279,13 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
      }
    }
  } else {
    for (int64_t i = 0, j = 0; i < rowlen && j < collen;
         i += row_stride, j += col_stride) {
      const IdType row_id = row_data[i], col_id = col_data[j];
      CHECK(row_id >= 0 && row_id < coo.num_rows)
          << "Invalid row index: " << row_id;
      CHECK(col_id >= 0 && col_id < coo.num_cols)
          << "Invalid col index: " << col_id;
      for (int64_t k = 0; k < coo.row->shape[0]; ++k) {
        if (coo_row_data[k] == row_id && coo_col_data[k] == col_id) {
          ret_rows.push_back(row_id);
@@ -281,8 +296,8 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
      }
    }
  }

  return {
      NDArray::FromVector(ret_rows), NDArray::FromVector(ret_cols),
      NDArray::FromVector(ret_data)};
}
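The branch above trades a one-off hash-map build against repeated linear scans, switching at roughly 200 lookups. A rough standalone sketch of the same trade-off (hypothetical helper with simple types and a naive pair hash, not NDArray or DGL's PairHash):

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

struct PairHashRef {  // naive stand-in for DGL's PairHash
  size_t operator()(const std::pair<int64_t, int64_t>& p) const {
    return std::hash<int64_t>()(p.first) * 1000003u ^
           std::hash<int64_t>()(p.second);
  }
};

// Return the COO positions of every queried (row, col) pair.
std::vector<int64_t> LookupPairs(
    const std::vector<int64_t>& coo_row, const std::vector<int64_t>& coo_col,
    const std::vector<std::pair<int64_t, int64_t>>& queries) {
  const int64_t nnz = static_cast<int64_t>(coo_row.size());
  std::vector<int64_t> hits;
  if (queries.size() >= 200) {  // many lookups: amortize one hash-map build
    std::unordered_multimap<std::pair<int64_t, int64_t>, int64_t, PairHashRef>
        pair_map;
    for (int64_t k = 0; k < nnz; ++k)
      pair_map.emplace(std::make_pair(coo_row[k], coo_col[k]), k);
    for (const auto& q : queries) {
      auto range = pair_map.equal_range(q);
      for (auto it = range.first; it != range.second; ++it)
        hits.push_back(it->second);
    }
  } else {  // few lookups: a linear scan per query is cheaper
    for (const auto& q : queries)
      for (int64_t k = 0; k < nnz; ++k)
        if (coo_row[k] == q.first && coo_col[k] == q.second) hits.push_back(k);
  }
  return hits;
}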
@@ -304,7 +319,8 @@ template COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOToCSR /////////////////////////////

namespace {
template <class IdType>
CSRMatrix SortedCOOToCSR(const COOMatrix &coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType *const row_data = static_cast<IdType *>(coo.row->data);
@@ -389,11 +405,13 @@ template <class IdType> CSRMatrix SortedCOOToCSR(const COOMatrix &coo) {
    std::fill(Bp, Bp + N + 1, 0);
  }

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

template <class IdType>
CSRMatrix UnSortedSparseCOOToCSR(const COOMatrix &coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType *const row_data = static_cast<IdType *>(coo.row->data);
@@ -507,11 +525,13 @@ template <class IdType> CSRMatrix UnSortedSparseCOOToCSR(const COOMatrix &coo) {
      Bp[i + 1] += i_start;
    }
  }

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

template <class IdType>
CSRMatrix UnSortedDenseCOOToCSR(const COOMatrix &coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType *const row_data = static_cast<IdType *>(coo.row->data);
@@ -597,8 +617,9 @@ template <class IdType> CSRMatrix UnSortedDenseCOOToCSR(const COOMatrix &coo) {
  }
  CHECK_EQ(Bp[N], NNZ);

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}
}  // namespace
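All three converters above ultimately build the CSR indptr by counting entries per row and prefix-summing the counts (the sorted variant can do it in a single pass over the already-grouped rows). A minimal single-threaded sketch of that construction, with illustrative names:

#include <cstdint>
#include <vector>

// Count COO entries per row, then scan the counts into a CSR indptr of length
// num_rows + 1; indptr[r + 1] - indptr[r] is row r's nonzero count.
std::vector<int64_t> CooRowsToIndptr(
    const std::vector<int64_t>& row, int64_t num_rows) {
  std::vector<int64_t> indptr(num_rows + 1, 0);
  for (int64_t r : row) ++indptr[r + 1];
  for (int64_t i = 0; i < num_rows; ++i) indptr[i + 1] += indptr[i];
  return indptr;
}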
@@ -643,9 +664,10 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
  CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
  CHECK(end > 0 && end <= coo.num_rows) << "Invalid end row " << end;
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;
@@ -660,13 +682,9 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
    }
  }
  return COOMatrix(
      end - start, coo.num_cols, NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),
      coo.row_sorted, coo.col_sorted);
}

template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
@@ -674,9 +692,10 @@ template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;
@@ -700,19 +719,22 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
      NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col),
      NDArray::FromVector(ret_data),
      coo.row_sorted,
      coo.col_sorted};
}

template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, NDArray);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, NDArray);

///////////////////////////// COOSliceMatrix /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
  const IdType *coo_row_data = static_cast<IdType *>(coo.row->data);
  const IdType *coo_col_data = static_cast<IdType *>(coo.col->data);
  const IdType *coo_data =
      COOHasData(coo) ? static_cast<IdType *>(coo.data->data) : nullptr;

  IdHashMap<IdType> row_map(rows), col_map(cols);
@@ -733,10 +755,9 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray
    }
  }

  return COOMatrix(
      rows->shape[0], cols->shape[0], NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),
      coo.row_sorted, coo.col_sorted);
}

@@ -745,36 +766,38 @@ template COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(
template COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

///////////////////////////// COOReorder /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_id_arr,
    runtime::NDArray new_col_id_arr) {
  CHECK_SAME_DTYPE(coo.row, new_row_id_arr);
  CHECK_SAME_DTYPE(coo.col, new_col_id_arr);

  // Input COO
  const IdType *in_rows = static_cast<IdType *>(coo.row->data);
  const IdType *in_cols = static_cast<IdType *>(coo.col->data);
  int64_t num_rows = coo.num_rows;
  int64_t num_cols = coo.num_cols;
  int64_t nnz = coo.row->shape[0];
  CHECK_EQ(num_rows, new_row_id_arr->shape[0])
      << "The new row Id array needs to be the same as the number of rows of "
         "COO";
  CHECK_EQ(num_cols, new_col_id_arr->shape[0])
      << "The new col Id array needs to be the same as the number of cols of "
         "COO";

  // New row/col Ids.
  const IdType *new_row_ids = static_cast<IdType *>(new_row_id_arr->data);
  const IdType *new_col_ids = static_cast<IdType *>(new_col_id_arr->data);

  // Output COO
  NDArray out_row_arr = NDArray::Empty({nnz}, coo.row->dtype, coo.row->ctx);
  NDArray out_col_arr = NDArray::Empty({nnz}, coo.col->dtype, coo.col->ctx);
  NDArray out_data_arr = COOHasData(coo) ? coo.data : NullArray();
  IdType *out_row = static_cast<IdType *>(out_row_arr->data);
  IdType *out_col = static_cast<IdType *>(out_col_arr->data);
  parallel_for(0, nnz, [=](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
@@ -785,10 +808,10 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
  return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);
}

template COOMatrix COOReorder<kDGLCPU, int64_t>(
    COOMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template COOMatrix COOReorder<kDGLCPU, int32_t>(
    COOMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);

}  // namespace impl
}  // namespace aten
@@ -5,10 +5,12 @@
 */
#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>

#include <atomic>
#include <numeric>
#include <unordered_set>
#include <vector>

#include "array_utils.h"

namespace dgl {
@@ -26,8 +28,8 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  if (csr.sorted) {
    const IdType* start = indices_data + indptr_data[row];
    const IdType* end = indices_data + indptr_data[row + 1];
    return std::binary_search(start, end, col);
  } else {
    for (IdType i = indptr_data[row]; i < indptr_data[row + 1]; ++i) {
@@ -53,11 +55,14 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  const IdType* col_data = static_cast<IdType*>(col->data);
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  runtime::parallel_for(
      0, std::max(rowlen, collen), 1, [=](int64_t b, int64_t e) {
        int64_t i = (row_stride == 0) ? 0 : b;
        int64_t j = (col_stride == 0) ? 0 : b;
        for (int64_t k = b; i < e && j < e;
             i += row_stride, j += col_stride, ++k)
          rst_data[k] =
              CSRIsNonZero<XPU, IdType>(csr, row_data[i], col_data[j]) ? 1 : 0;
      });
  return rst;
}
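The sorted-CSR fast path above reduces membership testing to a binary search over one row's column slice; a standalone sketch with plain vectors (illustrative names, not the DGL API):

#include <algorithm>
#include <cstdint>
#include <vector>

// True if (row, col) is a stored entry of a CSR matrix whose column indices
// are sorted within each row.
bool CsrHasEntry(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    int64_t row, int64_t col) {
  auto first = indices.begin() + indptr[row];
  auto last = indices.begin() + indptr[row + 1];
  return std::binary_search(first, last, col);
}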
@@ -73,7 +78,7 @@ bool CSRHasDuplicate(CSRMatrix csr) {
  const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
  for (IdType src = 0; src < csr.num_rows; ++src) {
    std::unordered_set<IdType> hashmap;
    for (IdType eid = indptr_data[src]; eid < indptr_data[src + 1]; ++eid) {
      const IdType dst = indices_data[eid];
      if (hashmap.count(dst)) {
        return true;
@@ -117,7 +122,7 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
template NDArray CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, NDArray);

/////////////////////////// CSRGetRowColumnIndices /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
@@ -140,7 +145,8 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  if (CSRHasData(csr))
    return csr.data.CreateView({len}, csr.data->dtype, offset);
  else
    return aten::Range(
        offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}

template NDArray CSRGetRowData<kDGLCPU, int32_t>(CSRMatrix, int64_t);
@@ -150,12 +156,12 @@ template NDArray CSRGetRowData<kDGLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
void CollectDataIndicesFromSorted(
    const IdType* indices_data, const IdType* data, const IdType start,
    const IdType end, const IdType col, std::vector<IdType>* col_vec,
    std::vector<IdType>* ret_vec) {
  const IdType* start_ptr = indices_data + start;
  const IdType* end_ptr = indices_data + end;
  auto it = std::lower_bound(start_ptr, end_ptr, col);
  // This might be a multi-graph. We need to collect all of the matched
  // columns.
...@@ -173,8 +179,10 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data ...@@ -173,8 +179,10 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data
} }
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) { std::vector<NDArray> CSRGetDataAndIndices(
// TODO(minjie): more efficient implementation for matrix without duplicate entries CSRMatrix csr, NDArray rows, NDArray cols) {
// TODO(minjie): more efficient implementation for matrix without duplicate
// entries
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
...@@ -188,38 +196,41 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c ...@@ -188,38 +196,41 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* data =
CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
std::vector<IdType> ret_rows, ret_cols; std::vector<IdType> ret_rows, ret_cols;
std::vector<IdType> ret_data; std::vector<IdType> ret_data;
for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) { for (int64_t i = 0, j = 0; i < rowlen && j < collen;
i += row_stride, j += col_stride) {
const IdType row_id = row_data[i], col_id = col_data[j]; const IdType row_id = row_data[i], col_id = col_data[j];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id; CHECK(row_id >= 0 && row_id < csr.num_rows)
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id; << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols)
<< "Invalid col index: " << col_id;
if (csr.sorted) { if (csr.sorted) {
// Here we collect col indices and data. // Here we collect col indices and data.
CollectDataIndicesFromSorted<XPU, IdType>(indices_data, data, CollectDataIndicesFromSorted<XPU, IdType>(
indptr_data[row_id], indices_data, data, indptr_data[row_id], indptr_data[row_id + 1],
indptr_data[row_id + 1], col_id, &ret_cols, &ret_data);
col_id, &ret_cols,
&ret_data);
// We need to add row Ids. // We need to add row Ids.
while (ret_rows.size() < ret_data.size()) { while (ret_rows.size() < ret_data.size()) {
ret_rows.push_back(row_id); ret_rows.push_back(row_id);
} }
} else { } else {
for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) { for (IdType i = indptr_data[row_id]; i < indptr_data[row_id + 1]; ++i) {
if (indices_data[i] == col_id) { if (indices_data[i] == col_id) {
ret_rows.push_back(row_id); ret_rows.push_back(row_id);
ret_cols.push_back(col_id); ret_cols.push_back(col_id);
ret_data.push_back(data? data[i] : i); ret_data.push_back(data ? data[i] : i);
} }
} }
} }
} }
return {NDArray::FromVector(ret_rows, csr.indptr->ctx), return {
NDArray::FromVector(ret_rows, csr.indptr->ctx),
NDArray::FromVector(ret_cols, csr.indptr->ctx), NDArray::FromVector(ret_cols, csr.indptr->ctx),
NDArray::FromVector(ret_data, csr.data->ctx)}; NDArray::FromVector(ret_data, csr.data->ctx)};
} }
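For context, CollectDataIndicesFromSorted relies on the row being sorted: a single std::lower_bound locates the first occurrence of the queried column, and a linear scan then picks up any parallel edges. A small self-contained sketch of that lookup on plain vectors (FindAllInSortedRow is a hypothetical helper, not DGL code):

#include <algorithm>
#include <cstdint>
#include <vector>

// Return the positions of every occurrence of `col` within the sorted CSR
// row [start, end); duplicates (parallel edges in a multigraph) are all kept.
std::vector<int64_t> FindAllInSortedRow(
    const std::vector<int64_t>& indices, int64_t start, int64_t end,
    int64_t col) {
  std::vector<int64_t> hits;
  const auto first = indices.begin() + start, last = indices.begin() + end;
  for (auto it = std::lower_bound(first, last, col);
       it != last && *it == col; ++it)
    hits.push_back(it - indices.begin());  // position doubles as the edge id
  return hits;
}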
...@@ -240,9 +251,12 @@ CSRMatrix CSRTranspose(CSRMatrix csr) { ...@@ -240,9 +251,12 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
const int64_t nnz = csr.indices->shape[0]; const int64_t nnz = csr.indices->shape[0];
const IdType* Ap = static_cast<IdType*>(csr.indptr->data); const IdType* Ap = static_cast<IdType*>(csr.indptr->data);
const IdType* Aj = static_cast<IdType*>(csr.indices->data); const IdType* Aj = static_cast<IdType*>(csr.indices->data);
const IdType* Ax = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* Ax =
NDArray ret_indptr = NDArray::Empty({M + 1}, csr.indptr->dtype, csr.indptr->ctx); CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
NDArray ret_indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx); NDArray ret_indptr =
NDArray::Empty({M + 1}, csr.indptr->dtype, csr.indptr->ctx);
NDArray ret_indices =
NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
NDArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx); NDArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
IdType* Bp = static_cast<IdType*>(ret_indptr->data); IdType* Bp = static_cast<IdType*>(ret_indptr->data);
IdType* Bi = static_cast<IdType*>(ret_indices->data); IdType* Bi = static_cast<IdType*>(ret_indices->data);
...@@ -263,10 +277,10 @@ CSRMatrix CSRTranspose(CSRMatrix csr) { ...@@ -263,10 +277,10 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
Bp[M] = nnz; Bp[M] = nnz;
for (int64_t i = 0; i < N; ++i) { for (int64_t i = 0; i < N; ++i) {
for (IdType j = Ap[i]; j < Ap[i+1]; ++j) { for (IdType j = Ap[i]; j < Ap[i + 1]; ++j) {
const IdType dst = Aj[j]; const IdType dst = Aj[j];
Bi[Bp[dst]] = i; Bi[Bp[dst]] = i;
Bx[Bp[dst]] = Ax? Ax[j] : j; Bx[Bp[dst]] = Ax ? Ax[j] : j;
Bp[dst]++; Bp[dst]++;
} }
} }
...@@ -278,7 +292,8 @@ CSRMatrix CSRTranspose(CSRMatrix csr) { ...@@ -278,7 +292,8 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
last = temp; last = temp;
} }
return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data}; return CSRMatrix{
csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
} }
template CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr); template CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr);
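The transpose follows the usual two-pass scheme: count the in-degree of every column, prefix-sum the counts into the new indptr, then scatter each entry into its slot. A compact sketch of the same algorithm on plain vectors, without the data array (illustrative only; SimpleCSR is not a DGL type):

#include <cstdint>
#include <vector>

struct SimpleCSR {
  int64_t num_rows, num_cols;
  std::vector<int64_t> indptr, indices;
};

// Two-pass CSR transpose: count, prefix-sum, scatter.
SimpleCSR Transpose(const SimpleCSR& a) {
  SimpleCSR b{a.num_cols, a.num_rows,
              std::vector<int64_t>(a.num_cols + 1, 0),
              std::vector<int64_t>(a.indices.size())};
  for (int64_t j : a.indices) ++b.indptr[j + 1];    // in-degree of each column
  for (int64_t j = 0; j < a.num_cols; ++j)
    b.indptr[j + 1] += b.indptr[j];                 // prefix sum into indptr
  std::vector<int64_t> cursor(b.indptr.begin(), b.indptr.end() - 1);
  for (int64_t i = 0; i < a.num_rows; ++i)
    for (int64_t p = a.indptr[i]; p < a.indptr[i + 1]; ++p)
      b.indices[cursor[a.indices[p]]++] = i;        // scatter the row id
  return b;
}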
...@@ -293,14 +308,13 @@ COOMatrix CSRToCOO(CSRMatrix csr) { ...@@ -293,14 +308,13 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
IdType* ret_row_data = static_cast<IdType*>(ret_row->data); IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
parallel_for(0, csr.indptr->shape[0] - 1, 10000, [=](int64_t b, int64_t e) { parallel_for(0, csr.indptr->shape[0] - 1, 10000, [=](int64_t b, int64_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
std::fill(ret_row_data + indptr_data[i], std::fill(
ret_row_data + indptr_data[i + 1], ret_row_data + indptr_data[i], ret_row_data + indptr_data[i + 1], i);
i);
} }
}); });
return COOMatrix(csr.num_rows, csr.num_cols, return COOMatrix(
ret_row, csr.indices, csr.data, csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,
true, csr.sorted); csr.sorted);
} }
template COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr); template COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr);
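CSRToCOO only has to materialize the row array; the column and data arrays are shared with the CSR. The row array is produced by repeating each row id over its indptr span, e.g. indptr = {0, 2, 2, 5} expands to rows = {0, 0, 2, 2, 2}. A tiny sketch of that expansion (hypothetical helper on plain vectors):

#include <algorithm>
#include <cstdint>
#include <vector>

// Expand a CSR indptr (at least one entry) into the explicit COO row array.
std::vector<int64_t> IndptrToRows(const std::vector<int64_t>& indptr) {
  std::vector<int64_t> rows(indptr.back());
  for (size_t i = 0; i + 1 < indptr.size(); ++i)
    std::fill(rows.begin() + indptr[i], rows.begin() + indptr[i + 1],
              static_cast<int64_t>(i));
  return rows;
}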
...@@ -315,7 +329,8 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) { ...@@ -315,7 +329,8 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
// data array should have the same type as the indices arrays // data array should have the same type as the indices arrays
const IdType* data = CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* data =
CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx); NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx); NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
IdType* ret_row_data = static_cast<IdType*>(ret_row->data); IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
...@@ -343,7 +358,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) { ...@@ -343,7 +358,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
const IdType* indptr = static_cast<IdType*>(csr.indptr->data); const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const int64_t num_rows = end - start; const int64_t num_rows = end - start;
const int64_t nnz = indptr[end] - indptr[start]; const int64_t nnz = indptr[end] - indptr[start];
IdArray ret_indptr = IdArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indices->ctx); IdArray ret_indptr =
IdArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indices->ctx);
IdType* r_indptr = static_cast<IdType*>(ret_indptr->data); IdType* r_indptr = static_cast<IdType*>(ret_indptr->data);
for (int64_t i = start; i < end + 1; ++i) { for (int64_t i = start; i < end + 1; ++i) {
r_indptr[i - start] = indptr[i] - indptr[start]; r_indptr[i - start] = indptr[i] - indptr[start];
...@@ -353,13 +369,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) { ...@@ -353,13 +369,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
{nnz}, csr.indices->dtype, indptr[start] * sizeof(IdType)); {nnz}, csr.indices->dtype, indptr[start] * sizeof(IdType));
IdArray ret_data; IdArray ret_data;
if (CSRHasData(csr)) if (CSRHasData(csr))
ret_data = csr.data.CreateView({nnz}, csr.data->dtype, indptr[start] * sizeof(IdType)); ret_data = csr.data.CreateView(
{nnz}, csr.data->dtype, indptr[start] * sizeof(IdType));
else else
ret_data = aten::Range(indptr[start], indptr[end], ret_data = aten::Range(
csr.indptr->dtype.bits, csr.indptr->ctx); indptr[start], indptr[end], csr.indptr->dtype.bits, csr.indptr->ctx);
return CSRMatrix(num_rows, csr.num_cols, return CSRMatrix(
ret_indptr, ret_indices, ret_data, num_rows, csr.num_cols, ret_indptr, ret_indices, ret_data, csr.sorted);
csr.sorted);
} }
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t); template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
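Slicing a contiguous row range [start, end) is zero-copy for indices and data, since those entries are already contiguous; only indptr has to be rebased so the slice starts at 0. A minimal sketch of the rebasing step (illustrative, plain vectors):

#include <cstdint>
#include <vector>

// Rebase indptr for the row slice [start, end); indices/data can then be
// reused as views starting at element indptr[start].
std::vector<int64_t> SliceIndptr(
    const std::vector<int64_t>& indptr, int64_t start, int64_t end) {
  std::vector<int64_t> out(end - start + 1);
  for (int64_t i = start; i <= end; ++i)
    out[i - start] = indptr[i] - indptr[start];
  return out;  // out.front() == 0, out.back() == nnz of the slice
}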
...@@ -370,7 +386,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -370,7 +386,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows); CHECK_SAME_DTYPE(csr.indices, rows);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* data =
CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
const auto len = rows->shape[0]; const auto len = rows->shape[0];
const IdType* rows_data = static_cast<IdType*>(rows->data); const IdType* rows_data = static_cast<IdType*>(rows->data);
int64_t nnz = 0; int64_t nnz = 0;
...@@ -389,13 +406,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -389,13 +406,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
bool err = false; bool err = false;
std::stringstream err_msg_stream; std::stringstream err_msg_stream;
// Perform two-round parallel prefix sum using OpenMP // Perform two-round parallel prefix sum using OpenMP
#pragma omp parallel #pragma omp parallel
{ {
int64_t tid = omp_get_thread_num(); int64_t tid = omp_get_thread_num();
int64_t num_threads = omp_get_num_threads(); int64_t num_threads = omp_get_num_threads();
#pragma omp single #pragma omp single
{ {
sums.resize(num_threads + 1); sums.resize(num_threads + 1);
sums[0] = 0; sums[0] = 0;
...@@ -403,14 +420,14 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -403,14 +420,14 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
int64_t sum = 0; int64_t sum = 0;
// First round of parallel prefix sum. All threads perform local prefix sums. // First round of parallel prefix sum. All threads perform local prefix sums.
#pragma omp for schedule(static) nowait #pragma omp for schedule(static) nowait
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
int64_t rid = rows_data[i]; int64_t rid = rows_data[i];
if (rid >= csr.num_rows) { if (rid >= csr.num_rows) {
if (!err_flag.test_and_set()) { if (!err_flag.test_and_set()) {
err_msg_stream << "expect row ID " << rid << " to be less than number of rows " err_msg_stream << "expect row ID " << rid
<< csr.num_rows; << " to be less than number of rows " << csr.num_rows;
err = true; err = true;
} }
} else { } else {
...@@ -419,20 +436,18 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -419,20 +436,18 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
} }
} }
sums[tid + 1] = sum; sums[tid + 1] = sum;
#pragma omp barrier #pragma omp barrier
#pragma omp single #pragma omp single
{ {
for (int64_t i = 1; i < num_threads; ++i) for (int64_t i = 1; i < num_threads; ++i) sums[i] += sums[i - 1];
sums[i] += sums[i - 1];
} }
int64_t offset = sums[tid]; int64_t offset = sums[tid];
// Second round of parallel prefix sum. Update the local prefix sums. // Second round of parallel prefix sum. Update the local prefix sums.
#pragma omp for schedule(static) #pragma omp for schedule(static)
for (int64_t i = 0; i < len; ++i) for (int64_t i = 0; i < len; ++i) ret_indptr_data[i + 1] += offset;
ret_indptr_data[i + 1] += offset;
} }
if (err) { if (err) {
LOG(FATAL) << err_msg_stream.str(); LOG(FATAL) << err_msg_stream.str();
...@@ -454,26 +469,30 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) { ...@@ -454,26 +469,30 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
const IdType rid = rows_data[i]; const IdType rid = rows_data[i];
// note: zero is allowed // note: zero is allowed
std::copy(indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1], std::copy(
indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1],
ret_indices_data + ret_indptr_data[i]); ret_indices_data + ret_indptr_data[i]);
if (data) if (data)
std::copy(data + indptr_data[rid], data + indptr_data[rid + 1], std::copy(
data + indptr_data[rid], data + indptr_data[rid + 1],
ret_data + ret_indptr_data[i]); ret_data + ret_indptr_data[i]);
else else
std::iota(ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 1], std::iota(
ret_data + ret_indptr_data[i], ret_data + ret_indptr_data[i + 1],
indptr_data[rid]); indptr_data[rid]);
} }
}); });
return ret; return ret;
} }
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix , NDArray); template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, NDArray);
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix , NDArray); template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, NDArray);
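The OpenMP block above is a two-round parallel prefix sum: each thread scans its static chunk and writes a local running sum, the per-thread totals are combined serially, and each thread then shifts its chunk by the total of all preceding threads. A self-contained sketch of that pattern, assuming the per-row lengths are already known (ParallelPrefixSum is an illustrative name, not DGL code):

#include <omp.h>

#include <cstdint>
#include <vector>

// Two-round parallel inclusive prefix sum over per-row lengths.
std::vector<int64_t> ParallelPrefixSum(const std::vector<int64_t>& lens) {
  const int64_t n = lens.size();
  std::vector<int64_t> out(n + 1, 0);
  std::vector<int64_t> sums;
#pragma omp parallel
  {
    const int tid = omp_get_thread_num();
    const int nthr = omp_get_num_threads();
#pragma omp single
    { sums.assign(nthr + 1, 0); }
    int64_t local = 0;
// Round 1: every thread prefix-sums its own static chunk.
#pragma omp for schedule(static) nowait
    for (int64_t i = 0; i < n; ++i) {
      local += lens[i];
      out[i + 1] = local;
    }
    sums[tid + 1] = local;
#pragma omp barrier
#pragma omp single
    {  // Combine the per-thread totals serially.
      for (int t = 1; t <= nthr; ++t) sums[t] += sums[t - 1];
    }
// Round 2: shift each chunk by the totals of all preceding threads.
// schedule(static) guarantees the same chunk assignment as round 1.
#pragma omp for schedule(static)
    for (int64_t i = 0; i < n; ++i) out[i + 1] += sums[tid];
  }
  return out;
}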
///////////////////////////// CSRSliceMatrix ///////////////////////////// ///////////////////////////// CSRSliceMatrix /////////////////////////////
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) { CSRMatrix CSRSliceMatrix(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
IdHashMap<IdType> hashmap(cols); IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0]; const int64_t new_nrows = rows->shape[0];
const int64_t new_ncols = cols->shape[0]; const int64_t new_ncols = cols->shape[0];
...@@ -482,7 +501,8 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -482,7 +501,8 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = has_data? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* data =
has_data ? static_cast<IdType*>(csr.data->data) : nullptr;
std::vector<IdType> sub_indptr, sub_indices; std::vector<IdType> sub_indptr, sub_indices;
std::vector<IdType> sub_data; std::vector<IdType> sub_data;
...@@ -498,7 +518,7 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -498,7 +518,7 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
if (newj != kInvalidId) { if (newj != kInvalidId) {
++sub_indptr[i]; ++sub_indptr[i];
sub_indices.push_back(newj); sub_indices.push_back(newj);
sub_data.push_back(has_data? data[p] : p); sub_data.push_back(has_data ? data[p] : p);
} }
} }
} }
...@@ -512,13 +532,13 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -512,13 +532,13 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
sub_indptr[new_nrows] = sub_indices.size(); sub_indptr[new_nrows] = sub_indices.size();
const int64_t nnz = sub_data.size(); const int64_t nnz = sub_data.size();
NDArray sub_data_arr = NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx); NDArray sub_data_arr =
NDArray::Empty({nnz}, csr.indptr->dtype, csr.indptr->ctx);
IdType* ptr = static_cast<IdType*>(sub_data_arr->data); IdType* ptr = static_cast<IdType*>(sub_data_arr->data);
std::copy(sub_data.begin(), sub_data.end(), ptr); std::copy(sub_data.begin(), sub_data.end(), ptr);
return CSRMatrix{new_nrows, new_ncols, return CSRMatrix{
NDArray::FromVector(sub_indptr, csr.indptr->ctx), new_nrows, new_ncols, NDArray::FromVector(sub_indptr, csr.indptr->ctx),
NDArray::FromVector(sub_indices, csr.indptr->ctx), NDArray::FromVector(sub_indices, csr.indptr->ctx), sub_data_arr};
sub_data_arr};
} }
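CSRSliceMatrix builds an old-to-new column id map from `cols` (via IdHashMap) and then keeps, for every requested row, only the entries whose column is in that map, relabelling the columns as it goes. A plain-vector sketch of the per-row filtering step (hypothetical helper, using std::unordered_map in place of IdHashMap):

#include <cstdint>
#include <unordered_map>
#include <vector>

// Keep only the entries of row `rid` whose column appears in `col_map`
// (old column id -> new column id), relabelling columns on the fly.
void SliceRowByColumns(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    int64_t rid, const std::unordered_map<int64_t, int64_t>& col_map,
    std::vector<int64_t>* new_cols, std::vector<int64_t>* kept_positions) {
  for (int64_t p = indptr[rid]; p < indptr[rid + 1]; ++p) {
    auto it = col_map.find(indices[p]);
    if (it != col_map.end()) {
      new_cols->push_back(it->second);   // relabelled column id
      kept_positions->push_back(p);      // serves as the edge id if no data
    }
  }
}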
template CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>( template CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>(
...@@ -529,7 +549,8 @@ template CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>( ...@@ -529,7 +549,8 @@ template CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>(
///////////////////////////// CSRReorder ///////////////////////////// ///////////////////////////// CSRReorder /////////////////////////////
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr, CSRMatrix CSRReorder(
CSRMatrix csr, runtime::NDArray new_row_id_arr,
runtime::NDArray new_col_id_arr) { runtime::NDArray new_col_id_arr) {
CHECK_SAME_DTYPE(csr.indices, new_row_id_arr); CHECK_SAME_DTYPE(csr.indices, new_row_id_arr);
CHECK_SAME_DTYPE(csr.indices, new_col_id_arr); CHECK_SAME_DTYPE(csr.indices, new_col_id_arr);
...@@ -543,21 +564,25 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr, ...@@ -543,21 +564,25 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
int64_t nnz = csr.indices->shape[0]; int64_t nnz = csr.indices->shape[0];
CHECK_EQ(nnz, in_indptr[num_rows]); CHECK_EQ(nnz, in_indptr[num_rows]);
CHECK_EQ(num_rows, new_row_id_arr->shape[0]) CHECK_EQ(num_rows, new_row_id_arr->shape[0])
<< "The length of the new row Id array needs to equal the number of rows of CSR"; << "The length of the new row Id array needs to equal the number of "
"rows of CSR";
CHECK_EQ(num_cols, new_col_id_arr->shape[0]) CHECK_EQ(num_cols, new_col_id_arr->shape[0])
<< "The length of the new col Id array needs to equal the number of cols of CSR"; << "The length of the new col Id array needs to equal the number of "
"cols of CSR";
// New row/col Ids. // New row/col Ids.
const IdType* new_row_ids = static_cast<IdType*>(new_row_id_arr->data); const IdType* new_row_ids = static_cast<IdType*>(new_row_id_arr->data);
const IdType* new_col_ids = static_cast<IdType*>(new_col_id_arr->data); const IdType* new_col_ids = static_cast<IdType*>(new_col_id_arr->data);
// Output CSR // Output CSR
NDArray out_indptr_arr = NDArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indptr->ctx); NDArray out_indptr_arr =
NDArray out_indices_arr = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx); NDArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indptr->ctx);
NDArray out_indices_arr =
NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
NDArray out_data_arr = NDArray::Empty({nnz}, csr.data->dtype, csr.data->ctx); NDArray out_data_arr = NDArray::Empty({nnz}, csr.data->dtype, csr.data->ctx);
IdType *out_indptr = static_cast<IdType*>(out_indptr_arr->data); IdType* out_indptr = static_cast<IdType*>(out_indptr_arr->data);
IdType *out_indices = static_cast<IdType*>(out_indices_arr->data); IdType* out_indices = static_cast<IdType*>(out_indices_arr->data);
IdType *out_data = static_cast<IdType*>(out_data_arr->data); IdType* out_data = static_cast<IdType*>(out_data_arr->data);
// Compute the length of rows for the new matrix. // Compute the length of rows for the new matrix.
std::vector<IdType> new_row_lens(num_rows, -1); std::vector<IdType> new_row_lens(num_rows, -1);
...@@ -579,12 +604,12 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr, ...@@ -579,12 +604,12 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
// Here I iterate rows in the order of the old matrix. // Here I iterate rows in the order of the old matrix.
parallel_for(0, num_rows, [=](size_t b, size_t e) { parallel_for(0, num_rows, [=](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
const IdType *in_row = in_indices + in_indptr[i]; const IdType* in_row = in_indices + in_indptr[i];
const IdType *in_row_data = in_data + in_indptr[i]; const IdType* in_row_data = in_data + in_indptr[i];
int64_t new_row_id = new_row_ids[i]; int64_t new_row_id = new_row_ids[i];
IdType *out_row = out_indices + out_indptr[new_row_id]; IdType* out_row = out_indices + out_indptr[new_row_id];
IdType *out_row_data = out_data + out_indptr[new_row_id]; IdType* out_row_data = out_data + out_indptr[new_row_id];
int64_t row_len = new_row_lens[new_row_id]; int64_t row_len = new_row_lens[new_row_id];
// Here I iterate col indices in a row in the order of the old matrix. // Here I iterate col indices in a row in the order of the old matrix.
...@@ -595,14 +620,14 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr, ...@@ -595,14 +620,14 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
// TODO(zhengda) maybe we should sort the column indices. // TODO(zhengda) maybe we should sort the column indices.
} }
}); });
return CSRMatrix(num_rows, num_cols, return CSRMatrix(
out_indptr_arr, out_indices_arr, out_data_arr); num_rows, num_cols, out_indptr_arr, out_indices_arr, out_data_arr);
} }
template CSRMatrix CSRReorder<kDGLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids, template CSRMatrix CSRReorder<kDGLCPU, int64_t>(
runtime::NDArray new_col_ids); CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template CSRMatrix CSRReorder<kDGLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids, template CSRMatrix CSRReorder<kDGLCPU, int32_t>(
runtime::NDArray new_col_ids); CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
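CSRReorder applies a row permutation and a column relabelling: the length of every new row is taken from the old row it comes from, those lengths are prefix-summed into the new indptr, and each old row is then copied into its new slot with relabelled columns (without re-sorting, per the TODO above). A compact sketch on plain vectors, assuming new_row_id and new_col_id are permutations:

#include <cstdint>
#include <vector>

// Permute a CSR matrix given old->new row and column id maps.
void ReorderCSR(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    const std::vector<int64_t>& new_row_id,
    const std::vector<int64_t>& new_col_id,
    std::vector<int64_t>* out_indptr, std::vector<int64_t>* out_indices) {
  const int64_t n = indptr.size() - 1;
  out_indptr->assign(n + 1, 0);
  for (int64_t i = 0; i < n; ++i)  // length of every new row
    (*out_indptr)[new_row_id[i] + 1] = indptr[i + 1] - indptr[i];
  for (int64_t i = 0; i < n; ++i) (*out_indptr)[i + 1] += (*out_indptr)[i];
  out_indices->assign(indices.size(), 0);
  for (int64_t i = 0; i < n; ++i) {
    int64_t dst = (*out_indptr)[new_row_id[i]];
    for (int64_t p = indptr[i]; p < indptr[i + 1]; ++p)
      (*out_indices)[dst++] = new_col_id[indices[p]];
  }
}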
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -8,14 +8,15 @@ ...@@ -8,14 +8,15 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/bcast.h> #include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/runtime/config.h> #include <dgl/runtime/config.h>
#include <dgl/runtime/parallel_for.h>
#include <math.h> #include <math.h>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <algorithm>
#include <vector> #include <vector>
#include "spmm_binary_ops.h" #include "spmm_binary_ops.h"
#if !defined(_WIN32) #if !defined(_WIN32)
#ifdef USE_AVX #ifdef USE_AVX
...@@ -44,7 +45,8 @@ namespace cpu { ...@@ -44,7 +45,8 @@ namespace cpu {
* JIT'ed kernel. * JIT'ed kernel.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void SpMMSumCsrXbyak(dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast, void SpMMSumCsrXbyak(
dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast,
const CSRMatrix& csr, const DType* X, const DType* W, DType* O) { const CSRMatrix& csr, const DType* X, const DType* W, DType* O) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>(); const IdType* indptr = csr.indptr.Ptr<IdType>();
...@@ -79,8 +81,9 @@ void SpMMSumCsrXbyak(dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast ...@@ -79,8 +81,9 @@ void SpMMSumCsrXbyak(dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast
* for the computation of different nodes. * for the computation of different nodes.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void SpMMSumCsrNaive(const BcastOff& bcast, const CSRMatrix& csr, const DType* X, void SpMMSumCsrNaive(
const DType* W, DType* O) { const BcastOff& bcast, const CSRMatrix& csr, const DType* X, const DType* W,
DType* O) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>(); const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>(); const IdType* indices = csr.indices.Ptr<IdType>();
...@@ -118,8 +121,9 @@ void SpMMSumCsrNaive(const BcastOff& bcast, const CSRMatrix& csr, const DType* X ...@@ -118,8 +121,9 @@ void SpMMSumCsrNaive(const BcastOff& bcast, const CSRMatrix& csr, const DType* X
* for the computation of different nodes. * for the computation of different nodes.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, void SpMMSumCsr(
NDArray efeat, NDArray out) { const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>(); const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>(); const IdType* indices = csr.indices.Ptr<IdType>();
...@@ -135,15 +139,13 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -135,15 +139,13 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
CHECK_NOTNULL(X); CHECK_NOTNULL(X);
} }
if (Op::use_rhs) { if (Op::use_rhs) {
if (has_idx) if (has_idx) CHECK_NOTNULL(edges);
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W); CHECK_NOTNULL(W);
} }
#if !defined(_WIN32) #if !defined(_WIN32)
#ifdef USE_AVX #ifdef USE_AVX
#ifdef USE_LIBXSMM #ifdef USE_LIBXSMM
const bool no_libxsmm = const bool no_libxsmm = bcast.use_bcast ||
bcast.use_bcast ||
std::is_same<DType, double>::value || std::is_same<DType, double>::value ||
!dgl::runtime::Config::Global()->IsLibxsmmAvailable(); !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
if (!no_libxsmm) { if (!no_libxsmm) {
...@@ -186,8 +188,9 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -186,8 +188,9 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
* we use atomic operators in the reduction phase. * we use atomic operators in the reduction phase.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void SpMMSumCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, void SpMMSumCoo(
NDArray efeat, NDArray out) { const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out) {
const bool has_idx = !IsNullArray(coo.data); const bool has_idx = !IsNullArray(coo.data);
const IdType* row = coo.row.Ptr<IdType>(); const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>(); const IdType* col = coo.col.Ptr<IdType>();
...@@ -232,16 +235,19 @@ void SpMMSumCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, ...@@ -232,16 +235,19 @@ void SpMMSumCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
* \param argu Arg-Min/Max on source nodes, which refers to the source node indices * \param argu Arg-Min/Max on source nodes, which refers to the source node indices
* corresponding to the minimum/maximum values of reduction result on * corresponding to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max * destination nodes. It's useful in computing gradients of Min/Max
* reducer. \param arge Arg-Min/Max on edges, which refers to the source node * reducer.
* indices corresponding to the minimum/maximum values of reduction result on * \param arge Arg-Min/Max on edges, which refers to the source node indices
corresponding to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max * destination nodes. It's useful in computing gradients of Min/Max
* reducer. \note It uses node parallel strategy, different threads are * reducer.
* responsible for the computation of different nodes. \note The result will * \note It uses node parallel strategy, different threads are responsible for
* contain infinity for zero-degree nodes. * the computation of different nodes.
* \note The result will contain infinity for zero-degree nodes.
*/ */
template <typename IdType, typename DType, typename Op, typename Cmp> template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, void SpMMCmpCsr(
NDArray efeat, NDArray out, NDArray argu, NDArray arge) { const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = static_cast<IdType*>(csr.indptr->data); const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const IdType* indices = static_cast<IdType*>(csr.indices->data); const IdType* indices = static_cast<IdType*>(csr.indices->data);
...@@ -262,8 +268,7 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -262,8 +268,7 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
CHECK_NOTNULL(argX); CHECK_NOTNULL(argX);
} }
if (Op::use_rhs) { if (Op::use_rhs) {
if (has_idx) if (has_idx) CHECK_NOTNULL(edges);
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W); CHECK_NOTNULL(W);
CHECK_NOTNULL(argW); CHECK_NOTNULL(argW);
} }
...@@ -271,12 +276,12 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -271,12 +276,12 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
#ifdef USE_AVX #ifdef USE_AVX
#ifdef USE_LIBXSMM #ifdef USE_LIBXSMM
const bool no_libxsmm = const bool no_libxsmm = bcast.use_bcast ||
bcast.use_bcast ||
std::is_same<DType, double>::value || std::is_same<DType, double>::value ||
!dgl::runtime::Config::Global()->IsLibxsmmAvailable(); !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
if (!no_libxsmm) { if (!no_libxsmm) {
SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge); SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(
bcast, csr, ufeat, efeat, out, argu, arge);
} else { } else {
#endif // USE_LIBXSMM #endif // USE_LIBXSMM
#endif // USE_AVX #endif // USE_AVX
...@@ -328,24 +333,26 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, ...@@ -328,24 +333,26 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
* corresponding to the minimum/maximum values of reduction result on * corresponding to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max * destination nodes. It's useful in computing gradients of Min/Max
* reducer. * reducer.
* \param arge Arg-Min/Max on edges, which refers to the source node * \param arge Arg-Min/Max on edges, which refers to the source node indices
* indices corresponding to the minimum/maximum values of reduction result on * corresponding to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max * destination nodes. It's useful in computing gradients of Min/Max
* reducer. * reducer.
* \param argu_ntype Node type of the arg-Min/Max on source nodes, which refers to the * \param argu_ntype Node type of the arg-Min/Max on source nodes, which refers
* source node types corresponding to the minimum/maximum values of reduction result * to the source node types corresponding to the minimum/maximum values of
* on destination nodes. It's useful in computing gradients of Min/Max reducer. * reduction result on destination nodes. It's useful in computing
* \param arge_etype Edge-type of the arg-Min/Max on edges, which refers to the source * gradients of Min/Max reducer.
* node indices corresponding to the minimum/maximum values of reduction result on * \param arge_etype Edge-type of the arg-Min/Max on edges, which refers to the
* destination nodes. It's useful in computing gradients of Min/Max reducer. * source node indices corresponding to the minimum/maximum values of
* reduction result on destination nodes. It's useful in computing
* gradients of Min/Max reducer.
* \param src_type Node type of the source nodes of an etype * \param src_type Node type of the source nodes of an etype
* \param etype Edge type * \param etype Edge type
*/ */
template <typename IdType, typename DType, typename Op, typename Cmp> template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, void SpMMCmpCsrHetero(
NDArray efeat, NDArray out, NDArray argu, NDArray arge, const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray argu_ntype, NDArray arge_etype, NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,
const int ntype, const int etype) { NDArray arge_etype, const int ntype, const int etype) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = static_cast<IdType*>(csr.indptr->data); const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const IdType* indices = static_cast<IdType*>(csr.indices->data); const IdType* indices = static_cast<IdType*>(csr.indices->data);
...@@ -358,8 +365,10 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat ...@@ -358,8 +365,10 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat
DType* O = static_cast<DType*>(out->data); DType* O = static_cast<DType*>(out->data);
IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr; IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr; IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
IdType* argX_ntype = Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr; IdType* argX_ntype =
IdType* argW_etype = Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr; Op::use_lhs ? static_cast<IdType*>(argu_ntype->data) : nullptr;
IdType* argW_etype =
Op::use_rhs ? static_cast<IdType*>(arge_etype->data) : nullptr;
CHECK_NOTNULL(indptr); CHECK_NOTNULL(indptr);
CHECK_NOTNULL(O); CHECK_NOTNULL(O);
if (Op::use_lhs) { if (Op::use_lhs) {
...@@ -368,8 +377,7 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat ...@@ -368,8 +377,7 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat
CHECK_NOTNULL(argX); CHECK_NOTNULL(argX);
} }
if (Op::use_rhs) { if (Op::use_rhs) {
if (has_idx) if (has_idx) CHECK_NOTNULL(edges);
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W); CHECK_NOTNULL(W);
CHECK_NOTNULL(argW); CHECK_NOTNULL(argW);
} }
...@@ -410,7 +418,6 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat ...@@ -410,7 +418,6 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat
}); });
} }
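The Cmp kernels keep, alongside the running minimum/maximum of each output element, the source node (argu) and edge (arge) that produced it; the backward pass of the Min/Max reducer later routes gradients through exactly those indices. A scalar sketch of the tracking pattern for a single destination node and feature channel (illustrative only, not the DGL kernel itself):

#include <cstdint>
#include <limits>
#include <vector>

struct MaxWithArg {
  double value = -std::numeric_limits<double>::infinity();  // identity of max
  int64_t arg_src = 0;   // source node that produced the maximum
  int64_t arg_edge = 0;  // edge that produced the maximum
};

// Max-reduce the incoming messages of one destination node, recording argmax.
MaxWithArg ReduceMax(
    const std::vector<double>& msg, const std::vector<int64_t>& src,
    const std::vector<int64_t>& eid) {
  MaxWithArg acc;
  for (size_t k = 0; k < msg.size(); ++k) {
    if (msg[k] > acc.value) {
      acc.value = msg[k];
      acc.arg_src = src[k];
      acc.arg_edge = eid[k];
    }
  }
  return acc;  // a zero-degree node keeps the infinity initial value
}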
/*! /*!
* \brief CPU kernel of SpMM-Min/Max on Coo format. * \brief CPU kernel of SpMM-Min/Max on Coo format.
* \param bcast Broadcast information. * \param bcast Broadcast information.
...@@ -421,17 +428,20 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat ...@@ -421,17 +428,20 @@ void SpMMCmpCsrHetero(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat
* \param argu Arg-Min/Max on source nodes, which refers to the source node indices * \param argu Arg-Min/Max on source nodes, which refers to the source node indices
* corresponding to the minimum/maximum values of reduction result on * corresponding to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max * destination nodes. It's useful in computing gradients of Min/Max
* reducer. \param arge Arg-Min/Max on edges, which refers to the source node * reducer.
* indices corresponding to the minimum/maximum values of reduction result on * \param arge Arg-Min/Max on edges, which refers to the source node indices
* corresponding to the minimum/maximum values of reduction result on
* destination nodes. It's useful in computing gradients of Min/Max * destination nodes. It's useful in computing gradients of Min/Max
* reducer. \note it uses node parallel strategy, different threads are * reducer.
* responsible for the computation of different nodes. To avoid possible data * \note it uses node parallel strategy, different threads are responsible for
* hazard, we use atomic operators in the reduction phase. \note The result will * the computation of different nodes. To avoid possible data hazard, we
* contain infinity for zero-degree nodes. * use atomic operators in the reduction phase.
* \note The result will contain infinity for zero-degree nodes.
*/ */
template <typename IdType, typename DType, typename Op, typename Cmp> template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, void SpMMCmpCoo(
NDArray efeat, NDArray out, NDArray argu, NDArray arge) { const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
const bool has_idx = !IsNullArray(coo.data); const bool has_idx = !IsNullArray(coo.data);
const IdType* row = static_cast<IdType*>(coo.row->data); const IdType* row = static_cast<IdType*>(coo.row->data);
const IdType* col = static_cast<IdType*>(coo.col->data); const IdType* col = static_cast<IdType*>(coo.col->data);
...@@ -474,7 +484,6 @@ void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, ...@@ -474,7 +484,6 @@ void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
} }
} }
/*! /*!
* \brief CPU kernel of Edge_softmax_csr_forward on Csr format. * \brief CPU kernel of Edge_softmax_csr_forward on Csr format.
* \param bcast Broadcast information. * \param bcast Broadcast information.
...@@ -484,8 +493,9 @@ void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, ...@@ -484,8 +493,9 @@ void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
* \param out The result of edge_softmax_forward. * \param out The result of edge_softmax_forward.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, void Edge_softmax_csr_forward(
NDArray efeat, NDArray out) { const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = static_cast<IdType*>(csr.indptr->data); const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const IdType* edges = const IdType* edges =
...@@ -495,8 +505,8 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr ...@@ -495,8 +505,8 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr
runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) { runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
for (auto rid = b; rid < e; ++rid) { for (auto rid = b; rid < e; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1]; const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
std::vector<DType> data_e(row_end-row_start, 0); std::vector<DType> data_e(row_end - row_start, 0);
std::vector<IdType> num(row_end-row_start, 0); std::vector<IdType> num(row_end - row_start, 0);
for (int64_t k = 0; k < dim; ++k) { for (int64_t k = 0; k < dim; ++k) {
DType max_v = -std::numeric_limits<DType>::infinity(); DType max_v = -std::numeric_limits<DType>::infinity();
for (IdType j = row_start; j < row_end; ++j) { for (IdType j = row_start; j < row_end; ++j) {
...@@ -504,8 +514,8 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr ...@@ -504,8 +514,8 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k; const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* rhs_off = const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr; Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
data_e[j-row_start] = *rhs_off; data_e[j - row_start] = *rhs_off;
num[j-row_start] = eid*rhs_dim+rhs_add; num[j - row_start] = eid * rhs_dim + rhs_add;
max_v = std::max<DType>(max_v, (*rhs_off)); max_v = std::max<DType>(max_v, (*rhs_off));
} }
DType exp_sum = 0; DType exp_sum = 0;
...@@ -514,15 +524,14 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr ...@@ -514,15 +524,14 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr
element = std::exp(element); element = std::exp(element);
exp_sum += element; exp_sum += element;
} }
for (int i=0; i < row_end-row_start; i++) { for (int i = 0; i < row_end - row_start; i++) {
out.Ptr<DType>()[num[i]] = data_e[i]/exp_sum; out.Ptr<DType>()[num[i]] = data_e[i] / exp_sum;
} }
} }
} }
}); });
} }
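The forward kernel above computes, per destination row and per feature channel, a max-shifted softmax over the scores of the incoming edges, which avoids overflow in std::exp. A minimal single-row sketch of that computation (hypothetical helper on a dense score vector):

#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax over the incoming-edge scores of one row.
std::vector<double> EdgeSoftmaxRow(const std::vector<double>& scores) {
  if (scores.empty()) return {};
  const double max_v = *std::max_element(scores.begin(), scores.end());
  std::vector<double> out(scores.size());
  double exp_sum = 0;
  for (size_t i = 0; i < scores.size(); ++i) {
    out[i] = std::exp(scores[i] - max_v);  // shift by the row maximum
    exp_sum += out[i];
  }
  for (double& v : out) v /= exp_sum;
  return out;
}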
/*! /*!
* \brief CPU kernel of Edge_softmax_csr_backward on Csr format. * \brief CPU kernel of Edge_softmax_csr_backward on Csr format.
* \param bcast Broadcast information. * \param bcast Broadcast information.
...@@ -532,8 +541,9 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr ...@@ -532,8 +541,9 @@ void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArr
* \param back_out The result of edge_softmax_backward. * \param back_out The result of edge_softmax_backward.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_backward(const BcastOff& bcast, const CSRMatrix& csr, NDArray out, void Edge_softmax_csr_backward(
NDArray sds, NDArray back_out) { const BcastOff& bcast, const CSRMatrix& csr, NDArray out, NDArray sds,
NDArray back_out) {
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = static_cast<IdType*>(csr.indptr->data); const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const IdType* edges = const IdType* edges =
...@@ -553,14 +563,15 @@ void Edge_softmax_csr_backward(const BcastOff& bcast, const CSRMatrix& csr, NDAr ...@@ -553,14 +563,15 @@ void Edge_softmax_csr_backward(const BcastOff& bcast, const CSRMatrix& csr, NDAr
Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr; Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
sum_sds += (*rhs_off_sds); sum_sds += (*rhs_off_sds);
} }
for (IdType j = row_start; j< row_end; ++j) { for (IdType j = row_start; j < row_end; ++j) {
const IdType eid = has_idx ? edges[j] : j; const IdType eid = has_idx ? edges[j] : j;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k; const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* rhs_off_out = const DType* rhs_off_out =
Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr; Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr;
const DType* rhs_off_sds = const DType* rhs_off_sds =
Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr; Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
back_out.Ptr<DType>()[eid*rhs_dim+rhs_add] = (*rhs_off_sds) - sum_sds*(*rhs_off_out); back_out.Ptr<DType>()[eid * rhs_dim + rhs_add] =
(*rhs_off_sds) - sum_sds * (*rhs_off_out);
} }
} }
} }
......
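The backward kernel writes, for every edge score, sds - out * sum(sds), where out is the forward softmax output and sds is presumably the elementwise product of out with the upstream gradient, which makes this the standard softmax gradient. A single-row sketch (illustrative helper on dense vectors, assuming out and sds have equal length):

#include <vector>

// Backward of edge softmax for one row:
//   grad_score[j] = sds[j] - out[j] * sum(sds).
std::vector<double> EdgeSoftmaxRowBackward(
    const std::vector<double>& out, const std::vector<double>& sds) {
  double sum_sds = 0;
  for (double v : sds) sum_sds += v;
  std::vector<double> grad(out.size());
  for (size_t j = 0; j < out.size(); ++j)
    grad[j] = sds[j] - sum_sds * out[j];
  return grad;
}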
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/bcast.h> #include <dgl/bcast.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <algorithm> #include <algorithm>
#if !defined(_WIN32) #if !defined(_WIN32)
#ifdef USE_AVX #ifdef USE_AVX
#ifdef USE_LIBXSMM #ifdef USE_LIBXSMM
#include <unistd.h>
#include <libxsmm.h> #include <libxsmm.h>
#include <unistd.h>
#ifdef DEBUG #ifdef DEBUG
#include <x86intrin.h> #include <x86intrin.h>
#endif // DEBUG #endif // DEBUG
...@@ -53,8 +54,10 @@ int32_t GetLLCSize() { ...@@ -53,8 +54,10 @@ int32_t GetLLCSize() {
* are assigned to OMP threads. * are assigned to OMP threads.
* \param csr The Csr matrix. * \param csr The Csr matrix.
* \param block_csr_array The array containing csr matrices of all blocks. * \param block_csr_array The array containing csr matrices of all blocks.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix. * \param num_M_blocks Number of blocks to create along the rows of adjacency
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix. * matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency
* matrix.
* \param M_block_size block size along the rows of adjacency matrix. * \param M_block_size block size along the rows of adjacency matrix.
* \param K_block_size block size along the columns of adjacency matrix. * \param K_block_size block size along the columns of adjacency matrix.
* \param use_lhs Whether to use lhs. * \param use_lhs Whether to use lhs.
...@@ -62,38 +65,30 @@ int32_t GetLLCSize() { ...@@ -62,38 +65,30 @@ int32_t GetLLCSize() {
*/ */
template <typename IdType> template <typename IdType>
inline void SpMMCreateBlocks( inline void SpMMCreateBlocks(
const CSRMatrix& csr, const CSRMatrix &csr, CSRMatrixInternal<IdType, IdType> *block_csr_array,
CSRMatrixInternal<IdType, IdType> *block_csr_array, IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
IdType num_M_blocks, IdType K_block_size, bool use_lhs, bool use_rhs) {
IdType num_K_blocks,
IdType M_block_size,
IdType K_block_size,
bool use_lhs, bool use_rhs) {
const IdType M = csr.num_rows; const IdType M = csr.num_rows;
const IdType K = csr.num_cols; const IdType K = csr.num_cols;
IdType* indptr = csr.indptr.Ptr<IdType>(); IdType *indptr = csr.indptr.Ptr<IdType>();
IdType* indices = csr.indices.Ptr<IdType>(); IdType *indices = csr.indices.Ptr<IdType>();
IdType* edges = csr.data.Ptr<IdType>(); IdType *edges = csr.data.Ptr<IdType>();
CHECK_NOTNULL(indptr); CHECK_NOTNULL(indptr);
if (use_lhs) if (use_lhs) CHECK_NOTNULL(indices);
CHECK_NOTNULL(indices); if (use_rhs) CHECK_NOTNULL(edges);
if (use_rhs)
CHECK_NOTNULL(edges);
if (num_K_blocks > 1) { if (num_K_blocks > 1) {
IdType *indptr_block_buf = reinterpret_cast<IdType *>(aligned_alloc(64, IdType *indptr_block_buf = reinterpret_cast<IdType *>(aligned_alloc(
(M_block_size + 1) * num_M_blocks * 64, (M_block_size + 1) * num_M_blocks * num_K_blocks * sizeof(IdType)));
num_K_blocks * sizeof(IdType))); IdType *indices_block_buf = reinterpret_cast<IdType *>(
IdType *indices_block_buf = reinterpret_cast<IdType *>(aligned_alloc(64, aligned_alloc(64, indptr[M] * sizeof(IdType)));
indptr[M] * sizeof(IdType))); IdType *edges_block_buf = reinterpret_cast<IdType *>(
IdType *edges_block_buf = reinterpret_cast<IdType *>(aligned_alloc(64, aligned_alloc(64, indptr[M] * sizeof(IdType)));
indptr[M] * sizeof(IdType)));
#pragma omp parallel #pragma omp parallel
{ {
IdType *my_cur_col_id = reinterpret_cast<IdType *>(aligned_alloc(64, 2 * M_block_size * IdType *my_cur_col_id = reinterpret_cast<IdType *>(
sizeof(IdType))); aligned_alloc(64, 2 * M_block_size * sizeof(IdType)));
#pragma omp for #pragma omp for
for (IdType m = 0; m < num_M_blocks; m++) { for (IdType m = 0; m < num_M_blocks; m++) {
...@@ -103,10 +98,8 @@ inline void SpMMCreateBlocks( ...@@ -103,10 +98,8 @@ inline void SpMMCreateBlocks(
IdType cur_indices_id = 0; IdType cur_indices_id = 0;
IdType *my_indices_block_buf, *my_edges_block_buf; IdType *my_indices_block_buf, *my_edges_block_buf;
if (use_lhs) if (use_lhs) my_indices_block_buf = indices_block_buf + indptr[M_start];
my_indices_block_buf = indices_block_buf + indptr[M_start]; if (use_rhs) my_edges_block_buf = edges_block_buf + indptr[M_start];
if (use_rhs)
my_edges_block_buf = edges_block_buf + indptr[M_start];
for (IdType i = M_start; i < M_end; i++) { for (IdType i = M_start; i < M_end; i++) {
my_cur_col_id[(i - M_start) * 2] = indptr[i]; my_cur_col_id[(i - M_start) * 2] = indptr[i];
...@@ -119,12 +112,11 @@ inline void SpMMCreateBlocks( ...@@ -119,12 +112,11 @@ inline void SpMMCreateBlocks(
cur_csr.num_rows = M_end - M_start; cur_csr.num_rows = M_end - M_start;
cur_csr.num_cols = K_end - K_start; cur_csr.num_cols = K_end - K_start;
// Create csr_ij // Create csr_ij
IdType *cur_csr_indptr = indptr_block_buf + (m * num_K_blocks + k) * (M_block_size + 1); IdType *cur_csr_indptr =
indptr_block_buf + (m * num_K_blocks + k) * (M_block_size + 1);
IdType *cur_csr_indices = nullptr, *cur_csr_edges = nullptr; IdType *cur_csr_indices = nullptr, *cur_csr_edges = nullptr;
if (use_lhs) if (use_lhs) cur_csr_indices = my_indices_block_buf + cur_indices_id;
cur_csr_indices = my_indices_block_buf + cur_indices_id; if (use_rhs) cur_csr_edges = my_edges_block_buf + cur_indices_id;
if (use_rhs)
cur_csr_edges = my_edges_block_buf + cur_indices_id;
IdType cur_nnz = 0; IdType cur_nnz = 0;
for (IdType i = M_start; i < M_end; i++) { for (IdType i = M_start; i < M_end; i++) {
const IdType row_start = my_cur_col_id[(i - M_start) * 2]; const IdType row_start = my_cur_col_id[(i - M_start) * 2];
...@@ -138,10 +130,8 @@ inline void SpMMCreateBlocks( ...@@ -138,10 +130,8 @@ inline void SpMMCreateBlocks(
break; break;
} }
CHECK_LT(cur_indices_id + cur_nnz, nnz); CHECK_LT(cur_indices_id + cur_nnz, nnz);
if (use_lhs) if (use_lhs) cur_csr_indices[cur_nnz] = src;
cur_csr_indices[cur_nnz] = src; if (use_rhs) cur_csr_edges[cur_nnz] = edge;
if (use_rhs)
cur_csr_edges[cur_nnz] = edge;
cur_nnz++; cur_nnz++;
} }
my_cur_col_id[(i - M_start) * 2] = eid; my_cur_col_id[(i - M_start) * 2] = eid;
...@@ -149,10 +139,8 @@ inline void SpMMCreateBlocks( ...@@ -149,10 +139,8 @@ inline void SpMMCreateBlocks(
cur_csr_indptr[cur_csr.num_rows] = cur_nnz; cur_csr_indptr[cur_csr.num_rows] = cur_nnz;
cur_indices_id += cur_nnz; cur_indices_id += cur_nnz;
cur_csr.indptr = cur_csr_indptr; cur_csr.indptr = cur_csr_indptr;
if (use_lhs) if (use_lhs) cur_csr.indices = cur_csr_indices;
cur_csr.indices = cur_csr_indices; if (use_rhs) cur_csr.data = cur_csr_edges;
if (use_rhs)
cur_csr.data = cur_csr_edges;
block_csr_array[m * num_K_blocks + k] = cur_csr; block_csr_array[m * num_K_blocks + k] = cur_csr;
} }
CHECK_EQ(nnz, cur_indices_id); CHECK_EQ(nnz, cur_indices_id);
...@@ -199,9 +187,7 @@ inline void SpMMCreateBlocks( ...@@ -199,9 +187,7 @@ inline void SpMMCreateBlocks(
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel( inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
bool has_idx, bool has_idx, IdType N, libxsmm_meltw_opreduce_vecs_flags redop_flag,
IdType N,
libxsmm_meltw_opreduce_vecs_flags redop_flag,
bool is_cmp) { bool is_cmp) {
int _ld = N; int _ld = N;
libxsmm_meltw_opreduce_vecs_flags opredop_flags; libxsmm_meltw_opreduce_vecs_flags opredop_flags;
...@@ -220,40 +206,52 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel( ...@@ -220,40 +206,52 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY; opredop_flags = LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OP_COPY;
} }
// Second, set which of lhs or rhs is considered first and second operand. // Second, set which of lhs or rhs is considered first and second operand.
// This is needed since libxsmm assumes that the copy operation always copies the first operand. // This is needed since libxsmm assumes that the copy operation always copies
// So, if we need to copy rhs, we need to set that as the first operand. // the first operand. So, if we need to copy rhs, we need to set that as the
// For rhs, we also set whether to use implicit indices or provided indices. // first operand. For rhs, we also set whether to use implicit indices or
// provided indices.
// TODO(Steve): fix this long line in a separate PR.
if (std::is_same<Op, op::CopyLhs<DType>>::value) { if (std::is_same<Op, op::CopyLhs<DType>>::value) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);
} else if (std::is_same<Op, op::CopyRhs<DType>>::value) { } else if (std::is_same<Op, op::CopyRhs<DType>>::value) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIN_VECIDX); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIN_VECIDX);
if (!has_idx) { if (!has_idx) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VECIDX); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VECIDX);
} }
} else { } else {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_OPORDER_VECIDX_VECIN);
if (has_idx) { if (has_idx) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_INDEXED_VEC); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_INDEXED_VEC);
} else { } else {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VEC); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_IMPLICIT_INDEXED_VEC);
} }
} }
// Third, we set the Redop in the opredop_flags // Third, we set the Redop in the opredop_flags
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | redop_flag); opredop_flags =
// Fourth, in case of Cmp Redop, set whether to record argmax/argmin for lhs/rhs (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | redop_flag);
// Fourth, in case of Cmp Redop, set whether to record argmax/argmin for
// lhs/rhs
if (is_cmp) { if (is_cmp) {
if (Op::use_lhs) { if (Op::use_lhs) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_0);
} }
if (Op::use_rhs) { if (Op::use_rhs) {
opredop_flags = (libxsmm_meltw_opreduce_vecs_flags)(opredop_flags | opredop_flags =
(libxsmm_meltw_opreduce_vecs_flags)(opredop_flags |
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1); LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_RECORD_ARGOP_OFF_VEC_1);
} }
} }
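The four numbered steps above just OR independent bit groups into a single libxsmm flag word. The sketch below mirrors that composition with made-up enumerator values (OP_MUL, ORDER_IDX_IN, IMPLICIT_IDX, REDOP_MAX, RECORD_ARG_0/1 are placeholders, not the real LIBXSMM_MELTW_* constants), only to show how the op, operand order, reduce op, and argop-recording bits end up in one mask.

    #include <cstdio>

    // Placeholder flag bits standing in for the LIBXSMM_MELTW_* enumerators above.
    enum Flags : unsigned {
      OP_MUL       = 1u << 0,  // step 1: the element-wise op
      ORDER_IDX_IN = 1u << 1,  // step 2: which operand is the indexed one
      IMPLICIT_IDX = 1u << 2,  // step 2: rhs uses implicit indices
      REDOP_MAX    = 1u << 3,  // step 3: the reduction op
      RECORD_ARG_0 = 1u << 4,  // step 4: record argmax/argmin for lhs
      RECORD_ARG_1 = 1u << 5   // step 4: record argmax/argmin for rhs
    };

    int main() {
      unsigned flags = OP_MUL;               // 1) op
      flags |= ORDER_IDX_IN;                 // 2) operand order / index mode
      flags |= REDOP_MAX;                    // 3) reduce op
      flags |= RECORD_ARG_0 | RECORD_ARG_1;  // 4) argop recording (Cmp redops only)
      std::printf("composed flag mask: 0x%x\n", flags);
      return 0;
    }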
...@@ -261,7 +259,8 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel( ...@@ -261,7 +259,8 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
if (std::is_same<DType, float>::value) { if (std::is_same<DType, float>::value) {
kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx( kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(
N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
(sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32, opredop_flags); (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
opredop_flags);
} }
if (kernel == nullptr) { if (kernel == nullptr) {
LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation." LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation."
...@@ -278,27 +277,29 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel( ...@@ -278,27 +277,29 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
* \param C The result feature on destination nodes. * \param C The result feature on destination nodes.
* \param has_idx For the edge features, are there indices available. * \param has_idx For the edge features, are there indices available.
* \param N Feature size. * \param N Feature size.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix. * \param num_M_blocks Number of blocks to create along the rows of adjacency
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix. * matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency
* matrix.
* \param M_block_size block size along the rows of adjacency matrix. * \param M_block_size block size along the rows of adjacency matrix.
* \param kernel The libxsmm kernel. * \param kernel The libxsmm kernel.
*/ */
template <typename IdType, typename DType> template <typename IdType, typename DType>
inline void SpMMBlockwiseOpSum( inline void SpMMBlockwiseOpSum(
CSRMatrixInternal<IdType, IdType> *block_csr_array, CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,
const DType *B, const DType *E, DType *C, bool has_idx, IdType N, const DType *E, DType *C, bool has_idx, IdType N, IdType num_M_blocks,
IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size, IdType num_K_blocks, IdType M_block_size,
libxsmm_meltwfunction_opreduce_vecs_idx kernel) { libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
DType(*in_matrix1)[N] = (DType(*)[N])B;
DType (*in_matrix1)[N] = (DType (*)[N])B; DType(*in_matrix2)[N] = (DType(*)[N])E;
DType (*in_matrix2)[N] = (DType (*)[N])E; DType(*output)[N] = (DType(*)[N])C;
DType (*output)[N] = (DType (*)[N])C;
#pragma omp parallel #pragma omp parallel
{ {
for (IdType k = 0; k < num_K_blocks; k++) { for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic) #pragma omp for schedule(dynamic)
for (IdType m = 0; m < num_M_blocks; m++) { for (IdType m = 0; m < num_M_blocks; m++) {
CSRMatrixInternal<IdType, IdType> cur_csr = block_csr_array[m * num_K_blocks + k]; CSRMatrixInternal<IdType, IdType> cur_csr =
block_csr_array[m * num_K_blocks + k];
const IdType M_start = m * M_block_size; const IdType M_start = m * M_block_size;
for (IdType i = 0; i < cur_csr.num_rows; i++) { for (IdType i = 0; i < cur_csr.num_rows; i++) {
...@@ -335,31 +336,32 @@ inline void SpMMBlockwiseOpSum( ...@@ -335,31 +336,32 @@ inline void SpMMBlockwiseOpSum(
* \param argE Arg-Min/Max on edges. * \param argE Arg-Min/Max on edges.
* \param has_idx For the edge features, are there indices available. * \param has_idx For the edge features, are there indices available.
* \param N Feature size. * \param N Feature size.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix. * \param num_M_blocks Number of blocks to create along the rows of adjacency
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix. * matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency
* matrix.
* \param M_block_size block size along the rows of adjacency matrix. * \param M_block_size block size along the rows of adjacency matrix.
* \param kernel The libxsmm kernel. * \param kernel The libxsmm kernel.
*/ */
template <typename IdType, typename DType, typename Op, typename Cmp> template <typename IdType, typename DType, typename Op, typename Cmp>
inline void SpMMBlockwiseOpCmp( inline void SpMMBlockwiseOpCmp(
CSRMatrixInternal<IdType, IdType> *block_csr_array, CSRMatrixInternal<IdType, IdType> *block_csr_array, const DType *B,
const DType *B, const DType *E, DType *C, IdType *argB, IdType *argE, const DType *E, DType *C, IdType *argB, IdType *argE, bool has_idx,
bool has_idx, IdType N, IdType N, IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
IdType num_M_blocks, IdType num_K_blocks, IdType M_block_size,
libxsmm_meltwfunction_opreduce_vecs_idx kernel) { libxsmm_meltwfunction_opreduce_vecs_idx kernel) {
DType(*in_matrix1)[N] = (DType(*)[N])B;
DType (*in_matrix1)[N] = (DType (*)[N])B; DType(*in_matrix2)[N] = (DType(*)[N])E;
DType (*in_matrix2)[N] = (DType (*)[N])E; DType(*output)[N] = (DType(*)[N])C;
DType (*output)[N] = (DType (*)[N])C; IdType(*out_matrix1)[N] = (IdType(*)[N])argB;
IdType (*out_matrix1)[N] = (IdType (*)[N])argB; IdType(*out_matrix2)[N] = (IdType(*)[N])argE;
IdType (*out_matrix2)[N] = (IdType (*)[N])argE;
#pragma omp parallel #pragma omp parallel
{ {
for (IdType k = 0; k < num_K_blocks; k++) { for (IdType k = 0; k < num_K_blocks; k++) {
#pragma omp for schedule(dynamic) #pragma omp for schedule(dynamic)
for (IdType m = 0; m < num_M_blocks; m++) { for (IdType m = 0; m < num_M_blocks; m++) {
CSRMatrixInternal<IdType, IdType> cur_csr = block_csr_array[m * num_K_blocks + k]; CSRMatrixInternal<IdType, IdType> cur_csr =
block_csr_array[m * num_K_blocks + k];
const IdType M_start = m * M_block_size; const IdType M_start = m * M_block_size;
for (IdType i = 0; i < cur_csr.num_rows; i++) { for (IdType i = 0; i < cur_csr.num_rows; i++) {
...@@ -391,23 +393,21 @@ inline void SpMMBlockwiseOpCmp( ...@@ -391,23 +393,21 @@ inline void SpMMBlockwiseOpCmp(
/*! /*!
* \brief Free the tiled CSR matrix data. * \brief Free the tiled CSR matrix data.
* \param block_csr_array The array containing csr matrices of all blocks. * \param block_csr_array The array containing csr matrices of all blocks.
* \param num_M_blocks Number of blocks to create along the rows of adjacency matrix. * \param num_M_blocks Number of blocks to create along the rows of adjacency
* \param num_K_blocks Number of blocks to create along the columns of adjacency matrix. * matrix.
* \param num_K_blocks Number of blocks to create along the columns of adjacency
* matrix.
* \param use_lhs Whether to use lhs. * \param use_lhs Whether to use lhs.
* \param use_rhs Whether to use rhs. * \param use_rhs Whether to use rhs.
*/ */
template <typename IdType> template <typename IdType>
inline void SpMMFreeBlocks( inline void SpMMFreeBlocks(
CSRMatrixInternal<IdType, IdType> *block_csr_array, CSRMatrixInternal<IdType, IdType> *block_csr_array, IdType num_M_blocks,
IdType num_M_blocks, IdType num_K_blocks, IdType num_K_blocks, bool use_lhs, bool use_rhs) {
bool use_lhs, bool use_rhs) {
if (num_K_blocks > 1) { if (num_K_blocks > 1) {
free(block_csr_array[0].indptr); free(block_csr_array[0].indptr);
if (use_lhs) if (use_lhs) free(block_csr_array[0].indices);
free(block_csr_array[0].indices); if (use_rhs) free(block_csr_array[0].data);
if (use_rhs)
free(block_csr_array[0].data);
} }
free(block_csr_array); free(block_csr_array);
} }
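For reference, the tiles freed above are laid out as one flat row-major array over (m, k), which is why the blockwise kernels index them as block_csr_array[m * num_K_blocks + k]. A tiny sketch of that indexing, with made-up block counts:

    #include <cstdio>

    int main() {
      const int num_M_blocks = 3, num_K_blocks = 2;  // made-up tile counts
      for (int k = 0; k < num_K_blocks; ++k)     // serial loop over column blocks
        for (int m = 0; m < num_M_blocks; ++m)   // (OpenMP-)parallel loop over row blocks
          std::printf("tile (m=%d, k=%d) -> flat index %d\n",
                      m, k, m * num_K_blocks + k);
      return 0;
    }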
...@@ -425,12 +425,8 @@ inline void SpMMFreeBlocks( ...@@ -425,12 +425,8 @@ inline void SpMMFreeBlocks(
*/ */
template <typename IdType, typename DType, typename Op, typename Redop> template <typename IdType, typename DType, typename Op, typename Redop>
void SpMMRedopCsrOpt( void SpMMRedopCsrOpt(
const BcastOff& bcast, const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
const CSRMatrix& csr, NDArray out, NDArray argu, NDArray arge) {
NDArray ufeat, NDArray efeat,
NDArray out,
NDArray argu, NDArray arge) {
int32_t llc_size = GetLLCSize(); int32_t llc_size = GetLLCSize();
#ifdef DEBUG #ifdef DEBUG
...@@ -440,11 +436,12 @@ void SpMMRedopCsrOpt( ...@@ -440,11 +436,12 @@ void SpMMRedopCsrOpt(
const bool has_idx = !IsNullArray(csr.data); const bool has_idx = !IsNullArray(csr.data);
DType* C = out.Ptr<DType>(); DType *C = out.Ptr<DType>();
const DType* B = ufeat.Ptr<DType>(); const DType *B = ufeat.Ptr<DType>();
const DType* E = efeat.Ptr<DType>(); const DType *E = efeat.Ptr<DType>();
IdType *argB, *argE; IdType *argB, *argE;
if (std::is_same<Redop, op::Max<DType>>::value || std::is_same<Redop, op::Min<DType>>::value) { if (std::is_same<Redop, op::Max<DType>>::value ||
std::is_same<Redop, op::Min<DType>>::value) {
argB = argu.Ptr<IdType>(); argB = argu.Ptr<IdType>();
argE = arge.Ptr<IdType>(); argE = arge.Ptr<IdType>();
} }
...@@ -453,7 +450,7 @@ void SpMMRedopCsrOpt( ...@@ -453,7 +450,7 @@ void SpMMRedopCsrOpt(
const IdType M = csr.num_rows; const IdType M = csr.num_rows;
const IdType N = bcast.out_len; const IdType N = bcast.out_len;
const IdType K = csr.num_cols; const IdType K = csr.num_cols;
const IdType* indptr = csr.indptr.Ptr<IdType>(); const IdType *indptr = csr.indptr.Ptr<IdType>();
CHECK_NOTNULL(indptr); CHECK_NOTNULL(indptr);
const IdType total_nnz = indptr[M]; const IdType total_nnz = indptr[M];
if (M <= 0 || K <= 0 || N <= 0 || total_nnz <= 0) return; if (M <= 0 || K <= 0 || N <= 0 || total_nnz <= 0) return;
...@@ -461,8 +458,9 @@ void SpMMRedopCsrOpt( ...@@ -461,8 +458,9 @@ void SpMMRedopCsrOpt(
const double avg_degree = total_nnz * 1.0 / M; const double avg_degree = total_nnz * 1.0 / M;
const double nnz_prob = avg_degree / K; const double nnz_prob = avg_degree / K;
IdType K_block_size = std::min((int64_t)K, (int64_t)(llc_size / (N * sizeof(DType) * IdType K_block_size = std::min(
nnz_prob * BLOCKING_HEURISTIC_PARAM))); (int64_t)K,
(int64_t)(llc_size / (N * sizeof(DType) * nnz_prob * BLOCKING_HEURISTIC_PARAM)));
IdType M_block_size = M / (nthreads * NUM_BLOCKS_PER_THREAD); IdType M_block_size = M / (nthreads * NUM_BLOCKS_PER_THREAD);
if (M_block_size == 0) M_block_size = 1; if (M_block_size == 0) M_block_size = 1;
if (K_block_size == 0) K_block_size = 1; if (K_block_size == 0) K_block_size = 1;
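To make the heuristic above concrete, here is a standalone sketch of the same arithmetic with deliberately made-up inputs (the real LLC size comes from GetLLCSize(), and BLOCKING_HEURISTIC_PARAM / NUM_BLOCKS_PER_THREAD are constants defined elsewhere in this file; the 2.0, 16 and 20 below are placeholders): K is blocked so the feature rows a tile is expected to touch roughly fit in the last-level cache, and M is blocked so each thread gets several row blocks to schedule dynamically.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      // Made-up inputs purely to exercise the block-size formulas above.
      const double  llc_size = 32.0 * 1024 * 1024;   // 32 MB last-level cache
      const int64_t M = 100'000, K = 100'000, N = 512;
      const int64_t total_nnz = 1'000'000'000;       // deliberately dense
      const double  avg_degree = static_cast<double>(total_nnz) / M;  // 10000
      const double  nnz_prob = avg_degree / K;                        // 0.1
      const double  heuristic_param = 2.0;                 // placeholder
      const int     nthreads = 16, blocks_per_thread = 20;  // placeholders
      int64_t K_block_size = std::min<int64_t>(
          K, static_cast<int64_t>(
                 llc_size / (N * sizeof(float) * nnz_prob * heuristic_param)));
      int64_t M_block_size = M / (nthreads * blocks_per_thread);
      if (M_block_size == 0) M_block_size = 1;
      if (K_block_size == 0) K_block_size = 1;
      std::printf("K_block_size = %lld, M_block_size = %lld\n",  // 81920, 312
                  static_cast<long long>(K_block_size),
                  static_cast<long long>(M_block_size));
      return 0;
    }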
...@@ -471,8 +469,9 @@ void SpMMRedopCsrOpt( ...@@ -471,8 +469,9 @@ void SpMMRedopCsrOpt(
IdType num_K_blocks = (K + K_block_size - 1) / K_block_size; IdType num_K_blocks = (K + K_block_size - 1) / K_block_size;
CSRMatrixInternal<IdType, IdType> *block_csr_array = CSRMatrixInternal<IdType, IdType> *block_csr_array =
(CSRMatrixInternal<IdType, IdType> *)aligned_alloc(64, (CSRMatrixInternal<IdType, IdType> *)aligned_alloc(
sizeof(CSRMatrixInternal<IdType, IdType>) * num_M_blocks * num_K_blocks); 64, sizeof(CSRMatrixInternal<IdType, IdType>) * num_M_blocks *
num_K_blocks);
#ifdef DEBUG #ifdef DEBUG
endTick = __rdtsc(); endTick = __rdtsc();
...@@ -489,14 +488,17 @@ void SpMMRedopCsrOpt( ...@@ -489,14 +488,17 @@ void SpMMRedopCsrOpt(
LOG(INFO) << "total_nnz = " << total_nnz << ", avg_degree = " << avg_degree; LOG(INFO) << "total_nnz = " << total_nnz << ", avg_degree = " << avg_degree;
LOG(INFO) << "has_idx = " << has_idx; LOG(INFO) << "has_idx = " << has_idx;
LOG(INFO) << "nnz_prob = " << nnz_prob; LOG(INFO) << "nnz_prob = " << nnz_prob;
LOG(INFO) << "K_block_size = " << K_block_size << ", M_block_size = " << M_block_size; LOG(INFO) << "K_block_size = " << K_block_size
LOG(INFO) << "num_K_blocks = " << num_K_blocks << ", num_M_blocks = " << num_M_blocks; << ", M_block_size = " << M_block_size;
LOG(INFO) << "num_K_blocks = " << num_K_blocks
<< ", num_M_blocks = " << num_M_blocks;
LOG(INFO) << "stage0 ticks = " << (endTick - startTick); LOG(INFO) << "stage0 ticks = " << (endTick - startTick);
startTick = __rdtsc(); startTick = __rdtsc();
#endif // DEBUG #endif // DEBUG
SpMMCreateBlocks(csr, block_csr_array, num_M_blocks, num_K_blocks, M_block_size, K_block_size, SpMMCreateBlocks(
Op::use_lhs, Op::use_rhs); csr, block_csr_array, num_M_blocks, num_K_blocks, M_block_size,
K_block_size, Op::use_lhs, Op::use_rhs);
#ifdef DEBUG #ifdef DEBUG
endTick = __rdtsc(); endTick = __rdtsc();
...@@ -506,17 +508,14 @@ void SpMMRedopCsrOpt( ...@@ -506,17 +508,14 @@ void SpMMRedopCsrOpt(
libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr; libxsmm_meltwfunction_opreduce_vecs_idx kernel = nullptr;
if (std::is_same<Redop, op::Max<DType>>::value) { if (std::is_same<Redop, op::Max<DType>>::value) {
kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(has_idx, N, kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX, has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MAX, true);
true);
} else if (std::is_same<Redop, op::Min<DType>>::value) { } else if (std::is_same<Redop, op::Min<DType>>::value) {
kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(has_idx, N, kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN, has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_MIN, true);
true);
} else if (std::is_same<Redop, op::Add<DType>>::value) { } else if (std::is_same<Redop, op::Add<DType>>::value) {
kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(has_idx, N, kernel = SpMMCreateLibxsmmKernel<IdType, DType, Op>(
LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM, has_idx, N, LIBXSMM_MELTW_FLAG_OPREDUCE_VECS_REDOP_SUM, false);
false);
} }
#ifdef DEBUG #ifdef DEBUG
...@@ -525,11 +524,14 @@ void SpMMRedopCsrOpt( ...@@ -525,11 +524,14 @@ void SpMMRedopCsrOpt(
startTick = __rdtsc(); startTick = __rdtsc();
#endif // DEBUG #endif // DEBUG
if (std::is_same<Redop, op::Max<DType>>::value || std::is_same<Redop, op::Min<DType>>::value) { if (std::is_same<Redop, op::Max<DType>>::value ||
SpMMBlockwiseOpCmp<IdType, DType, Op, Redop>(block_csr_array, B, E, C, argB, argE, has_idx, N, std::is_same<Redop, op::Min<DType>>::value) {
num_M_blocks, num_K_blocks, M_block_size, kernel); SpMMBlockwiseOpCmp<IdType, DType, Op, Redop>(
block_csr_array, B, E, C, argB, argE, has_idx, N, num_M_blocks,
num_K_blocks, M_block_size, kernel);
} else { } else {
SpMMBlockwiseOpSum(block_csr_array, B, E, C, has_idx, N, num_M_blocks, num_K_blocks, SpMMBlockwiseOpSum(
block_csr_array, B, E, C, has_idx, N, num_M_blocks, num_K_blocks,
M_block_size, kernel); M_block_size, kernel);
} }
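The branch above is resolved per instantiation: both calls are compiled, but which path runs is fixed by the Redop type through std::is_same. A minimal sketch of that dispatch pattern with placeholder op tags (MaxOp/MinOp/AddOp stand in for op::Max/op::Min/op::Add<DType>):

    #include <iostream>
    #include <type_traits>

    struct MaxOp {}; struct MinOp {}; struct AddOp {};  // placeholder reduce-op tags

    template <typename Redop>
    void Dispatch() {
      if (std::is_same<Redop, MaxOp>::value || std::is_same<Redop, MinOp>::value) {
        std::cout << "Cmp path: also fills the argmax/argmin arrays\n";
      } else {
        std::cout << "Sum path: accumulates values only\n";
      }
    }

    int main() {
      Dispatch<MaxOp>();  // takes the Cmp path
      Dispatch<AddOp>();  // takes the Sum path
      return 0;
    }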
...@@ -539,7 +541,8 @@ void SpMMRedopCsrOpt( ...@@ -539,7 +541,8 @@ void SpMMRedopCsrOpt(
startTick = __rdtsc(); startTick = __rdtsc();
#endif // DEBUG #endif // DEBUG
SpMMFreeBlocks(block_csr_array, num_M_blocks, num_K_blocks, Op::use_lhs, Op::use_rhs); SpMMFreeBlocks(
block_csr_array, num_M_blocks, num_K_blocks, Op::use_lhs, Op::use_rhs);
#ifdef DEBUG #ifdef DEBUG
endTick = __rdtsc(); endTick = __rdtsc();
...@@ -557,10 +560,12 @@ void SpMMRedopCsrOpt( ...@@ -557,10 +560,12 @@ void SpMMRedopCsrOpt(
* \note it uses libxsmm, blocking and dynamic thread scheduling. * \note it uses libxsmm, blocking and dynamic thread scheduling.
*/ */
template <typename IdType, typename DType, typename Op> template <typename IdType, typename DType, typename Op>
void SpMMSumCsrLibxsmm(const BcastOff& bcast, const CSRMatrix& csr, void SpMMSumCsrLibxsmm(
NDArray ufeat, NDArray efeat, NDArray out) { const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
NDArray out) {
NDArray dummy; NDArray dummy;
SpMMRedopCsrOpt<IdType, DType, Op, op::Add<DType>>(bcast, csr, ufeat, efeat, out, dummy, dummy); SpMMRedopCsrOpt<IdType, DType, Op, op::Add<DType>>(
bcast, csr, ufeat, efeat, out, dummy, dummy);
} }
/*! /*!
...@@ -575,9 +580,11 @@ void SpMMSumCsrLibxsmm(const BcastOff& bcast, const CSRMatrix& csr, ...@@ -575,9 +580,11 @@ void SpMMSumCsrLibxsmm(const BcastOff& bcast, const CSRMatrix& csr,
* \note it uses libxsmm, blocking and dynamic thread scheduling. * \note it uses libxsmm, blocking and dynamic thread scheduling.
*/ */
template <typename IdType, typename DType, typename Op, typename Cmp> template <typename IdType, typename DType, typename Op, typename Cmp>
void SpMMCmpCsrLibxsmm(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, void SpMMCmpCsrLibxsmm(
NDArray efeat, NDArray out, NDArray argu, NDArray arge) { const BcastOff &bcast, const CSRMatrix &csr, NDArray ufeat, NDArray efeat,
SpMMRedopCsrOpt<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge); NDArray out, NDArray argu, NDArray arge) {
SpMMRedopCsrOpt<IdType, DType, Op, Cmp>(
bcast, csr, ufeat, efeat, out, argu, arge);
} }
} // namespace cpu } // namespace cpu
......
...@@ -4,57 +4,48 @@ ...@@ -4,57 +4,48 @@
* \brief Graph traversal implementation * \brief Graph traversal implementation
*/ */
#include "./traversal.h"
#include <dgl/graph_traversal.h> #include <dgl/graph_traversal.h>
#include <algorithm> #include <algorithm>
#include <queue> #include <queue>
#include "./traversal.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
namespace { namespace {
// A utility view class to wrap a vector into a queue. // A utility view class to wrap a vector into a queue.
template<typename DType> template <typename DType>
struct VectorQueueWrapper { struct VectorQueueWrapper {
std::vector<DType>* vec; std::vector<DType>* vec;
size_t head = 0; size_t head = 0;
explicit VectorQueueWrapper(std::vector<DType>* vec): vec(vec) {} explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}
void push(const DType& elem) { void push(const DType& elem) { vec->push_back(elem); }
vec->push_back(elem);
}
DType top() const { DType top() const { return vec->operator[](head); }
return vec->operator[](head);
}
void pop() { void pop() { ++head; }
++head;
}
bool empty() const { bool empty() const { return head == vec->size(); }
return head == vec->size();
}
size_t size() const { size_t size() const { return vec->size() - head; }
return vec->size() - head;
}
}; };
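The point of this wrapper is that pop() only advances a cursor, so after a traversal the backing vector still holds every id in visitation order and can be handed to the merge helpers below. A self-contained sketch (the struct is re-declared locally so the snippet compiles on its own):

    #include <cstdio>
    #include <vector>

    template <typename DType>
    struct VectorQueueWrapper {  // local copy of the wrapper above
      std::vector<DType>* vec;
      size_t head = 0;
      explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}
      void push(const DType& elem) { vec->push_back(elem); }
      DType top() const { return (*vec)[head]; }
      void pop() { ++head; }
      bool empty() const { return head == vec->size(); }
      size_t size() const { return vec->size() - head; }
    };

    int main() {
      std::vector<int> trace;
      VectorQueueWrapper<int> q(&trace);
      q.push(1); q.push(2); q.pop(); q.push(3); q.pop(); q.pop();
      // The queue is now empty, but the vector still holds the full push order.
      for (int v : trace) std::printf("%d ", v);  // prints: 1 2 3
      std::printf("\n");
      return 0;
    }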
// Internal function to merge multiple traversal traces into one ndarray. // Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zip the vectors together. // It is similar to zip the vectors together.
template<typename DType> template <typename DType>
IdArray MergeMultipleTraversals( IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0, total_len = 0; int64_t max_len = 0, total_len = 0;
for (size_t i = 0; i < traces.size(); ++i) { for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
total_len += traces[i].size(); total_len += traces[i].size();
} }
IdArray ret = IdArray::Empty({total_len}, IdArray ret = IdArray::Empty(
DGLDataType{kDGLInt, sizeof(DType) * 8, 1}, {total_len}, DGLDataType{kDGLInt, sizeof(DType) * 8, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
DType* ret_data = static_cast<DType*>(ret->data); DType* ret_data = static_cast<DType*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
...@@ -71,15 +62,15 @@ IdArray MergeMultipleTraversals( ...@@ -71,15 +62,15 @@ IdArray MergeMultipleTraversals(
// Internal function to compute sections if multiple traversal traces // Internal function to compute sections if multiple traversal traces
// are merged into one ndarray. // are merged into one ndarray.
template<typename DType> template <typename DType>
IdArray ComputeMergedSections( IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
const std::vector<std::vector<DType>>& traces) {
int64_t max_len = 0; int64_t max_len = 0;
for (size_t i = 0; i < traces.size(); ++i) { for (size_t i = 0; i < traces.size(); ++i) {
const int64_t tracelen = traces[i].size(); const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen); max_len = std::max(max_len, tracelen);
} }
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty(
{max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data); int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) { for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0; int64_t sec_len = 0;
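A plain-vector sketch of what the two helpers above compute (NDArray replaced by std::vector, values made up): the traces are zipped element-wise, and sections records how many ids each zip step contributed, so shorter traces simply stop contributing.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // Two per-source traces, e.g. BFS orders starting from two source nodes.
      std::vector<std::vector<int>> traces = {{1, 2, 3}, {7, 8}};
      std::vector<int> merged, sections;
      size_t max_len = 0;
      for (const auto& t : traces) max_len = std::max(max_len, t.size());
      for (size_t i = 0; i < max_len; ++i) {
        int sec_len = 0;
        for (const auto& t : traces)
          if (i < t.size()) { merged.push_back(t[i]); ++sec_len; }
        sections.push_back(sec_len);
      }
      // merged = 1 7 2 8 3, sections = 2 2 1
      for (int v : merged) std::printf("%d ", v);
      std::printf("| ");
      for (int s : sections) std::printf("%d ", s);
      std::printf("\n");
      return 0;
    }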
...@@ -101,8 +92,8 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) { ...@@ -101,8 +92,8 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
std::vector<IdType> ids; std::vector<IdType> ids;
std::vector<int64_t> sections; std::vector<int64_t> sections;
VectorQueueWrapper<IdType> queue(&ids); VectorQueueWrapper<IdType> queue(&ids);
auto visit = [&] (const int64_t v) { }; auto visit = [&](const int64_t v) {};
auto make_frontier = [&] () { auto make_frontier = [&]() {
if (!queue.empty()) { if (!queue.empty()) {
// do not push zero-length frontier // do not push zero-length frontier
sections.push_back(queue.size()); sections.push_back(queue.size());
...@@ -116,8 +107,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) { ...@@ -116,8 +107,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
return front; return front;
} }
template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray); template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(
template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray); const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(
const CSRMatrix&, IdArray);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) { Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
...@@ -126,7 +119,7 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) { ...@@ -126,7 +119,7 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
// NOTE: std::queue has no top() method. // NOTE: std::queue has no top() method.
std::vector<IdType> nodes; std::vector<IdType> nodes;
VectorQueueWrapper<IdType> queue(&nodes); VectorQueueWrapper<IdType> queue(&nodes);
auto visit = [&] (const IdType e) { ids.push_back(e); }; auto visit = [&](const IdType e) { ids.push_back(e); };
bool first_frontier = true; bool first_frontier = true;
auto make_frontier = [&] { auto make_frontier = [&] {
if (first_frontier) { if (first_frontier) {
...@@ -144,16 +137,18 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) { ...@@ -144,16 +137,18 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
return front; return front;
} }
template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray); template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(
template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray); const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(
const CSRMatrix&, IdArray);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) { Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
std::vector<IdType> ids; std::vector<IdType> ids;
std::vector<int64_t> sections; std::vector<int64_t> sections;
VectorQueueWrapper<IdType> queue(&ids); VectorQueueWrapper<IdType> queue(&ids);
auto visit = [&] (const uint64_t v) { }; auto visit = [&](const uint64_t v) {};
auto make_frontier = [&] () { auto make_frontier = [&]() {
if (!queue.empty()) { if (!queue.empty()) {
// do not push zero-length frontier // do not push zero-length frontier
sections.push_back(queue.size()); sections.push_back(queue.size());
...@@ -167,8 +162,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) { ...@@ -167,8 +162,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
return front; return front;
} }
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&); template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&); const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(
const CSRMatrix&);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) { Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
...@@ -177,7 +174,7 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) { ...@@ -177,7 +174,7 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
std::vector<std::vector<IdType>> edges(len); std::vector<std::vector<IdType>> edges(len);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (IdType e, int tag) { edges[i].push_back(e); }; auto visit = [&](IdType e, int tag) { edges[i].push_back(e); };
DFSLabeledEdges<IdType>(csr, src_data[i], false, false, visit); DFSLabeledEdges<IdType>(csr, src_data[i], false, false, visit);
} }
...@@ -191,11 +188,9 @@ template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray); ...@@ -191,11 +188,9 @@ template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray); template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, Frontiers DGLDFSLabeledEdges(
IdArray source, const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_reverse_edge, const bool has_nontree_edge, const bool return_labels) {
const bool has_nontree_edge,
const bool return_labels) {
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const IdType* src_data = static_cast<IdType*>(source->data); const IdType* src_data = static_cast<IdType*>(source->data);
std::vector<std::vector<IdType>> edges(len); std::vector<std::vector<IdType>> edges(len);
...@@ -206,14 +201,14 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, ...@@ -206,14 +201,14 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
} }
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
auto visit = [&] (IdType e, int64_t tag) { auto visit = [&](IdType e, int64_t tag) {
edges[i].push_back(e); edges[i].push_back(e);
if (return_labels) { if (return_labels) {
tags[i].push_back(tag); tags[i].push_back(tag);
} }
}; };
DFSLabeledEdges<IdType>(csr, src_data[i], DFSLabeledEdges<IdType>(
has_reverse_edge, has_nontree_edge, visit); csr, src_data[i], has_reverse_edge, has_nontree_edge, visit);
} }
Frontiers front; Frontiers front;
...@@ -226,16 +221,10 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, ...@@ -226,16 +221,10 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
return front; return front;
} }
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(const CSRMatrix&, template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(
IdArray, const CSRMatrix&, IdArray, const bool, const bool, const bool);
const bool, template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(
const bool, const CSRMatrix&, IdArray, const bool, const bool, const bool);
const bool);
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(const CSRMatrix&,
IdArray,
const bool,
const bool,
const bool);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
* \file array/cpu/traversal.h * \file array/cpu/traversal.h
* \brief Graph traversal routines. * \brief Graph traversal routines.
* *
* Traversal routines generate frontiers. Frontiers can be node frontiers or edge * Traversal routines generate frontiers. Frontiers can be node frontiers or
* frontiers depending on the traversal function. Each frontier is a * edge frontiers depending on the traversal function. Each frontier is a list
* list of nodes/edges (specified by their ids). An optional tag can be specified * of nodes/edges (specified by their ids). An optional tag can be specified for
* for each node/edge (represented by an int value). * each node/edge (represented by an int value).
*/ */
#ifndef DGL_ARRAY_CPU_TRAVERSAL_H_ #ifndef DGL_ARRAY_CPU_TRAVERSAL_H_
#define DGL_ARRAY_CPU_TRAVERSAL_H_ #define DGL_ARRAY_CPU_TRAVERSAL_H_
#include <dgl/graph_interface.h> #include <dgl/graph_interface.h>
#include <stack> #include <stack>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -43,16 +44,16 @@ namespace impl { ...@@ -43,16 +44,16 @@ namespace impl {
* \param reversed If true, BFS follows the in-edge direction * \param reversed If true, BFS follows the in-edge direction
* \param queue The queue used to do bfs. * \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited. * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new frontier can be made; * \param make_frontier The function to indicate that a new frontier can be
* made;
*/ */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn> template <
void BFSTraverseNodes(const CSRMatrix& csr, typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
IdArray source, void BFSTraverseNodes(
Queue* queue, const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,
VisitFn visit,
FrontierFn make_frontier) { FrontierFn make_frontier) {
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const IdType *src_data = static_cast<IdType*>(source->data); const IdType *src_data = static_cast<IdType *>(source->data);
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
...@@ -71,7 +72,7 @@ void BFSTraverseNodes(const CSRMatrix& csr, ...@@ -71,7 +72,7 @@ void BFSTraverseNodes(const CSRMatrix& csr,
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
const IdType u = queue->top(); const IdType u = queue->top();
queue->pop(); queue->pop();
for (auto idx = indptr_data[u]; idx < indptr_data[u+1]; ++idx) { for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
auto v = indices_data[idx]; auto v = indices_data[idx];
if (!visited[v]) { if (!visited[v]) {
visited[v] = true; visited[v] = true;
...@@ -109,16 +110,16 @@ void BFSTraverseNodes(const CSRMatrix& csr, ...@@ -109,16 +110,16 @@ void BFSTraverseNodes(const CSRMatrix& csr,
* \param queue The queue used to do bfs. * \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited. * \param visit The function to call when a node is visited.
* The argument would be edge ID. * The argument would be edge ID.
* \param make_frontier The function to indicate that a new frontier can be made; * \param make_frontier The function to indicate that a new frontier can be
* made;
*/ */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn> template <
void BFSTraverseEdges(const CSRMatrix& csr, typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
IdArray source, void BFSTraverseEdges(
Queue* queue, const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,
VisitFn visit,
FrontierFn make_frontier) { FrontierFn make_frontier) {
const int64_t len = source->shape[0]; const int64_t len = source->shape[0];
const IdType* src_data = static_cast<IdType*>(source->data); const IdType *src_data = static_cast<IdType *>(source->data);
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
...@@ -138,7 +139,7 @@ void BFSTraverseEdges(const CSRMatrix& csr, ...@@ -138,7 +139,7 @@ void BFSTraverseEdges(const CSRMatrix& csr,
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
const IdType u = queue->top(); const IdType u = queue->top();
queue->pop(); queue->pop();
for (auto idx = indptr_data[u]; idx < indptr_data[u+1]; ++idx) { for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
auto e = eid_data ? eid_data[idx] : idx; auto e = eid_data ? eid_data[idx] : idx;
const IdType v = indices_data[idx]; const IdType v = indices_data[idx];
if (!visited[v]) { if (!visited[v]) {
...@@ -174,12 +175,13 @@ void BFSTraverseEdges(const CSRMatrix& csr, ...@@ -174,12 +175,13 @@ void BFSTraverseEdges(const CSRMatrix& csr,
* \param reversed If true, follows the in-edge direction * \param reversed If true, follows the in-edge direction
* \param queue The queue used to do bfs. * \param queue The queue used to do bfs.
* \param visit The function to call when a node is visited. * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new frontier can be made; * \param make_frontier The function to indicate that a new frontier can be
* made;
*/ */
template<typename IdType, typename Queue, typename VisitFn, typename FrontierFn> template <
void TopologicalNodes(const CSRMatrix& csr, typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
Queue* queue, void TopologicalNodes(
VisitFn visit, const CSRMatrix &csr, Queue *queue, VisitFn visit,
FrontierFn make_frontier) { FrontierFn make_frontier) {
int64_t num_visited_nodes = 0; int64_t num_visited_nodes = 0;
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
...@@ -206,7 +208,7 @@ void TopologicalNodes(const CSRMatrix& csr, ...@@ -206,7 +208,7 @@ void TopologicalNodes(const CSRMatrix& csr,
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
const IdType u = queue->top(); const IdType u = queue->top();
queue->pop(); queue->pop();
for (auto idx = indptr_data[u]; idx < indptr_data[u+1]; ++idx) { for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
const IdType v = indices_data[idx]; const IdType v = indices_data[idx];
if (--(degrees[v]) == 0) { if (--(degrees[v]) == 0) {
visit(v); visit(v);
...@@ -219,7 +221,8 @@ void TopologicalNodes(const CSRMatrix& csr, ...@@ -219,7 +221,8 @@ void TopologicalNodes(const CSRMatrix& csr,
} }
if (num_visited_nodes != num_nodes) { if (num_visited_nodes != num_nodes) {
LOG(FATAL) << "Error in topological traversal: loop detected in the given graph."; LOG(FATAL)
<< "Error in topological traversal: loop detected in the given graph.";
} }
} }
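TopologicalNodes above is Kahn's algorithm: nodes with zero remaining in-degree form a frontier, visiting them decrements their successors' degrees, and failing to visit every node means the graph has a cycle. A CPU sketch on a plain adjacency list (the graph is made up) that prints the same kind of frontiers:

    #include <cstdio>
    #include <queue>
    #include <vector>

    int main() {
      // Tiny DAG: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3.
      std::vector<std::vector<int>> adj = {{1, 2}, {3}, {3}, {}};
      std::vector<int> indeg(adj.size(), 0);
      for (const auto& out : adj) for (int v : out) ++indeg[v];
      std::queue<int> q;
      for (int v = 0; v < static_cast<int>(adj.size()); ++v)
        if (indeg[v] == 0) q.push(v);
      size_t num_visited = 0;
      while (!q.empty()) {
        size_t level = q.size();  // everything queued now is one frontier
        std::printf("frontier:");
        for (size_t i = 0; i < level; ++i) {
          int u = q.front(); q.pop(); ++num_visited;
          std::printf(" %d", u);
          for (int v : adj[u]) if (--indeg[v] == 0) q.push(v);
        }
        std::printf("\n");  // prints: {0}, then {1 2}, then {3}
      }
      if (num_visited != adj.size())
        std::printf("loop detected in the given graph\n");
      return 0;
    }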
...@@ -236,32 +239,29 @@ enum DFSEdgeTag { ...@@ -236,32 +239,29 @@ enum DFSEdgeTag {
* FORWARD(0), REVERSE(1), NONTREE(2) * FORWARD(0), REVERSE(1), NONTREE(2)
* *
 * A FORWARD edge is one in which `u` has been visited but `v` has not. * A FORWARD edge is one in which `u` has been visited but `v` has not.
 * A REVERSE edge is one in which both `u` and `v` have been visited and the edge * A REVERSE edge is one in which both `u` and `v` have been visited and the
 * is in the DFS tree. * edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` have
 * A NONTREE edge is one in which both `u` and `v` have been visited but the edge * been visited but the edge is NOT in the DFS tree.
* is NOT in the DFS tree.
* *
* \param source Source node. * \param source Source node.
* \param reversed If true, DFS follows the in-edge direction * \param reversed If true, DFS follows the in-edge direction
* \param has_reverse_edge If true, REVERSE edges are included * \param has_reverse_edge If true, REVERSE edges are included
* \param has_nontree_edge If true, NONTREE edges are included * \param has_nontree_edge If true, NONTREE edges are included
* \param visit The function to call when an edge is visited; the edge id and its * \param visit The function to call when an edge is visited; the edge id and
* tag will be given as the arguments. * its tag will be given as the arguments.
*/ */
template<typename IdType, typename VisitFn> template <typename IdType, typename VisitFn>
void DFSLabeledEdges(const CSRMatrix& csr, void DFSLabeledEdges(
IdType source, const CSRMatrix &csr, IdType source, bool has_reverse_edge,
bool has_reverse_edge, bool has_nontree_edge, VisitFn visit) {
bool has_nontree_edge,
VisitFn visit) {
const int64_t num_nodes = csr.num_rows; const int64_t num_nodes = csr.num_rows;
CHECK_GE(num_nodes, source) << "source " << source << CHECK_GE(num_nodes, source)
" is out of range [0," << num_nodes << "]"; << "source " << source << " is out of range [0," << num_nodes << "]";
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
const IdType *eid_data = static_cast<IdType *>(csr.data->data); const IdType *eid_data = static_cast<IdType *>(csr.data->data);
if (indptr_data[source+1]-indptr_data[source] == 0) { if (indptr_data[source + 1] - indptr_data[source] == 0) {
// no out-going edges from the source node // no out-going edges from the source node
return; return;
} }
...@@ -278,7 +278,8 @@ void DFSLabeledEdges(const CSRMatrix& csr, ...@@ -278,7 +278,8 @@ void DFSLabeledEdges(const CSRMatrix& csr,
while (!stack.empty()) { while (!stack.empty()) {
std::tie(u, i, on_tree) = stack.top(); std::tie(u, i, on_tree) = stack.top();
const IdType v = indices_data[indptr_data[u] + i]; const IdType v = indices_data[indptr_data[u] + i];
const IdType uv = eid_data ? eid_data[indptr_data[u] + i] : indptr_data[u] + i; const IdType uv =
eid_data ? eid_data[indptr_data[u] + i] : indptr_data[u] + i;
if (visited[v]) { if (visited[v]) {
if (!on_tree && has_nontree_edge) { if (!on_tree && has_nontree_edge) {
visit(uv, kNonTree); visit(uv, kNonTree);
...@@ -288,7 +289,7 @@ void DFSLabeledEdges(const CSRMatrix& csr, ...@@ -288,7 +289,7 @@ void DFSLabeledEdges(const CSRMatrix& csr,
stack.pop(); stack.pop();
// find next one. // find next one.
if (indptr_data[u] + i < indptr_data[u + 1] - 1) { if (indptr_data[u] + i < indptr_data[u + 1] - 1) {
stack.push(std::make_tuple(u, i+1, false)); stack.push(std::make_tuple(u, i + 1, false));
} }
} else { } else {
visited[v] = true; visited[v] = true;
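For the tag semantics documented above, a simplified recursive sketch may help (the real routine is iterative and reports edge ids rather than endpoints; the graph here is made up): FORWARD is emitted when the head is unvisited, REVERSE when a tree edge is closed on backtracking, and NONTREE when the head was already visited.

    #include <cstdio>
    #include <vector>

    // Graph: 0 -> 1, 0 -> 2, 1 -> 2.
    std::vector<std::vector<int>> adj = {{1, 2}, {2}, {}};
    std::vector<bool> seen(3, false);

    void dfs(int u) {
      seen[u] = true;
      for (int v : adj[u]) {
        if (!seen[v]) {
          std::printf("edge %d->%d: FORWARD\n", u, v);
          dfs(v);
          std::printf("edge %d->%d: REVERSE\n", u, v);  // closing the tree edge
        } else {
          std::printf("edge %d->%d: NONTREE\n", u, v);  // head seen, edge not in tree
        }
      }
    }

    // Output from dfs(0): 0->1 FORWARD, 1->2 FORWARD, 1->2 REVERSE,
    // 0->1 REVERSE, 0->2 NONTREE.
    int main() { dfs(0); return 0; }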
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
* \brief Array cumsum GPU implementation * \brief Array cumsum GPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh" #include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl { namespace dgl {
using runtime::NDArray; using runtime::NDArray;
...@@ -17,7 +18,8 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -17,7 +18,8 @@ template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) { IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements(); const int64_t len = array.NumElements();
if (len == 0) if (len == 0)
return !prepend_zero ? array : aten::Full(0, 1, array->dtype.bits, array->ctx); return !prepend_zero ? array
: aten::Full(0, 1, array->dtype.bits, array->ctx);
auto device = runtime::DeviceAPI::Get(array->ctx); auto device = runtime::DeviceAPI::Get(array->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
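For reference, the semantics this GPU path implements (via a CUB scan on the lines elided below) appear to be an inclusive cumulative sum with an optional leading zero; a CPU sketch of that behavior, under that assumption:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // CPU sketch of CumSum: inclusive scan, optionally prefixed with a single 0.
    std::vector<int64_t> CumSumCpu(const std::vector<int64_t>& in, bool prepend_zero) {
      std::vector<int64_t> out;
      if (prepend_zero) out.push_back(0);
      int64_t running = 0;
      for (int64_t v : in) { running += v; out.push_back(running); }
      return out;
    }

    int main() {
      for (int64_t v : CumSumCpu({1, 2, 3}, /*prepend_zero=*/true))
        std::printf("%lld ", static_cast<long long>(v));  // 0 1 3 6
      std::printf("\n");
      return 0;
    }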
......
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh" #include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl { namespace dgl {
using runtime::NDArray; using runtime::NDArray;
...@@ -16,14 +17,11 @@ namespace impl { ...@@ -16,14 +17,11 @@ namespace impl {
template <typename IdType> template <typename IdType>
struct IsNonZeroIndex { struct IsNonZeroIndex {
explicit IsNonZeroIndex(const IdType * array) : array_(array) { explicit IsNonZeroIndex(const IdType* array) : array_(array) {}
}
__device__ bool operator() (const int64_t index) { __device__ bool operator()(const int64_t index) { return array_[index] != 0; }
return array_[index] != 0;
}
const IdType * array_; const IdType* array_;
}; };
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
...@@ -36,22 +34,23 @@ IdArray NonZero(IdArray array) { ...@@ -36,22 +34,23 @@ IdArray NonZero(IdArray array) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const IdType * const in_data = static_cast<const IdType*>(array->data); const IdType* const in_data = static_cast<const IdType*>(array->data);
int64_t * const out_data = static_cast<int64_t*>(ret->data); int64_t* const out_data = static_cast<int64_t*>(ret->data);
IsNonZeroIndex<IdType> comp(in_data); IsNonZeroIndex<IdType> comp(in_data);
cub::CountingInputIterator<int64_t> counter(0); cub::CountingInputIterator<int64_t> counter(0);
// room for cub to output on GPU // room for cub to output on GPU
int64_t * d_num_nonzeros = static_cast<int64_t*>( int64_t* d_num_nonzeros =
device->AllocWorkspace(ctx, sizeof(int64_t))); static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
size_t temp_size = 0; size_t temp_size = 0;
CUDA_CALL(cub::DeviceSelect::If(nullptr, temp_size, counter, out_data, CUDA_CALL(cub::DeviceSelect::If(
d_num_nonzeros, len, comp, stream)); nullptr, temp_size, counter, out_data, d_num_nonzeros, len, comp,
void * temp = device->AllocWorkspace(ctx, temp_size); stream));
CUDA_CALL(cub::DeviceSelect::If(temp, temp_size, counter, out_data, void* temp = device->AllocWorkspace(ctx, temp_size);
d_num_nonzeros, len, comp, stream)); CUDA_CALL(cub::DeviceSelect::If(
temp, temp_size, counter, out_data, d_num_nonzeros, len, comp, stream));
device->FreeWorkspace(ctx, temp); device->FreeWorkspace(ctx, temp);
// copy number of selected elements from GPU to CPU // copy number of selected elements from GPU to CPU
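The two identical-looking DeviceSelect::If calls above follow CUB's standard two-pass convention: the first call, with a null temporary buffer, only reports how many workspace bytes are needed; the second call does the selection. The same pattern shows up with DeviceRadixSort::SortPairs in the sort file below. A standalone CUDA/C++ sketch of the pattern (the array contents and the IsPositive predicate are made up; error checking omitted):

    #include <cub/cub.cuh>
    #include <cuda_runtime.h>
    #include <cstdio>

    struct IsPositive {
      __device__ bool operator()(const int x) const { return x > 0; }
    };

    int main() {
      const int n = 6;
      const int h_in[n] = {0, 3, 0, -1, 5, 2};
      int *d_in, *d_out, *d_num;
      cudaMalloc(&d_in, n * sizeof(int));
      cudaMalloc(&d_out, n * sizeof(int));
      cudaMalloc(&d_num, sizeof(int));
      cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

      size_t temp_bytes = 0;
      // Pass 1: null temp pointer -> only computes the required workspace size.
      cub::DeviceSelect::If(nullptr, temp_bytes, d_in, d_out, d_num, n, IsPositive());
      void* d_temp = nullptr;
      cudaMalloc(&d_temp, temp_bytes);
      // Pass 2: same arguments plus the workspace -> performs the selection.
      cub::DeviceSelect::If(d_temp, temp_bytes, d_in, d_out, d_num, n, IsPositive());

      int h_num = 0;
      cudaMemcpy(&h_num, d_num, sizeof(int), cudaMemcpyDeviceToHost);
      std::printf("selected %d positive elements\n", h_num);  // 3
      cudaFree(d_in); cudaFree(d_out); cudaFree(d_num); cudaFree(d_temp);
      return 0;
    }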
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
* \brief Array sort GPU implementation * \brief Array sort GPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh" #include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl { namespace dgl {
using runtime::NDArray; using runtime::NDArray;
...@@ -29,26 +30,30 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) { ...@@ -29,26 +30,30 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
if (num_bits == 0) { if (num_bits == 0) {
num_bits = sizeof(IdType)*8; num_bits = sizeof(IdType) * 8;
} }
// Allocate workspace // Allocate workspace
size_t workspace_size = 0; size_t workspace_size = 0;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, workspace_size, CUDA_CALL(cub::DeviceRadixSort::SortPairs(
keys_in, keys_out, values_in, values_out, nitems, 0, num_bits, stream)); nullptr, workspace_size, keys_in, keys_out, values_in, values_out, nitems,
0, num_bits, stream));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute // Compute
CUDA_CALL(cub::DeviceRadixSort::SortPairs(workspace, workspace_size, CUDA_CALL(cub::DeviceRadixSort::SortPairs(
keys_in, keys_out, values_in, values_out, nitems, 0, num_bits, stream)); workspace, workspace_size, keys_in, keys_out, values_in, values_out,
nitems, 0, num_bits, stream));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
return std::make_pair(sorted_array, sorted_idx); return std::make_pair(sorted_array, sorted_idx);
} }
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(IdArray, int num_bits); template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(IdArray, int num_bits); IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(
IdArray, int num_bits);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief COO2CSR * \brief COO2CSR
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
...@@ -46,18 +47,15 @@ CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) { ...@@ -46,18 +47,15 @@ CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
if (!COOHasData(coo)) if (!COOHasData(coo))
coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx); coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);
NDArray indptr = aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits); NDArray indptr =
aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits);
int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data); int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
CUSPARSE_CALL(cusparseXcoo2csr( CUSPARSE_CALL(cusparseXcoo2csr(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, coo.row.Ptr<int32_t>(), nnz, coo.num_rows,
coo.row.Ptr<int32_t>(), indptr_ptr, CUSPARSE_INDEX_BASE_ZERO));
nnz,
coo.num_rows, return CSRMatrix(
indptr_ptr, coo.num_rows, coo.num_cols, indptr, coo.col, coo.data, col_sorted);
CUSPARSE_INDEX_BASE_ZERO));
return CSRMatrix(coo.num_rows, coo.num_cols,
indptr, coo.col, coo.data, col_sorted);
} }
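What the cusparseXcoo2csr call above produces can be written down on the CPU in a few lines: with the row ids already sorted, count the entries per row and prefix-sum the counts into indptr (the numbers here are made up):

    #include <cstdio>
    #include <vector>

    int main() {
      // Row-sorted COO row ids for a 4-row matrix with 5 non-zeros.
      const int num_rows = 4;
      std::vector<int> row = {0, 0, 1, 3, 3};
      std::vector<int> indptr(num_rows + 1, 0);
      for (int r : row) ++indptr[r + 1];                              // per-row counts
      for (int i = 0; i < num_rows; ++i) indptr[i + 1] += indptr[i];  // prefix sum
      for (int v : indptr) std::printf("%d ", v);  // 0 2 3 3 5
      std::printf("\n");
      return 0;
    }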
/*! /*!
...@@ -77,9 +75,8 @@ CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) { ...@@ -77,9 +75,8 @@ CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
*/ */
template <typename IdType> template <typename IdType>
__global__ void _SortedSearchKernelUpperBound( __global__ void _SortedSearchKernelUpperBound(
const IdType* hay, int64_t hay_size, const IdType* hay, int64_t hay_size, const IdType* needles,
const IdType* needles, int64_t num_needles, int64_t num_needles, IdType* pos) {
IdType* pos) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < num_needles) { while (tx < num_needles) {
...@@ -123,14 +120,12 @@ CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) { ...@@ -123,14 +120,12 @@ CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
const int nt = cuda::FindNumThreads(coo.num_rows); const int nt = cuda::FindNumThreads(coo.num_rows);
const int nb = (coo.num_rows + nt - 1) / nt; const int nb = (coo.num_rows + nt - 1) / nt;
IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx); IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
CUDA_KERNEL_CALL(_SortedSearchKernelUpperBound, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SortedSearchKernelUpperBound, nb, nt, 0, stream, coo.row.Ptr<int64_t>(),
coo.row.Ptr<int64_t>(), nnz, nnz, rowids.Ptr<int64_t>(), coo.num_rows, indptr.Ptr<int64_t>() + 1);
rowids.Ptr<int64_t>(), coo.num_rows,
indptr.Ptr<int64_t>() + 1); return CSRMatrix(
coo.num_rows, coo.num_cols, indptr, coo.col, coo.data, col_sorted);
return CSRMatrix(coo.num_rows, coo.num_cols,
indptr, coo.col, coo.data, col_sorted);
} }
template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo); template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief Sort COO index * \brief Sort COO index
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
namespace dgl { namespace dgl {
...@@ -18,21 +19,20 @@ namespace impl { ...@@ -18,21 +19,20 @@ namespace impl {
///////////////////////////// COOSort_ ///////////////////////////// ///////////////////////////// COOSort_ /////////////////////////////
/** /**
* @brief Encode row and column IDs into a single scalar per edge. * @brief Encode row and column IDs into a single scalar per edge.
* *
* @tparam IdType The type to encode as. * @tparam IdType The type to encode as.
* @param row The row (src) IDs per edge. * @param row The row (src) IDs per edge.
* @param col The column (dst) IDs per edge. * @param col The column (dst) IDs per edge.
* @param nnz The number of edges. * @param nnz The number of edges.
* @param col_bits The number of bits used to encode the destination. The row * @param col_bits The number of bits used to encode the destination. The row
* information is packed into the remaining bits. * information is packed into the remaining bits.
* @param key The encoded edges (output). * @param key The encoded edges (output).
*/ */
template <typename IdType> template <typename IdType>
__global__ void _COOEncodeEdgesKernel( __global__ void _COOEncodeEdgesKernel(
const IdType* const row, const IdType* const col, const IdType* const row, const IdType* const col, const int64_t nnz,
const int64_t nnz, const int col_bits, IdType * const key) { const int col_bits, IdType* const key) {
int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (tx < nnz) { if (tx < nnz) {
...@@ -41,20 +41,19 @@ __global__ void _COOEncodeEdgesKernel( ...@@ -41,20 +41,19 @@ __global__ void _COOEncodeEdgesKernel(
} }
/** /**
* @brief Decode row and column IDs from the encoded edges. * @brief Decode row and column IDs from the encoded edges.
* *
* @tparam IdType The type the edges are encoded as. * @tparam IdType The type the edges are encoded as.
* @param key The encoded edges. * @param key The encoded edges.
* @param nnz The number of edges. * @param nnz The number of edges.
* @param col_bits The number of bits used to store the column/dst ID. * @param col_bits The number of bits used to store the column/dst ID.
* @param row The row (src) IDs per edge (output). * @param row The row (src) IDs per edge (output).
* @param col The col (dst) IDs per edge (output). * @param col The col (dst) IDs per edge (output).
*/ */
template <typename IdType> template <typename IdType>
__global__ void _COODecodeEdgesKernel( __global__ void _COODecodeEdgesKernel(
const IdType* const key, const int64_t nnz, const int col_bits, const IdType* const key, const int64_t nnz, const int col_bits,
IdType * const row, IdType * const col) { IdType* const row, IdType* const col) {
int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (tx < nnz) { if (tx < nnz) {
...@@ -64,9 +63,7 @@ __global__ void _COODecodeEdgesKernel( ...@@ -64,9 +63,7 @@ __global__ void _COODecodeEdgesKernel(
} }
} }
template <typename T>
template<typename T>
int _NumberOfBits(const T& range) { int _NumberOfBits(const T& range) {
if (range <= 1) { if (range <= 1) {
// ranges of 0 or 1 require no bits to store // ranges of 0 or 1 require no bits to store
...@@ -74,12 +71,12 @@ int _NumberOfBits(const T& range) { ...@@ -74,12 +71,12 @@ int _NumberOfBits(const T& range) {
} }
int bits = 1; int bits = 1;
while (bits < static_cast<int>(sizeof(T)*8) && (1 << bits) < range) { while (bits < static_cast<int>(sizeof(T) * 8) && (1 << bits) < range) {
++bits; ++bits;
} }
CHECK_EQ((range-1) >> bits, 0); CHECK_EQ((range - 1) >> bits, 0);
CHECK_NE((range-1) >> (bits-1), 0); CHECK_NE((range - 1) >> (bits - 1), 0);
return bits; return bits;
} }
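The encode/decode kernels above pack each (row, col) pair into one integer key so that a single radix sort orders edges lexicographically by row and then column; _NumberOfBits picks how many low bits the column needs. A host-side sketch of the packing with made-up sizes:

    #include <cstdint>
    #include <cstdio>

    // Bits needed to represent values in [0, range), mirroring _NumberOfBits.
    int NumberOfBits(int64_t range) {
      if (range <= 1) return 0;
      int bits = 1;
      while (bits < 64 && (int64_t(1) << bits) < range) ++bits;
      return bits;
    }

    int main() {
      const int64_t num_cols = 10;                   // made-up graph width
      const int col_bits = NumberOfBits(num_cols);   // 4 bits cover cols 0..9
      const int64_t row = 6, col = 9;
      const int64_t key = (row << col_bits) | col;   // encode: row in the high bits
      const int64_t dec_row = key >> col_bits;
      const int64_t dec_col = key & ((int64_t(1) << col_bits) - 1);
      std::printf("key=%lld decoded row=%lld col=%lld\n",
                  static_cast<long long>(key),        // 6*16 + 9 = 105
                  static_cast<long long>(dec_row),    // 6
                  static_cast<long long>(dec_col));   // 9
      return 0;
    }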
...@@ -95,20 +92,20 @@ void COOSort_(COOMatrix* coo, bool sort_column) { ...@@ -95,20 +92,20 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
const int num_bits = row_bits + col_bits; const int num_bits = row_bits + col_bits;
const int nt = 256; const int nt = 256;
const int nb = (nnz+nt-1)/nt; const int nb = (nnz + nt - 1) / nt;
CHECK(static_cast<int64_t>(nb)*nt >= nnz); CHECK(static_cast<int64_t>(nb) * nt >= nnz);
IdArray pos = aten::NewIdArray(nnz, coo->row->ctx, coo->row->dtype.bits); IdArray pos = aten::NewIdArray(nnz, coo->row->ctx, coo->row->dtype.bits);
CUDA_KERNEL_CALL(_COOEncodeEdgesKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>(), _COOEncodeEdgesKernel, nb, nt, 0, stream, coo->row.Ptr<IdType>(),
nnz, col_bits, pos.Ptr<IdType>()); coo->col.Ptr<IdType>(), nnz, col_bits, pos.Ptr<IdType>());
auto sorted = Sort(pos, num_bits); auto sorted = Sort(pos, num_bits);
CUDA_KERNEL_CALL(_COODecodeEdgesKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
sorted.first.Ptr<IdType>(), nnz, col_bits, _COODecodeEdgesKernel, nb, nt, 0, stream, sorted.first.Ptr<IdType>(),
coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>()); nnz, col_bits, coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>());
if (aten::COOHasData(*coo)) if (aten::COOHasData(*coo))
coo->data = IndexSelect(coo->data, sorted.second); coo->data = IndexSelect(coo->data, sorted.second);
...@@ -138,8 +135,8 @@ template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column); ...@@ -138,8 +135,8 @@ template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);
template <typename IdType> template <typename IdType>
__global__ void _COOIsSortedKernel( __global__ void _COOIsSortedKernel(
const IdType* row, const IdType* col, const IdType* row, const IdType* col, int64_t nnz, int8_t* row_sorted,
int64_t nnz, int8_t* row_sorted, int8_t* col_sorted) { int8_t* col_sorted) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < nnz) { while (tx < nnz) {
...@@ -148,8 +145,8 @@ __global__ void _COOIsSortedKernel( ...@@ -148,8 +145,8 @@ __global__ void _COOIsSortedKernel(
col_sorted[0] = 1; col_sorted[0] = 1;
} else { } else {
row_sorted[tx] = static_cast<int8_t>(row[tx - 1] <= row[tx]); row_sorted[tx] = static_cast<int8_t>(row[tx - 1] <= row[tx]);
col_sorted[tx] = static_cast<int8_t>( col_sorted[tx] =
row[tx - 1] < row[tx] || col[tx - 1] <= col[tx]); static_cast<int8_t>(row[tx - 1] < row[tx] || col[tx - 1] <= col[tx]);
} }
tx += stride_x; tx += stride_x;
} }
...@@ -161,18 +158,19 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) { ...@@ -161,18 +158,19 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const auto& ctx = coo.row->ctx; const auto& ctx = coo.row->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
// We allocate a workspace of 2*nnz bytes. It wastes a little bit memory but should // We allocate a workspace of 2*nnz bytes. It wastes a little bit memory but
// be fine. // should be fine.
int8_t* row_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz)); int8_t* row_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
int8_t* col_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz)); int8_t* col_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
const int nt = cuda::FindNumThreads(nnz); const int nt = cuda::FindNumThreads(nnz);
const int nb = (nnz + nt - 1) / nt; const int nb = (nnz + nt - 1) / nt;
CUDA_KERNEL_CALL(_COOIsSortedKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
coo.row.Ptr<IdType>(), coo.col.Ptr<IdType>(), _COOIsSortedKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),
nnz, row_flags, col_flags); coo.col.Ptr<IdType>(), nnz, row_flags, col_flags);
const bool row_sorted = cuda::AllTrue(row_flags, nnz, ctx); const bool row_sorted = cuda::AllTrue(row_flags, nnz, ctx);
const bool col_sorted = row_sorted? cuda::AllTrue(col_flags, nnz, ctx) : false; const bool col_sorted =
row_sorted ? cuda::AllTrue(col_flags, nnz, ctx) : false;
device->FreeWorkspace(ctx, row_flags); device->FreeWorkspace(ctx, row_flags);
device->FreeWorkspace(ctx, col_flags); device->FreeWorkspace(ctx, col_flags);
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief CSR2COO * \brief CSR2COO
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
...@@ -32,20 +33,16 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -32,20 +33,16 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data; NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;
const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data); const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
NDArray row = aten::NewIdArray(indices->shape[0], indptr->ctx, indptr->dtype.bits); NDArray row =
aten::NewIdArray(indices->shape[0], indptr->ctx, indptr->dtype.bits);
int32_t* row_ptr = static_cast<int32_t*>(row->data); int32_t* row_ptr = static_cast<int32_t*>(row->data);
CUSPARSE_CALL(cusparseXcsr2coo( CUSPARSE_CALL(cusparseXcsr2coo(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, indptr_ptr, indices->shape[0], csr.num_rows,
indptr_ptr, row_ptr, CUSPARSE_INDEX_BASE_ZERO));
indices->shape[0],
csr.num_rows, return COOMatrix(
row_ptr, csr.num_rows, csr.num_cols, row, indices, data, true, csr.sorted);
CUSPARSE_INDEX_BASE_ZERO));
return COOMatrix(csr.num_rows, csr.num_cols,
row, indices, data,
true, csr.sorted);
} }
/*! /*!
...@@ -65,8 +62,8 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -65,8 +62,8 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
*/ */
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void _RepeatKernel( __global__ void _RepeatKernel(
const DType* val, const IdType* pos, const DType* val, const IdType* pos, DType* out, int64_t n_row,
DType* out, int64_t n_row, int64_t length) { int64_t length) {
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x; IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
...@@ -88,15 +85,13 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) { ...@@ -88,15 +85,13 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
const int nt = 256; const int nt = 256;
const int nb = (nnz + nt - 1) / nt; const int nb = (nnz + nt - 1) / nt;
CUDA_KERNEL_CALL(_RepeatKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _RepeatKernel, nb, nt, 0, stream, rowids.Ptr<int64_t>(),
rowids.Ptr<int64_t>(), csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(), csr.num_rows, nnz);
csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(),
csr.num_rows, nnz); return COOMatrix(
csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,
return COOMatrix(csr.num_rows, csr.num_cols, csr.sorted);
ret_row, csr.indices, csr.data,
true, csr.sorted);
} }
template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr); template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
...@@ -111,8 +106,7 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) { ...@@ -111,8 +106,7 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
template <> template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) { COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr); COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);
if (aten::IsNullArray(coo.data)) if (aten::IsNullArray(coo.data)) return coo;
return coo;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(coo.row->ctx); auto device = runtime::DeviceAPI::Get(coo.row->ctx);
...@@ -130,21 +124,12 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -130,21 +124,12 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
size_t workspace_size = 0; size_t workspace_size = 0;
CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt( CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],
coo.num_rows, coo.num_cols, data_ptr, row_ptr, &workspace_size));
row->shape[0],
data_ptr,
row_ptr,
&workspace_size));
void* workspace = device->AllocWorkspace(row->ctx, workspace_size); void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
CUSPARSE_CALL(cusparseXcoosortByRow( CUSPARSE_CALL(cusparseXcoosortByRow(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],
coo.num_rows, coo.num_cols, data_ptr, row_ptr, col_ptr, workspace));
row->shape[0],
data_ptr,
row_ptr,
col_ptr,
workspace));
device->FreeWorkspace(row->ctx, workspace); device->FreeWorkspace(row->ctx, workspace);
// The row and column field have already been reordered according // The row and column field have already been reordered according
...@@ -158,8 +143,7 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -158,8 +143,7 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
template <> template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) { COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr); COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);
if (aten::IsNullArray(coo.data)) if (aten::IsNullArray(coo.data)) return coo;
return coo;
const auto& sorted = Sort(coo.data); const auto& sorted = Sort(coo.data);
coo.row = IndexSelect(coo.row, sorted.second); coo.row = IndexSelect(coo.row, sorted.second);
......
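Aside on csr2coo.cu: both paths (the `cusparseXcsr2coo` call for int32 and `_RepeatKernel` for int64) expand the CSR `indptr` offsets into one row id per nonzero. A minimal host-side sketch of that expansion, assuming `indptr` is a valid offset array of length `num_rows + 1` (hypothetical helper, for illustration only):

```cpp
// Host-side sketch of the CSR -> COO row expansion done on the GPU above:
// every edge stored in row r gets r written into the COO row array.
// Hypothetical helper, not part of DGL.
#include <cstdint>
#include <vector>

std::vector<int64_t> ExpandIndptrToRowIds(const std::vector<int64_t>& indptr) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  std::vector<int64_t> row_ids(indptr.back());
  for (int64_t r = 0; r < num_rows; ++r) {
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e) row_ids[e] = r;
  }
  return row_ids;
}
```

The COO produced this way keeps the original `indices` as its column array, which is why the constructors above pass `csr.indices` through unchanged.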
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
* \brief Sort CSR index * \brief Sort CSR index
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh" #include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl { namespace dgl {
...@@ -20,8 +21,8 @@ namespace impl { ...@@ -20,8 +21,8 @@ namespace impl {
*/ */
template <typename IdType> template <typename IdType>
__global__ void _SegmentIsSorted( __global__ void _SegmentIsSorted(
const IdType* indptr, const IdType* indices, const IdType* indptr, const IdType* indices, int64_t num_rows,
int64_t num_rows, int8_t* flags) { int8_t* flags) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < num_rows) { while (tx < num_rows) {
...@@ -39,15 +40,15 @@ bool CSRIsSorted(CSRMatrix csr) { ...@@ -39,15 +40,15 @@ bool CSRIsSorted(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx; const auto& ctx = csr.indptr->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
// We allocate a workspace of num_rows bytes. It wastes a little bit memory but should // We allocate a workspace of num_rows bytes. It wastes a little bit memory
// be fine. // but should be fine.
int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows)); int8_t* flags =
static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
const int nt = cuda::FindNumThreads(csr.num_rows); const int nt = cuda::FindNumThreads(csr.num_rows);
const int nb = (csr.num_rows + nt - 1) / nt; const int nb = (csr.num_rows + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentIsSorted, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SegmentIsSorted, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows, flags);
csr.num_rows, flags);
bool ret = cuda::AllTrue(flags, csr.num_rows, ctx); bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
device->FreeWorkspace(ctx, flags); device->FreeWorkspace(ctx, flags);
return ret; return ret;
...@@ -82,10 +83,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) { ...@@ -82,10 +83,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
size_t workspace_size = 0; size_t workspace_size = 0;
CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt( CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz,
csr->num_rows, csr->num_cols, nnz, indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), &workspace_size));
indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
&workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
cusparseMatDescr_t descr; cusparseMatDescr_t descr;
...@@ -93,11 +92,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) { ...@@ -93,11 +92,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL)); CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO)); CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_CALL(cusparseXcsrsort( CUSPARSE_CALL(cusparseXcsrsort(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz, descr,
csr->num_rows, csr->num_cols, nnz, indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), data.Ptr<int32_t>(),
descr,
indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
data.Ptr<int32_t>(),
workspace)); workspace));
csr->sorted = true; csr->sorted = true;
...@@ -115,8 +111,7 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) { ...@@ -115,8 +111,7 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
const auto& ctx = csr->indptr->ctx; const auto& ctx = csr->indptr->ctx;
const int64_t nnz = csr->indices->shape[0]; const int64_t nnz = csr->indices->shape[0];
const auto nbits = csr->indptr->dtype.bits; const auto nbits = csr->indptr->dtype.bits;
if (!aten::CSRHasData(*csr)) if (!aten::CSRHasData(*csr)) csr->data = aten::Range(0, nnz, nbits, ctx);
csr->data = aten::Range(0, nnz, nbits, ctx);
IdArray new_indices = csr->indices.Clone(); IdArray new_indices = csr->indices.Clone();
IdArray new_data = csr->data.Clone(); IdArray new_data = csr->data.Clone();
...@@ -129,15 +124,15 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) { ...@@ -129,15 +124,15 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
// Allocate workspace // Allocate workspace
size_t workspace_size = 0; size_t workspace_size = 0;
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size, CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(
key_in, key_out, value_in, value_out, nullptr, workspace_size, key_in, key_out, value_in, value_out, nnz,
nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream)); csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute // Compute
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size, CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(
key_in, key_out, value_in, value_out, workspace, workspace_size, key_in, key_out, value_in, value_out, nnz,
nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream)); csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));
csr->sorted = true; csr->sorted = true;
csr->indices = new_indices; csr->indices = new_indices;
......
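Aside on csr_sort.cu: the two `cub::DeviceSegmentedRadixSort::SortPairs` calls follow CUB's usual pattern (the first call with a null workspace pointer only queries `workspace_size`; the second does the work) and sort each CSR row's column indices, carrying `data` along as values. A host-side reference of the intended result, assuming 64-bit ids (hypothetical helper, for illustration only):

```cpp
// Host-side reference for the segmented sort: within each CSR row
// (segment delimited by indptr), column indices are sorted ascending and
// the data array is permuted along with them. Hypothetical helper.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

void CSRSortRowsReference(
    const std::vector<int64_t>& indptr, std::vector<int64_t>* indices,
    std::vector<int64_t>* data) {
  for (size_t r = 0; r + 1 < indptr.size(); ++r) {
    const int64_t begin = indptr[r], end = indptr[r + 1];
    // Sort a permutation of the row's local positions by column id.
    std::vector<int64_t> perm(end - begin);
    std::iota(perm.begin(), perm.end(), 0);
    std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
      return (*indices)[begin + a] < (*indices)[begin + b];
    });
    // Apply the permutation to both indices and data.
    std::vector<int64_t> ind_sorted(end - begin), dat_sorted(end - begin);
    for (int64_t k = 0; k < end - begin; ++k) {
      ind_sorted[k] = (*indices)[begin + perm[k]];
      dat_sorted[k] = (*data)[begin + perm[k]];
    }
    std::copy(ind_sorted.begin(), ind_sorted.end(), indices->begin() + begin);
    std::copy(dat_sorted.begin(), dat_sorted.end(), data->begin() + begin);
  }
}
```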
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief CSR transpose (convert to CSC) * \brief CSR transpose (convert to CSC)
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
namespace dgl { namespace dgl {
...@@ -33,14 +34,13 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -33,14 +34,13 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
const int64_t nnz = indices->shape[0]; const int64_t nnz = indices->shape[0];
const auto& ctx = indptr->ctx; const auto& ctx = indptr->ctx;
const auto bits = indptr->dtype.bits; const auto bits = indptr->dtype.bits;
if (aten::IsNullArray(data)) if (aten::IsNullArray(data)) data = aten::Range(0, nnz, bits, ctx);
data = aten::Range(0, nnz, bits, ctx);
const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data); const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
const int32_t* indices_ptr = static_cast<int32_t*>(indices->data); const int32_t* indices_ptr = static_cast<int32_t*>(indices->data);
const void* data_ptr = data->data; const void* data_ptr = data->data;
// (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz == 0. // (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz
// We need to do it ourselves. // == 0. We need to do it ourselves.
NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx); NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx);
NDArray t_indices = aten::NewIdArray(nnz, ctx, bits); NDArray t_indices = aten::NewIdArray(nnz, ctx, bits);
NDArray t_data = aten::NewIdArray(nnz, ctx, bits); NDArray t_data = aten::NewIdArray(nnz, ctx, bits);
...@@ -53,40 +53,29 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -53,40 +53,29 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
// workspace // workspace
size_t workspace_size; size_t workspace_size;
CUSPARSE_CALL(cusparseCsr2cscEx2_bufferSize( CUSPARSE_CALL(cusparseCsr2cscEx2_bufferSize(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,
csr.num_rows, csr.num_cols, nnz, indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,
data_ptr, indptr_ptr, indices_ptr, CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO,
t_data_ptr, t_indptr_ptr, t_indices_ptr,
CUDA_R_32F,
CUSPARSE_ACTION_NUMERIC,
CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference
&workspace_size)); &workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseCsr2cscEx2( CUSPARSE_CALL(cusparseCsr2cscEx2(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,
csr.num_rows, csr.num_cols, nnz, indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,
data_ptr, indptr_ptr, indices_ptr, CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO,
t_data_ptr, t_indptr_ptr, t_indices_ptr,
CUDA_R_32F,
CUSPARSE_ACTION_NUMERIC,
CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference
workspace)); workspace));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
#else #else
CUSPARSE_CALL(cusparseScsr2csc( CUSPARSE_CALL(cusparseScsr2csc(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz,
csr.num_rows, csr.num_cols, nnz,
static_cast<const float*>(data_ptr), indptr_ptr, indices_ptr, static_cast<const float*>(data_ptr), indptr_ptr, indices_ptr,
static_cast<float*>(t_data_ptr), t_indices_ptr, t_indptr_ptr, static_cast<float*>(t_data_ptr), t_indices_ptr, t_indptr_ptr,
CUSPARSE_ACTION_NUMERIC, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_INDEX_BASE_ZERO));
#endif #endif
return CSRMatrix(csr.num_cols, csr.num_rows, return CSRMatrix(
t_indptr, t_indices, t_data, csr.num_cols, csr.num_rows, t_indptr, t_indices, t_data, false);
false);
} }
template <> template <>
......
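Aside on csr_transpose.cu: `cusparseCsr2cscEx2` converts the CSR matrix to CSC, which is the same as transposing it; conceptually it is a counting sort over column indices. A host-side sketch under that interpretation, with int64 ids for brevity (hypothetical helper, not the cuSPARSE algorithm):

```cpp
// Host-side sketch of CSR -> CSC (i.e. CSR transpose): histogram entries per
// column, scan the counts into the transposed indptr, then scatter.
// Hypothetical helper, not part of DGL.
#include <cstdint>
#include <vector>

struct CSCParts {
  std::vector<int64_t> indptr, indices, data;
};

CSCParts CSRTransposeReference(
    int64_t num_rows, int64_t num_cols, const std::vector<int64_t>& indptr,
    const std::vector<int64_t>& indices, const std::vector<int64_t>& data) {
  CSCParts t;
  t.indptr.assign(num_cols + 1, 0);
  t.indices.resize(indices.size());
  t.data.resize(indices.size());
  for (int64_t c : indices) ++t.indptr[c + 1];  // per-column histogram
  for (int64_t c = 0; c < num_cols; ++c) t.indptr[c + 1] += t.indptr[c];
  std::vector<int64_t> cursor(t.indptr.begin(), t.indptr.end() - 1);
  for (int64_t r = 0; r < num_rows; ++r) {
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e) {
      const int64_t pos = cursor[indices[e]]++;
      t.indices[pos] = r;  // original row becomes the CSC row index
      t.data[pos] = data[e];
    }
  }
  return t;
}
```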
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "../filter.h"
#include "../../runtime/cuda/cuda_hashtable.cuh" #include "../../runtime/cuda/cuda_hashtable.cuh"
#include "../filter.h"
#include "./dgl_cub.cuh" #include "./dgl_cub.cuh"
using namespace dgl::runtime::cuda; using namespace dgl::runtime::cuda;
...@@ -20,35 +20,29 @@ namespace { ...@@ -20,35 +20,29 @@ namespace {
cudaStream_t cudaStream = runtime::getCurrentCUDAStream(); cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
template<typename IdType, bool include> template <typename IdType, bool include>
__global__ void _IsInKernel( __global__ void _IsInKernel(
DeviceOrderedHashTable<IdType> table, DeviceOrderedHashTable<IdType> table, const IdType* const array,
const IdType * const array, const int64_t size, IdType* const mark) {
const int64_t size, const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
IdType * const mark) {
const int64_t idx = threadIdx.x + blockDim.x*blockIdx.x;
if (idx < size) { if (idx < size) {
mark[idx] = table.Contains(array[idx]) ^ (!include); mark[idx] = table.Contains(array[idx]) ^ (!include);
} }
} }
template<typename IdType> template <typename IdType>
__global__ void _InsertKernel( __global__ void _InsertKernel(
const IdType * const prefix, const IdType* const prefix, const int64_t size, IdType* const result) {
const int64_t size, const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
IdType * const result) {
const int64_t idx = threadIdx.x + blockDim.x*blockIdx.x;
if (idx < size) { if (idx < size) {
if (prefix[idx] != prefix[idx+1]) { if (prefix[idx] != prefix[idx + 1]) {
result[prefix[idx]] = idx; result[prefix[idx]] = idx;
} }
} }
} }
template<typename IdType, bool include> template <typename IdType, bool include>
IdArray _PerformFilter( IdArray _PerformFilter(const OrderedHashTable<IdType>& table, IdArray test) {
const OrderedHashTable<IdType>& table,
IdArray test) {
const auto& ctx = test->ctx; const auto& ctx = test->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
const int64_t size = test->shape[0]; const int64_t size = test->shape[0];
...@@ -60,22 +54,20 @@ IdArray _PerformFilter( ...@@ -60,22 +54,20 @@ IdArray _PerformFilter(
// we need two arrays: 1) to act as a prefixsum // we need two arrays: 1) to act as a prefixsum
// for the number of entries that will be inserted, and // for the number of entries that will be inserted, and
// 2) to collect the included items. // 2) to collect the included items.
IdType * prefix = static_cast<IdType*>( IdType* prefix = static_cast<IdType*>(
device->AllocWorkspace(ctx, sizeof(IdType)*(size+1))); device->AllocWorkspace(ctx, sizeof(IdType) * (size + 1)));
// will resize down later // will resize down later
IdArray result = aten::NewIdArray(size, ctx, sizeof(IdType)*8); IdArray result = aten::NewIdArray(size, ctx, sizeof(IdType) * 8);
  // mark each index based on its existence in the hashtable   // mark each index based on its existence in the hashtable
{ {
const dim3 block(256); const dim3 block(256);
const dim3 grid((size+block.x-1)/block.x); const dim3 grid((size + block.x - 1) / block.x);
CUDA_KERNEL_CALL((_IsInKernel<IdType, include>), CUDA_KERNEL_CALL(
grid, block, 0, cudaStream, (_IsInKernel<IdType, include>), grid, block, 0, cudaStream,
table.DeviceHandle(), table.DeviceHandle(), static_cast<const IdType*>(test->data), size,
static_cast<const IdType*>(test->data),
size,
prefix); prefix);
} }
...@@ -83,40 +75,28 @@ IdArray _PerformFilter( ...@@ -83,40 +75,28 @@ IdArray _PerformFilter(
{ {
size_t workspace_bytes; size_t workspace_bytes;
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr, nullptr, workspace_bytes, static_cast<IdType*>(nullptr),
workspace_bytes, static_cast<IdType*>(nullptr), size + 1, cudaStream));
static_cast<IdType*>(nullptr), void* workspace = device->AllocWorkspace(ctx, workspace_bytes);
static_cast<IdType*>(nullptr),
size+1, cudaStream));
void * workspace = device->AllocWorkspace(ctx, workspace_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
workspace, workspace, workspace_bytes, prefix, prefix, size + 1, cudaStream));
workspace_bytes,
prefix,
prefix,
size+1, cudaStream));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
} }
// copy number using the internal current stream; // copy number using the internal current stream;
IdType num_unique; IdType num_unique;
device->CopyDataFromTo(prefix+size, 0, device->CopyDataFromTo(
&num_unique, 0, prefix + size, 0, &num_unique, 0, sizeof(num_unique), ctx,
sizeof(num_unique), DGLContext{kDGLCPU, 0}, test->dtype);
ctx,
DGLContext{kDGLCPU, 0},
test->dtype);
// insert items into set // insert items into set
{ {
const dim3 block(256); const dim3 block(256);
const dim3 grid((size+block.x-1)/block.x); const dim3 grid((size + block.x - 1) / block.x);
CUDA_KERNEL_CALL(_InsertKernel, CUDA_KERNEL_CALL(
grid, block, 0, cudaStream, _InsertKernel, grid, block, 0, cudaStream, prefix, size,
prefix,
size,
static_cast<IdType*>(result->data)); static_cast<IdType*>(result->data));
} }
device->FreeWorkspace(ctx, prefix); device->FreeWorkspace(ctx, prefix);
...@@ -124,16 +104,13 @@ IdArray _PerformFilter( ...@@ -124,16 +104,13 @@ IdArray _PerformFilter(
return result.CreateView({num_unique}, result->dtype); return result.CreateView({num_unique}, result->dtype);
} }
template <typename IdType>
template<typename IdType>
class CudaFilterSet : public Filter { class CudaFilterSet : public Filter {
public: public:
explicit CudaFilterSet(IdArray array) : explicit CudaFilterSet(IdArray array)
table_(array->shape[0], array->ctx, cudaStream) { : table_(array->shape[0], array->ctx, cudaStream) {
table_.FillWithUnique( table_.FillWithUnique(
static_cast<const IdType*>(array->data), static_cast<const IdType*>(array->data), array->shape[0], cudaStream);
array->shape[0],
cudaStream);
} }
IdArray find_included_indices(IdArray test) override { IdArray find_included_indices(IdArray test) override {
...@@ -150,7 +127,7 @@ class CudaFilterSet : public Filter { ...@@ -150,7 +127,7 @@ class CudaFilterSet : public Filter {
} // namespace } // namespace
template<DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set) { FilterRef CreateSetFilter(IdArray set) {
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set)); return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
} }
......
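Aside on the filter above: `_PerformFilter` follows the classic stream-compaction recipe — mark membership with `_IsInKernel`, turn the marks into output positions with `cub::DeviceScan::ExclusiveSum`, then scatter the surviving indices with `_InsertKernel`. A host-side reference of the same recipe, assuming an `unordered_set` stands in for the GPU hash table (hypothetical helper, for illustration only):

```cpp
// Host-side reference for mark -> exclusive prefix sum -> scatter: returns the
// positions in `test` whose value is (when include == true) or is not
// (include == false) contained in `set`. Hypothetical helper, not part of DGL.
#include <cstdint>
#include <unordered_set>
#include <vector>

std::vector<int64_t> FilterIndicesReference(
    const std::unordered_set<int64_t>& set, const std::vector<int64_t>& test,
    bool include) {
  const int64_t size = static_cast<int64_t>(test.size());
  // prefix[i] = number of kept elements among test[0..i-1] (exclusive sum of marks).
  std::vector<int64_t> prefix(size + 1, 0);
  for (int64_t i = 0; i < size; ++i)
    prefix[i + 1] = prefix[i] + (((set.count(test[i]) > 0) == include) ? 1 : 0);
  // Scatter: an element is kept exactly where the prefix sum increases.
  std::vector<int64_t> result(prefix[size]);
  for (int64_t i = 0; i < size; ++i)
    if (prefix[i] != prefix[i + 1]) result[prefix[i]] = i;
  return result;
}
```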
/** /**
* Copyright (c) 2022, NVIDIA CORPORATION. * Copyright (c) 2022, NVIDIA CORPORATION.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
* *
* \file array/gpu/disjoint_union.cu * \file array/gpu/disjoint_union.cu
* \brief Disjoint union GPU implementation. * \brief Disjoint union GPU implementation.
*/ */
#include <dgl/runtime/parallel_for.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <vector> #include <dgl/runtime/parallel_for.h>
#include <tuple> #include <tuple>
#include <vector>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
...@@ -31,8 +33,8 @@ namespace impl { ...@@ -31,8 +33,8 @@ namespace impl {
template <typename IdType> template <typename IdType>
__global__ void _DisjointUnionKernel( __global__ void _DisjointUnionKernel(
IdType** arrs, IdType* prefix, IdType* offset, IdType* out, IdType** arrs, IdType* prefix, IdType* offset, IdType* out, int64_t n_arrs,
int64_t n_arrs, int n_elms) { int n_elms) {
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x; IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < n_elms) { while (tx < n_elms) {
...@@ -48,7 +50,8 @@ __global__ void _DisjointUnionKernel( ...@@ -48,7 +50,8 @@ __global__ void _DisjointUnionKernel(
} }
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) { std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(
const std::vector<COOMatrix>& coos) {
IdType n = coos.size(), nbits = coos[0].row->dtype.bits; IdType n = coos.size(), nbits = coos[0].row->dtype.bits;
IdArray n_rows = NewIdArray(n, CPU, nbits); IdArray n_rows = NewIdArray(n, CPU, nbits);
IdArray n_cols = NewIdArray(n, CPU, nbits); IdArray n_cols = NewIdArray(n, CPU, nbits);
...@@ -58,7 +61,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa ...@@ -58,7 +61,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
IdType* n_cols_data = n_cols.Ptr<IdType>(); IdType* n_cols_data = n_cols.Ptr<IdType>();
IdType* n_elms_data = n_elms.Ptr<IdType>(); IdType* n_elms_data = n_elms.Ptr<IdType>();
dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e){ dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e) {
for (IdType i = b; i < e; ++i) { for (IdType i = b; i < e; ++i) {
n_rows_data[i] = coos[i].num_rows; n_rows_data[i] = coos[i].num_rows;
n_cols_data[i] = coos[i].num_cols; n_cols_data[i] = coos[i].num_cols;
...@@ -66,30 +69,30 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa ...@@ -66,30 +69,30 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
} }
}); });
return std::make_tuple(CumSum(n_rows.CopyTo(coos[0].row->ctx), true), return std::make_tuple(
CumSum(n_rows.CopyTo(coos[0].row->ctx), true),
CumSum(n_cols.CopyTo(coos[0].row->ctx), true), CumSum(n_cols.CopyTo(coos[0].row->ctx), true),
CumSum(n_elms.CopyTo(coos[0].row->ctx), true)); CumSum(n_elms.CopyTo(coos[0].row->ctx), true));
} }
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out, void _Merge(
int64_t n_arrs, int n_elms, IdType** arrs, IdType* prefix, IdType* offset, IdType* out, int64_t n_arrs,
DGLContext ctx, DGLDataType dtype, cudaStream_t stream) { int n_elms, DGLContext ctx, DGLDataType dtype, cudaStream_t stream) {
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
int nt = 256; int nt = 256;
int nb = (n_elms + nt - 1) / nt; int nb = (n_elms + nt - 1) / nt;
IdType** arrs_dev = static_cast<IdType**>( IdType** arrs_dev = static_cast<IdType**>(
device->AllocWorkspace(ctx, n_arrs*sizeof(IdType*))); device->AllocWorkspace(ctx, n_arrs * sizeof(IdType*)));
device->CopyDataFromTo( device->CopyDataFromTo(
arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs, arrs, 0, arrs_dev, 0, sizeof(IdType*) * n_arrs, DGLContext{kDGLCPU, 0},
DGLContext{kDGLCPU, 0}, ctx, dtype); ctx, dtype);
CUDA_KERNEL_CALL(_DisjointUnionKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _DisjointUnionKernel, nb, nt, 0, stream, arrs_dev, prefix, offset, out,
arrs_dev, prefix, offset, n_arrs, n_elms);
out, n_arrs, n_elms);
device->FreeWorkspace(ctx, arrs_dev); device->FreeWorkspace(ctx, arrs_dev);
} }
...@@ -132,52 +135,50 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { ...@@ -132,52 +135,50 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
IdType n_elements = 0; IdType n_elements = 0;
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_elm[coos.size()], 0, &n_elements, 0, &prefix_elm[coos.size()], 0, &n_elements, 0, sizeof(IdType),
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);
coos[0].row->dtype);
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_src[coos.size()], 0, &src_offset, 0, &prefix_src[coos.size()], 0, &src_offset, 0, sizeof(IdType),
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);
coos[0].row->dtype);
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_dst[coos.size()], 0, &dst_offset, 0, &prefix_dst[coos.size()], 0, &dst_offset, 0, sizeof(IdType),
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);
coos[0].row->dtype);
// Union src array // Union src array
IdArray result_src = NewIdArray( IdArray result_src =
n_elements, coos[0].row->ctx, coos[0].row->dtype.bits); NewIdArray(n_elements, coos[0].row->ctx, coos[0].row->dtype.bits);
_Merge<XPU, IdType>(rows.get(), prefix_src, prefix_elm, result_src.Ptr<IdType>(), _Merge<XPU, IdType>(
coos.size(), n_elements, ctx, dtype, stream); rows.get(), prefix_src, prefix_elm, result_src.Ptr<IdType>(), coos.size(),
n_elements, ctx, dtype, stream);
// Union dst array // Union dst array
IdArray result_dst = NewIdArray( IdArray result_dst =
n_elements, coos[0].col->ctx, coos[0].col->dtype.bits); NewIdArray(n_elements, coos[0].col->ctx, coos[0].col->dtype.bits);
_Merge<XPU, IdType>(cols.get(), prefix_dst, prefix_elm, result_dst.Ptr<IdType>(), _Merge<XPU, IdType>(
coos.size(), n_elements, ctx, dtype, stream); cols.get(), prefix_dst, prefix_elm, result_dst.Ptr<IdType>(), coos.size(),
n_elements, ctx, dtype, stream);
// Union data array if exists and fetch number of elements // Union data array if exists and fetch number of elements
IdArray result_dat = NullArray(); IdArray result_dat = NullArray();
if (has_data) { if (has_data) {
result_dat = NewIdArray( result_dat =
n_elements, coos[0].row->ctx, coos[0].row->dtype.bits); NewIdArray(n_elements, coos[0].row->ctx, coos[0].row->dtype.bits);
_Merge<XPU, IdType>(data.get(), prefix_elm, prefix_elm, result_dat.Ptr<IdType>(), _Merge<XPU, IdType>(
data.get(), prefix_elm, prefix_elm, result_dat.Ptr<IdType>(),
coos.size(), n_elements, ctx, dtype, stream); coos.size(), n_elements, ctx, dtype, stream);
} }
return COOMatrix( return COOMatrix(
src_offset, dst_offset, src_offset, dst_offset, result_src, result_dst, result_dat, row_sorted,
result_src,
result_dst,
result_dat,
row_sorted,
col_sorted); col_sorted);
} }
template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(const std::vector<COOMatrix>& coos); template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(
template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(const std::vector<COOMatrix>& coos); const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(
const std::vector<COOMatrix>& coos);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
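Aside on disjoint_union.cu: the `_DisjointUnionKernel` body is collapsed in this view, so the following is an inference from the `_Merge` call sites — element `k` of graph `j` appears to land at output position `prefix_elm[j] + k`, with its id shifted by the corresponding vertex (or edge) prefix, which is how a disjoint union relabels nodes and edges. A host-side sketch under that assumption, taking the prefix arrays as exclusive cumulative sums with a trailing total (hypothetical helper, for illustration only):

```cpp
// Host-side sketch of the assumed disjoint-union merge: copy each input array
// into its slice of the output while shifting ids by that graph's offset.
// Hypothetical helper, not part of DGL.
#include <cstdint>
#include <vector>

std::vector<int64_t> DisjointUnionMergeReference(
    const std::vector<std::vector<int64_t>>& arrs,
    const std::vector<int64_t>& id_prefix,      // per-graph id offset
    const std::vector<int64_t>& edge_prefix) {  // per-graph output offset, plus total
  std::vector<int64_t> out(edge_prefix.back());
  for (size_t j = 0; j < arrs.size(); ++j)
    for (size_t k = 0; k < arrs[j].size(); ++k)
      out[edge_prefix[j] + k] = arrs[j][k] + id_prefix[j];
  return out;
}
```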