Unverified Commit ac74233c authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix csrtocoo (#1324)

parent 9caff617
...@@ -392,14 +392,13 @@ template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr); ...@@ -392,14 +392,13 @@ template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr);
// complexity: time O(NNZ), space O(1) // complexity: time O(NNZ), space O(1)
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) { COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
CHECK(CSRHasData(csr)) << "missing data array.";
const int64_t N = csr.num_rows; const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols; const int64_t M = csr.num_cols;
const int64_t nnz = csr.indices->shape[0]; const int64_t nnz = csr.indices->shape[0];
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
// data array should have the same type as the indices arrays // data array should have the same type as the indices arrays
const IdType* data = static_cast<IdType*>(csr.data->data); const IdType* data = CSRHasData(csr) ? static_cast<IdType*>(csr.data->data) : nullptr;
NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx); NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx); NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
IdType* ret_row_data = static_cast<IdType*>(ret_row->data); IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
...@@ -408,8 +407,8 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) { ...@@ -408,8 +407,8 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
for (IdType row = 0; row < N; ++row) { for (IdType row = 0; row < N; ++row) {
for (IdType j = indptr_data[row]; j < indptr_data[row + 1]; ++j) { for (IdType j = indptr_data[row]; j < indptr_data[row + 1]; ++j) {
const IdType col = indices_data[j]; const IdType col = indices_data[j];
ret_row_data[data[j]] = row; ret_row_data[data ? data[j] : j] = row;
ret_col_data[data[j]] = col; ret_col_data[data ? data[j] : j] = col;
} }
} }
return COOMatrix(N, M, ret_row, ret_col); return COOMatrix(N, M, ret_row, ret_col);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment