Unverified Commit b2d38ca8 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4803)



* [Misc] clang-format auto fix.

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 07dc8fb6
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
...@@ -14,7 +15,7 @@ using runtime::parallel_for; ...@@ -14,7 +15,7 @@ using runtime::parallel_for;
namespace aten { namespace aten {
namespace impl { namespace impl {
template<DGLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) { std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
const int64_t rows = lengths->shape[0]; const int64_t rows = lengths->shape[0];
const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]); const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]);
...@@ -41,16 +42,24 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) { ...@@ -41,16 +42,24 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
return std::make_pair(concat, offsets); return std::make_pair(concat, offsets);
} }
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int32_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int32_t>(
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int32_t>(NDArray, IdArray); NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int32_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int32_t>(
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int32_t>(NDArray, IdArray); NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int64_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int32_t>(
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int64_t>(NDArray, IdArray); NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int64_t>(NDArray, IdArray); template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int32_t>(
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int64_t>(NDArray, IdArray); NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int64_t>(
NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int64_t>(
NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int64_t>(
NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int64_t>(
NDArray, IdArray);
template<DGLDeviceType XPU, typename DType> template <DGLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) { std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
CHECK_NDIM(array, 2, "array"); CHECK_NDIM(array, 2, "array");
const DType *array_data = static_cast<DType *>(array->data); const DType *array_data = static_cast<DType *>(array->data);
...@@ -64,8 +73,7 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) { ...@@ -64,8 +73,7 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
int64_t j; int64_t j;
for (j = 0; j < cols; ++j) { for (j = 0; j < cols; ++j) {
const DType val = array_data[i * cols + j]; const DType val = array_data[i * cols + j];
if (val == pad_value) if (val == pad_value) break;
break;
} }
length_data[i] = j; length_data[i] = j;
} }
...@@ -75,10 +83,14 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) { ...@@ -75,10 +83,14 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
return std::make_tuple(ret.first, length, ret.second); return std::make_tuple(ret.first, length, ret.second);
} }
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int32_t>(NDArray, int32_t); template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int32_t>(
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int64_t>(NDArray, int64_t); NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, float>(NDArray, float); template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int64_t>(
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, double>(NDArray, double); NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, float>(
NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, double>(
NDArray, double);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* \brief Array repeat CPU implementation * \brief Array repeat CPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <algorithm> #include <algorithm>
namespace dgl { namespace dgl {
...@@ -13,21 +14,23 @@ namespace impl { ...@@ -13,21 +14,23 @@ namespace impl {
template <DGLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats) { NDArray Repeat(NDArray array, IdArray repeats) {
CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch"; CHECK(array->shape[0] == repeats->shape[0])
<< "shape of array and repeats mismatch";
const int64_t len = array->shape[0]; const int64_t len = array->shape[0];
const DType *array_data = static_cast<DType *>(array->data); const DType *array_data = static_cast<DType *>(array->data);
const IdType *repeats_data = static_cast<IdType *>(repeats->data); const IdType *repeats_data = static_cast<IdType *>(repeats->data);
IdType num_elements = 0; IdType num_elements = 0;
for (int64_t i = 0; i < len; ++i) for (int64_t i = 0; i < len; ++i) num_elements += repeats_data[i];
num_elements += repeats_data[i];
NDArray result = NDArray::Empty({num_elements}, array->dtype, array->ctx); NDArray result = NDArray::Empty({num_elements}, array->dtype, array->ctx);
DType *result_data = static_cast<DType *>(result->data); DType *result_data = static_cast<DType *>(result->data);
IdType curr = 0; IdType curr = 0;
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
std::fill(result_data + curr, result_data + curr + repeats_data[i], array_data[i]); std::fill(
result_data + curr, result_data + curr + repeats_data[i],
array_data[i]);
curr += repeats_data[i]; curr += repeats_data[i];
} }
......
...@@ -13,7 +13,8 @@ namespace impl { ...@@ -13,7 +13,8 @@ namespace impl {
template <DGLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices) { NDArray Scatter(NDArray array, IdArray indices) {
NDArray result = NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx); NDArray result =
NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx);
const DType *array_data = static_cast<DType *>(array->data); const DType *array_data = static_cast<DType *>(array->data);
const IdType *indices_data = static_cast<IdType *>(indices->data); const IdType *indices_data = static_cast<IdType *>(indices->data);
...@@ -37,9 +38,9 @@ template NDArray Scatter<kDGLCPU, double, int64_t>(NDArray, IdArray); ...@@ -37,9 +38,9 @@ template NDArray Scatter<kDGLCPU, double, int64_t>(NDArray, IdArray);
template <DGLDeviceType XPU, typename DType, typename IdType> template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) { void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0]; const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>(); const IdType *idx = index.Ptr<IdType>();
const DType* val = value.Ptr<DType>(); const DType *val = value.Ptr<DType>();
DType* outd = out.Ptr<DType>(); DType *outd = out.Ptr<DType>();
runtime::parallel_for(0, len, [&](size_t b, size_t e) { runtime::parallel_for(0, len, [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
outd[idx[i]] = val[i]; outd[idx[i]] = val[i];
......
...@@ -17,8 +17,7 @@ struct PairRef { ...@@ -17,8 +17,7 @@ struct PairRef {
PairRef() = delete; PairRef() = delete;
PairRef(const PairRef& other) = default; PairRef(const PairRef& other) = default;
PairRef(PairRef&& other) = default; PairRef(PairRef&& other) = default;
PairRef(V1 *const r, V2 *const c) PairRef(V1* const r, V2* const c) : row(r), col(c) {}
: row(r), col(c) {}
PairRef& operator=(const PairRef& other) { PairRef& operator=(const PairRef& other) {
*row = *other.row; *row = *other.row;
...@@ -31,17 +30,15 @@ struct PairRef { ...@@ -31,17 +30,15 @@ struct PairRef {
return *this; return *this;
} }
operator std::pair<V1, V2>() const { operator std::pair<V1, V2>() const { return std::make_pair(*row, *col); }
return std::make_pair(*row, *col);
}
void Swap(const PairRef& other) const { void Swap(const PairRef& other) const {
std::swap(*row, *other.row); std::swap(*row, *other.row);
std::swap(*col, *other.col); std::swap(*col, *other.col);
} }
V1 *row; V1* row;
V2 *col; V2* col;
}; };
using std::swap; using std::swap;
...@@ -51,43 +48,30 @@ void swap(const PairRef<V1, V2>& r1, const PairRef<V1, V2>& r2) { ...@@ -51,43 +48,30 @@ void swap(const PairRef<V1, V2>& r1, const PairRef<V1, V2>& r2) {
} }
template <typename V1, typename V2> template <typename V1, typename V2>
struct PairIterator : public std::iterator<std::random_access_iterator_tag, struct PairIterator
std::pair<V1, V2>, : public std::iterator<
std::ptrdiff_t, std::random_access_iterator_tag, std::pair<V1, V2>, std::ptrdiff_t,
std::pair<V1*, V2*>, std::pair<V1*, V2*>, PairRef<V1, V2>> {
PairRef<V1, V2>> {
PairIterator() = default; PairIterator() = default;
PairIterator(const PairIterator& other) = default; PairIterator(const PairIterator& other) = default;
PairIterator(PairIterator&& other) = default; PairIterator(PairIterator&& other) = default;
PairIterator(V1 *r, V2 *c): row(r), col(c) {} PairIterator(V1* r, V2* c) : row(r), col(c) {}
PairIterator& operator=(const PairIterator& other) = default; PairIterator& operator=(const PairIterator& other) = default;
PairIterator& operator=(PairIterator&& other) = default; PairIterator& operator=(PairIterator&& other) = default;
~PairIterator() = default; ~PairIterator() = default;
bool operator==(const PairIterator& other) const { bool operator==(const PairIterator& other) const { return row == other.row; }
return row == other.row;
}
bool operator!=(const PairIterator& other) const { bool operator!=(const PairIterator& other) const { return row != other.row; }
return row != other.row;
}
bool operator<(const PairIterator& other) const { bool operator<(const PairIterator& other) const { return row < other.row; }
return row < other.row;
}
bool operator>(const PairIterator& other) const { bool operator>(const PairIterator& other) const { return row > other.row; }
return row > other.row;
}
bool operator<=(const PairIterator& other) const { bool operator<=(const PairIterator& other) const { return row <= other.row; }
return row <= other.row;
}
bool operator>=(const PairIterator& other) const { bool operator>=(const PairIterator& other) const { return row >= other.row; }
return row >= other.row;
}
PairIterator& operator+=(const std::ptrdiff_t& movement) { PairIterator& operator+=(const std::ptrdiff_t& movement) {
row += movement; row += movement;
...@@ -101,13 +85,9 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -101,13 +85,9 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
return *this; return *this;
} }
PairIterator& operator++() { PairIterator& operator++() { return operator+=(1); }
return operator+=(1);
}
PairIterator& operator--() { PairIterator& operator--() { return operator-=(1); }
return operator-=(1);
}
PairIterator operator++(int) { PairIterator operator++(int) {
PairIterator ret(*this); PairIterator ret(*this);
...@@ -137,20 +117,16 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -137,20 +117,16 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
return row - other.row; return row - other.row;
} }
PairRef<V1, V2> operator*() const { PairRef<V1, V2> operator*() const { return PairRef<V1, V2>(row, col); }
return PairRef<V1, V2>(row, col); PairRef<V1, V2> operator*() { return PairRef<V1, V2>(row, col); }
}
PairRef<V1, V2> operator*() {
return PairRef<V1, V2>(row, col);
}
// required for random access iterators in VS2019 // required for random access iterators in VS2019
PairRef<V1, V2> operator[](size_t offset) const { PairRef<V1, V2> operator[](size_t offset) const {
return PairRef<V1, V2>(row + offset, col + offset); return PairRef<V1, V2>(row + offset, col + offset);
} }
V1 *row; V1* row;
V2 *col; V2* col;
}; };
} // namespace } // namespace
...@@ -175,14 +151,16 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) { ...@@ -175,14 +151,16 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
#endif #endif
PairIterator<IdType, int64_t>(val_data, idx_data), PairIterator<IdType, int64_t>(val_data, idx_data),
PairIterator<IdType, int64_t>(val_data, idx_data) + nitem, PairIterator<IdType, int64_t>(val_data, idx_data) + nitem,
[] (const Pair& a, const Pair& b) { [](const Pair& a, const Pair& b) {
return std::get<0>(a) < std::get<0>(b); return std::get<0>(a) < std::get<0>(b);
}); });
return std::make_pair(val, idx); return std::make_pair(val, idx);
} }
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(IdArray, int num_bits); template std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(IdArray, int num_bits); IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(
IdArray, int num_bits);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -8,16 +8,19 @@ ...@@ -8,16 +8,19 @@
#include <dgl/aten/types.h> #include <dgl/aten/types.h>
#include <parallel_hashmap/phmap.h> #include <parallel_hashmap/phmap.h>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector>
#include "../../c_api_common.h" #include "../../c_api_common.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
/*! /*!
* \brief A hashmap that maps each ids in the given array to new ids starting from zero. * \brief A hashmap that maps each ids in the given array to new ids starting
* from zero.
* *
* Useful for relabeling integers and finding unique integers. * Useful for relabeling integers and finding unique integers.
* *
...@@ -27,23 +30,21 @@ template <typename IdType> ...@@ -27,23 +30,21 @@ template <typename IdType>
class IdHashMap { class IdHashMap {
public: public:
// default ctor // default ctor
IdHashMap(): filter_(kFilterSize, false) {} IdHashMap() : filter_(kFilterSize, false) {}
// Construct the hashmap using the given id array. // Construct the hashmap using the given id array.
// The id array could contain duplicates. // The id array could contain duplicates.
// If the id array has no duplicates, the array will be relabeled to consecutive // If the id array has no duplicates, the array will be relabeled to
// integers starting from 0. // consecutive integers starting from 0.
explicit IdHashMap(IdArray ids): filter_(kFilterSize, false) { explicit IdHashMap(IdArray ids) : filter_(kFilterSize, false) {
oldv2newv_.reserve(ids->shape[0]); oldv2newv_.reserve(ids->shape[0]);
Update(ids); Update(ids);
} }
// copy ctor // copy ctor
IdHashMap(const IdHashMap &other) = default; IdHashMap(const IdHashMap& other) = default;
void Reserve(const int64_t size) { void Reserve(const int64_t size) { oldv2newv_.reserve(size); }
oldv2newv_.reserve(size);
}
// Update the hashmap with given id array. // Update the hashmap with given id array.
// The id array could contain duplicates. // The id array could contain duplicates.
...@@ -52,8 +53,8 @@ class IdHashMap { ...@@ -52,8 +53,8 @@ class IdHashMap {
const int64_t len = ids->shape[0]; const int64_t len = ids->shape[0];
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
const IdType id = ids_data[i]; const IdType id = ids_data[i];
// phmap::flat_hash_map::insert assures that an insertion will not happen if the // phmap::flat_hash_map::insert assures that an insertion will not happen
// key already exists. // if the key already exists.
oldv2newv_.insert({id, oldv2newv_.size()}); oldv2newv_.insert({id, oldv2newv_.size()});
filter_[id & kFilterMask] = true; filter_[id & kFilterMask] = true;
} }
...@@ -88,22 +89,21 @@ class IdHashMap { ...@@ -88,22 +89,21 @@ class IdHashMap {
// Return all the old ids collected so far, ordered by new id. // Return all the old ids collected so far, ordered by new id.
IdArray Values() const { IdArray Values() const {
IdArray values = NewIdArray(oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8); IdArray values = NewIdArray(
oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* values_data = static_cast<IdType*>(values->data); IdType* values_data = static_cast<IdType*>(values->data);
for (auto pair : oldv2newv_) for (auto pair : oldv2newv_) values_data[pair.second] = pair.first;
values_data[pair.second] = pair.first;
return values; return values;
} }
inline size_t Size() const { inline size_t Size() const { return oldv2newv_.size(); }
return oldv2newv_.size();
}
private: private:
static constexpr int32_t kFilterMask = 0xFFFFFF; static constexpr int32_t kFilterMask = 0xFFFFFF;
static constexpr int32_t kFilterSize = kFilterMask + 1; static constexpr int32_t kFilterSize = kFilterMask + 1;
// This bitmap is used as a bloom filter to remove some lookups. // This bitmap is used as a bloom filter to remove some lookups.
// Hashtable is very slow. Using bloom filter can significantly speed up lookups. // Hashtable is very slow. Using bloom filter can significantly speed up
// lookups.
std::vector<bool> filter_; std::vector<bool> filter_;
// The hashmap from old vid to new vid // The hashmap from old vid to new vid
phmap::flat_hash_map<IdType, IdType> oldv2newv_; phmap::flat_hash_map<IdType, IdType> oldv2newv_;
...@@ -114,7 +114,7 @@ class IdHashMap { ...@@ -114,7 +114,7 @@ class IdHashMap {
*/ */
struct PairHash { struct PairHash {
template <class T1, class T2> template <class T1, class T2>
std::size_t operator() (const std::pair<T1, T2>& pair) const { std::size_t operator()(const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second); return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
} }
}; };
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <vector> #include <vector>
namespace dgl { namespace dgl {
...@@ -19,8 +20,7 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) { ...@@ -19,8 +20,7 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
if (!coo.row_sorted || !coo.col_sorted) if (!coo.row_sorted || !coo.col_sorted) coo = COOSort(coo, true);
coo = COOSort(coo, true);
std::vector<IdType> new_row, new_col, count; std::vector<IdType> new_row, new_col, count;
IdType prev_row = -1, prev_col = -1; IdType prev_row = -1, prev_col = -1;
...@@ -39,8 +39,12 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) { ...@@ -39,8 +39,12 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
} }
COOMatrix coo_result = COOMatrix{ COOMatrix coo_result = COOMatrix{
coo.num_rows, coo.num_cols, NDArray::FromVector(new_row), NDArray::FromVector(new_col), coo.num_rows,
NullArray(), true}; coo.num_cols,
NDArray::FromVector(new_row),
NDArray::FromVector(new_col),
NullArray(),
true};
return std::make_pair(coo_result, NDArray::FromVector(count)); return std::make_pair(coo_result, NDArray::FromVector(count));
} }
......
...@@ -5,24 +5,24 @@ ...@@ -5,24 +5,24 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <numeric>
#include <algorithm> #include <algorithm>
#include <vector>
#include <iterator> #include <iterator>
#include <numeric>
#include <vector>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) { COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
IdType* coo_row = coo.row.Ptr<IdType>(); IdType* coo_row = coo.row.Ptr<IdType>();
IdType* coo_col = coo.col.Ptr<IdType>(); IdType* coo_col = coo.col.Ptr<IdType>();
IdArray data = COOHasData(coo) ? coo.data : Range(0, IdArray data = COOHasData(coo)
nnz, ? coo.data
coo.row->dtype.bits, : Range(0, nnz, coo.row->dtype.bits, coo.row->ctx);
coo.row->ctx);
IdType* data_data = data.Ptr<IdType>(); IdType* data_data = data.Ptr<IdType>();
std::vector<IdType> new_row; std::vector<IdType> new_row;
std::vector<IdType> new_col; std::vector<IdType> new_col;
...@@ -32,8 +32,7 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) { ...@@ -32,8 +32,7 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
IdType v = coo_col[i]; IdType v = coo_col[i];
for (int64_t j = 0; j < nnz; ++j) { for (int64_t j = 0; j < nnz; ++j) {
// no self-loop // no self-loop
if (i == j) if (i == j) continue;
continue;
// succ_u == v // succ_u == v
// if not backtracking succ_u != u // if not backtracking succ_u != u
...@@ -44,14 +43,16 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) { ...@@ -44,14 +43,16 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
} }
} }
COOMatrix res = COOMatrix(nnz, nnz, NDArray::FromVector(new_row), NDArray::FromVector(new_col), COOMatrix res = COOMatrix(
nnz, nnz, NDArray::FromVector(new_row), NDArray::FromVector(new_col),
NullArray(), false, false); NullArray(), false, false);
return res; return res;
} }
template COOMatrix COOLineGraph<kDGLCPU, int32_t>(
template COOMatrix COOLineGraph<kDGLCPU, int32_t>(const COOMatrix &coo, bool backtracking); const COOMatrix& coo, bool backtracking);
template COOMatrix COOLineGraph<kDGLCPU, int64_t>(const COOMatrix &coo, bool backtracking); template COOMatrix COOLineGraph<kDGLCPU, int64_t>(
const COOMatrix& coo, bool backtracking);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
* \brief COO matrix remove entries CPU implementation * \brief COO matrix remove entries CPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "array_utils.h" #include "array_utils.h"
namespace dgl { namespace dgl {
...@@ -15,14 +17,12 @@ namespace impl { ...@@ -15,14 +17,12 @@ namespace impl {
namespace { namespace {
/*! \brief COORemove implementation for COOMatrix with default consecutive edge IDs */ /*! \brief COORemove implementation for COOMatrix with default consecutive edge
* IDs */
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void COORemoveConsecutive( void COORemoveConsecutive(
COOMatrix coo, COOMatrix coo, IdArray entries, std::vector<IdType> *new_rows,
IdArray entries, std::vector<IdType> *new_cols, std::vector<IdType> *new_eids) {
std::vector<IdType> *new_rows,
std::vector<IdType> *new_cols,
std::vector<IdType> *new_eids) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const int64_t n_entries = entries->shape[0]; const int64_t n_entries = entries->shape[0];
const IdType *row_data = static_cast<IdType *>(coo.row->data); const IdType *row_data = static_cast<IdType *>(coo.row->data);
...@@ -36,8 +36,7 @@ void COORemoveConsecutive( ...@@ -36,8 +36,7 @@ void COORemoveConsecutive(
for (int64_t i = 0; i < nnz; ++i) { for (int64_t i = 0; i < nnz; ++i) {
if (j < n_entries && entry_data_sorted[j] == i) { if (j < n_entries && entry_data_sorted[j] == i) {
// Move on to the next different entry // Move on to the next different entry
while (j < n_entries && entry_data_sorted[j] == i) while (j < n_entries && entry_data_sorted[j] == i) ++j;
++j;
continue; continue;
} }
new_rows->push_back(row_data[i]); new_rows->push_back(row_data[i]);
...@@ -49,11 +48,8 @@ void COORemoveConsecutive( ...@@ -49,11 +48,8 @@ void COORemoveConsecutive(
/*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */ /*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void COORemoveShuffled( void COORemoveShuffled(
COOMatrix coo, COOMatrix coo, IdArray entries, std::vector<IdType> *new_rows,
IdArray entries, std::vector<IdType> *new_cols, std::vector<IdType> *new_eids) {
std::vector<IdType> *new_rows,
std::vector<IdType> *new_cols,
std::vector<IdType> *new_eids) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const IdType *row_data = static_cast<IdType *>(coo.row->data); const IdType *row_data = static_cast<IdType *>(coo.row->data);
const IdType *col_data = static_cast<IdType *>(coo.col->data); const IdType *col_data = static_cast<IdType *>(coo.col->data);
...@@ -63,8 +59,7 @@ void COORemoveShuffled( ...@@ -63,8 +59,7 @@ void COORemoveShuffled(
for (int64_t i = 0; i < nnz; ++i) { for (int64_t i = 0; i < nnz; ++i) {
const IdType eid = eid_data[i]; const IdType eid = eid_data[i];
if (eid_map.Contains(eid)) if (eid_map.Contains(eid)) continue;
continue;
new_rows->push_back(row_data[i]); new_rows->push_back(row_data[i]);
new_cols->push_back(col_data[i]); new_cols->push_back(col_data[i]);
new_eids->push_back(eid); new_eids->push_back(eid);
...@@ -77,8 +72,7 @@ template <DGLDeviceType XPU, typename IdType> ...@@ -77,8 +72,7 @@ template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries) { COOMatrix COORemove(COOMatrix coo, IdArray entries) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const int64_t n_entries = entries->shape[0]; const int64_t n_entries = entries->shape[0];
if (n_entries == 0) if (n_entries == 0) return coo;
return coo;
std::vector<IdType> new_rows, new_cols, new_eids; std::vector<IdType> new_rows, new_cols, new_eids;
new_rows.reserve(nnz - n_entries); new_rows.reserve(nnz - n_entries);
...@@ -86,16 +80,16 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) { ...@@ -86,16 +80,16 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
new_eids.reserve(nnz - n_entries); new_eids.reserve(nnz - n_entries);
if (COOHasData(coo)) if (COOHasData(coo))
COORemoveShuffled<XPU, IdType>(coo, entries, &new_rows, &new_cols, &new_eids); COORemoveShuffled<XPU, IdType>(
coo, entries, &new_rows, &new_cols, &new_eids);
else else
// Removing from COO ordered by eid has more efficient implementation. // Removing from COO ordered by eid has more efficient implementation.
COORemoveConsecutive<XPU, IdType>(coo, entries, &new_rows, &new_cols, &new_eids); COORemoveConsecutive<XPU, IdType>(
coo, entries, &new_rows, &new_cols, &new_eids);
return COOMatrix( return COOMatrix(
coo.num_rows, coo.num_cols, coo.num_rows, coo.num_cols, IdArray::FromVector(new_rows),
IdArray::FromVector(new_rows), IdArray::FromVector(new_cols), IdArray::FromVector(new_eids));
IdArray::FromVector(new_cols),
IdArray::FromVector(new_eids));
} }
template COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix coo, IdArray entries); template COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix coo, IdArray entries);
......
...@@ -7,11 +7,11 @@ ...@@ -7,11 +7,11 @@
#ifdef PARALLEL_ALGORITHMS #ifdef PARALLEL_ALGORITHMS
#include <parallel/algorithm> #include <parallel/algorithm>
#endif #endif
#include <numeric>
#include <algorithm> #include <algorithm>
#include <vector>
#include <iterator> #include <iterator>
#include <numeric>
#include <tuple> #include <tuple>
#include <vector>
namespace { namespace {
...@@ -20,7 +20,7 @@ struct TupleRef { ...@@ -20,7 +20,7 @@ struct TupleRef {
TupleRef() = delete; TupleRef() = delete;
TupleRef(const TupleRef& other) = default; TupleRef(const TupleRef& other) = default;
TupleRef(TupleRef&& other) = default; TupleRef(TupleRef&& other) = default;
TupleRef(IdType *const r, IdType *const c, IdType *const d) TupleRef(IdType* const r, IdType* const c, IdType* const d)
: row(r), col(c), data(d) {} : row(r), col(c), data(d) {}
TupleRef& operator=(const TupleRef& other) { TupleRef& operator=(const TupleRef& other) {
...@@ -56,43 +56,31 @@ void swap(const TupleRef<IdType>& r1, const TupleRef<IdType>& r2) { ...@@ -56,43 +56,31 @@ void swap(const TupleRef<IdType>& r1, const TupleRef<IdType>& r2) {
} }
template <typename IdType> template <typename IdType>
struct CooIterator : public std::iterator<std::random_access_iterator_tag, struct CooIterator
std::tuple<IdType, IdType, IdType>, : public std::iterator<
std::ptrdiff_t, std::random_access_iterator_tag, std::tuple<IdType, IdType, IdType>,
std::tuple<IdType*, IdType*, IdType*>, std::ptrdiff_t, std::tuple<IdType*, IdType*, IdType*>,
TupleRef<IdType>> { TupleRef<IdType>> {
CooIterator() = default; CooIterator() = default;
CooIterator(const CooIterator& other) = default; CooIterator(const CooIterator& other) = default;
CooIterator(CooIterator&& other) = default; CooIterator(CooIterator&& other) = default;
CooIterator(IdType *r, IdType *c, IdType *d): row(r), col(c), data(d) {} CooIterator(IdType* r, IdType* c, IdType* d) : row(r), col(c), data(d) {}
CooIterator& operator=(const CooIterator& other) = default; CooIterator& operator=(const CooIterator& other) = default;
CooIterator& operator=(CooIterator&& other) = default; CooIterator& operator=(CooIterator&& other) = default;
~CooIterator() = default; ~CooIterator() = default;
bool operator==(const CooIterator& other) const { bool operator==(const CooIterator& other) const { return row == other.row; }
return row == other.row;
}
bool operator!=(const CooIterator& other) const { bool operator!=(const CooIterator& other) const { return row != other.row; }
return row != other.row;
}
bool operator<(const CooIterator& other) const { bool operator<(const CooIterator& other) const { return row < other.row; }
return row < other.row;
}
bool operator>(const CooIterator& other) const { bool operator>(const CooIterator& other) const { return row > other.row; }
return row > other.row;
}
bool operator<=(const CooIterator& other) const { bool operator<=(const CooIterator& other) const { return row <= other.row; }
return row <= other.row;
}
bool operator>=(const CooIterator& other) const { bool operator>=(const CooIterator& other) const { return row >= other.row; }
return row >= other.row;
}
CooIterator& operator+=(const std::ptrdiff_t& movement) { CooIterator& operator+=(const std::ptrdiff_t& movement) {
row += movement; row += movement;
...@@ -108,13 +96,9 @@ struct CooIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -108,13 +96,9 @@ struct CooIterator : public std::iterator<std::random_access_iterator_tag,
return *this; return *this;
} }
CooIterator& operator++() { CooIterator& operator++() { return operator+=(1); }
return operator+=(1);
}
CooIterator& operator--() { CooIterator& operator--() { return operator-=(1); }
return operator-=(1);
}
CooIterator operator++(int) { CooIterator operator++(int) {
CooIterator ret(*this); CooIterator ret(*this);
...@@ -147,9 +131,7 @@ struct CooIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -147,9 +131,7 @@ struct CooIterator : public std::iterator<std::random_access_iterator_tag,
TupleRef<IdType> operator*() const { TupleRef<IdType> operator*() const {
return TupleRef<IdType>(row, col, data); return TupleRef<IdType>(row, col, data);
} }
TupleRef<IdType> operator*() { TupleRef<IdType> operator*() { return TupleRef<IdType>(row, col, data); }
return TupleRef<IdType>(row, col, data);
}
// required for random access iterators in VS2019 // required for random access iterators in VS2019
TupleRef<IdType> operator[](size_t offset) const { TupleRef<IdType> operator[](size_t offset) const {
...@@ -188,8 +170,9 @@ void COOSort_(COOMatrix* coo, bool sort_column) { ...@@ -188,8 +170,9 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
CooIterator<IdType>(coo_row, coo_col, coo_data), CooIterator<IdType>(coo_row, coo_col, coo_data),
CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz, CooIterator<IdType>(coo_row, coo_col, coo_data) + nnz,
[](const Tuple& a, const Tuple& b) { [](const Tuple& a, const Tuple& b) {
return (std::get<0>(a) != std::get<0>(b)) ? return (std::get<0>(a) != std::get<0>(b))
(std::get<0>(a) < std::get<0>(b)) : (std::get<1>(a) < std::get<1>(b)); ? (std::get<0>(a) < std::get<0>(b))
: (std::get<1>(a) < std::get<1>(b));
}); });
} else { } else {
#ifdef PARALLEL_ALGORITHMS #ifdef PARALLEL_ALGORITHMS
...@@ -211,7 +194,6 @@ void COOSort_(COOMatrix* coo, bool sort_column) { ...@@ -211,7 +194,6 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool); template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool); template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);
///////////////////////////// COOIsSorted ///////////////////////////// ///////////////////////////// COOIsSorted /////////////////////////////
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
...@@ -225,8 +207,7 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) { ...@@ -225,8 +207,7 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
row_sorted = (row[i - 1] <= row[i]); row_sorted = (row[i - 1] <= row[i]);
col_sorted = col_sorted && (row[i - 1] < row[i] || col[i - 1] <= col[i]); col_sorted = col_sorted && (row[i - 1] < row[i] || col[i - 1] <= col[i]);
} }
if (!row_sorted) if (!row_sorted) col_sorted = false;
col_sorted = false;
return {row_sorted, col_sorted}; return {row_sorted, col_sorted};
} }
......
...@@ -5,9 +5,11 @@ ...@@ -5,9 +5,11 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <vector>
#include <unordered_set>
#include <numeric> #include <numeric>
#include <unordered_set>
#include <vector>
#include "array_utils.h" #include "array_utils.h"
namespace dgl { namespace dgl {
...@@ -18,11 +20,11 @@ namespace aten { ...@@ -18,11 +20,11 @@ namespace aten {
namespace impl { namespace impl {
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data, void CollectDataFromSorted(
const IdType start, const IdType end, const IdType col, const IdType* indices_data, const IdType* data, const IdType start,
std::vector<IdType> *ret_vec) { const IdType end, const IdType col, std::vector<IdType>* ret_vec) {
const IdType *start_ptr = indices_data + start; const IdType* start_ptr = indices_data + start;
const IdType *end_ptr = indices_data + end; const IdType* end_ptr = indices_data + end;
auto it = std::lower_bound(start_ptr, end_ptr, col); auto it = std::lower_bound(start_ptr, end_ptr, col);
// This might be a multi-graph. We need to collect all of the matched // This might be a multi-graph. We need to collect all of the matched
// columns. // columns.
...@@ -30,7 +32,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data, ...@@ -30,7 +32,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
// If the col exist // If the col exist
if (*it == col) { if (*it == col) {
IdType idx = it - indices_data; IdType idx = it - indices_data;
ret_vec->push_back(data? data[idx] : idx); ret_vec->push_back(data ? data[idx] : idx);
} else { } else {
// If we find a column that is different, we can stop searching now. // If we find a column that is different, we can stop searching now.
break; break;
...@@ -40,7 +42,8 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data, ...@@ -40,7 +42,8 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData( NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) { CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0]; const int64_t collen = cols->shape[0];
...@@ -54,30 +57,35 @@ NDArray CSRGetData( ...@@ -54,30 +57,35 @@ NDArray CSRGetData(
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const IdType* data = CSRHasData(csr)? static_cast<IdType*>(csr.data->data) : nullptr; const IdType* data =
CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
const int64_t retlen = std::max(rowlen, collen); const int64_t retlen = std::max(rowlen, collen);
const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>(); const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
if (return_eids) if (return_eids)
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) << BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
"DType does not match row's dtype."; << "DType does not match row's dtype.";
NDArray ret = Full(filler, retlen, rows->ctx); NDArray ret = Full(filler, retlen, rows->ctx);
DType* ret_data = ret.Ptr<DType>(); DType* ret_data = ret.Ptr<DType>();
// NOTE: In most cases, the input csr is already sorted. If not, we might need to // NOTE: In most cases, the input csr is already sorted. If not, we might need
// consider sorting it especially when the number of (row, col) pairs is large. // to
// Need more benchmarks to justify the choice. // consider sorting it especially when the number of (row, col) pairs is
// large. Need more benchmarks to justify the choice.
if (csr.sorted) { if (csr.sorted) {
// use binary search on each row // use binary search on each row
parallel_for(0, retlen, [&](size_t b, size_t e) { parallel_for(0, retlen, [&](size_t b, size_t e) {
for (auto p = b; p < e; ++p) { for (auto p = b; p < e; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride]; const IdType row_id = row_data[p * row_stride],
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id; col_id = col_data[p * col_stride];
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id; CHECK(row_id >= 0 && row_id < csr.num_rows)
const IdType *start_ptr = indices_data + indptr_data[row_id]; << "Invalid row index: " << row_id;
const IdType *end_ptr = indices_data + indptr_data[row_id + 1]; CHECK(col_id >= 0 && col_id < csr.num_cols)
<< "Invalid col index: " << col_id;
const IdType* start_ptr = indices_data + indptr_data[row_id];
const IdType* end_ptr = indices_data + indptr_data[row_id + 1];
auto it = std::lower_bound(start_ptr, end_ptr, col_id); auto it = std::lower_bound(start_ptr, end_ptr, col_id);
if (it != end_ptr && *it == col_id) { if (it != end_ptr && *it == col_id) {
const IdType idx = it - indices_data; const IdType idx = it - indices_data;
...@@ -90,10 +98,14 @@ NDArray CSRGetData( ...@@ -90,10 +98,14 @@ NDArray CSRGetData(
// linear search on each row // linear search on each row
parallel_for(0, retlen, [&](size_t b, size_t e) { parallel_for(0, retlen, [&](size_t b, size_t e) {
for (auto p = b; p < e; ++p) { for (auto p = b; p < e; ++p) {
const IdType row_id = row_data[p * row_stride], col_id = col_data[p * col_stride]; const IdType row_id = row_data[p * row_stride],
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id; col_id = col_data[p * col_stride];
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id; CHECK(row_id >= 0 && row_id < csr.num_rows)
for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1]; ++idx) { << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols)
<< "Invalid col index: " << col_id;
for (IdType idx = indptr_data[row_id]; idx < indptr_data[row_id + 1];
++idx) {
if (indices_data[idx] == col_id) { if (indices_data[idx] == col_id) {
IdType eid = data ? data[idx] : idx; IdType eid = data ? data[idx] : idx;
ret_data[p] = return_eids ? eid : weight_data[eid]; ret_data[p] = return_eids ? eid : weight_data[eid];
...@@ -107,19 +119,25 @@ NDArray CSRGetData( ...@@ -107,19 +119,25 @@ NDArray CSRGetData(
} }
template NDArray CSRGetData<kDGLCPU, int32_t, float>( template NDArray CSRGetData<kDGLCPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, float filler);
template NDArray CSRGetData<kDGLCPU, int64_t, float>( template NDArray CSRGetData<kDGLCPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, float filler);
template NDArray CSRGetData<kDGLCPU, int32_t, double>( template NDArray CSRGetData<kDGLCPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, double filler);
template NDArray CSRGetData<kDGLCPU, int64_t, double>( template NDArray CSRGetData<kDGLCPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray) // For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>( template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, int32_t filler);
template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>( template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler); CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, int64_t filler);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -7,7 +7,9 @@ ...@@ -7,7 +7,9 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <parallel_hashmap/phmap.h> #include <parallel_hashmap/phmap.h>
#include <vector> #include <vector>
#include "array_utils.h" #include "array_utils.h"
namespace dgl { namespace dgl {
...@@ -22,12 +24,8 @@ namespace { ...@@ -22,12 +24,8 @@ namespace {
// TODO(BarclayII): avoid using map for sorted CSRs // TODO(BarclayII): avoid using map for sorted CSRs
template <typename IdType> template <typename IdType>
void CountNNZPerRow( void CountNNZPerRow(
const IdType* A_indptr, const IdType* A_indptr, const IdType* A_indices, const IdType* B_indptr,
const IdType* A_indices, const IdType* B_indices, IdType* C_indptr_data, int64_t M) {
const IdType* B_indptr,
const IdType* B_indices,
IdType* C_indptr_data,
int64_t M) {
parallel_for(0, M, [=](size_t b, size_t e) { parallel_for(0, M, [=](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
phmap::flat_hash_set<IdType> set; phmap::flat_hash_set<IdType> set;
...@@ -56,18 +54,10 @@ int64_t ComputeIndptrInPlace(IdType* C_indptr_data, int64_t M) { ...@@ -56,18 +54,10 @@ int64_t ComputeIndptrInPlace(IdType* C_indptr_data, int64_t M) {
template <typename IdType, typename DType> template <typename IdType, typename DType>
void ComputeIndicesAndData( void ComputeIndicesAndData(
const IdType* A_indptr, const IdType* A_indptr, const IdType* A_indices, const IdType* A_eids,
const IdType* A_indices, const DType* A_data, const IdType* B_indptr, const IdType* B_indices,
const IdType* A_eids, const IdType* B_eids, const DType* B_data, const IdType* C_indptr_data,
const DType* A_data, IdType* C_indices_data, DType* C_weights_data, int64_t M) {
const IdType* B_indptr,
const IdType* B_indices,
const IdType* B_eids,
const DType* B_data,
const IdType* C_indptr_data,
IdType* C_indices_data,
DType* C_weights_data,
int64_t M) {
parallel_for(0, M, [=](size_t b, size_t e) { parallel_for(0, M, [=](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
phmap::flat_hash_map<IdType, DType> map; phmap::flat_hash_map<IdType, DType> map;
...@@ -95,11 +85,10 @@ void ComputeIndicesAndData( ...@@ -95,11 +85,10 @@ void ComputeIndicesAndData(
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM( std::pair<CSRMatrix, NDArray> CSRMM(
const CSRMatrix& A, const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
NDArray A_weights,
const CSRMatrix& B,
NDArray B_weights) { NDArray B_weights) {
CHECK_EQ(A.num_cols, B.num_rows) << "A's number of columns must equal to B's number of rows"; CHECK_EQ(A.num_cols, B.num_rows)
<< "A's number of columns must equal to B's number of rows";
const bool A_has_eid = !IsNullArray(A.data); const bool A_has_eid = !IsNullArray(A.data);
const bool B_has_eid = !IsNullArray(B.data); const bool B_has_eid = !IsNullArray(B.data);
const IdType* A_indptr = A.indptr.Ptr<IdType>(); const IdType* A_indptr = A.indptr.Ptr<IdType>();
...@@ -116,7 +105,8 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -116,7 +105,8 @@ std::pair<CSRMatrix, NDArray> CSRMM(
IdArray C_indptr = IdArray::Empty({M + 1}, A.indptr->dtype, A.indptr->ctx); IdArray C_indptr = IdArray::Empty({M + 1}, A.indptr->dtype, A.indptr->ctx);
IdType* C_indptr_data = C_indptr.Ptr<IdType>(); IdType* C_indptr_data = C_indptr.Ptr<IdType>();
CountNNZPerRow<IdType>(A_indptr, A_indices, B_indptr, B_indices, C_indptr_data, M); CountNNZPerRow<IdType>(
A_indptr, A_indices, B_indptr, B_indices, C_indptr_data, M);
int64_t nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M); int64_t nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M);
// Allocate indices and weights array // Allocate indices and weights array
IdArray C_indices = IdArray::Empty({nnz}, A.indices->dtype, A.indices->ctx); IdArray C_indices = IdArray::Empty({nnz}, A.indices->dtype, A.indices->ctx);
...@@ -125,12 +115,12 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -125,12 +115,12 @@ std::pair<CSRMatrix, NDArray> CSRMM(
DType* C_weights_data = C_weights.Ptr<DType>(); DType* C_weights_data = C_weights.Ptr<DType>();
ComputeIndicesAndData<IdType, DType>( ComputeIndicesAndData<IdType, DType>(
A_indptr, A_indices, A_eids, A_data, A_indptr, A_indices, A_eids, A_data, B_indptr, B_indices, B_eids, B_data,
B_indptr, B_indices, B_eids, B_data,
C_indptr_data, C_indices_data, C_weights_data, M); C_indptr_data, C_indices_data, C_weights_data, M);
return { return {
CSRMatrix(M, P, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)), CSRMatrix(
M, P, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
C_weights}; C_weights};
} }
......
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
* \brief CSR matrix remove entries CPU implementation * \brief CSR matrix remove entries CPU implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "array_utils.h" #include "array_utils.h"
namespace dgl { namespace dgl {
...@@ -17,11 +19,8 @@ namespace { ...@@ -17,11 +19,8 @@ namespace {
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRRemoveConsecutive( void CSRRemoveConsecutive(
CSRMatrix csr, CSRMatrix csr, IdArray entries, std::vector<IdType> *new_indptr,
IdArray entries, std::vector<IdType> *new_indices, std::vector<IdType> *new_eids) {
std::vector<IdType> *new_indptr,
std::vector<IdType> *new_indices,
std::vector<IdType> *new_eids) {
CHECK_SAME_DTYPE(csr.indices, entries); CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t n_entries = entries->shape[0]; const int64_t n_entries = entries->shape[0];
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
...@@ -37,8 +36,7 @@ void CSRRemoveConsecutive( ...@@ -37,8 +36,7 @@ void CSRRemoveConsecutive(
for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) { for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {
if (k < n_entries && entry_data_sorted[k] == j) { if (k < n_entries && entry_data_sorted[k] == j) {
// Move on to the next different entry // Move on to the next different entry
while (k < n_entries && entry_data_sorted[k] == j) while (k < n_entries && entry_data_sorted[k] == j) ++k;
++k;
continue; continue;
} }
new_indices->push_back(indices_data[j]); new_indices->push_back(indices_data[j]);
...@@ -50,11 +48,8 @@ void CSRRemoveConsecutive( ...@@ -50,11 +48,8 @@ void CSRRemoveConsecutive(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRRemoveShuffled( void CSRRemoveShuffled(
CSRMatrix csr, CSRMatrix csr, IdArray entries, std::vector<IdType> *new_indptr,
IdArray entries, std::vector<IdType> *new_indices, std::vector<IdType> *new_eids) {
std::vector<IdType> *new_indptr,
std::vector<IdType> *new_indices,
std::vector<IdType> *new_eids) {
CHECK_SAME_DTYPE(csr.indices, entries); CHECK_SAME_DTYPE(csr.indices, entries);
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
...@@ -66,8 +61,7 @@ void CSRRemoveShuffled( ...@@ -66,8 +61,7 @@ void CSRRemoveShuffled(
for (int64_t i = 0; i < csr.num_rows; ++i) { for (int64_t i = 0; i < csr.num_rows; ++i) {
for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) { for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {
const IdType eid = eid_data ? eid_data[j] : j; const IdType eid = eid_data ? eid_data[j] : j;
if (eid_map.Contains(eid)) if (eid_map.Contains(eid)) continue;
continue;
new_indices->push_back(indices_data[j]); new_indices->push_back(indices_data[j]);
new_eids->push_back(eid); new_eids->push_back(eid);
} }
...@@ -82,8 +76,7 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) { ...@@ -82,8 +76,7 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CHECK_SAME_DTYPE(csr.indices, entries); CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t nnz = csr.indices->shape[0]; const int64_t nnz = csr.indices->shape[0];
const int64_t n_entries = entries->shape[0]; const int64_t n_entries = entries->shape[0];
if (n_entries == 0) if (n_entries == 0) return csr;
return csr;
std::vector<IdType> new_indptr, new_indices, new_eids; std::vector<IdType> new_indptr, new_indices, new_eids;
new_indptr.reserve(nnz - n_entries); new_indptr.reserve(nnz - n_entries);
...@@ -91,16 +84,16 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) { ...@@ -91,16 +84,16 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
new_eids.reserve(nnz - n_entries); new_eids.reserve(nnz - n_entries);
if (CSRHasData(csr)) if (CSRHasData(csr))
CSRRemoveShuffled<XPU, IdType>(csr, entries, &new_indptr, &new_indices, &new_eids); CSRRemoveShuffled<XPU, IdType>(
csr, entries, &new_indptr, &new_indices, &new_eids);
else else
// Removing from CSR ordered by eid has more efficient implementation // Removing from CSR ordered by eid has more efficient implementation
CSRRemoveConsecutive<XPU, IdType>(csr, entries, &new_indptr, &new_indices, &new_eids); CSRRemoveConsecutive<XPU, IdType>(
csr, entries, &new_indptr, &new_indices, &new_eids);
return CSRMatrix( return CSRMatrix(
csr.num_rows, csr.num_cols, csr.num_rows, csr.num_cols, IdArray::FromVector(new_indptr),
IdArray::FromVector(new_indptr), IdArray::FromVector(new_indices), IdArray::FromVector(new_eids));
IdArray::FromVector(new_indices),
IdArray::FromVector(new_eids));
} }
template CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries); template CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries);
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <numeric>
#include <algorithm> #include <algorithm>
#include <numeric>
#include <vector> #include <vector>
namespace dgl { namespace dgl {
...@@ -16,14 +17,14 @@ namespace impl { ...@@ -16,14 +17,14 @@ namespace impl {
///////////////////////////// CSRIsSorted ///////////////////////////// ///////////////////////////// CSRIsSorted /////////////////////////////
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) { bool CSRIsSorted(CSRMatrix csr) {
const IdType* indptr = csr.indptr.Ptr<IdType>(); const IdType *indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>(); const IdType *indices = csr.indices.Ptr<IdType>();
return runtime::parallel_reduce(0, csr.num_rows, 1, 1, return runtime::parallel_reduce(
0, csr.num_rows, 1, 1,
[indptr, indices](size_t b, size_t e, bool ident) { [indptr, indices](size_t b, size_t e, bool ident) {
for (size_t row = b; row < e; ++row) { for (size_t row = b; row < e; ++row) {
for (IdType i = indptr[row] + 1; i < indptr[row + 1]; ++i) { for (IdType i = indptr[row] + 1; i < indptr[row + 1]; ++i) {
if (indices[i - 1] > indices[i]) if (indices[i - 1] > indices[i]) return false;
return false;
} }
} }
return ident; return ident;
...@@ -37,12 +38,12 @@ template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr); ...@@ -37,12 +38,12 @@ template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);
///////////////////////////// CSRSort ///////////////////////////// ///////////////////////////// CSRSort /////////////////////////////
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) { void CSRSort_(CSRMatrix *csr) {
typedef std::pair<IdType, IdType> ShufflePair; typedef std::pair<IdType, IdType> ShufflePair;
const int64_t num_rows = csr->num_rows; const int64_t num_rows = csr->num_rows;
const int64_t nnz = csr->indices->shape[0]; const int64_t nnz = csr->indices->shape[0];
const IdType* indptr_data = static_cast<IdType*>(csr->indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr->indptr->data);
IdType* indices_data = static_cast<IdType*>(csr->indices->data); IdType *indices_data = static_cast<IdType *>(csr->indices->data);
if (CSRIsSorted(*csr)) { if (CSRIsSorted(*csr)) {
csr->sorted = true; csr->sorted = true;
...@@ -52,7 +53,7 @@ void CSRSort_(CSRMatrix* csr) { ...@@ -52,7 +53,7 @@ void CSRSort_(CSRMatrix* csr) {
if (!CSRHasData(*csr)) { if (!CSRHasData(*csr)) {
csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx); csr->data = aten::Range(0, nnz, csr->indptr->dtype.bits, csr->indptr->ctx);
} }
IdType* eid_data = static_cast<IdType*>(csr->data->data); IdType *eid_data = static_cast<IdType *>(csr->data->data);
runtime::parallel_for(0, num_rows, [=](size_t b, size_t e) { runtime::parallel_for(0, num_rows, [=](size_t b, size_t e) {
for (auto row = b; row < e; ++row) { for (auto row = b; row < e; ++row) {
...@@ -65,7 +66,8 @@ void CSRSort_(CSRMatrix* csr) { ...@@ -65,7 +66,8 @@ void CSRSort_(CSRMatrix* csr) {
reorder_vec[i].first = col[i]; reorder_vec[i].first = col[i];
reorder_vec[i].second = eid[i]; reorder_vec[i].second = eid[i];
} }
std::sort(reorder_vec.begin(), reorder_vec.end(), std::sort(
reorder_vec.begin(), reorder_vec.end(),
[](const ShufflePair &e1, const ShufflePair &e2) { [](const ShufflePair &e1, const ShufflePair &e2) {
return e1.first < e2.first; return e1.first < e2.first;
}); });
...@@ -79,8 +81,8 @@ void CSRSort_(CSRMatrix* csr) { ...@@ -79,8 +81,8 @@ void CSRSort_(CSRMatrix* csr) {
csr->sorted = true; csr->sorted = true;
} }
template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix* csr); template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix *csr);
template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix* csr); template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix *csr);
template <DGLDeviceType XPU, typename IdType, typename TagType> template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag( std::pair<CSRMatrix, NDArray> CSRSortByTag(
...@@ -93,15 +95,15 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -93,15 +95,15 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
const auto tag_data = static_cast<const TagType *>(tag_array->data); const auto tag_data = static_cast<const TagType *>(tag_array->data);
const int64_t num_rows = csr.num_rows; const int64_t num_rows = csr.num_rows;
NDArray tag_pos = NDArray::Empty({csr.num_rows, num_tags + 1}, NDArray tag_pos = NDArray::Empty(
csr.indptr->dtype, csr.indptr->ctx); {csr.num_rows, num_tags + 1}, csr.indptr->dtype, csr.indptr->ctx);
auto tag_pos_data = static_cast<IdType *>(tag_pos->data); auto tag_pos_data = static_cast<IdType *>(tag_pos->data);
std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0); std::fill(tag_pos_data, tag_pos_data + csr.num_rows * (num_tags + 1), 0);
aten::CSRMatrix output(csr.num_rows, csr.num_cols, csr.indptr.Clone(), aten::CSRMatrix output(
csr.indices.Clone(), csr.num_rows, csr.num_cols, csr.indptr.Clone(), csr.indices.Clone(),
NDArray::Empty({csr.indices->shape[0]}, NDArray::Empty(
csr.indices->dtype, csr.indices->ctx), {csr.indices->shape[0]}, csr.indices->dtype, csr.indices->ctx),
csr.sorted); csr.sorted);
auto out_indices_data = static_cast<IdType *>(output.indices->data); auto out_indices_data = static_cast<IdType *>(output.indices->data);
...@@ -115,18 +117,18 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag( ...@@ -115,18 +117,18 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
auto tag_pos_row = tag_pos_data + src * (num_tags + 1); auto tag_pos_row = tag_pos_data + src * (num_tags + 1);
std::vector<IdType> pointer(num_tags, 0); std::vector<IdType> pointer(num_tags, 0);
for (IdType ptr = start ; ptr < end ; ++ptr) { for (IdType ptr = start; ptr < end; ++ptr) {
const IdType eid = eid_data ? eid_data[ptr] : ptr; const IdType eid = eid_data ? eid_data[ptr] : ptr;
const TagType tag = tag_data[eid]; const TagType tag = tag_data[eid];
CHECK_LT(tag, num_tags); CHECK_LT(tag, num_tags);
++tag_pos_row[tag + 1]; ++tag_pos_row[tag + 1];
} // count } // count
for (TagType tag = 1 ; tag <= num_tags; ++tag) { for (TagType tag = 1; tag <= num_tags; ++tag) {
tag_pos_row[tag] += tag_pos_row[tag - 1]; tag_pos_row[tag] += tag_pos_row[tag - 1];
} // cumulate } // cumulate
for (IdType ptr = start ; ptr < end ; ++ptr) { for (IdType ptr = start; ptr < end; ++ptr) {
const IdType dst = indices_data[ptr]; const IdType dst = indices_data[ptr];
const IdType eid = eid_data ? eid_data[ptr] : ptr; const IdType eid = eid_data ? eid_data[ptr] : ptr;
const TagType tag = tag_data[eid]; const TagType tag = tag_data[eid];
......
...@@ -7,7 +7,9 @@ ...@@ -7,7 +7,9 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <parallel_hashmap/phmap.h> #include <parallel_hashmap/phmap.h>
#include <vector> #include <vector>
#include "array_utils.h" #include "array_utils.h"
namespace dgl { namespace dgl {
...@@ -22,8 +24,7 @@ namespace { ...@@ -22,8 +24,7 @@ namespace {
template <typename IdType> template <typename IdType>
void CountNNZPerRow( void CountNNZPerRow(
const std::vector<const IdType*>& A_indptr, const std::vector<const IdType*>& A_indptr,
const std::vector<const IdType*>& A_indices, const std::vector<const IdType*>& A_indices, IdType* C_indptr_data,
IdType* C_indptr_data,
int64_t M) { int64_t M) {
int64_t n = A_indptr.size(); int64_t n = A_indptr.size();
...@@ -57,11 +58,8 @@ void ComputeIndicesAndData( ...@@ -57,11 +58,8 @@ void ComputeIndicesAndData(
const std::vector<const IdType*>& A_indptr, const std::vector<const IdType*>& A_indptr,
const std::vector<const IdType*>& A_indices, const std::vector<const IdType*>& A_indices,
const std::vector<const IdType*>& A_eids, const std::vector<const IdType*>& A_eids,
const std::vector<const DType*>& A_data, const std::vector<const DType*>& A_data, const IdType* C_indptr_data,
const IdType* C_indptr_data, IdType* C_indices_data, DType* C_weights_data, int64_t M) {
IdType* C_indices_data,
DType* C_weights_data,
int64_t M) {
int64_t n = A_indptr.size(); int64_t n = A_indptr.size();
runtime::parallel_for(0, M, [=](size_t b, size_t e) { runtime::parallel_for(0, M, [=](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
...@@ -87,10 +85,10 @@ void ComputeIndicesAndData( ...@@ -87,10 +85,10 @@ void ComputeIndicesAndData(
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum( std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A, const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights) {
const std::vector<NDArray>& A_weights) {
CHECK(A.size() > 0) << "List of matrices can't be empty."; CHECK(A.size() > 0) << "List of matrices can't be empty.";
CHECK_EQ(A.size(), A_weights.size()) << "List of matrices and weights must have same length"; CHECK_EQ(A.size(), A_weights.size())
<< "List of matrices and weights must have same length";
const int64_t M = A[0].num_rows; const int64_t M = A[0].num_rows;
const int64_t N = A[0].num_cols; const int64_t N = A[0].num_cols;
const int64_t n = A.size(); const int64_t n = A.size();
...@@ -111,22 +109,26 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -111,22 +109,26 @@ std::pair<CSRMatrix, NDArray> CSRSum(
A_data[i] = data.Ptr<DType>(); A_data[i] = data.Ptr<DType>();
} }
IdArray C_indptr = IdArray::Empty({M + 1}, A[0].indptr->dtype, A[0].indptr->ctx); IdArray C_indptr =
IdArray::Empty({M + 1}, A[0].indptr->dtype, A[0].indptr->ctx);
IdType* C_indptr_data = C_indptr.Ptr<IdType>(); IdType* C_indptr_data = C_indptr.Ptr<IdType>();
CountNNZPerRow<IdType>(A_indptr, A_indices, C_indptr_data, M); CountNNZPerRow<IdType>(A_indptr, A_indices, C_indptr_data, M);
IdType nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M); IdType nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M);
// Allocate indices and weights array // Allocate indices and weights array
IdArray C_indices = IdArray::Empty({nnz}, A[0].indices->dtype, A[0].indices->ctx); IdArray C_indices =
NDArray C_weights = NDArray::Empty({nnz}, A_weights[0]->dtype, A_weights[0]->ctx); IdArray::Empty({nnz}, A[0].indices->dtype, A[0].indices->ctx);
NDArray C_weights =
NDArray::Empty({nnz}, A_weights[0]->dtype, A_weights[0]->ctx);
IdType* C_indices_data = C_indices.Ptr<IdType>(); IdType* C_indices_data = C_indices.Ptr<IdType>();
DType* C_weights_data = C_weights.Ptr<DType>(); DType* C_weights_data = C_weights.Ptr<DType>();
ComputeIndicesAndData<IdType, DType>( ComputeIndicesAndData<IdType, DType>(
A_indptr, A_indices, A_eids, A_data, A_indptr, A_indices, A_eids, A_data, C_indptr_data, C_indices_data,
C_indptr_data, C_indices_data, C_weights_data, M); C_weights_data, M);
return { return {
CSRMatrix(M, N, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)), CSRMatrix(
M, N, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
C_weights}; C_weights};
} }
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief CSR sorting * \brief CSR sorting
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <numeric>
#include <algorithm> #include <algorithm>
#include <numeric>
#include <vector> #include <vector>
namespace dgl { namespace dgl {
...@@ -14,11 +15,10 @@ namespace impl { ...@@ -14,11 +15,10 @@ namespace impl {
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) { std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
if (!csr.sorted) if (!csr.sorted) csr = CSRSort(csr);
csr = CSRSort(csr);
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType*>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
std::vector<IdType> indptr; std::vector<IdType> indptr;
std::vector<IdType> indices; std::vector<IdType> indices;
...@@ -27,16 +27,16 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) { ...@@ -27,16 +27,16 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
indptr[0] = 0; indptr[0] = 0;
for (int64_t i = 1; i < csr.indptr->shape[0]; ++i) { for (int64_t i = 1; i < csr.indptr->shape[0]; ++i) {
if (indptr_data[i-1] == indptr_data[i]) { if (indptr_data[i - 1] == indptr_data[i]) {
indptr[i] = indptr[i-1]; indptr[i] = indptr[i - 1];
continue; continue;
} }
int64_t cnt = 1; int64_t cnt = 1;
int64_t dup_cnt = 1; int64_t dup_cnt = 1;
indices.push_back(indices_data[indptr_data[i-1]]); indices.push_back(indices_data[indptr_data[i - 1]]);
for (int64_t j = indptr_data[i-1]+1; j < indptr_data[i]; ++j) { for (int64_t j = indptr_data[i - 1] + 1; j < indptr_data[i]; ++j) {
if (indices_data[j-1] == indices_data[j]) { if (indices_data[j - 1] == indices_data[j]) {
++dup_cnt; ++dup_cnt;
continue; continue;
} }
...@@ -46,29 +46,27 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) { ...@@ -46,29 +46,27 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
++cnt; ++cnt;
} }
count.push_back(dup_cnt); count.push_back(dup_cnt);
indptr[i] = indptr[i-1] + cnt; indptr[i] = indptr[i - 1] + cnt;
} }
CSRMatrix res_csr = CSRMatrix( CSRMatrix res_csr = CSRMatrix(
csr.num_rows, csr.num_rows, csr.num_cols, IdArray::FromVector(indptr),
csr.num_cols, IdArray::FromVector(indices), NullArray(), true);
IdArray::FromVector(indptr),
IdArray::FromVector(indices),
NullArray(),
true);
const IdArray &edge_count = IdArray::FromVector(count); const IdArray &edge_count = IdArray::FromVector(count);
const IdArray new_eids = Range( const IdArray new_eids =
0, res_csr.indices->shape[0], sizeof(IdType) * 8, csr.indptr->ctx); Range(0, res_csr.indices->shape[0], sizeof(IdType) * 8, csr.indptr->ctx);
const IdArray eids_remapped = CSRHasData(csr) ? const IdArray eids_remapped =
Scatter(Repeat(new_eids, edge_count), csr.data) : CSRHasData(csr) ? Scatter(Repeat(new_eids, edge_count), csr.data)
Repeat(new_eids, edge_count); : Repeat(new_eids, edge_count);
return std::make_tuple(res_csr, edge_count, eids_remapped); return std::make_tuple(res_csr, edge_count, eids_remapped);
} }
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int32_t>(CSRMatrix); template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int32_t>(
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int64_t>(CSRMatrix); CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int64_t>(
CSRMatrix);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -5,17 +5,18 @@ ...@@ -5,17 +5,18 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <numeric>
#include <algorithm> #include <algorithm>
#include <vector>
#include <iterator> #include <iterator>
#include <numeric>
#include <vector>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) { CSRMatrix UnionCsr(const std::vector<CSRMatrix> &csrs) {
std::vector<IdType> res_indptr; std::vector<IdType> res_indptr;
std::vector<IdType> res_indices; std::vector<IdType> res_indices;
std::vector<IdType> res_data; std::vector<IdType> res_data;
...@@ -31,12 +32,12 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) { ...@@ -31,12 +32,12 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
for (size_t i = 0; i < csrs.size(); ++i) { for (size_t i = 0; i < csrs.size(); ++i) {
// eids of csrs[0] remains unchanged // eids of csrs[0] remains unchanged
// eids of csrs[1] will be increased by number of edges of csrs[0], etc. // eids of csrs[1] will be increased by number of edges of csrs[0], etc.
data.push_back(CSRHasData(csrs[i]) ? data.push_back(
csrs[i].data + num_edges: CSRHasData(csrs[i])
Range(num_edges, ? csrs[i].data + num_edges
num_edges + csrs[i].indices->shape[0], : Range(
csrs[i].indptr->dtype.bits, num_edges, num_edges + csrs[i].indices->shape[0],
csrs[i].indptr->ctx)); csrs[i].indptr->dtype.bits, csrs[i].indptr->ctx));
data_data.push_back(data[i].Ptr<IdType>()); data_data.push_back(data[i].Ptr<IdType>());
indptr_data.push_back(csrs[i].indptr.Ptr<IdType>()); indptr_data.push_back(csrs[i].indptr.Ptr<IdType>());
indices_data.push_back(csrs[i].indices.Ptr<IdType>()); indices_data.push_back(csrs[i].indices.Ptr<IdType>());
...@@ -55,13 +56,13 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) { ...@@ -55,13 +56,13 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
std::vector<int64_t> indices_off; std::vector<int64_t> indices_off;
res_indptr[i] = indptr_data[0][i]; res_indptr[i] = indptr_data[0][i];
indices_off.push_back(indptr_data[0][i-1]); indices_off.push_back(indptr_data[0][i - 1]);
for (size_t j = 1; j < csrs.size(); ++j) { for (size_t j = 1; j < csrs.size(); ++j) {
res_indptr[i] += indptr_data[j][i]; res_indptr[i] += indptr_data[j][i];
indices_off.push_back(indptr_data[j][i-1]); indices_off.push_back(indptr_data[j][i - 1]);
} }
IdType off = res_indptr[i-1]; IdType off = res_indptr[i - 1];
while (off < res_indptr[i]) { while (off < res_indptr[i]) {
IdType min = csrs[0].num_cols + 1; IdType min = csrs[0].num_cols + 1;
int64_t min_idx = -1; int64_t min_idx = -1;
...@@ -84,33 +85,29 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) { ...@@ -84,33 +85,29 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
} else { // some csrs are not sorted } else { // some csrs are not sorted
#pragma omp for #pragma omp for
for (int64_t i = 1; i <= csrs[0].num_rows; ++i) { for (int64_t i = 1; i <= csrs[0].num_rows; ++i) {
IdType off = res_indptr[i-1]; IdType off = res_indptr[i - 1];
res_indptr[i] = 0; res_indptr[i] = 0;
for (size_t j = 0; j < csrs.size(); ++j) { for (size_t j = 0; j < csrs.size(); ++j) {
std::memcpy(&res_indices[off], std::memcpy(
&indices_data[j][indptr_data[j][i-1]], &res_indices[off], &indices_data[j][indptr_data[j][i - 1]],
sizeof(IdType) * (indptr_data[j][i] - indptr_data[j][i-1])); sizeof(IdType) * (indptr_data[j][i] - indptr_data[j][i - 1]));
std::memcpy(&res_data[off], std::memcpy(
&data_data[j][indptr_data[j][i-1]], &res_data[off], &data_data[j][indptr_data[j][i - 1]],
sizeof(IdType) * (indptr_data[j][i] - indptr_data[j][i-1])); sizeof(IdType) * (indptr_data[j][i] - indptr_data[j][i - 1]));
off += indptr_data[j][i] - indptr_data[j][i-1]; off += indptr_data[j][i] - indptr_data[j][i - 1];
} }
res_indptr[i] = off; res_indptr[i] = off;
} // omp for } // omp for
} }
return CSRMatrix( return CSRMatrix(
csrs[0].num_rows, csrs[0].num_rows, csrs[0].num_cols, IdArray::FromVector(res_indptr),
csrs[0].num_cols, IdArray::FromVector(res_indices), IdArray::FromVector(res_data), sorted);
IdArray::FromVector(res_indptr),
IdArray::FromVector(res_indices),
IdArray::FromVector(res_data),
sorted);
} }
template CSRMatrix UnionCsr<kDGLCPU, int64_t>(const std::vector<CSRMatrix>&); template CSRMatrix UnionCsr<kDGLCPU, int64_t>(const std::vector<CSRMatrix> &);
template CSRMatrix UnionCsr<kDGLCPU, int32_t>(const std::vector<CSRMatrix>&); template CSRMatrix UnionCsr<kDGLCPU, int32_t>(const std::vector<CSRMatrix> &);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
/** /**
* Copyright (c) 2022, NVIDIA CORPORATION. * Copyright (c) 2022, NVIDIA CORPORATION.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
* *
* \file array/cpu/disjoint_union.cc * \file array/cpu/disjoint_union.cc
* \brief Disjoint union CPU implementation. * \brief Disjoint union CPU implementation.
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <tuple> #include <tuple>
namespace dgl { namespace dgl {
...@@ -27,19 +28,20 @@ namespace aten { ...@@ -27,19 +28,20 @@ namespace aten {
namespace impl { namespace impl {
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) { std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(
IdArray prefix_src_arr = NewIdArray( const std::vector<COOMatrix>& coos) {
coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits); IdArray prefix_src_arr =
IdArray prefix_dst_arr = NewIdArray( NewIdArray(coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);
coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits); IdArray prefix_dst_arr =
IdArray prefix_elm_arr = NewIdArray( NewIdArray(coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);
coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits); IdArray prefix_elm_arr =
NewIdArray(coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);
auto prefix_src = prefix_src_arr.Ptr<IdType>(); auto prefix_src = prefix_src_arr.Ptr<IdType>();
auto prefix_dst = prefix_dst_arr.Ptr<IdType>(); auto prefix_dst = prefix_dst_arr.Ptr<IdType>();
auto prefix_elm = prefix_elm_arr.Ptr<IdType>(); auto prefix_elm = prefix_elm_arr.Ptr<IdType>();
dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e){ dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e) {
for (IdType i = b; i < e; ++i) { for (IdType i = b; i < e; ++i) {
prefix_src[i] = coos[i].num_rows; prefix_src[i] = coos[i].num_rows;
prefix_dst[i] = coos[i].num_cols; prefix_dst[i] = coos[i].num_cols;
...@@ -47,8 +49,8 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa ...@@ -47,8 +49,8 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
} }
}); });
return std::make_tuple(CumSum(prefix_src_arr, true), return std::make_tuple(
CumSum(prefix_dst_arr, true), CumSum(prefix_src_arr, true), CumSum(prefix_dst_arr, true),
CumSum(prefix_elm_arr, true)); CumSum(prefix_elm_arr, true));
} }
...@@ -83,9 +85,9 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { ...@@ -83,9 +85,9 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
auto res_dst_data = result_dst.Ptr<IdType>(); auto res_dst_data = result_dst.Ptr<IdType>();
auto res_dat_data = result_dat.Ptr<IdType>(); auto res_dat_data = result_dat.Ptr<IdType>();
dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e){ dgl::runtime::parallel_for(0, coos.size(), [&](IdType b, IdType e) {
for (IdType i = b; i < e; ++i) { for (IdType i = b; i < e; ++i) {
const aten::COOMatrix &coo = coos[i]; const aten::COOMatrix& coo = coos[i];
if (!coo.row_sorted) row_sorted = false; if (!coo.row_sorted) row_sorted = false;
if (!coo.col_sorted) col_sorted = false; if (!coo.col_sorted) col_sorted = false;
...@@ -104,22 +106,20 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { ...@@ -104,22 +106,20 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
if (has_data) { if (has_data) {
for (IdType j = 0; j < coo.row->shape[0]; j++) { for (IdType j = 0; j < coo.row->shape[0]; j++) {
const auto d = (!COOHasData(coo)) ? j : edges_dat[j]; const auto d = (!COOHasData(coo)) ? j : edges_dat[j];
res_dat_data[prefix_elm[i]+j] = d + prefix_elm[i]; res_dat_data[prefix_elm[i] + j] = d + prefix_elm[i];
} }
} }
} }
}); });
return COOMatrix( return COOMatrix(
prefix_src[coos.size()], prefix_dst[coos.size()], prefix_src[coos.size()], prefix_dst[coos.size()], result_src, result_dst,
result_src, result_dat, row_sorted, col_sorted);
result_dst,
result_dat,
row_sorted,
col_sorted);
} }
template COOMatrix DisjointUnionCoo<kDGLCPU, int32_t>(const std::vector<COOMatrix>& coos); template COOMatrix DisjointUnionCoo<kDGLCPU, int32_t>(
template COOMatrix DisjointUnionCoo<kDGLCPU, int64_t>(const std::vector<COOMatrix>& coos); const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCPU, int64_t>(
const std::vector<COOMatrix>& coos);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/bcast.h> #include <dgl/bcast.h>
#include <utility> #include <utility>
namespace dgl { namespace dgl {
...@@ -25,8 +26,9 @@ void transpose(const DType *in, DType *out, const int N, const int M) { ...@@ -25,8 +26,9 @@ void transpose(const DType *in, DType *out, const int N, const int M) {
} }
template <typename DType> template <typename DType>
void matmul(const DType *A, const DType *B, void matmul(
DType *C, const int M, const int N, const int K) { const DType *A, const DType *B, DType *C, const int M, const int N,
const int K) {
#pragma omp parallel #pragma omp parallel
{ {
int i, j, k; int i, j, k;
...@@ -55,25 +57,24 @@ void matmul(const DType *A, const DType *B, ...@@ -55,25 +57,24 @@ void matmul(const DType *A, const DType *B,
* \param b_trans Matrix B to be transposed * \param b_trans Matrix B to be transposed
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void gatherMM_SortedEtype(const NDArray A, void gatherMM_SortedEtype(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray A_dim1_per_rel,
NDArray C, const NDArray B_dim1_per_rel, bool a_trans, bool b_trans) {
const NDArray A_dim1_per_rel,
const NDArray B_dim1_per_rel,
bool a_trans, bool b_trans) {
assert(A_dim1_per_rel.NumElements() == B_dim1_per_rel.NumElements()); assert(A_dim1_per_rel.NumElements() == B_dim1_per_rel.NumElements());
int64_t num_rel = A_dim1_per_rel.NumElements(); int64_t num_rel = A_dim1_per_rel.NumElements();
const DType *A_data = A.Ptr<DType>(); const DType *A_data = A.Ptr<DType>();
const DType *B_data = B.Ptr<DType>(); const DType *B_data = B.Ptr<DType>();
const IdType* A_rel_data = A_dim1_per_rel.Ptr<IdType>(); const IdType *A_rel_data = A_dim1_per_rel.Ptr<IdType>();
const IdType* B_rel_data = B_dim1_per_rel.Ptr<IdType>(); const IdType *B_rel_data = B_dim1_per_rel.Ptr<IdType>();
DType *C_data = C.Ptr<DType>(); DType *C_data = C.Ptr<DType>();
int64_t A_offset = 0, B_offset = 0, C_offset = 0; int64_t A_offset = 0, B_offset = 0, C_offset = 0;
int64_t m, n, k, h_col, w_row; int64_t m, n, k, h_col, w_row;
for (int etype = 0; etype < num_rel; ++etype) { for (int etype = 0; etype < num_rel; ++etype) {
assert((a_trans) ? A_rel_data[etype] : A->shape[1] == \ assert(
(b_trans) ? B->shape[1] : B_rel_data[etype]); (a_trans) ? A_rel_data[etype]
: A->shape[1] == (b_trans) ? B->shape[1]
: B_rel_data[etype]);
m = A_rel_data[etype]; // rows of A m = A_rel_data[etype]; // rows of A
n = B->shape[1]; // cols of B n = B->shape[1]; // cols of B
k = B_rel_data[etype]; // rows of B == cols of A k = B_rel_data[etype]; // rows of B == cols of A
...@@ -81,16 +82,17 @@ void gatherMM_SortedEtype(const NDArray A, ...@@ -81,16 +82,17 @@ void gatherMM_SortedEtype(const NDArray A,
NDArray A_trans, B_trans; NDArray A_trans, B_trans;
if (a_trans) { if (a_trans) {
A_trans = NDArray::Empty({m * k}, A->dtype, A->ctx); A_trans = NDArray::Empty({m * k}, A->dtype, A->ctx);
transpose<DType>(A_data + A_offset, static_cast<DType *>(A_trans->data), m, k); transpose<DType>(
A_data + A_offset, static_cast<DType *>(A_trans->data), m, k);
} }
if (b_trans) { if (b_trans) {
B_trans = NDArray::Empty({k * n}, B->dtype, B->ctx); B_trans = NDArray::Empty({k * n}, B->dtype, B->ctx);
transpose<DType>(B_data + B_offset, static_cast<DType *>(B_trans->data), k, n); transpose<DType>(
B_data + B_offset, static_cast<DType *>(B_trans->data), k, n);
} }
if (a_trans || b_trans) { if (a_trans || b_trans) {
int64_t tmp = k; int64_t tmp = k;
if (a_trans) if (a_trans) std::swap(m, k);
std::swap(m, k);
if (b_trans) { if (b_trans) {
k = tmp; k = tmp;
std::swap(n, k); std::swap(n, k);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment