Unverified Commit 8ae50c42 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4804)



* [Misc] clang-format auto fix.

* manual

* manual

* manual

* manual

* todo

* fix
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 81831111
@@ -6,10 +6,11 @@
#include <dgl/array.h>
#include <dgl/array_iterator.h>
#include <dgl/random.h>
#include <dgl/runtime/parallel_for.h>

#include <algorithm>
#include <utility>

using namespace dgl::runtime;

@@ -19,15 +20,12 @@ namespace impl {
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
    const CSRMatrix& csr, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  const int64_t num_row = csr.num_rows;
  const int64_t num_col = csr.num_cols;
  const int64_t num_actual_samples =
      static_cast<int64_t>(num_samples * (1 + redundancy));
  IdArray row = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdArray col = Full<IdType>(-1, num_actual_samples, csr.indptr->ctx);
  IdType* row_data = row.Ptr<IdType>();

@@ -48,23 +46,30 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
  });
  PairIterator<IdType> begin(row_data, col_data);
  PairIterator<IdType> end = std::remove_if(
      begin, begin + num_actual_samples,
      [](const std::pair<IdType, IdType>& val) { return val.first == -1; });
  if (!replace) {
    std::sort(
        begin, end,
        [](const std::pair<IdType, IdType>& a,
           const std::pair<IdType, IdType>& b) {
          return a.first < b.first ||
                 (a.first == b.first && a.second < b.second);
        });
    end = std::unique(begin, end);
  }
  int64_t num_sampled =
      std::min(static_cast<int64_t>(end - begin), num_samples);
  return {
      row.CreateView({num_sampled}, row->dtype),
      col.CreateView({num_sampled}, col->dtype)};
}

template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCPU, int32_t>(const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<
    kDGLCPU, int64_t>(const CSRMatrix&, int64_t, int, bool, bool, double);

};  // namespace impl
};  // namespace aten
......
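For reference, the dedup step above drops failed draws (pairs whose first element is -1) and, when `replace` is false, sorts and uniquifies the remaining (row, col) pairs. A minimal standalone sketch of that logic, using std::vector<std::pair<...>> instead of DGL's PairIterator over two parallel arrays (hypothetical example, not part of this commit):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Candidate negative edges; -1 in the first slot marks a failed draw.
  std::vector<std::pair<int64_t, int64_t>> pairs = {
      {3, 7}, {-1, -1}, {0, 2}, {3, 7}, {1, 5}};

  // Drop failed draws.
  auto end = std::remove_if(pairs.begin(), pairs.end(),
                            [](const auto& p) { return p.first == -1; });
  // Sort lexicographically, then drop duplicates (the !replace branch).
  std::sort(pairs.begin(), end);
  end = std::unique(pairs.begin(), end);
  pairs.erase(end, pairs.end());

  for (const auto& p : pairs)
    std::cout << p.first << " -> " << p.second << "\n";  // (0,2) (1,5) (3,7)
  return 0;
}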
@@ -9,6 +9,7 @@
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>

#include "../selector.h"

namespace dgl {

@@ -25,38 +26,41 @@ namespace cpu {
 * \note it uses node parallel strategy, different threads are responsible
 * for the computation of different nodes.
 */
template <
    typename IdType, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCsr(
    const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,
    NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = csr.indptr.Ptr<IdType>();
  const IdType* indices = csr.indices.Ptr<IdType>();
  const IdType* edges = csr.data.Ptr<IdType>();
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
  DType* O = out.Ptr<DType>();
  runtime::parallel_for(0, csr.num_rows, [=](IdType b, IdType e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (IdType j = row_start; j < row_end; ++j) {
        const IdType cid = indices[j];
        const IdType eid = has_idx ? edges[j] : j;
        DType* out_off = O + eid * dim;
        for (int64_t k = 0; k < dim; ++k) {
          const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* lhs_off =
              Op::use_lhs
                  ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +
                        lhs_add * reduce_size
                  : nullptr;
          const DType* rhs_off =
              Op::use_rhs
                  ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +
                        rhs_add * reduce_size
                  : nullptr;
          out_off[k] = Op::Call(lhs_off, rhs_off, reduce_size);
        }
      }

@@ -74,35 +78,38 @@ void SDDMMCsr(const BcastOff& bcast,
 * \note it uses edge parallel strategy, different threads are responsible
 * for the computation of different edges.
 */
template <
    typename IdType, typename DType, typename Op, int LhsTarget = 0,
    int RhsTarget = 2>
void SDDMMCoo(
    const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
    NDArray out) {
  const bool has_idx = !IsNullArray(coo.data);
  const IdType* row = coo.row.Ptr<IdType>();
  const IdType* col = coo.col.Ptr<IdType>();
  const IdType* edges = coo.data.Ptr<IdType>();
  const DType* X = lhs.Ptr<DType>();
  const DType* Y = rhs.Ptr<DType>();
  const int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len,
                rhs_dim = bcast.rhs_len, reduce_size = bcast.reduce_size;
  DType* O = out.Ptr<DType>();
#pragma omp parallel for
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType rid = row[i];
    const IdType cid = col[i];
    const IdType eid = has_idx ? edges[i] : i;
    DType* out_off = O + eid * dim;
    for (int64_t k = 0; k < dim; ++k) {
      const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
      const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
      const DType* lhs_off =
          Op::use_lhs ? X + Selector<LhsTarget>::Call(rid, eid, cid) * lhs_dim +
                            lhs_add * reduce_size
                      : nullptr;
      const DType* rhs_off =
          Op::use_rhs ? Y + Selector<RhsTarget>::Call(rid, eid, cid) * rhs_dim +
                            rhs_add * reduce_size
                      : nullptr;
      out_off[k] = Op::Call(lhs_off, rhs_off, bcast.reduce_size);
    }
  }

@@ -110,12 +117,13 @@ void SDDMMCoo(const BcastOff& bcast,
namespace op {

////////////////////////// binary operators on CPU /////////////////////////////
template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off + *rhs_off;
  }
};

@@ -124,7 +132,8 @@ template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off - *rhs_off;
  }
};

@@ -133,7 +142,8 @@ template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off * *rhs_off;
  }
};

@@ -142,7 +152,8 @@ template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    return *lhs_off / *rhs_off;
  }
};

@@ -151,7 +162,8 @@ template <typename DType>
struct CopyLhs {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  inline static DType Call(
      const DType* lhs_off, const DType*, int64_t len = 1) {
    return *lhs_off;
  }
};

@@ -160,7 +172,8 @@ template <typename DType>
struct CopyRhs {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType*, const DType* rhs_off, int64_t len = 1) {
    return *rhs_off;
  }
};

@@ -169,7 +182,8 @@ template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    DType rst = 0;
    for (int64_t l = 0; l < len; ++l) {
      rst += lhs_off[l] * rhs_off[l];

@@ -178,32 +192,32 @@ struct Dot {
  }
};

#define SWITCH_OP(op, Op, ...)                                   \
  do {                                                           \
    if ((op) == "add") {                                         \
      typedef dgl::aten::cpu::op::Add<DType> Op;                 \
      { __VA_ARGS__ }                                            \
    } else if ((op) == "sub") {                                  \
      typedef dgl::aten::cpu::op::Sub<DType> Op;                 \
      { __VA_ARGS__ }                                            \
    } else if ((op) == "mul") {                                  \
      typedef dgl::aten::cpu::op::Mul<DType> Op;                 \
      { __VA_ARGS__ }                                            \
    } else if ((op) == "div") {                                  \
      typedef dgl::aten::cpu::op::Div<DType> Op;                 \
      { __VA_ARGS__ }                                            \
    } else if ((op) == "copy_lhs") {                             \
      typedef dgl::aten::cpu::op::CopyLhs<DType> Op;             \
      { __VA_ARGS__ }                                            \
    } else if ((op) == "copy_rhs") {                             \
      typedef dgl::aten::cpu::op::CopyRhs<DType> Op;             \
      { __VA_ARGS__ }                                            \
    } else if ((op) == "dot") {                                  \
      typedef dgl::aten::cpu::op::Dot<DType> Op;                 \
      { __VA_ARGS__ }                                            \
    } else {                                                     \
      LOG(FATAL) << "Unsupported SDDMM binary operator: " << op; \
    }                                                            \
  } while (0)

}  // namespace op
......
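The Op functors above plug into the SDDMM kernels through Op::Call, with use_lhs/use_rhs deciding which feature pointer is passed and len driving the reduction length. A self-contained sketch of the same functor pattern on plain arrays (hypothetical names, independent of DGL's headers and NDArray types):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Same shape as the DGL functors: static flags plus a Call() reducer.
template <typename DType>
struct DotOp {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  inline static DType Call(
      const DType* lhs_off, const DType* rhs_off, int64_t len = 1) {
    DType rst = 0;
    for (int64_t l = 0; l < len; ++l) rst += lhs_off[l] * rhs_off[l];
    return rst;
  }
};

// Toy "SDDMM": one output value per edge, reducing over `len` features.
template <typename Op, typename DType>
void EdgeApply(const std::vector<DType>& src_feat,
               const std::vector<DType>& dst_feat,
               const std::vector<std::pair<int, int>>& edges, int64_t len,
               std::vector<DType>* out) {
  for (size_t e = 0; e < edges.size(); ++e) {
    const DType* lhs = Op::use_lhs ? &src_feat[edges[e].first * len] : nullptr;
    const DType* rhs = Op::use_rhs ? &dst_feat[edges[e].second * len] : nullptr;
    (*out)[e] = Op::Call(lhs, rhs, len);
  }
}

int main() {
  std::vector<float> src = {1, 2, 3, 4};  // 2 nodes x 2 features
  std::vector<float> dst = {5, 6, 7, 8};
  std::vector<std::pair<int, int>> edges = {{0, 1}, {1, 0}};
  std::vector<float> out(edges.size());
  EdgeApply<DotOp<float>, float>(src, dst, edges, 2, &out);
  for (float v : out) std::cout << v << "\n";  // 1*7+2*8=23, 3*5+4*6=39
  return 0;
}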
@@ -7,10 +7,11 @@
#define DGL_ARRAY_CPU_SEGMENT_REDUCE_H_

#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/parallel_for.h>

#include <string>
#include <vector>

namespace dgl {
namespace aten {

@@ -26,11 +27,10 @@ template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  runtime::parallel_for(0, n, [=](int b, int e) {
    for (auto i = b; i < e; ++i) {
      for (IdType j = offsets_data[i]; j < offsets_data[i + 1]; ++j) {

@@ -51,16 +51,14 @@ void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
 * used in backward phase.
 */
template <typename IdType, typename DType, typename Cmp>
void SegmentCmp(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {
  int n = out->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  IdType* arg_data = arg.Ptr<IdType>();
  std::fill(out_data, out_data + out.NumElements(), Cmp::zero);
  std::fill(arg_data, arg_data + arg.NumElements(), -1);
  runtime::parallel_for(0, n, [=](int b, int e) {

@@ -89,8 +87,7 @@ template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();

@@ -114,24 +111,26 @@ void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
 * \param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    HeteroGraphPtr graph, const std::string& op,
    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,
    const std::vector<NDArray>& list_idx_types,
    std::vector<NDArray>* list_out) {
  if (op == "copy_lhs" || op == "copy_rhs") {
    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_ntype = pair.first;  // graph is reversed
      const dgl_id_t src_ntype = pair.second;
      auto same_src_dst_ntype = std::find(
          std::begin(src_dst_ntypes[dst_ntype]),
          std::end(src_dst_ntypes[dst_ntype]), src_ntype);
      // if op is "copy_lhs", relation type with same src and dst node type will
      // be updated once
      if (op == "copy_lhs" &&
          same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
        continue;
      src_dst_ntypes[dst_ntype].push_back(src_ntype);
      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();

@@ -149,7 +148,8 @@ void UpdateGradMinMax_hetero(HeteroGraphPtr graph,
          if (type == idx_type_data[i * dim + k]) {
            const int write_row = idx_data[i * dim + k];
#pragma omp atomic
            out_data[write_row * dim + k] +=
                feat_data[i * dim + k];  // feat = dZ
          }
        }
      }

@@ -170,8 +170,7 @@ template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  int n = feat->shape[0];
  int dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
......
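SegmentSum above reduces rows of a feature matrix into segments delimited by a CSR-style offsets array. A hedged standalone sketch of the same semantics (plain vectors instead of NDArray, sequential loop instead of runtime::parallel_for):

#include <cstdint>
#include <iostream>
#include <vector>

// Sum feat rows [offsets[i], offsets[i+1]) into out row i, feature-wise.
void SegmentSumRef(const std::vector<float>& feat,
                   const std::vector<int64_t>& offsets, int64_t dim,
                   std::vector<float>* out) {
  const int64_t num_segments = static_cast<int64_t>(offsets.size()) - 1;
  out->assign(num_segments * dim, 0.f);
  for (int64_t i = 0; i < num_segments; ++i)
    for (int64_t j = offsets[i]; j < offsets[i + 1]; ++j)
      for (int64_t k = 0; k < dim; ++k)
        (*out)[i * dim + k] += feat[j * dim + k];
}

int main() {
  // 4 rows of dim-2 features, split into segments {0,1} and {2,3}.
  std::vector<float> feat = {1, 1, 2, 2, 3, 3, 4, 4};
  std::vector<int64_t> offsets = {0, 2, 4};
  std::vector<float> out;
  SegmentSumRef(feat, offsets, 2, &out);
  for (float v : out) std::cout << v << " ";  // 3 3 7 7
  std::cout << "\n";
  return 0;
}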
@@ -4,58 +4,49 @@
 * \brief Graph traversal implementation
 */

#include <dgl/graph_traversal.h>

#include <algorithm>
#include <queue>

#include "./traversal.h"

namespace dgl {
namespace aten {
namespace impl {
namespace {

// A utility view class to wrap a vector into a queue.
template <typename DType>
struct VectorQueueWrapper {
  std::vector<DType>* vec;
  size_t head = 0;

  explicit VectorQueueWrapper(std::vector<DType>* vec) : vec(vec) {}

  void push(const DType& elem) { vec->push_back(elem); }

  DType top() const { return vec->operator[](head); }

  void pop() { ++head; }

  bool empty() const { return head == vec->size(); }

  size_t size() const { return vec->size() - head; }
};

// Internal function to merge multiple traversal traces into one ndarray.
// It is similar to zip the vectors together.
template <typename DType>
IdArray MergeMultipleTraversals(const std::vector<std::vector<DType>>& traces) {
  int64_t max_len = 0, total_len = 0;
  for (size_t i = 0; i < traces.size(); ++i) {
    const int64_t tracelen = traces[i].size();
    max_len = std::max(max_len, tracelen);
    total_len += traces[i].size();
  }
  IdArray ret = IdArray::Empty(
      {total_len}, DGLDataType{kDGLInt, sizeof(DType) * 8, 1},
      DGLContext{kDGLCPU, 0});
  DType* ret_data = static_cast<DType*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    for (size_t j = 0; j < traces.size(); ++j) {

@@ -71,15 +62,15 @@ IdArray MergeMultipleTraversals(
// Internal function to compute sections if multiple traversal traces
// are merged into one ndarray.
template <typename DType>
IdArray ComputeMergedSections(const std::vector<std::vector<DType>>& traces) {
  int64_t max_len = 0;
  for (size_t i = 0; i < traces.size(); ++i) {
    const int64_t tracelen = traces[i].size();
    max_len = std::max(max_len, tracelen);
  }
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty(
{max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
  int64_t* ret_data = static_cast<int64_t*>(ret->data);
  for (int64_t i = 0; i < max_len; ++i) {
    int64_t sec_len = 0;

@@ -101,13 +92,13 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  std::vector<IdType> ids;
  std::vector<int64_t> sections;
  VectorQueueWrapper<IdType> queue(&ids);
  auto visit = [&](const int64_t v) {};
  auto make_frontier = [&]() {
    if (!queue.empty()) {
      // do not push zero-length frontier
      sections.push_back(queue.size());
    }
  };
  BFSTraverseNodes<IdType>(csr, source, &queue, visit, make_frontier);

  Frontiers front;

@@ -116,8 +107,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
  return front;
}

template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(
    const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(
    const CSRMatrix&, IdArray);

template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {

@@ -126,16 +119,16 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  // NOTE: std::queue has no top() method.
  std::vector<IdType> nodes;
  VectorQueueWrapper<IdType> queue(&nodes);
  auto visit = [&](const IdType e) { ids.push_back(e); };
  bool first_frontier = true;
  auto make_frontier = [&] {
    if (first_frontier) {
      first_frontier = false;  // do not push the first section when doing edges
    } else if (!queue.empty()) {
      // do not push zero-length frontier
      sections.push_back(queue.size());
    }
  };
  BFSTraverseEdges<IdType>(csr, source, &queue, visit, make_frontier);

  Frontiers front;

@@ -144,21 +137,23 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
  return front;
}

template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(
    const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(
    const CSRMatrix&, IdArray);

template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  std::vector<IdType> ids;
  std::vector<int64_t> sections;
  VectorQueueWrapper<IdType> queue(&ids);
  auto visit = [&](const uint64_t v) {};
  auto make_frontier = [&]() {
    if (!queue.empty()) {
      // do not push zero-length frontier
      sections.push_back(queue.size());
    }
  };
  TopologicalNodes<IdType>(csr, &queue, visit, make_frontier);

  Frontiers front;

@@ -167,8 +162,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
  return front;
}

template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(
    const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(
    const CSRMatrix&);

template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {

@@ -177,7 +174,7 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
  std::vector<std::vector<IdType>> edges(len);
  for (int64_t i = 0; i < len; ++i) {
    auto visit = [&](IdType e, int tag) { edges[i].push_back(e); };
    DFSLabeledEdges<IdType>(csr, src_data[i], false, false, visit);
  }

@@ -191,11 +188,9 @@ template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);

template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(
    const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
    const bool has_nontree_edge, const bool return_labels) {
  const int64_t len = source->shape[0];
  const IdType* src_data = static_cast<IdType*>(source->data);
  std::vector<std::vector<IdType>> edges(len);

@@ -206,14 +201,14 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
  }

  for (int64_t i = 0; i < len; ++i) {
    auto visit = [&](IdType e, int64_t tag) {
      edges[i].push_back(e);
      if (return_labels) {
        tags[i].push_back(tag);
      }
    };
    DFSLabeledEdges<IdType>(
        csr, src_data[i], has_reverse_edge, has_nontree_edge, visit);
  }

  Frontiers front;

@@ -226,16 +221,10 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
  return front;
}

template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(
    const CSRMatrix&, IdArray, const bool, const bool, const bool);
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(
    const CSRMatrix&, IdArray, const bool, const bool, const bool);

}  // namespace impl
}  // namespace aten
......
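MergeMultipleTraversals and ComputeMergedSections above zip per-source traces into one flat id array plus a sections array, one entry per traversal step. A hedged standalone sketch of that zipping, on plain vectors instead of IdArray:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Zip traces position-by-position: ids holds trace[0][i], trace[1][i], ...
// for each i; sections[i] counts how many traces still had an element at i.
void MergeTraces(const std::vector<std::vector<int64_t>>& traces,
                 std::vector<int64_t>* ids, std::vector<int64_t>* sections) {
  size_t max_len = 0;
  for (const auto& t : traces) max_len = std::max(max_len, t.size());
  for (size_t i = 0; i < max_len; ++i) {
    int64_t sec_len = 0;
    for (const auto& t : traces)
      if (i < t.size()) {
        ids->push_back(t[i]);
        ++sec_len;
      }
    sections->push_back(sec_len);
  }
}

int main() {
  std::vector<std::vector<int64_t>> traces = {{0, 2, 5}, {1, 3}};
  std::vector<int64_t> ids, sections;
  MergeTraces(traces, &ids, &sections);
  for (auto v : ids) std::cout << v << " ";       // 0 1 2 3 5
  std::cout << "| ";
  for (auto v : sections) std::cout << v << " ";  // 2 2 1
  std::cout << "\n";
  return 0;
}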
@@ -3,15 +3,16 @@
 * \file array/cpu/traversal.h
 * \brief Graph traversal routines.
 *
 * Traversal routines generate frontiers. Frontiers can be node frontiers or
 * edge frontiers depending on the traversal function. Each frontier is a list
 * of nodes/edges (specified by their ids). An optional tag can be specified for
 * each node/edge (represented by an int value).
 */
#ifndef DGL_ARRAY_CPU_TRAVERSAL_H_
#define DGL_ARRAY_CPU_TRAVERSAL_H_

#include <dgl/graph_interface.h>

#include <stack>
#include <tuple>
#include <vector>

@@ -43,16 +44,16 @@ namespace impl {
 * \param reversed If true, BFS follows the in-edge direction
 * \param queue The queue used to do bfs.
 * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new froniter can be
 * made;
 */
template <
    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseNodes(
    const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,
    FrontierFn make_frontier) {
  const int64_t len = source->shape[0];
  const IdType *src_data = static_cast<IdType *>(source->data);
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);

@@ -71,7 +72,7 @@ void BFSTraverseNodes(const CSRMatrix& csr,
    for (size_t i = 0; i < size; ++i) {
      const IdType u = queue->top();
      queue->pop();
      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
        auto v = indices_data[idx];
        if (!visited[v]) {
          visited[v] = true;

@@ -109,16 +110,16 @@ void BFSTraverseNodes(const CSRMatrix& csr,
 * \param queue The queue used to do bfs.
 * \param visit The function to call when a node is visited.
 * The argument would be edge ID.
 * \param make_frontier The function to indicate that a new frontier can be
 * made;
 */
template <
    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void BFSTraverseEdges(
    const CSRMatrix &csr, IdArray source, Queue *queue, VisitFn visit,
    FrontierFn make_frontier) {
  const int64_t len = source->shape[0];
  const IdType *src_data = static_cast<IdType *>(source->data);
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);

@@ -138,7 +139,7 @@ void BFSTraverseEdges(const CSRMatrix& csr,
    for (size_t i = 0; i < size; ++i) {
      const IdType u = queue->top();
      queue->pop();
      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
        auto e = eid_data ? eid_data[idx] : idx;
        const IdType v = indices_data[idx];
        if (!visited[v]) {

@@ -174,13 +175,14 @@ void BFSTraverseEdges(const CSRMatrix& csr,
 * \param reversed If true, follows the in-edge direction
 * \param queue The queue used to do bfs.
 * \param visit The function to call when a node is visited.
 * \param make_frontier The function to indicate that a new froniter can be
 * made;
 */
template <
    typename IdType, typename Queue, typename VisitFn, typename FrontierFn>
void TopologicalNodes(
    const CSRMatrix &csr, Queue *queue, VisitFn visit,
    FrontierFn make_frontier) {
  int64_t num_visited_nodes = 0;
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);

@@ -206,7 +208,7 @@ void TopologicalNodes(const CSRMatrix& csr,
    for (size_t i = 0; i < size; ++i) {
      const IdType u = queue->top();
      queue->pop();
      for (auto idx = indptr_data[u]; idx < indptr_data[u + 1]; ++idx) {
        const IdType v = indices_data[idx];
        if (--(degrees[v]) == 0) {
          visit(v);

@@ -219,7 +221,8 @@ void TopologicalNodes(const CSRMatrix& csr,
  }

  if (num_visited_nodes != num_nodes) {
    LOG(FATAL)
        << "Error in topological traversal: loop detected in the given graph.";
  }
}

@@ -236,32 +239,29 @@ enum DFSEdgeTag {
 * FORWARD(0), REVERSE(1), NONTREE(2)
 *
 * A FORWARD edge is one in which `u` has been visisted but `v` has not.
 * A REVERSE edge is one in which both `u` and `v` have been visisted and the
 * edge is in the DFS tree. A NONTREE edge is one in which both `u` and `v` have
 * been visisted but the edge is NOT in the DFS tree.
 *
 * \param source Source node.
 * \param reversed If true, DFS follows the in-edge direction
 * \param has_reverse_edge If true, REVERSE edges are included
 * \param has_nontree_edge If true, NONTREE edges are included
 * \param visit The function to call when an edge is visited; the edge id and
 * its tag will be given as the arguments.
 */
template <typename IdType, typename VisitFn>
void DFSLabeledEdges(
    const CSRMatrix &csr, IdType source, bool has_reverse_edge,
    bool has_nontree_edge, VisitFn visit) {
  const int64_t num_nodes = csr.num_rows;
  CHECK_GE(num_nodes, source)
      << "source " << source << " is out of range [0," << num_nodes << "]";
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
  const IdType *eid_data = static_cast<IdType *>(csr.data->data);
  if (indptr_data[source + 1] - indptr_data[source] == 0) {
    // no out-going edges from the source node
    return;
  }

@@ -278,7 +278,8 @@ void DFSLabeledEdges(const CSRMatrix& csr,
  while (!stack.empty()) {
    std::tie(u, i, on_tree) = stack.top();
    const IdType v = indices_data[indptr_data[u] + i];
    const IdType uv =
        eid_data ? eid_data[indptr_data[u] + i] : indptr_data[u] + i;
    if (visited[v]) {
      if (!on_tree && has_nontree_edge) {
        visit(uv, kNonTree);

@@ -288,7 +289,7 @@ void DFSLabeledEdges(const CSRMatrix& csr,
      stack.pop();
      // find next one.
      if (indptr_data[u] + i < indptr_data[u + 1] - 1) {
        stack.push(std::make_tuple(u, i + 1, false));
      }
    } else {
      visited[v] = true;
......
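TopologicalNodes above is Kahn's algorithm driven through the queue/visit/make_frontier callbacks; each frontier is the set of nodes whose in-degree drops to zero in one round. A hedged standalone sketch of the same idea on a plain adjacency list (no DGL types):

#include <iostream>
#include <vector>

int main() {
  // DAG: 0->1, 0->2, 1->3, 2->3
  std::vector<std::vector<int>> adj = {{1, 2}, {3}, {3}, {}};
  std::vector<int> indeg(adj.size(), 0);
  for (const auto& nbrs : adj)
    for (int v : nbrs) ++indeg[v];

  // Kahn's algorithm, emitting one frontier (level) per round.
  std::vector<int> frontier;
  for (size_t u = 0; u < adj.size(); ++u)
    if (indeg[u] == 0) frontier.push_back(static_cast<int>(u));
  while (!frontier.empty()) {
    std::vector<int> next;
    for (int u : frontier) {
      std::cout << u << " ";
      for (int v : adj[u])
        if (--indeg[v] == 0) next.push_back(v);
    }
    std::cout << "| ";  // frontier boundary, like the `sections` array
    frontier = std::move(next);
  }
  std::cout << "\n";  // prints: 0 | 1 2 | 3 |
  return 0;
}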
@@ -4,9 +4,10 @@
 * \brief Array cumsum GPU implementation
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"

namespace dgl {
using runtime::NDArray;

@@ -17,7 +18,8 @@ template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
  const int64_t len = array.NumElements();
  if (len == 0)
    return !prepend_zero ? array
                         : aten::Full(0, 1, array->dtype.bits, array->ctx);

  auto device = runtime::DeviceAPI::Get(array->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
......
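CumSum is an inclusive prefix sum; with prepend_zero the result starts with 0 and is one element longer. A hedged CPU reference of the same semantics using std::partial_sum (not the CUDA implementation above):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Inclusive prefix sum; with prepend_zero the output is {0, in[0], in[0]+in[1], ...}.
std::vector<int64_t> CumSumRef(const std::vector<int64_t>& in,
                               bool prepend_zero) {
  std::vector<int64_t> out(prepend_zero ? in.size() + 1 : in.size(), 0);
  std::partial_sum(in.begin(), in.end(),
                   out.begin() + (prepend_zero ? 1 : 0));
  return out;
}

int main() {
  for (int64_t v : CumSumRef({2, 3, 1}, true)) std::cout << v << " ";  // 0 2 5 6
  std::cout << "\n";
  return 0;
}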
@@ -5,9 +5,10 @@
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"

namespace dgl {
using runtime::NDArray;

@@ -16,14 +17,11 @@ namespace impl {
template <typename IdType>
struct IsNonZeroIndex {
  explicit IsNonZeroIndex(const IdType* array) : array_(array) {}

  __device__ bool operator()(const int64_t index) { return array_[index] != 0; }

  const IdType* array_;
};

template <DGLDeviceType XPU, typename IdType>

@@ -36,22 +34,23 @@ IdArray NonZero(IdArray array) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  const IdType* const in_data = static_cast<const IdType*>(array->data);
  int64_t* const out_data = static_cast<int64_t*>(ret->data);

  IsNonZeroIndex<IdType> comp(in_data);
  cub::CountingInputIterator<int64_t> counter(0);

  // room for cub to output on GPU
  int64_t* d_num_nonzeros =
      static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));

  size_t temp_size = 0;
  CUDA_CALL(cub::DeviceSelect::If(
      nullptr, temp_size, counter, out_data, d_num_nonzeros, len, comp,
      stream));
  void* temp = device->AllocWorkspace(ctx, temp_size);
  CUDA_CALL(cub::DeviceSelect::If(
      temp, temp_size, counter, out_data, d_num_nonzeros, len, comp, stream));
  device->FreeWorkspace(ctx, temp);

  // copy number of selected elements from GPU to CPU
......
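NonZero selects the indices of non-zero entries (above via cub::DeviceSelect::If over a counting iterator, with the usual two-phase call that first queries the temporary-storage size). The equivalent host-side semantics, as a hedged sketch:

#include <cstdint>
#include <iostream>
#include <vector>

// Return the positions of all non-zero entries, in order.
std::vector<int64_t> NonZeroRef(const std::vector<int64_t>& array) {
  std::vector<int64_t> idx;
  for (int64_t i = 0; i < static_cast<int64_t>(array.size()); ++i)
    if (array[i] != 0) idx.push_back(i);
  return idx;
}

int main() {
  for (int64_t i : NonZeroRef({0, 3, 0, 0, 7, 1})) std::cout << i << " ";  // 1 4 5
  std::cout << "\n";
  return 0;
}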
@@ -4,9 +4,10 @@
 * \brief Array sort GPU implementation
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"

namespace dgl {
using runtime::NDArray;

@@ -29,26 +30,30 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();

  if (num_bits == 0) {
    num_bits = sizeof(IdType) * 8;
  }

  // Allocate workspace
  size_t workspace_size = 0;
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      nullptr, workspace_size, keys_in, keys_out, values_in, values_out, nitems,
      0, num_bits, stream));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);

  // Compute
  CUDA_CALL(cub::DeviceRadixSort::SortPairs(
      workspace, workspace_size, keys_in, keys_out, values_in, values_out,
      nitems, 0, num_bits, stream));
  device->FreeWorkspace(ctx, workspace);

  return std::make_pair(sorted_array, sorted_idx);
}

template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(
    IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(
    IdArray, int num_bits);

}  // namespace impl
}  // namespace aten
......
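Sort returns the sorted values together with the permutation of original indices (a key/value radix sort on the GPU). A hedged host-side sketch of the same contract using std::stable_sort on an index array:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Returns {sorted values, original indices in sorted order}.
std::pair<std::vector<int64_t>, std::vector<int64_t>> SortRef(
    const std::vector<int64_t>& array) {
  std::vector<int64_t> idx(array.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int64_t a, int64_t b) { return array[a] < array[b]; });
  std::vector<int64_t> sorted(array.size());
  for (size_t i = 0; i < idx.size(); ++i) sorted[i] = array[idx[i]];
  return {sorted, idx};
}

int main() {
  auto res = SortRef({30, 10, 20});
  for (int64_t v : res.first) std::cout << v << " ";   // 10 20 30
  std::cout << "| ";
  for (int64_t v : res.second) std::cout << v << " ";  // 1 2 0
  std::cout << "\n";
  return 0;
}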
@@ -4,6 +4,7 @@
 * \brief COO2CSR
 */
#include <dgl/array.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"

@@ -46,18 +47,15 @@ CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
  if (!COOHasData(coo))
    coo.data = aten::Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);

  NDArray indptr =
      aten::NewIdArray(coo.num_rows + 1, coo.row->ctx, coo.row->dtype.bits);
  int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
  CUSPARSE_CALL(cusparseXcoo2csr(
      thr_entry->cusparse_handle, coo.row.Ptr<int32_t>(), nnz, coo.num_rows,
      indptr_ptr, CUSPARSE_INDEX_BASE_ZERO));

  return CSRMatrix(
      coo.num_rows, coo.num_cols, indptr, coo.col, coo.data, col_sorted);
}

/*!

@@ -77,9 +75,8 @@ CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
 */
template <typename IdType>
__global__ void _SortedSearchKernelUpperBound(
    const IdType* hay, int64_t hay_size, const IdType* needles,
    int64_t num_needles, IdType* pos) {
  int tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride_x = gridDim.x * blockDim.x;
  while (tx < num_needles) {

@@ -123,14 +120,12 @@ CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
  const int nt = cuda::FindNumThreads(coo.num_rows);
  const int nb = (coo.num_rows + nt - 1) / nt;
  IdArray indptr = Full(0, coo.num_rows + 1, nbits, ctx);
  CUDA_KERNEL_CALL(
      _SortedSearchKernelUpperBound, nb, nt, 0, stream, coo.row.Ptr<int64_t>(),
      nnz, rowids.Ptr<int64_t>(), coo.num_rows, indptr.Ptr<int64_t>() + 1);

  return CSRMatrix(
      coo.num_rows, coo.num_cols, indptr, coo.col, coo.data, col_sorted);
}

template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
......
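Both specializations above build the CSR indptr from the row IDs of a row-sorted COO (via cusparseXcoo2csr or an upper-bound search kernel). A hedged host-side sketch of that conversion by counting per-row entries and prefix-summing:

#include <cstdint>
#include <iostream>
#include <vector>

// Build CSR indptr (length num_rows + 1) from COO row indices.
std::vector<int64_t> RowsToIndptr(const std::vector<int64_t>& coo_rows,
                                  int64_t num_rows) {
  std::vector<int64_t> indptr(num_rows + 1, 0);
  for (int64_t r : coo_rows) ++indptr[r + 1];  // count per row
  for (int64_t i = 0; i < num_rows; ++i)       // prefix sum into offsets
    indptr[i + 1] += indptr[i];
  return indptr;
}

int main() {
  // Row-sorted COO rows for a 3-row matrix with 4 non-zeros.
  for (int64_t v : RowsToIndptr({0, 0, 2, 2}, 3)) std::cout << v << " ";  // 0 2 2 4
  std::cout << "\n";
  return 0;
}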
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief Sort COO index * \brief Sort COO index
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
namespace dgl { namespace dgl {
...@@ -18,21 +19,20 @@ namespace impl { ...@@ -18,21 +19,20 @@ namespace impl {
///////////////////////////// COOSort_ ///////////////////////////// ///////////////////////////// COOSort_ /////////////////////////////
/** /**
* @brief Encode row and column IDs into a single scalar per edge. * @brief Encode row and column IDs into a single scalar per edge.
* *
* @tparam IdType The type to encode as. * @tparam IdType The type to encode as.
* @param row The row (src) IDs per edge. * @param row The row (src) IDs per edge.
* @param col The column (dst) IDs per edge. * @param col The column (dst) IDs per edge.
* @param nnz The number of edges. * @param nnz The number of edges.
* @param col_bits The number of bits used to encode the destination. The row * @param col_bits The number of bits used to encode the destination. The row
* information is packed into the remaining bits. * information is packed into the remaining bits.
* @param key The encoded edges (output). * @param key The encoded edges (output).
*/ */
template <typename IdType> template <typename IdType>
__global__ void _COOEncodeEdgesKernel( __global__ void _COOEncodeEdgesKernel(
const IdType* const row, const IdType* const col, const IdType* const row, const IdType* const col, const int64_t nnz,
const int64_t nnz, const int col_bits, IdType * const key) { const int col_bits, IdType* const key) {
int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (tx < nnz) { if (tx < nnz) {
...@@ -41,20 +41,19 @@ __global__ void _COOEncodeEdgesKernel( ...@@ -41,20 +41,19 @@ __global__ void _COOEncodeEdgesKernel(
} }
/** /**
* @brief Decode row and column IDs from the encoded edges. * @brief Decode row and column IDs from the encoded edges.
* *
* @tparam IdType The type the edges are encoded as. * @tparam IdType The type the edges are encoded as.
* @param key The encoded edges. * @param key The encoded edges.
* @param nnz The number of edges. * @param nnz The number of edges.
* @param col_bits The number of bits used to store the column/dst ID. * @param col_bits The number of bits used to store the column/dst ID.
* @param row The row (src) IDs per edge (output). * @param row The row (src) IDs per edge (output).
* @param col The col (dst) IDs per edge (output). * @param col The col (dst) IDs per edge (output).
*/ */
template <typename IdType> template <typename IdType>
__global__ void _COODecodeEdgesKernel( __global__ void _COODecodeEdgesKernel(
const IdType* const key, const int64_t nnz, const int col_bits, const IdType* const key, const int64_t nnz, const int col_bits,
IdType * const row, IdType * const col) { IdType* const row, IdType* const col) {
int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (tx < nnz) { if (tx < nnz) {
...@@ -64,9 +63,7 @@ __global__ void _COODecodeEdgesKernel( ...@@ -64,9 +63,7 @@ __global__ void _COODecodeEdgesKernel(
} }
} }
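The encode/decode kernels above pack each (row, col) pair into one integer key, presumably with the column id in the low col_bits bits and the row id in the bits above it, so that sorting the keys orders edges by row and then by column. A hedged host-side sketch of that packing scheme (plain C++; the helper names are hypothetical):

#include <cassert>
#include <cstdint>

// Pack a (row, col) pair into one key; col occupies the low col_bits bits.
inline int64_t EncodeEdge(int64_t row, int64_t col, int col_bits) {
  return (row << col_bits) | col;
}

// Recover (row, col) from a packed key.
inline void DecodeEdge(int64_t key, int col_bits, int64_t* row, int64_t* col) {
  *col = key & ((int64_t{1} << col_bits) - 1);
  *row = key >> col_bits;
}

int main() {
  int64_t row, col;
  DecodeEdge(EncodeEdge(/*row=*/5, /*col=*/3, /*col_bits=*/4), 4, &row, &col);
  assert(row == 5 && col == 3);
  return 0;
}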
template <typename T>
template<typename T>
int _NumberOfBits(const T& range) { int _NumberOfBits(const T& range) {
if (range <= 1) { if (range <= 1) {
// ranges of 0 or 1 require no bits to store // ranges of 0 or 1 require no bits to store
...@@ -74,12 +71,12 @@ int _NumberOfBits(const T& range) { ...@@ -74,12 +71,12 @@ int _NumberOfBits(const T& range) {
} }
int bits = 1; int bits = 1;
while (bits < static_cast<int>(sizeof(T)*8) && (1 << bits) < range) { while (bits < static_cast<int>(sizeof(T) * 8) && (1 << bits) < range) {
++bits; ++bits;
} }
CHECK_EQ((range-1) >> bits, 0); CHECK_EQ((range - 1) >> bits, 0);
CHECK_NE((range-1) >> (bits-1), 0); CHECK_NE((range - 1) >> (bits - 1), 0);
return bits; return bits;
} }
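_NumberOfBits returns how many bits are needed to hold ids in [0, range): it grows bits until (1 << bits) >= range, and the CHECKs assert that range - 1 fits in exactly that many bits. For example, a range of 5 (ids 0-4) needs 3 bits, while a range of 4 needs only 2. A small host-side analogue, under the assumption that ranges of 0 or 1 yield 0 bits as the comment suggests:

#include <cassert>
#include <cstdint>

// Number of bits needed to hold every id in [0, range).
int NumberOfBits(int64_t range) {
  if (range <= 1) return 0;  // assumption: 0/1-sized ranges need no bits
  int bits = 1;
  while (bits < 64 && (int64_t{1} << bits) < range) ++bits;
  return bits;
}

int main() {
  assert(NumberOfBits(5) == 3);  // ids 0..4
  assert(NumberOfBits(4) == 2);  // ids 0..3
  assert(NumberOfBits(1) == 0);
  return 0;
}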
...@@ -95,20 +92,20 @@ void COOSort_(COOMatrix* coo, bool sort_column) { ...@@ -95,20 +92,20 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
const int num_bits = row_bits + col_bits; const int num_bits = row_bits + col_bits;
const int nt = 256; const int nt = 256;
const int nb = (nnz+nt-1)/nt; const int nb = (nnz + nt - 1) / nt;
CHECK(static_cast<int64_t>(nb)*nt >= nnz); CHECK(static_cast<int64_t>(nb) * nt >= nnz);
IdArray pos = aten::NewIdArray(nnz, coo->row->ctx, coo->row->dtype.bits); IdArray pos = aten::NewIdArray(nnz, coo->row->ctx, coo->row->dtype.bits);
CUDA_KERNEL_CALL(_COOEncodeEdgesKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>(), _COOEncodeEdgesKernel, nb, nt, 0, stream, coo->row.Ptr<IdType>(),
nnz, col_bits, pos.Ptr<IdType>()); coo->col.Ptr<IdType>(), nnz, col_bits, pos.Ptr<IdType>());
auto sorted = Sort(pos, num_bits); auto sorted = Sort(pos, num_bits);
CUDA_KERNEL_CALL(_COODecodeEdgesKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
sorted.first.Ptr<IdType>(), nnz, col_bits, _COODecodeEdgesKernel, nb, nt, 0, stream, sorted.first.Ptr<IdType>(),
coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>()); nnz, col_bits, coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>());
if (aten::COOHasData(*coo)) if (aten::COOHasData(*coo))
coo->data = IndexSelect(coo->data, sorted.second); coo->data = IndexSelect(coo->data, sorted.second);
...@@ -138,8 +135,8 @@ template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column); ...@@ -138,8 +135,8 @@ template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);
template <typename IdType> template <typename IdType>
__global__ void _COOIsSortedKernel( __global__ void _COOIsSortedKernel(
const IdType* row, const IdType* col, const IdType* row, const IdType* col, int64_t nnz, int8_t* row_sorted,
int64_t nnz, int8_t* row_sorted, int8_t* col_sorted) { int8_t* col_sorted) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < nnz) { while (tx < nnz) {
...@@ -148,8 +145,8 @@ __global__ void _COOIsSortedKernel( ...@@ -148,8 +145,8 @@ __global__ void _COOIsSortedKernel(
col_sorted[0] = 1; col_sorted[0] = 1;
} else { } else {
row_sorted[tx] = static_cast<int8_t>(row[tx - 1] <= row[tx]); row_sorted[tx] = static_cast<int8_t>(row[tx - 1] <= row[tx]);
col_sorted[tx] = static_cast<int8_t>( col_sorted[tx] =
row[tx - 1] < row[tx] || col[tx - 1] <= col[tx]); static_cast<int8_t>(row[tx - 1] < row[tx] || col[tx - 1] <= col[tx]);
} }
tx += stride_x; tx += stride_x;
} }
...@@ -161,18 +158,19 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) { ...@@ -161,18 +158,19 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const auto& ctx = coo.row->ctx; const auto& ctx = coo.row->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
// We allocate a workspace of 2*nnz bytes. It wastes a little bit of memory but should // We allocate a workspace of 2*nnz bytes. It wastes a little bit of memory but
// be fine. // should be fine.
int8_t* row_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz)); int8_t* row_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
int8_t* col_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz)); int8_t* col_flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, nnz));
const int nt = cuda::FindNumThreads(nnz); const int nt = cuda::FindNumThreads(nnz);
const int nb = (nnz + nt - 1) / nt; const int nb = (nnz + nt - 1) / nt;
CUDA_KERNEL_CALL(_COOIsSortedKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
coo.row.Ptr<IdType>(), coo.col.Ptr<IdType>(), _COOIsSortedKernel, nb, nt, 0, stream, coo.row.Ptr<IdType>(),
nnz, row_flags, col_flags); coo.col.Ptr<IdType>(), nnz, row_flags, col_flags);
const bool row_sorted = cuda::AllTrue(row_flags, nnz, ctx); const bool row_sorted = cuda::AllTrue(row_flags, nnz, ctx);
const bool col_sorted = row_sorted? cuda::AllTrue(col_flags, nnz, ctx) : false; const bool col_sorted =
row_sorted ? cuda::AllTrue(col_flags, nnz, ctx) : false;
device->FreeWorkspace(ctx, row_flags); device->FreeWorkspace(ctx, row_flags);
device->FreeWorkspace(ctx, col_flags); device->FreeWorkspace(ctx, col_flags);
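COOIsSorted fills two per-edge flag arrays in parallel and then reduces them with cuda::AllTrue: row_sorted[i] is set when row[i-1] <= row[i], and the column flag additionally allows a strictly increasing row; column order is only reported when the rows themselves are sorted. A sequential sketch of the same check (illustrative helper, not part of DGL):

#include <cstdint>
#include <utility>
#include <vector>

// Returns {row_sorted, col_sorted} for a COO edge list; col_sorted also
// requires the columns to be non-decreasing within each row.
std::pair<bool, bool> COOIsSortedHost(
    const std::vector<int64_t>& row, const std::vector<int64_t>& col) {
  bool row_ok = true, col_ok = true;
  for (size_t i = 1; i < row.size(); ++i) {
    row_ok = row_ok && (row[i - 1] <= row[i]);
    col_ok = col_ok && (row[i - 1] < row[i] || col[i - 1] <= col[i]);
  }
  return {row_ok, row_ok && col_ok};
}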
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief CSR2COO * \brief CSR2COO
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
...@@ -32,20 +33,16 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -32,20 +33,16 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data; NDArray indptr = csr.indptr, indices = csr.indices, data = csr.data;
const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data); const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
NDArray row = aten::NewIdArray(indices->shape[0], indptr->ctx, indptr->dtype.bits); NDArray row =
aten::NewIdArray(indices->shape[0], indptr->ctx, indptr->dtype.bits);
int32_t* row_ptr = static_cast<int32_t*>(row->data); int32_t* row_ptr = static_cast<int32_t*>(row->data);
CUSPARSE_CALL(cusparseXcsr2coo( CUSPARSE_CALL(cusparseXcsr2coo(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, indptr_ptr, indices->shape[0], csr.num_rows,
indptr_ptr, row_ptr, CUSPARSE_INDEX_BASE_ZERO));
indices->shape[0],
csr.num_rows, return COOMatrix(
row_ptr, csr.num_rows, csr.num_cols, row, indices, data, true, csr.sorted);
CUSPARSE_INDEX_BASE_ZERO));
return COOMatrix(csr.num_rows, csr.num_cols,
row, indices, data,
true, csr.sorted);
} }
/*! /*!
...@@ -65,8 +62,8 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -65,8 +62,8 @@ COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
*/ */
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void _RepeatKernel( __global__ void _RepeatKernel(
const DType* val, const IdType* pos, const DType* val, const IdType* pos, DType* out, int64_t n_row,
DType* out, int64_t n_row, int64_t length) { int64_t length) {
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x; IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
...@@ -88,15 +85,13 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) { ...@@ -88,15 +85,13 @@ COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
const int nt = 256; const int nt = 256;
const int nb = (nnz + nt - 1) / nt; const int nb = (nnz + nt - 1) / nt;
CUDA_KERNEL_CALL(_RepeatKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _RepeatKernel, nb, nt, 0, stream, rowids.Ptr<int64_t>(),
rowids.Ptr<int64_t>(), csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(), csr.num_rows, nnz);
csr.indptr.Ptr<int64_t>(), ret_row.Ptr<int64_t>(),
csr.num_rows, nnz); return COOMatrix(
csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data, true,
return COOMatrix(csr.num_rows, csr.num_cols, csr.sorted);
ret_row, csr.indices, csr.data,
true, csr.sorted);
} }
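CSRToCOO<kDGLCUDA, int64_t> expands indptr into an explicit per-edge row array: row id r is repeated indptr[r + 1] - indptr[r] times, while the CSR indices and data arrays are reused as the COO columns and data. A minimal host-side sketch of that expansion (names are illustrative):

#include <cstdint>
#include <vector>

// Expand a CSR indptr array (length num_rows + 1) into per-edge row ids.
std::vector<int64_t> ExpandIndptrToRows(const std::vector<int64_t>& indptr) {
  std::vector<int64_t> rows;
  if (indptr.empty()) return rows;
  rows.reserve(indptr.back());
  for (size_t r = 0; r + 1 < indptr.size(); ++r) {
    for (int64_t e = indptr[r]; e < indptr[r + 1]; ++e) {
      rows.push_back(static_cast<int64_t>(r));
    }
  }
  return rows;
}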
template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr); template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
...@@ -111,8 +106,7 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) { ...@@ -111,8 +106,7 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
template <> template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) { COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr); COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);
if (aten::IsNullArray(coo.data)) if (aten::IsNullArray(coo.data)) return coo;
return coo;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(coo.row->ctx); auto device = runtime::DeviceAPI::Get(coo.row->ctx);
...@@ -130,21 +124,12 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -130,21 +124,12 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
size_t workspace_size = 0; size_t workspace_size = 0;
CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt( CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],
coo.num_rows, coo.num_cols, data_ptr, row_ptr, &workspace_size));
row->shape[0],
data_ptr,
row_ptr,
&workspace_size));
void* workspace = device->AllocWorkspace(row->ctx, workspace_size); void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
CUSPARSE_CALL(cusparseXcoosortByRow( CUSPARSE_CALL(cusparseXcoosortByRow(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, coo.num_rows, coo.num_cols, row->shape[0],
coo.num_rows, coo.num_cols, data_ptr, row_ptr, col_ptr, workspace));
row->shape[0],
data_ptr,
row_ptr,
col_ptr,
workspace));
device->FreeWorkspace(row->ctx, workspace); device->FreeWorkspace(row->ctx, workspace);
// The row and column field have already been reordered according // The row and column field have already been reordered according
...@@ -158,8 +143,7 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -158,8 +143,7 @@ COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
template <> template <>
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) { COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr); COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);
if (aten::IsNullArray(coo.data)) if (aten::IsNullArray(coo.data)) return coo;
return coo;
const auto& sorted = Sort(coo.data); const auto& sorted = Sort(coo.data);
coo.row = IndexSelect(coo.row, sorted.second); coo.row = IndexSelect(coo.row, sorted.second);
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
* \brief Sort CSR index * \brief Sort CSR index
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh" #include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl { namespace dgl {
...@@ -20,8 +21,8 @@ namespace impl { ...@@ -20,8 +21,8 @@ namespace impl {
*/ */
template <typename IdType> template <typename IdType>
__global__ void _SegmentIsSorted( __global__ void _SegmentIsSorted(
const IdType* indptr, const IdType* indices, const IdType* indptr, const IdType* indices, int64_t num_rows,
int64_t num_rows, int8_t* flags) { int8_t* flags) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < num_rows) { while (tx < num_rows) {
...@@ -39,15 +40,15 @@ bool CSRIsSorted(CSRMatrix csr) { ...@@ -39,15 +40,15 @@ bool CSRIsSorted(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx; const auto& ctx = csr.indptr->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
// We allocate a workspace of num_rows bytes. It wastes a little bit of memory but should // We allocate a workspace of num_rows bytes. It wastes a little bit of memory
// be fine. // but should be fine.
int8_t* flags = static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows)); int8_t* flags =
static_cast<int8_t*>(device->AllocWorkspace(ctx, csr.num_rows));
const int nt = cuda::FindNumThreads(csr.num_rows); const int nt = cuda::FindNumThreads(csr.num_rows);
const int nb = (csr.num_rows + nt - 1) / nt; const int nb = (csr.num_rows + nt - 1) / nt;
CUDA_KERNEL_CALL(_SegmentIsSorted, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _SegmentIsSorted, nb, nt, 0, stream, csr.indptr.Ptr<IdType>(),
csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows, flags);
csr.num_rows, flags);
bool ret = cuda::AllTrue(flags, csr.num_rows, ctx); bool ret = cuda::AllTrue(flags, csr.num_rows, ctx);
device->FreeWorkspace(ctx, flags); device->FreeWorkspace(ctx, flags);
return ret; return ret;
...@@ -82,10 +83,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) { ...@@ -82,10 +83,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
size_t workspace_size = 0; size_t workspace_size = 0;
CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt( CUSPARSE_CALL(cusparseXcsrsort_bufferSizeExt(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz,
csr->num_rows, csr->num_cols, nnz, indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), &workspace_size));
indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
&workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
cusparseMatDescr_t descr; cusparseMatDescr_t descr;
...@@ -93,11 +92,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) { ...@@ -93,11 +92,8 @@ void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL)); CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO)); CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_CALL(cusparseXcsrsort( CUSPARSE_CALL(cusparseXcsrsort(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr->num_rows, csr->num_cols, nnz, descr,
csr->num_rows, csr->num_cols, nnz, indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(), data.Ptr<int32_t>(),
descr,
indptr.Ptr<int32_t>(), indices.Ptr<int32_t>(),
data.Ptr<int32_t>(),
workspace)); workspace));
csr->sorted = true; csr->sorted = true;
...@@ -115,8 +111,7 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) { ...@@ -115,8 +111,7 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
const auto& ctx = csr->indptr->ctx; const auto& ctx = csr->indptr->ctx;
const int64_t nnz = csr->indices->shape[0]; const int64_t nnz = csr->indices->shape[0];
const auto nbits = csr->indptr->dtype.bits; const auto nbits = csr->indptr->dtype.bits;
if (!aten::CSRHasData(*csr)) if (!aten::CSRHasData(*csr)) csr->data = aten::Range(0, nnz, nbits, ctx);
csr->data = aten::Range(0, nnz, nbits, ctx);
IdArray new_indices = csr->indices.Clone(); IdArray new_indices = csr->indices.Clone();
IdArray new_data = csr->data.Clone(); IdArray new_data = csr->data.Clone();
...@@ -129,15 +124,15 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) { ...@@ -129,15 +124,15 @@ void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
// Allocate workspace // Allocate workspace
size_t workspace_size = 0; size_t workspace_size = 0;
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(nullptr, workspace_size, CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(
key_in, key_out, value_in, value_out, nullptr, workspace_size, key_in, key_out, value_in, value_out, nnz,
nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream)); csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute // Compute
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(workspace, workspace_size, CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairs(
key_in, key_out, value_in, value_out, workspace, workspace_size, key_in, key_out, value_in, value_out, nnz,
nnz, csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t)*8, stream)); csr->num_rows, offsets, offsets + 1, 0, sizeof(int64_t) * 8, stream));
csr->sorted = true; csr->sorted = true;
csr->indices = new_indices; csr->indices = new_indices;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief CSR transpose (convert to CSC) * \brief CSR transpose (convert to CSC)
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
namespace dgl { namespace dgl {
...@@ -33,14 +34,13 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -33,14 +34,13 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
const int64_t nnz = indices->shape[0]; const int64_t nnz = indices->shape[0];
const auto& ctx = indptr->ctx; const auto& ctx = indptr->ctx;
const auto bits = indptr->dtype.bits; const auto bits = indptr->dtype.bits;
if (aten::IsNullArray(data)) if (aten::IsNullArray(data)) data = aten::Range(0, nnz, bits, ctx);
data = aten::Range(0, nnz, bits, ctx);
const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data); const int32_t* indptr_ptr = static_cast<int32_t*>(indptr->data);
const int32_t* indices_ptr = static_cast<int32_t*>(indices->data); const int32_t* indices_ptr = static_cast<int32_t*>(indices->data);
const void* data_ptr = data->data; const void* data_ptr = data->data;
// (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz == 0. // (BarclayII) csr2csc doesn't seem to clear the content of cscColPtr if nnz
// We need to do it ourselves. // == 0. We need to do it ourselves.
NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx); NDArray t_indptr = aten::Full(0, csr.num_cols + 1, bits, ctx);
NDArray t_indices = aten::NewIdArray(nnz, ctx, bits); NDArray t_indices = aten::NewIdArray(nnz, ctx, bits);
NDArray t_data = aten::NewIdArray(nnz, ctx, bits); NDArray t_data = aten::NewIdArray(nnz, ctx, bits);
...@@ -53,40 +53,29 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -53,40 +53,29 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
// workspace // workspace
size_t workspace_size; size_t workspace_size;
CUSPARSE_CALL(cusparseCsr2cscEx2_bufferSize( CUSPARSE_CALL(cusparseCsr2cscEx2_bufferSize(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,
csr.num_rows, csr.num_cols, nnz, indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,
data_ptr, indptr_ptr, indices_ptr, CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO,
t_data_ptr, t_indptr_ptr, t_indices_ptr,
CUDA_R_32F,
CUSPARSE_ACTION_NUMERIC,
CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference
&workspace_size)); &workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseCsr2cscEx2( CUSPARSE_CALL(cusparseCsr2cscEx2(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz, data_ptr,
csr.num_rows, csr.num_cols, nnz, indptr_ptr, indices_ptr, t_data_ptr, t_indptr_ptr, t_indices_ptr,
data_ptr, indptr_ptr, indices_ptr, CUDA_R_32F, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO,
t_data_ptr, t_indptr_ptr, t_indices_ptr,
CUDA_R_32F,
CUSPARSE_ACTION_NUMERIC,
CUSPARSE_INDEX_BASE_ZERO,
CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference CUSPARSE_CSR2CSC_ALG1, // see cusparse doc for reference
workspace)); workspace));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
#else #else
CUSPARSE_CALL(cusparseScsr2csc( CUSPARSE_CALL(cusparseScsr2csc(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, csr.num_rows, csr.num_cols, nnz,
csr.num_rows, csr.num_cols, nnz,
static_cast<const float*>(data_ptr), indptr_ptr, indices_ptr, static_cast<const float*>(data_ptr), indptr_ptr, indices_ptr,
static_cast<float*>(t_data_ptr), t_indices_ptr, t_indptr_ptr, static_cast<float*>(t_data_ptr), t_indices_ptr, t_indptr_ptr,
CUSPARSE_ACTION_NUMERIC, CUSPARSE_ACTION_NUMERIC, CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_INDEX_BASE_ZERO));
#endif #endif
return CSRMatrix(csr.num_cols, csr.num_rows, return CSRMatrix(
t_indptr, t_indices, t_data, csr.num_cols, csr.num_rows, t_indptr, t_indices, t_data, false);
false);
} }
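The cusparseCsr2cscEx2 / cusparseScsr2csc call above transposes the CSR matrix by building its CSC form: count the entries per column, prefix-sum the counts into the transposed indptr, then scatter every (row, col, data) entry into its column bucket. A plain C++ sketch of that counting-sort formulation, with a hypothetical SimpleCSR type and data treated as per-edge ids:

#include <cstdint>
#include <vector>

struct SimpleCSR {
  int64_t num_rows, num_cols;
  std::vector<int64_t> indptr, indices, data;
};

// Transpose CSR -> CSC (returned as a CSR over the transposed shape).
SimpleCSR TransposeCSR(const SimpleCSR& csr) {
  SimpleCSR t;
  t.num_rows = csr.num_cols;
  t.num_cols = csr.num_rows;
  const int64_t nnz = static_cast<int64_t>(csr.indices.size());
  t.indptr.assign(csr.num_cols + 1, 0);
  t.indices.resize(nnz);
  t.data.resize(nnz);
  // Count entries per column.
  for (int64_t e = 0; e < nnz; ++e) ++t.indptr[csr.indices[e] + 1];
  // Prefix sum gives the transposed indptr.
  for (int64_t c = 0; c < csr.num_cols; ++c) t.indptr[c + 1] += t.indptr[c];
  // Scatter each entry into its column bucket.
  std::vector<int64_t> cursor(t.indptr.begin(), t.indptr.end() - 1);
  for (int64_t r = 0; r < csr.num_rows; ++r) {
    for (int64_t e = csr.indptr[r]; e < csr.indptr[r + 1]; ++e) {
      const int64_t dst = cursor[csr.indices[e]]++;
      t.indices[dst] = r;
      t.data[dst] = csr.data[e];
    }
  }
  return t;
}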
template <> template <>
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "../filter.h"
#include "../../runtime/cuda/cuda_hashtable.cuh" #include "../../runtime/cuda/cuda_hashtable.cuh"
#include "../filter.h"
#include "./dgl_cub.cuh" #include "./dgl_cub.cuh"
using namespace dgl::runtime::cuda; using namespace dgl::runtime::cuda;
...@@ -20,35 +20,29 @@ namespace { ...@@ -20,35 +20,29 @@ namespace {
cudaStream_t cudaStream = runtime::getCurrentCUDAStream(); cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
template<typename IdType, bool include> template <typename IdType, bool include>
__global__ void _IsInKernel( __global__ void _IsInKernel(
DeviceOrderedHashTable<IdType> table, DeviceOrderedHashTable<IdType> table, const IdType* const array,
const IdType * const array, const int64_t size, IdType* const mark) {
const int64_t size, const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
IdType * const mark) {
const int64_t idx = threadIdx.x + blockDim.x*blockIdx.x;
if (idx < size) { if (idx < size) {
mark[idx] = table.Contains(array[idx]) ^ (!include); mark[idx] = table.Contains(array[idx]) ^ (!include);
} }
} }
template<typename IdType> template <typename IdType>
__global__ void _InsertKernel( __global__ void _InsertKernel(
const IdType * const prefix, const IdType* const prefix, const int64_t size, IdType* const result) {
const int64_t size, const int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
IdType * const result) {
const int64_t idx = threadIdx.x + blockDim.x*blockIdx.x;
if (idx < size) { if (idx < size) {
if (prefix[idx] != prefix[idx+1]) { if (prefix[idx] != prefix[idx + 1]) {
result[prefix[idx]] = idx; result[prefix[idx]] = idx;
} }
} }
} }
template<typename IdType, bool include> template <typename IdType, bool include>
IdArray _PerformFilter( IdArray _PerformFilter(const OrderedHashTable<IdType>& table, IdArray test) {
const OrderedHashTable<IdType>& table,
IdArray test) {
const auto& ctx = test->ctx; const auto& ctx = test->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
const int64_t size = test->shape[0]; const int64_t size = test->shape[0];
...@@ -60,22 +54,20 @@ IdArray _PerformFilter( ...@@ -60,22 +54,20 @@ IdArray _PerformFilter(
// we need two arrays: 1) to act as a prefixsum // we need two arrays: 1) to act as a prefixsum
// for the number of entries that will be inserted, and // for the number of entries that will be inserted, and
// 2) to collect the included items. // 2) to collect the included items.
IdType * prefix = static_cast<IdType*>( IdType* prefix = static_cast<IdType*>(
device->AllocWorkspace(ctx, sizeof(IdType)*(size+1))); device->AllocWorkspace(ctx, sizeof(IdType) * (size + 1)));
// will resize down later // will resize down later
IdArray result = aten::NewIdArray(size, ctx, sizeof(IdType)*8); IdArray result = aten::NewIdArray(size, ctx, sizeof(IdType) * 8);
// mark each index based on its existence in the hashtable // mark each index based on its existence in the hashtable
{ {
const dim3 block(256); const dim3 block(256);
const dim3 grid((size+block.x-1)/block.x); const dim3 grid((size + block.x - 1) / block.x);
CUDA_KERNEL_CALL((_IsInKernel<IdType, include>), CUDA_KERNEL_CALL(
grid, block, 0, cudaStream, (_IsInKernel<IdType, include>), grid, block, 0, cudaStream,
table.DeviceHandle(), table.DeviceHandle(), static_cast<const IdType*>(test->data), size,
static_cast<const IdType*>(test->data),
size,
prefix); prefix);
} }
...@@ -83,40 +75,28 @@ IdArray _PerformFilter( ...@@ -83,40 +75,28 @@ IdArray _PerformFilter(
{ {
size_t workspace_bytes; size_t workspace_bytes;
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr, nullptr, workspace_bytes, static_cast<IdType*>(nullptr),
workspace_bytes, static_cast<IdType*>(nullptr), size + 1, cudaStream));
static_cast<IdType*>(nullptr), void* workspace = device->AllocWorkspace(ctx, workspace_bytes);
static_cast<IdType*>(nullptr),
size+1, cudaStream));
void * workspace = device->AllocWorkspace(ctx, workspace_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum( CUDA_CALL(cub::DeviceScan::ExclusiveSum(
workspace, workspace, workspace_bytes, prefix, prefix, size + 1, cudaStream));
workspace_bytes,
prefix,
prefix,
size+1, cudaStream));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
} }
// copy the number of kept entries back to the host using the current stream // copy the number of kept entries back to the host using the current stream
IdType num_unique; IdType num_unique;
device->CopyDataFromTo(prefix+size, 0, device->CopyDataFromTo(
&num_unique, 0, prefix + size, 0, &num_unique, 0, sizeof(num_unique), ctx,
sizeof(num_unique), DGLContext{kDGLCPU, 0}, test->dtype);
ctx,
DGLContext{kDGLCPU, 0},
test->dtype);
// insert items into set // insert items into set
{ {
const dim3 block(256); const dim3 block(256);
const dim3 grid((size+block.x-1)/block.x); const dim3 grid((size + block.x - 1) / block.x);
CUDA_KERNEL_CALL(_InsertKernel, CUDA_KERNEL_CALL(
grid, block, 0, cudaStream, _InsertKernel, grid, block, 0, cudaStream, prefix, size,
prefix,
size,
static_cast<IdType*>(result->data)); static_cast<IdType*>(result->data));
} }
device->FreeWorkspace(ctx, prefix); device->FreeWorkspace(ctx, prefix);
...@@ -124,16 +104,13 @@ IdArray _PerformFilter( ...@@ -124,16 +104,13 @@ IdArray _PerformFilter(
return result.CreateView({num_unique}, result->dtype); return result.CreateView({num_unique}, result->dtype);
} }
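_PerformFilter is a stream compaction: each id in test is marked 0/1 by a hash-table lookup, an exclusive prefix sum over the marks gives every kept id its output slot, and _InsertKernel scatters the indices whose mark changed the prefix. A sequential sketch of the same compaction, using std::unordered_set in place of the GPU hash table (illustrative only):

#include <cstdint>
#include <unordered_set>
#include <vector>

// Return the indices of `test` whose membership in `set` matches `include`.
std::vector<int64_t> FilterIndices(
    const std::unordered_set<int64_t>& set, const std::vector<int64_t>& test,
    bool include) {
  const size_t n = test.size();
  // 1) mark: 1 if the element at this index should be kept.
  std::vector<int64_t> mark(n, 0);
  for (size_t i = 0; i < n; ++i)
    mark[i] = ((set.count(test[i]) > 0) == include) ? 1 : 0;
  // 2) exclusive prefix sum: prefix[i] is the output slot of index i.
  std::vector<int64_t> prefix(n + 1, 0);
  for (size_t i = 0; i < n; ++i) prefix[i + 1] = prefix[i] + mark[i];
  // 3) scatter the kept indices, mirroring _InsertKernel.
  std::vector<int64_t> result(prefix[n]);
  for (size_t i = 0; i < n; ++i)
    if (prefix[i] != prefix[i + 1]) result[prefix[i]] = static_cast<int64_t>(i);
  return result;
}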
template <typename IdType>
template<typename IdType>
class CudaFilterSet : public Filter { class CudaFilterSet : public Filter {
public: public:
explicit CudaFilterSet(IdArray array) : explicit CudaFilterSet(IdArray array)
table_(array->shape[0], array->ctx, cudaStream) { : table_(array->shape[0], array->ctx, cudaStream) {
table_.FillWithUnique( table_.FillWithUnique(
static_cast<const IdType*>(array->data), static_cast<const IdType*>(array->data), array->shape[0], cudaStream);
array->shape[0],
cudaStream);
} }
IdArray find_included_indices(IdArray test) override { IdArray find_included_indices(IdArray test) override {
...@@ -150,7 +127,7 @@ class CudaFilterSet : public Filter { ...@@ -150,7 +127,7 @@ class CudaFilterSet : public Filter {
} // namespace } // namespace
template<DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set) { FilterRef CreateSetFilter(IdArray set) {
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set)); return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
} }
......
/*! /*!
* Copyright (c) 2021 by Contributors * Copyright (c) 2021 by Contributors
* \file cuda_common.h * \file cuda_common.h
* \brief Wrapper to place cub in dgl namespace. * \brief Wrapper to place cub in dgl namespace.
*/ */
#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ #ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
......
/** /**
* Copyright (c) 2022, NVIDIA CORPORATION. * Copyright (c) 2022, NVIDIA CORPORATION.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
* *
* \file array/gpu/disjoint_union.cu * \file array/gpu/disjoint_union.cu
* \brief Disjoint union GPU implementation. * \brief Disjoint union GPU implementation.
*/ */
#include <dgl/runtime/parallel_for.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <vector> #include <dgl/runtime/parallel_for.h>
#include <tuple> #include <tuple>
#include <vector>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
...@@ -31,8 +33,8 @@ namespace impl { ...@@ -31,8 +33,8 @@ namespace impl {
template <typename IdType> template <typename IdType>
__global__ void _DisjointUnionKernel( __global__ void _DisjointUnionKernel(
IdType** arrs, IdType* prefix, IdType* offset, IdType* out, IdType** arrs, IdType* prefix, IdType* offset, IdType* out, int64_t n_arrs,
int64_t n_arrs, int n_elms) { int n_elms) {
IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x; IdType tx = static_cast<IdType>(blockIdx.x) * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < n_elms) { while (tx < n_elms) {
...@@ -48,7 +50,8 @@ __global__ void _DisjointUnionKernel( ...@@ -48,7 +50,8 @@ __global__ void _DisjointUnionKernel(
} }
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) { std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(
const std::vector<COOMatrix>& coos) {
IdType n = coos.size(), nbits = coos[0].row->dtype.bits; IdType n = coos.size(), nbits = coos[0].row->dtype.bits;
IdArray n_rows = NewIdArray(n, CPU, nbits); IdArray n_rows = NewIdArray(n, CPU, nbits);
IdArray n_cols = NewIdArray(n, CPU, nbits); IdArray n_cols = NewIdArray(n, CPU, nbits);
...@@ -58,7 +61,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa ...@@ -58,7 +61,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
IdType* n_cols_data = n_cols.Ptr<IdType>(); IdType* n_cols_data = n_cols.Ptr<IdType>();
IdType* n_elms_data = n_elms.Ptr<IdType>(); IdType* n_elms_data = n_elms.Ptr<IdType>();
dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e){ dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e) {
for (IdType i = b; i < e; ++i) { for (IdType i = b; i < e; ++i) {
n_rows_data[i] = coos[i].num_rows; n_rows_data[i] = coos[i].num_rows;
n_cols_data[i] = coos[i].num_cols; n_cols_data[i] = coos[i].num_cols;
...@@ -66,30 +69,30 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa ...@@ -66,30 +69,30 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
} }
}); });
return std::make_tuple(CumSum(n_rows.CopyTo(coos[0].row->ctx), true), return std::make_tuple(
CumSum(n_cols.CopyTo(coos[0].row->ctx), true), CumSum(n_rows.CopyTo(coos[0].row->ctx), true),
CumSum(n_elms.CopyTo(coos[0].row->ctx), true)); CumSum(n_cols.CopyTo(coos[0].row->ctx), true),
CumSum(n_elms.CopyTo(coos[0].row->ctx), true));
} }
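DisjointUnionCoo relies on prefix sums (with a leading zero) over the per-graph row, column, and edge counts: the i-th entries are the offsets by which graph i's node ids and edge slots are shifted in the merged COO. A host-side sketch of the resulting semantics, with a hypothetical SimpleCOO type:

#include <cstdint>
#include <vector>

struct SimpleCOO {
  int64_t num_rows, num_cols;
  std::vector<int64_t> row, col;
};

// Disjoint union: each input graph keeps its structure, but its node ids are
// shifted by the total row/col counts of the graphs placed before it.
SimpleCOO DisjointUnion(const std::vector<SimpleCOO>& coos) {
  SimpleCOO out{0, 0, {}, {}};
  for (const auto& g : coos) {
    for (size_t e = 0; e < g.row.size(); ++e) {
      out.row.push_back(g.row[e] + out.num_rows);  // row offset so far
      out.col.push_back(g.col[e] + out.num_cols);  // col offset so far
    }
    out.num_rows += g.num_rows;
    out.num_cols += g.num_cols;
  }
  return out;
}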
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out, void _Merge(
int64_t n_arrs, int n_elms, IdType** arrs, IdType* prefix, IdType* offset, IdType* out, int64_t n_arrs,
DGLContext ctx, DGLDataType dtype, cudaStream_t stream) { int n_elms, DGLContext ctx, DGLDataType dtype, cudaStream_t stream) {
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
int nt = 256; int nt = 256;
int nb = (n_elms + nt - 1) / nt; int nb = (n_elms + nt - 1) / nt;
IdType** arrs_dev = static_cast<IdType**>( IdType** arrs_dev = static_cast<IdType**>(
device->AllocWorkspace(ctx, n_arrs*sizeof(IdType*))); device->AllocWorkspace(ctx, n_arrs * sizeof(IdType*)));
device->CopyDataFromTo( device->CopyDataFromTo(
arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs, arrs, 0, arrs_dev, 0, sizeof(IdType*) * n_arrs, DGLContext{kDGLCPU, 0},
DGLContext{kDGLCPU, 0}, ctx, dtype); ctx, dtype);
CUDA_KERNEL_CALL(_DisjointUnionKernel, CUDA_KERNEL_CALL(
nb, nt, 0, stream, _DisjointUnionKernel, nb, nt, 0, stream, arrs_dev, prefix, offset, out,
arrs_dev, prefix, offset, n_arrs, n_elms);
out, n_arrs, n_elms);
device->FreeWorkspace(ctx, arrs_dev); device->FreeWorkspace(ctx, arrs_dev);
} }
...@@ -132,52 +135,50 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { ...@@ -132,52 +135,50 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
IdType n_elements = 0; IdType n_elements = 0;
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_elm[coos.size()], 0, &n_elements, 0, &prefix_elm[coos.size()], 0, &n_elements, 0, sizeof(IdType),
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);
coos[0].row->dtype);
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_src[coos.size()], 0, &src_offset, 0, &prefix_src[coos.size()], 0, &src_offset, 0, sizeof(IdType),
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);
coos[0].row->dtype);
device->CopyDataFromTo( device->CopyDataFromTo(
&prefix_dst[coos.size()], 0, &dst_offset, 0, &prefix_dst[coos.size()], 0, &dst_offset, 0, sizeof(IdType),
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->ctx, DGLContext{kDGLCPU, 0}, coos[0].row->dtype);
coos[0].row->dtype);
// Union src array // Union src array
IdArray result_src = NewIdArray( IdArray result_src =
n_elements, coos[0].row->ctx, coos[0].row->dtype.bits); NewIdArray(n_elements, coos[0].row->ctx, coos[0].row->dtype.bits);
_Merge<XPU, IdType>(rows.get(), prefix_src, prefix_elm, result_src.Ptr<IdType>(), _Merge<XPU, IdType>(
coos.size(), n_elements, ctx, dtype, stream); rows.get(), prefix_src, prefix_elm, result_src.Ptr<IdType>(), coos.size(),
n_elements, ctx, dtype, stream);
// Union dst array // Union dst array
IdArray result_dst = NewIdArray( IdArray result_dst =
n_elements, coos[0].col->ctx, coos[0].col->dtype.bits); NewIdArray(n_elements, coos[0].col->ctx, coos[0].col->dtype.bits);
_Merge<XPU, IdType>(cols.get(), prefix_dst, prefix_elm, result_dst.Ptr<IdType>(), _Merge<XPU, IdType>(
coos.size(), n_elements, ctx, dtype, stream); cols.get(), prefix_dst, prefix_elm, result_dst.Ptr<IdType>(), coos.size(),
n_elements, ctx, dtype, stream);
// Union data array if exists and fetch number of elements // Union data array if exists and fetch number of elements
IdArray result_dat = NullArray(); IdArray result_dat = NullArray();
if (has_data) { if (has_data) {
result_dat = NewIdArray( result_dat =
n_elements, coos[0].row->ctx, coos[0].row->dtype.bits); NewIdArray(n_elements, coos[0].row->ctx, coos[0].row->dtype.bits);
_Merge<XPU, IdType>(data.get(), prefix_elm, prefix_elm, result_dat.Ptr<IdType>(), _Merge<XPU, IdType>(
coos.size(), n_elements, ctx, dtype, stream); data.get(), prefix_elm, prefix_elm, result_dat.Ptr<IdType>(),
coos.size(), n_elements, ctx, dtype, stream);
} }
return COOMatrix( return COOMatrix(
src_offset, dst_offset, src_offset, dst_offset, result_src, result_dst, result_dat, row_sorted,
result_src, col_sorted);
result_dst,
result_dat,
row_sorted,
col_sorted);
} }
template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(const std::vector<COOMatrix>& coos); template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(
template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(const std::vector<COOMatrix>& coos); const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(
const std::vector<COOMatrix>& coos);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......