"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5199ee4f7bc6323f9bffae9dd243eb3d9ff19bbb"
Unverified Commit 2576647c authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[Performance] Improve COO to CSR, and sort columns of CSR only when necessary. (#2391)

* Remove double-checking sorted

* Remove sorting of CSR by default

* Update unit test to use unsorted matrix

* delete whitespace

* Expand unit tests

* Replace cusparse sort

* Fix row column sorting

* Explicitly don't sort columns

* Fix linting errors

* Fix bit-width calculation

* Fix sorting assertion and unit test

* Fix linting

* Improve CPU COO2CSR

* Remove references

* Rename and add documentation to edge encoding/decoding functions

* Fix sorting keys as 64 bit

* Revert cosmetic changes to unit tests

* Update documentation

* Update complexity documentation for coo to csr conversion

* Remove COOIsSorted check in CPU implementation too
parent 0855d255
...@@ -314,9 +314,19 @@ IdArray NonZero(NDArray array); ...@@ -314,9 +314,19 @@ IdArray NonZero(NDArray array);
* is always in int64. * is always in int64.
* *
* \param array Input array. * \param array Input array.
* \param num_bits The number of bits used by the range of values in the array,
* or 0 to use all bits of the type. This is currently only used when sort
* arrays on the GPU.
* \param num_bits The number of bits used in key comparison. The bits are
* right aligned. For example, setting `num_bits` to 8 means using bits from
* `sizeof(IdType) * 8 - num_bits` (inclusive) to `sizeof(IdType) * 8`
* (exclusive). Setting it to a small value could speed up the sorting if the
* underlying sorting algorithm is radix sort (e.g., on GPU). Setting it to
* zero uses all bits of the type (sizeof(IdType)*8).
* On CPU, it currently has no effect.
* \return A pair of arrays: sorted values and sorted index to the original position. * \return A pair of arrays: sorted values and sorted index to the original position.
*/ */
std::pair<IdArray, IdArray> Sort(IdArray array); std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits = 0);
/*! /*!
* \brief Return a string that prints out some debug information. * \brief Return a string that prints out some debug information.
......
...@@ -286,7 +286,7 @@ IdArray NonZero(NDArray array) { ...@@ -286,7 +286,7 @@ IdArray NonZero(NDArray array) {
return ret; return ret;
} }
std::pair<IdArray, IdArray> Sort(IdArray array) { std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
if (array.NumElements() == 0) { if (array.NumElements() == 0) {
IdArray idx = NewIdArray(0, array->ctx, 64); IdArray idx = NewIdArray(0, array->ctx, 64);
return std::make_pair(array, idx); return std::make_pair(array, idx);
...@@ -294,7 +294,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array) { ...@@ -294,7 +294,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array) {
std::pair<IdArray, IdArray> ret; std::pair<IdArray, IdArray> ret;
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", { ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "Sort", {
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, { ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
ret = impl::Sort<XPU, IdType>(array); ret = impl::Sort<XPU, IdType>(array, num_bits);
}); });
}); });
return ret; return ret;
......
...@@ -47,7 +47,7 @@ template <DLDeviceType XPU, typename DType> ...@@ -47,7 +47,7 @@ template <DLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr); IdArray NonZero(BoolArray bool_arr);
template <DLDeviceType XPU, typename DType> template <DLDeviceType XPU, typename DType>
std::pair<IdArray, IdArray> Sort(IdArray array); std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
template <DLDeviceType XPU, typename DType, typename IdType> template <DLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices); NDArray Scatter(NDArray array, IdArray indices);
......
...@@ -161,7 +161,7 @@ namespace aten { ...@@ -161,7 +161,7 @@ namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array) { std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
const int64_t nitem = array->shape[0]; const int64_t nitem = array->shape[0];
IdArray val = array.Clone(); IdArray val = array.Clone();
IdArray idx = aten::Range(0, nitem, 64, array->ctx); IdArray idx = aten::Range(0, nitem, 64, array->ctx);
...@@ -181,8 +181,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array) { ...@@ -181,8 +181,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array) {
return std::make_pair(val, idx); return std::make_pair(val, idx);
} }
template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray); template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray); template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray, int num_bits);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file array/cpu/spmat_op_impl.cc * \file array/cpu/spmat_op_impl.cc
* \brief CPU implementation of COO sparse matrix operators * \brief CPU implementation of COO sparse matrix operators
*/ */
#include <omp.h>
#include <vector> #include <vector>
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
...@@ -296,14 +297,16 @@ template COOMatrix COOTranspose<kDLCPU, int64_t>(COOMatrix coo); ...@@ -296,14 +297,16 @@ template COOMatrix COOTranspose<kDLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOToCSR ///////////////////////////// ///////////////////////////// COOToCSR /////////////////////////////
// complexity: time O(NNZ), space O(1) // complexity: time O(NNZ), space O(1) if the coo is row sorted,
// time O(NNZ/p + N), space O(NNZ + N*p) otherwise, where p is the number of
// threads.
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) { CSRMatrix COOToCSR(COOMatrix coo) {
const int64_t N = coo.num_rows; const int64_t N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0]; const int64_t NNZ = coo.row->shape[0];
const IdType* row_data = static_cast<IdType*>(coo.row->data); const IdType* const row_data = static_cast<IdType*>(coo.row->data);
const IdType* col_data = static_cast<IdType*>(coo.col->data); const IdType* const col_data = static_cast<IdType*>(coo.col->data);
const IdType* data = COOHasData(coo)? static_cast<IdType*>(coo.data->data) : nullptr; const IdType* const data = COOHasData(coo)? static_cast<IdType*>(coo.data->data) : nullptr;
NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx); NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
NDArray ret_indices; NDArray ret_indices;
...@@ -311,11 +314,6 @@ CSRMatrix COOToCSR(COOMatrix coo) { ...@@ -311,11 +314,6 @@ CSRMatrix COOToCSR(COOMatrix coo) {
bool row_sorted = coo.row_sorted; bool row_sorted = coo.row_sorted;
bool col_sorted = coo.col_sorted; bool col_sorted = coo.col_sorted;
if (!row_sorted) {
// It is possible that the flag is simply not set (default value is false),
// so we still perform a linear scan to check the flag.
std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
}
if (row_sorted) { if (row_sorted) {
// compute indptr // compute indptr
...@@ -340,32 +338,84 @@ CSRMatrix COOToCSR(COOMatrix coo) { ...@@ -340,32 +338,84 @@ CSRMatrix COOToCSR(COOMatrix coo) {
ret_data = coo.data; ret_data = coo.data;
} else { } else {
// compute indptr // compute indptr
IdType* Bp = static_cast<IdType*>(ret_indptr->data); IdType* const Bp = static_cast<IdType*>(ret_indptr->data);
*(Bp++) = 0; Bp[0] = 0;
std::fill(Bp, Bp + N, 0);
for (int64_t i = 0; i < NNZ; ++i) {
Bp[row_data[i]]++;
}
// cumsum
for (int64_t i = 0, cumsum = 0; i < N; ++i) {
const IdType temp = Bp[i];
Bp[i] = cumsum;
cumsum += temp;
}
// compute indices and data // compute indices and data
ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx); ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx); ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
IdType* Bi = static_cast<IdType*>(ret_indices->data); IdType* const Bi = static_cast<IdType*>(ret_indices->data);
IdType* Bx = static_cast<IdType*>(ret_data->data); IdType* const Bx = static_cast<IdType*>(ret_data->data);
for (int64_t i = 0; i < NNZ; ++i) { // the offset within each row, that each thread will write to
const IdType r = row_data[i]; std::vector<std::vector<IdType>> local_ptrs;
Bi[Bp[r]] = col_data[i]; std::vector<int64_t> thread_prefixsum;
Bx[Bp[r]] = data? data[i] : i;
Bp[r]++; #pragma omp parallel
{
const int num_threads = omp_get_num_threads();
const int thread_id = omp_get_thread_num();
CHECK_LT(thread_id, num_threads);
const int64_t nz_chunk = (NNZ+num_threads-1)/num_threads;
const int64_t nz_start = thread_id*nz_chunk;
const int64_t nz_end = std::min(NNZ, nz_start+nz_chunk);
const int64_t n_chunk = (N+num_threads-1)/num_threads;
const int64_t n_start = thread_id*n_chunk;
const int64_t n_end = std::min(N, n_start+n_chunk);
#pragma omp master
{
local_ptrs.resize(num_threads);
thread_prefixsum.resize(num_threads+1);
}
#pragma omp barrier
local_ptrs[thread_id].resize(N, 0);
for (int64_t i = nz_start; i < nz_end; ++i) {
++local_ptrs[thread_id][row_data[i]];
}
#pragma omp barrier
// compute prefixsum in parallel
int64_t sum = 0;
for (int64_t i = n_start; i < n_end; ++i) {
IdType tmp = 0;
for (int j = 0; j < num_threads; ++j) {
std::swap(tmp, local_ptrs[j][i]);
tmp += local_ptrs[j][i];
}
sum += tmp;
Bp[i+1] = sum;
}
thread_prefixsum[thread_id+1] = sum;
#pragma omp barrier
#pragma omp master
{
for (int64_t i = 0; i < num_threads; ++i) {
thread_prefixsum[i+1] += thread_prefixsum[i];
}
CHECK_EQ(thread_prefixsum[num_threads], NNZ);
}
#pragma omp barrier
sum = thread_prefixsum[thread_id];
for (int64_t i = n_start; i < n_end; ++i) {
Bp[i+1] += sum;
}
#pragma omp barrier
for (int64_t i = nz_start; i < nz_end; ++i) {
const IdType r = row_data[i];
const int64_t index = Bp[r] + local_ptrs[thread_id][r]++;
Bi[index] = col_data[i];
Bx[index] = data ? data[i] : i;
}
} }
CHECK_EQ(Bp[N], NNZ);
} }
return CSRMatrix(coo.num_rows, coo.num_cols, return CSRMatrix(coo.num_rows, coo.num_cols,
......
...@@ -14,7 +14,7 @@ namespace aten { ...@@ -14,7 +14,7 @@ namespace aten {
namespace impl { namespace impl {
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array) { std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
const auto& ctx = array->ctx; const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
const int64_t nitems = array->shape[0]; const int64_t nitems = array->shape[0];
...@@ -27,23 +27,27 @@ std::pair<IdArray, IdArray> Sort(IdArray array) { ...@@ -27,23 +27,27 @@ std::pair<IdArray, IdArray> Sort(IdArray array) {
IdType* keys_out = sorted_array.Ptr<IdType>(); IdType* keys_out = sorted_array.Ptr<IdType>();
int64_t* values_out = sorted_idx.Ptr<int64_t>(); int64_t* values_out = sorted_idx.Ptr<int64_t>();
if (num_bits == 0) {
num_bits = sizeof(IdType)*8;
}
// Allocate workspace // Allocate workspace
size_t workspace_size = 0; size_t workspace_size = 0;
cub::DeviceRadixSort::SortPairs(nullptr, workspace_size, cub::DeviceRadixSort::SortPairs(nullptr, workspace_size,
keys_in, keys_out, values_in, values_out, nitems); keys_in, keys_out, values_in, values_out, nitems, 0, num_bits);
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
// Compute // Compute
cub::DeviceRadixSort::SortPairs(workspace, workspace_size, cub::DeviceRadixSort::SortPairs(workspace, workspace_size,
keys_in, keys_out, values_in, values_out, nitems); keys_in, keys_out, values_in, values_out, nitems, 0, num_bits);
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
return std::make_pair(sorted_array, sorted_idx); return std::make_pair(sorted_array, sorted_idx);
} }
template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray); template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray); template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
...@@ -32,12 +32,8 @@ CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) { ...@@ -32,12 +32,8 @@ CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
bool row_sorted = coo.row_sorted; bool row_sorted = coo.row_sorted;
bool col_sorted = coo.col_sorted; bool col_sorted = coo.col_sorted;
if (!row_sorted) { if (!row_sorted) {
// It is possible that the flag is simply not set (default value is false), // we only need to sort the rows to perform conversion
// so we still perform a linear scan to check the flag. coo = COOSort(coo, false);
std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
}
if (!row_sorted) {
coo = COOSort(coo);
col_sorted = coo.col_sorted; col_sorted = coo.col_sorted;
} }
...@@ -110,12 +106,7 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) { ...@@ -110,12 +106,7 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
bool row_sorted = coo.row_sorted; bool row_sorted = coo.row_sorted;
bool col_sorted = coo.col_sorted; bool col_sorted = coo.col_sorted;
if (!row_sorted) { if (!row_sorted) {
// It is possible that the flag is simply not set (default value is false), coo = COOSort(coo, false);
// so we still perform a linear scan to check the flag.
std::tie(row_sorted, col_sorted) = COOIsSorted(coo);
}
if (!row_sorted) {
coo = COOSort(coo);
col_sorted = coo.col_sorted; col_sorted = coo.col_sorted;
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "../../c_api_common.h"
#include "./utils.h" #include "./utils.h"
namespace dgl { namespace dgl {
...@@ -16,107 +17,118 @@ namespace impl { ...@@ -16,107 +17,118 @@ namespace impl {
///////////////////////////// COOSort_ ///////////////////////////// ///////////////////////////// COOSort_ /////////////////////////////
/**
 * @brief Pack the (row, column) pair of every edge into one scalar key.
 *
 * The column ID is stored in the low `col_bits` bits of the key and the row
 * ID is shifted into the bits above it, i.e. `key = (row << col_bits) | col`.
 *
 * @tparam IdType The integer type of the IDs and of the produced keys.
 * @param row The row (src) ID of every edge.
 * @param col The column (dst) ID of every edge.
 * @param nnz The number of edges.
 * @param col_bits The number of low bits reserved for the column ID. The row
 * ID is packed into the remaining bits.
 * @param key The packed key of every edge (output).
 */
template <typename IdType>
__global__ void _COOEncodeEdgesKernel(
    const IdType* const row, const IdType* const col,
    const int64_t nnz, const int col_bits, IdType * const key) {
  // each thread handles at most one edge; indices past nnz are ignored
  const int64_t edge = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (edge >= nnz) {
    return;
  }

  const IdType shifted_row = row[edge] << col_bits;
  key[edge] = shifted_row | col[edge];
}
/**
 * @brief Decode row and column IDs from the encoded edges.
 *
 * Reverses the packing `key = (row << col_bits) | col`: the low `col_bits`
 * bits of each key are extracted as the column ID and the remaining high bits
 * as the row ID.
 *
 * @tparam IdType The type the edges are encoded as.
 * @param key The encoded edges.
 * @param nnz The number of edges.
 * @param col_bits The number of bits used to store the column/dst ID.
 * @param row The row (src) IDs per edge (output).
 * @param col The col (dst) IDs per edge (output).
 */
template <typename IdType>
__global__ void _COODecodeEdgesKernel(
    const IdType* const key, const int64_t nnz, const int col_bits,
    IdType * const row, IdType * const col) {
  int64_t tx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (tx < nnz) {
    const IdType k = key[tx];
    // Build the mask with a 64-bit shift. The literal `1` is an `int`, so
    // `1 << col_bits` overflows as soon as col_bits >= 31 and would yield a
    // wrong column mask for 64-bit IDs.
    const IdType col_mask = static_cast<IdType>(
        (static_cast<uint64_t>(1) << col_bits) - 1);
    row[tx] = k >> col_bits;
    col[tx] = k & col_mask;
  }
}
NDArray row = coo->row;
NDArray col = coo->col;
if (!aten::COOHasData(*coo))
coo->data = aten::Range(0, row->shape[0], row->dtype.bits, row->ctx);
NDArray data = coo->data;
int32_t* row_ptr = static_cast<int32_t*>(row->data);
int32_t* col_ptr = static_cast<int32_t*>(col->data);
int32_t* data_ptr = static_cast<int32_t*>(data->data);
// sort row
size_t workspace_size = 0;
CUSPARSE_CALL(cusparseXcoosort_bufferSizeExt(
thr_entry->cusparse_handle,
coo->num_rows, coo->num_cols,
row->shape[0],
row_ptr,
col_ptr,
&workspace_size));
void* workspace = device->AllocWorkspace(row->ctx, workspace_size);
CUSPARSE_CALL(cusparseXcoosortByRow(
thr_entry->cusparse_handle,
coo->num_rows, coo->num_cols,
row->shape[0],
row_ptr,
col_ptr,
data_ptr,
workspace));
device->FreeWorkspace(row->ctx, workspace);
if (sort_column) {
// First create a row indptr array and then call csrsort
/**
 * @brief Compute how many bits are needed to store any value in [0, range).
 *
 * @tparam T The integer type of the range.
 * @param range The exclusive upper bound of the values to store.
 *
 * @return 0 if `range` is 0 or 1, otherwise the position of the highest set
 * bit of `range - 1` (counted from 1).
 */
template<typename T>
int _NumberOfBits(const T& range) {
  if (range <= 1) {
    // ranges of 0 or 1 require no bits to store
    return 0;
  }

  const int max_bits = static_cast<int>(sizeof(T)*8);
  const T max_value = range-1;

  // Shift the value down instead of shifting the literal `1` up: `1 << bits`
  // is an `int` expression and overflows once bits reaches 31, which made the
  // result wrong for ranges larger than 2^31.
  int bits = 1;
  while (bits < max_bits && (max_value >> bits) != 0) {
    ++bits;
  }

  if (bits < max_bits) {
    // shifting by the full width of the type would be undefined
    CHECK_EQ(max_value >> bits, 0);
  }
  CHECK_NE(max_value >> (bits-1), 0);

  return bits;
}
template <> template <DLDeviceType XPU, typename IdType>
void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column) { void COOSort_(COOMatrix* coo, bool sort_column) {
// Always sort the COO to be both row and column sorted. auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
IdArray pos = coo->row * coo->num_cols + coo->col; const int row_bits = _NumberOfBits(coo->num_rows);
const auto& sorted = Sort(pos);
coo->row = sorted.first / coo->num_cols; const int64_t nnz = coo->row->shape[0];
coo->col = sorted.first % coo->num_cols; if (sort_column) {
if (aten::COOHasData(*coo)) const int col_bits = _NumberOfBits(coo->num_cols);
coo->data = IndexSelect(coo->data, sorted.second); const int num_bits = row_bits + col_bits;
else
coo->data = AsNumBits(sorted.second, coo->row->dtype.bits); const int nt = 256;
coo->row_sorted = coo->col_sorted = true; const int nb = (nnz+nt-1)/nt;
CHECK(static_cast<int64_t>(nb)*nt >= nnz);
IdArray pos = aten::NewIdArray(nnz, coo->row->ctx, coo->row->dtype.bits);
CUDA_KERNEL_CALL(_COOEncodeEdgesKernel, nb, nt, 0, thr_entry->stream,
coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>(),
nnz, col_bits, pos.Ptr<IdType>());
auto sorted = Sort(pos, num_bits);
CUDA_KERNEL_CALL(_COODecodeEdgesKernel, nb, nt, 0, thr_entry->stream,
sorted.first.Ptr<IdType>(), nnz, col_bits,
coo->row.Ptr<IdType>(), coo->col.Ptr<IdType>());
if (aten::COOHasData(*coo))
coo->data = IndexSelect(coo->data, sorted.second);
else
coo->data = AsNumBits(sorted.second, coo->row->dtype.bits);
coo->row_sorted = coo->col_sorted = true;
} else {
const int num_bits = row_bits;
auto sorted = Sort(coo->row, num_bits);
coo->row = sorted.first;
coo->col = IndexSelect(coo->col, sorted.second);
if (aten::COOHasData(*coo))
coo->data = IndexSelect(coo->data, sorted.second);
else
coo->data = AsNumBits(sorted.second, coo->row->dtype.bits);
coo->row_sorted = true;
}
} }
template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column); template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
......
...@@ -653,7 +653,11 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -653,7 +653,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK(order.empty() || order == std::string("srcdst")) CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\"," << "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\"."; << " but got \"" << order << "\".";
const auto& coo = aten::CSRToCOO(adj_, false); auto coo = aten::CSRToCOO(adj_, false);
if (order == std::string("srcdst")) {
// make sure the coo is sorted if an order is requested
coo = aten::COOSort(coo, true);
}
return EdgeArray{coo.row, coo.col, coo.data}; return EdgeArray{coo.row, coo.col, coo.data};
} }
...@@ -1308,7 +1312,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1308,7 +1312,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
CSRPtr ret = in_csr_; CSRPtr ret = in_csr_;
if (!in_csr_->defined()) { if (!in_csr_->defined()) {
if (out_csr_->defined()) { if (out_csr_->defined()) {
const auto& newadj = aten::CSRSort(aten::CSRTranspose(out_csr_->adj())); const auto& newadj = aten::CSRTranspose(out_csr_->adj());
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
...@@ -1316,8 +1320,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1316,8 +1320,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} else { } else {
CHECK(coo_->defined()) << "None of CSR, COO exist"; CHECK(coo_->defined()) << "None of CSR, COO exist";
const auto& newadj = aten::CSRSort(aten::COOToCSR( const auto& newadj = aten::COOToCSR(
aten::COOTranspose(coo_->adj()))); aten::COOTranspose(coo_->adj()));
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
...@@ -1337,7 +1341,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1337,7 +1341,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
CSRPtr ret = out_csr_; CSRPtr ret = out_csr_;
if (!out_csr_->defined()) { if (!out_csr_->defined()) {
if (in_csr_->defined()) { if (in_csr_->defined()) {
const auto& newadj = aten::CSRSort(aten::CSRTranspose(in_csr_->adj())); const auto& newadj = aten::CSRTranspose(in_csr_->adj());
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
...@@ -1345,7 +1349,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1345,7 +1349,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} else { } else {
CHECK(coo_->defined()) << "None of CSR, COO exist"; CHECK(coo_->defined()) << "None of CSR, COO exist";
const auto& newadj = aten::CSRSort(aten::COOToCSR(coo_->adj())); const auto& newadj = aten::COOToCSR(coo_->adj());
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
......
...@@ -38,6 +38,30 @@ aten::CSRMatrix CSR2(DLContext ctx = CTX) { ...@@ -38,6 +38,30 @@ aten::CSRMatrix CSR2(DLContext ctx = CTX) {
false); false);
} }
template <typename IDX>
aten::CSRMatrix CSR3(DLContext ctx = CTX) {
  // A 9x6 matrix that has duplicate entries and whose column indices are not
  // sorted within the rows (an entry of 2 marks a duplicated edge):
  // [[0, 1, 1, 1, 0, 0],
  //  [1, 0, 0, 0, 0, 0],
  //  [0, 0, 1, 1, 0, 0],
  //  [0, 0, 0, 0, 0, 0],
  //  [1, 1, 1, 0, 0, 0],
  //  [0, 0, 0, 1, 0, 0],
  //  [0, 0, 0, 0, 0, 0],
  //  [1, 2, 1, 1, 0, 0],
  //  [0, 1, 0, 0, 0, 1]]
  // The data array is a permutation of 0..16.
  const auto nbits = sizeof(IDX)*8;
  const std::vector<IDX> indptr({0, 3, 4, 6, 6, 9, 10, 10, 15, 17});
  const std::vector<IDX> indices({3, 2, 1, 0, 2, 3, 1, 2, 0, 3, 1,
                                  2, 1, 3, 0, 5, 1});
  const std::vector<IDX> data({0, 2, 5, 3, 1, 4, 6, 8, 7, 9, 13,
                               10, 11, 14, 12, 16, 15});
  return aten::CSRMatrix(
      9, 6,
      aten::VecToIdArray(indptr, nbits, ctx),
      aten::VecToIdArray(indices, nbits, ctx),
      aten::VecToIdArray(data, nbits, ctx),
      false);
}
template <typename IDX> template <typename IDX>
aten::COOMatrix COO1(DLContext ctx = CTX) { aten::COOMatrix COO1(DLContext ctx = CTX) {
// [[0, 1, 1, 0, 0], // [[0, 1, 1, 0, 0],
...@@ -115,7 +139,7 @@ aten::COOMatrix COO3(DLContext ctx) { ...@@ -115,7 +139,7 @@ aten::COOMatrix COO3(DLContext ctx) {
} // namespace } // namespace
template <typename IDX> template <typename IDX>
void _TestCSRIsNonZero(DLContext ctx) { void _TestCSRIsNonZero1(DLContext ctx) {
auto csr = CSR1<IDX>(ctx); auto csr = CSR1<IDX>(ctx);
ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1)); ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1));
ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0)); ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0));
...@@ -126,12 +150,28 @@ void _TestCSRIsNonZero(DLContext ctx) { ...@@ -126,12 +150,28 @@ void _TestCSRIsNonZero(DLContext ctx) {
ASSERT_TRUE(ArrayEQ<IDX>(x, tx)); ASSERT_TRUE(ArrayEQ<IDX>(x, tx));
} }
// Checks CSRIsNonZero on CSR3, whose column indices are unsorted.
template <typename IDX>
void _TestCSRIsNonZero2(DLContext ctx) {
  const auto nbits = sizeof(IDX)*8;
  auto csr = CSR3<IDX>(ctx);

  // single-entry queries
  ASSERT_TRUE(aten::CSRIsNonZero(csr, 0, 1));
  ASSERT_FALSE(aten::CSRIsNonZero(csr, 0, 0));

  // batched query over the first five columns of row 0
  const std::vector<IDX> rows({0, 0, 0, 0, 0, });
  const std::vector<IDX> cols({0, 1, 2, 3, 4, });
  const std::vector<IDX> expected({0, 1, 1, 1, 0});
  IdArray r = aten::VecToIdArray(rows, nbits, ctx);
  IdArray c = aten::VecToIdArray(cols, nbits, ctx);
  IdArray x = aten::CSRIsNonZero(csr, r, c);
  IdArray tx = aten::VecToIdArray(expected, nbits, ctx);
  ASSERT_TRUE(ArrayEQ<IDX>(x, tx)) << " x = " << x << ", tx = " << tx;
}
TEST(SpmatTest, TestCSRIsNonZero) {
  // variant 1 runs on CSR1, variant 2 on the unsorted CSR3
  _TestCSRIsNonZero1<int32_t>(CPU);
  _TestCSRIsNonZero1<int64_t>(CPU);
  _TestCSRIsNonZero2<int32_t>(CPU);
  _TestCSRIsNonZero2<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  // repeat both variants on the GPU when CUDA support is compiled in
  _TestCSRIsNonZero1<int32_t>(GPU);
  _TestCSRIsNonZero1<int64_t>(GPU);
  _TestCSRIsNonZero2<int32_t>(GPU);
  _TestCSRIsNonZero2<int64_t>(GPU);
#endif
}
...@@ -382,7 +422,7 @@ TEST(SpmatTest, TestCSRSliceRows) { ...@@ -382,7 +422,7 @@ TEST(SpmatTest, TestCSRSliceRows) {
} }
template <typename IDX> template <typename IDX>
void _TestCSRSliceMatrix(DLContext ctx) { void _TestCSRSliceMatrix1(DLContext ctx) {
auto csr = CSR2<IDX>(ctx); auto csr = CSR2<IDX>(ctx);
{ {
// square // square
...@@ -439,12 +479,76 @@ void _TestCSRSliceMatrix(DLContext ctx) { ...@@ -439,12 +479,76 @@ void _TestCSRSliceMatrix(DLContext ctx) {
} }
} }
// Checks CSRSliceMatrix on CSR3, whose column indices are unsorted.
template <typename IDX>
void _TestCSRSliceMatrix2(DLContext ctx) {
  // builds an id array of the tested index width on the tested device
  const auto ids = [&ctx](std::vector<IDX> values) {
    return aten::VecToIdArray(values, sizeof(IDX)*8, ctx);
  };
  auto csr = CSR3<IDX>(ctx);

  {
    // square slice: rows {0, 1, 3} x columns {1, 2, 3}
    // [[1, 1, 1],
    //  [0, 0, 0],
    //  [0, 0, 0]]
    auto x = aten::CSRSliceMatrix(csr, ids({0, 1, 3}), ids({1, 2, 3}));
    ASSERT_EQ(x.num_rows, 3);
    ASSERT_EQ(x.num_cols, 3);
    auto tp = ids({0, 3, 3, 3});
    // row 0 of CSR3 stores its columns in reverse order, and so does the slice
    auto ti = ids({2, 1, 0});
    auto td = ids({0, 2, 5});
    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
  }

  {
    // non-square slice: rows {0, 1, 2} x columns {0, 1}
    // [[0, 1],
    //  [1, 0],
    //  [0, 0]]
    auto x = aten::CSRSliceMatrix(csr, ids({0, 1, 2}), ids({0, 1}));
    ASSERT_EQ(x.num_rows, 3);
    ASSERT_EQ(x.num_cols, 2);
    auto tp = ids({0, 1, 2, 2});
    auto ti = ids({1, 0});
    auto td = ids({5, 3});
    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
  }

  {
    // empty slice: rows {2, 3} x columns {0, 1}
    // [[0, 0],
    //  [0, 0]]
    auto x = aten::CSRSliceMatrix(csr, ids({2, 3}), ids({0, 1}));
    ASSERT_EQ(x.num_rows, 2);
    ASSERT_EQ(x.num_cols, 2);
    auto tp = ids({0, 0, 0});
    auto ti = ids({});
    auto td = ids({});
    ASSERT_TRUE(ArrayEQ<IDX>(x.indptr, tp));
    ASSERT_TRUE(ArrayEQ<IDX>(x.indices, ti));
    ASSERT_TRUE(ArrayEQ<IDX>(x.data, td));
  }
}
TEST(SpmatTest, CSRSliceMatrix) {
  // variant 1 runs on CSR2, variant 2 on the unsorted CSR3
  _TestCSRSliceMatrix1<int32_t>(CPU);
  _TestCSRSliceMatrix1<int64_t>(CPU);
  _TestCSRSliceMatrix2<int32_t>(CPU);
  _TestCSRSliceMatrix2<int64_t>(CPU);
#ifdef DGL_USE_CUDA
  // repeat both variants on the GPU when CUDA support is compiled in
  _TestCSRSliceMatrix1<int32_t>(GPU);
  _TestCSRSliceMatrix1<int64_t>(GPU);
  _TestCSRSliceMatrix2<int32_t>(GPU);
  _TestCSRSliceMatrix2<int64_t>(GPU);
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment