/**
 *  Copyright (c) 2019 by Contributors
 * @file array/cpu/spmat_op_impl.cc
 * @brief CPU implementation of COO sparse matrix operators
 */
#include <dgl/array.h>
#include <dgl/runtime/parallel_for.h>
#include <dmlc/omp.h>

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "array_utils.h"

namespace dgl {

using runtime::NDArray;
using runtime::parallel_for;

namespace aten {
namespace impl {

/**
 * TODO(BarclayII):
 * For row-major sorted COOs, we have a faster implementation with binary
 * search, sorted search, etc. Later we should benchmark how much we can gain
 * with sorted COOs on hypersparse graphs.
 */

///////////////////////////// COOIsNonZero /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
  CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row && coo_col_data[i] == col) return true;
  }
  return false;
}

template bool COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);

template <DGLDeviceType XPU, typename IdType>
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
  const auto rowlen = row->shape[0];
  const auto collen = col->shape[0];
  const auto rstlen = std::max(rowlen, collen);
  NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
  const IdType* row_data = static_cast<IdType*>(row->data);
  const IdType* col_data = static_cast<IdType*>(col->data);
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const int64_t kmax = std::max(rowlen, collen);
  parallel_for(0, kmax, [=](size_t b, size_t e) {
    for (auto k = b; k < e; ++k) {
      int64_t i = row_stride * k;
      int64_t j = col_stride * k;
      rst_data[k] =
          COOIsNonZero<XPU, IdType>(coo, row_data[i], col_data[j]) ? 1 : 0;
    }
  });
  return rst;
}

template NDArray COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);

///////////////////////////// COOHasDuplicate /////////////////////////////
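// Returns true if any (row, col) pair appears more than once in the COO.
// A hash set of (row, col) pairs is used, so the check takes O(NNZ) time
// and O(NNZ) extra space.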
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo) {
  std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
  const IdType* src_data = static_cast<IdType*>(coo.row->data);
  const IdType* dst_data = static_cast<IdType*>(coo.col->data);
  const auto nnz = coo.row->shape[0];
  for (IdType eid = 0; eid < nnz; ++eid) {
    const auto& p = std::make_pair(src_data[eid], dst_data[eid]);
    if (hashmap.count(p)) {
      return true;
    } else {
      hashmap.insert(p);
    }
  }
  return false;
}

template bool COOHasDuplicate<kDGLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);

///////////////////////////// COOGetRowNNZ /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  int64_t result = 0;
  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row) ++result;
  }
  return result;
}

template int64_t COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, int64_t);

template <DGLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
  CHECK_SAME_DTYPE(coo.col, rows);
  const auto len = rows->shape[0];
  const IdType* vid_data = static_cast<IdType*>(rows->data);
  NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
  IdType* rst_data = static_cast<IdType*>(rst->data);
#pragma omp parallel for
  for (int64_t i = 0; i < len; ++i) {
    rst_data[i] = COOGetRowNNZ<XPU, IdType>(coo, vid_data[i]);
  }
  return rst;
}

template NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);

////////////////////////// COOGetRowDataAndIndices //////////////////////////

template <DGLDeviceType XPU, typename IdType>
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
    COOMatrix coo, int64_t row) {
  CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;

  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
  const IdType* coo_data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  std::vector<IdType> indices;
  std::vector<IdType> data;

  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    if (coo_row_data[i] == row) {
      indices.push_back(coo_col_data[i]);
      data.push_back(coo_data ? coo_data[i] : i);
    }
  }

  return std::make_pair(
      NDArray::FromVector(data), NDArray::FromVector(indices));
}

template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDGLCPU, int32_t>(COOMatrix, int64_t);
template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDGLCPU, int64_t>(COOMatrix, int64_t);

///////////////////////////// COOGetData /////////////////////////////
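// Looks up the data (edge id) stored at each queried (row, col) pair. If the
// COO carries no data array, the index of the matching non-zero is returned
// instead; pairs that are not found are reported as -1. A length-1 rows or
// cols array is broadcast against the other array.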
template <DGLDeviceType XPU, typename IdType>
IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];
  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
      << "Invalid row and col Id array: " << rows << " " << cols;
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = rows.Ptr<IdType>();
  const IdType* col_data = cols.Ptr<IdType>();

  const IdType* coo_row = coo.row.Ptr<IdType>();
  const IdType* coo_col = coo.col.Ptr<IdType>();
  const IdType* data = COOHasData(coo) ? coo.data.Ptr<IdType>() : nullptr;

  const int64_t nnz = coo.row->shape[0];
  const int64_t retlen = std::max(rowlen, collen);
  IdArray ret = Full(-1, retlen, rows->dtype.bits, rows->ctx);
  IdType* ret_data = ret.Ptr<IdType>();

  // TODO(minjie): We might need to consider sorting the COO beforehand,
  //   especially when the number of (row, col) pairs is large. Need more
  //   benchmarks to justify the choice.
  if (coo.row_sorted) {
    parallel_for(0, retlen, [&](size_t b, size_t e) {
      for (auto p = b; p < e; ++p) {
        const IdType row_id = row_data[p * row_stride],
                     col_id = col_data[p * col_stride];
        auto it = std::lower_bound(coo_row, coo_row + nnz, row_id);
        for (; it < coo_row + nnz && *it == row_id; ++it) {
          const auto idx = it - coo_row;
          if (coo_col[idx] == col_id) {
            ret_data[p] = data ? data[idx] : idx;
            break;
          }
        }
      }
    });
  } else {
#pragma omp parallel for
    for (int64_t p = 0; p < retlen; ++p) {
      const IdType row_id = row_data[p * row_stride],
                   col_id = col_data[p * col_stride];
      for (int64_t idx = 0; idx < nnz; ++idx) {
        if (coo_row[idx] == row_id && coo_col[idx] == col_id) {
          ret_data[p] = data ? data[idx] : idx;
          break;
        }
      }
    }
  }

  return ret;
}

template IdArray COOGetData<kDGLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);

//////////////////////////// COOGetDataAndIndices ////////////////////////////
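// Like COOGetData, but returns every match (including duplicate entries) as
// three aligned arrays: rows, cols and data. For large query batches a hash
// multimap over all non-zeros is built first; small batches fall back to a
// linear scan per query.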
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices(
    COOMatrix coo, NDArray rows, NDArray cols) {
  CHECK_SAME_DTYPE(coo.col, rows);
  CHECK_SAME_DTYPE(coo.col, cols);
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];
  const int64_t len = std::max(rowlen, collen);

  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
      << "Invalid row and col id array.";

  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const IdType* row_data = static_cast<IdType*>(rows->data);
  const IdType* col_data = static_cast<IdType*>(cols->data);

  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
  const IdType* data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  std::vector<IdType> ret_rows, ret_cols;
  std::vector<IdType> ret_data;
  ret_rows.reserve(len);
  ret_cols.reserve(len);
  ret_data.reserve(len);

  // NOTE(BarclayII): With a small number of lookups, linear scan is faster.
  // The threshold 200 comes from benchmarking both algorithms on a P3.8x
  // instance. I also tried sorting plus binary search. The speed gain is only
  // significant for medium-sized graphs and lookups, so I didn't include it.
  if (len >= 200) {
    // TODO(BarclayII): Ideally we would want to cache this object. However,
    // I'm not sure what is the best way to do so since this object is valid
    // for CPU only.
    std::unordered_multimap<std::pair<IdType, IdType>, IdType, PairHash>
        pair_map;
    pair_map.reserve(coo.row->shape[0]);
    for (int64_t k = 0; k < coo.row->shape[0]; ++k)
      pair_map.emplace(
          std::make_pair(coo_row_data[k], coo_col_data[k]),
          data ? data[k] : k);

    for (int64_t i = 0, j = 0; i < rowlen && j < collen;
         i += row_stride, j += col_stride) {
      const IdType row_id = row_data[i], col_id = col_data[j];
      CHECK(row_id >= 0 && row_id < coo.num_rows)
          << "Invalid row index: " << row_id;
      CHECK(col_id >= 0 && col_id < coo.num_cols)
          << "Invalid col index: " << col_id;
      auto range = pair_map.equal_range({row_id, col_id});
      for (auto it = range.first; it != range.second; ++it) {
        ret_rows.push_back(row_id);
        ret_cols.push_back(col_id);
        ret_data.push_back(it->second);
      }
    }
  } else {
    for (int64_t i = 0, j = 0; i < rowlen && j < collen;
         i += row_stride, j += col_stride) {
      const IdType row_id = row_data[i], col_id = col_data[j];
      CHECK(row_id >= 0 && row_id < coo.num_rows)
          << "Invalid row index: " << row_id;
      CHECK(col_id >= 0 && col_id < coo.num_cols)
          << "Invalid col index: " << col_id;
      for (int64_t k = 0; k < coo.row->shape[0]; ++k) {
        if (coo_row_data[k] == row_id && coo_col_data[k] == col_id) {
          ret_rows.push_back(row_id);
          ret_cols.push_back(col_id);
          ret_data.push_back(data ? data[k] : k);
        }
      }
    }
  }

  return {
      NDArray::FromVector(ret_rows), NDArray::FromVector(ret_cols),
      NDArray::FromVector(ret_data)};
}

template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int32_t>(
    COOMatrix coo, NDArray rows, NDArray cols);
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int64_t>(
    COOMatrix coo, NDArray rows, NDArray cols);

///////////////////////////// COOTranspose /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo) {
  return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};
}

template COOMatrix COOTranspose<kDGLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);

///////////////////////////// COOToCSR /////////////////////////////

namespace {
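// Converts a row-sorted COO to CSR. The column array is reused as the CSR
// indices, so only indptr (and, when the COO has no data array, an identity
// data array) is materialized. For example, a sorted row array [0, 0, 1, 3]
// with num_rows = 4 yields indptr [0, 2, 3, 3, 4].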
template <class IdType>
CSRMatrix SortedCOOToCSR(const COOMatrix& coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType* const row_data = static_cast<IdType*>(coo.row->data);
  const IdType* const data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
  NDArray ret_indices = coo.col;
  NDArray ret_data =
      data == nullptr ? NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx)
                      : coo.data;

  // compute indptr
  IdType* const Bp = static_cast<IdType*>(ret_indptr->data);
  Bp[0] = 0;

  IdType* const fill_data =
      data ? nullptr : static_cast<IdType*>(ret_data->data);

  if (NNZ > 0) {
    auto num_threads = omp_get_max_threads();
    parallel_for(0, num_threads, [&](int b, int e) {
      for (auto thread_id = b; thread_id < e; ++thread_id) {
        // We partition the set of non-zeros among the threads.
        const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;
        const int64_t nz_start = thread_id * nz_chunk;
        const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);

        // Each thread searches the row array for a change, and marks its
        // location in Bp. Threads other than the first start at the last
        // index covered by the previous thread, in order to detect changes in
        // the row array between thread partitions. This means that each
        // thread after the first searches the range [nz_start-1, nz_end).
        // That is, if we had 10 non-zeros and 4 threads, the indexes searched
        // by each thread would be:
        // 0: [0, 1, 2]
        // 1: [2, 3, 4, 5]
        // 2: [5, 6, 7, 8]
        // 3: [8, 9]
        //
        // That way, if the row array were [0, 0, 1, 2, 2, 2, 4, 5, 5, 6],
        // each change in row would be captured by one thread:
        //
        // 0: [0, 0, 1] - row 0
        // 1: [1, 2, 2, 2] - row 1
        // 2: [2, 4, 5, 5] - rows 2, 3, and 4
        // 3: [5, 6] - rows 5 and 6
        //
        int64_t row = 0;
        if (nz_start < nz_end) {
          row = nz_start == 0 ? 0 : row_data[nz_start - 1];
          for (int64_t i = nz_start; i < nz_end; ++i) {
            while (row != row_data[i]) {
              ++row;
              Bp[row] = i;
            }
          }

          // We will not detect the row change for the last row, nor any empty
          // rows at the end of the matrix, so the last active thread needs to
          // mark all remaining rows in Bp with NNZ.
          if (nz_end == NNZ) {
            while (row < N) {
              ++row;
              Bp[row] = NNZ;
            }
          }

          if (fill_data) {
            // TODO(minjie): Many of our current implementations assume that
            //   CSR must have a data array. This is a temporary workaround.
            //   Remove this after:
            //   - The old immutable graph implementation is deprecated.
            //   - The old binary reduce kernel is deprecated.
            std::iota(fill_data + nz_start, fill_data + nz_end, nz_start);
          }
        }
      }
    });
  } else {
    std::fill(Bp, Bp + N + 1, 0);
  }

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}
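// Converts an unsorted COO to CSR without sorting the full edge list. Each
// thread counts how many of its non-zeros fall into every thread's row range,
// a prefix sum over those counts assigns each thread a private output region,
// the non-zeros are bucketed by row range, and every thread then builds the
// CSR slice for its own rows independently.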
template <class IdType>
CSRMatrix UnSortedSparseCOOToCSR(const COOMatrix& coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType* const row_data = static_cast<IdType*>(coo.row->data);
  const IdType* const col_data = static_cast<IdType*>(coo.col->data);
  const IdType* const data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);

  IdType* const Bp = static_cast<IdType*>(ret_indptr->data);
  Bp[0] = 0;
  IdType* const Bi = static_cast<IdType*>(ret_indices->data);
  IdType* const Bx = static_cast<IdType*>(ret_data->data);

  // store sorted data and original index.
  NDArray sorted_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  NDArray sorted_data_pos =
      NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  IdType* const Sx = static_cast<IdType*>(sorted_data->data);
  IdType* const Si = static_cast<IdType*>(sorted_data_pos->data);

  // record row_idx in each thread.
  std::vector<std::vector<int64_t>> p_sum;

#pragma omp parallel
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    CHECK_LT(thread_id, num_threads);

    const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;
    const int64_t nz_start = thread_id * nz_chunk;
    const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);

    const int64_t n_chunk = (N + num_threads - 1) / num_threads;
    const int64_t n_start = thread_id * n_chunk;
    const int64_t n_end = std::min(N, n_start + n_chunk);

    // init Bp to zero; a one-element shift is always applied when accessing
    // Bp since its length is N+1.
    for (auto i = n_start; i < n_end; ++i) {
      Bp[i + 1] = 0;
    }

#pragma omp master
    { p_sum.resize(num_threads); }
#pragma omp barrier

    // iterate on NNZ data and count row_idx.
    p_sum[thread_id].resize(num_threads, 0);
    for (auto i = nz_start; i < nz_end; ++i) {
      const int64_t row_idx = row_data[i];
      const int64_t row_thread_id = row_idx / n_chunk;
      ++p_sum[thread_id][row_thread_id];
    }
#pragma omp barrier

#pragma omp master
    // accumulate row_idx.
    {
      int64_t cum = 0;
      for (size_t j = 0; j < p_sum.size(); ++j) {
        for (size_t i = 0; i < p_sum.size(); ++i) {
          auto tmp = p_sum[i][j];
          p_sum[i][j] = cum;
          cum += tmp;
        }
      }
      CHECK_EQ(cum, NNZ);
    }
#pragma omp barrier

    // sort data by row_idx and place into Sx/Si.
    std::vector<int64_t> data_pos(p_sum[thread_id]);
    for (auto i = nz_start; i < nz_end; ++i) {
      const int64_t row_idx = row_data[i];
      const int64_t row_thread_id = row_idx / n_chunk;
      const int64_t pos = data_pos[row_thread_id]++;
      Sx[pos] = data == nullptr ? i : data[i];
      Si[pos] = i;
    }
#pragma omp barrier

    // Now we're able to do coo2csr on sorted data in each thread in parallel.
    // compute data number on each row_idx.
    const int64_t i_start = p_sum[0][thread_id];
    const int64_t i_end =
        thread_id + 1 == num_threads ? NNZ : p_sum[0][thread_id + 1];
    for (auto i = i_start; i < i_end; ++i) {
      const int64_t row_idx = row_data[Si[i]];
      ++Bp[row_idx + 1];
    }

    // accumulate on each row
    IdType cumsum = 0;
    for (auto i = n_start; i < n_end; ++i) {
      const auto tmp = Bp[i + 1];
      Bp[i + 1] = cumsum;
      cumsum += tmp;
    }

    // update Bi/Bp/Bx
    for (auto i = i_start; i < i_end; ++i) {
      const int64_t row_idx = row_data[Si[i]];
      const int64_t dest = (Bp[row_idx + 1]++) + i_start;
      Bi[dest] = col_data[Si[i]];
      Bx[dest] = Sx[i];
    }

    for (auto i = n_start; i < n_end; ++i) {
      Bp[i + 1] += i_start;
    }
  }

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

template <class IdType>
CSRMatrix UnSortedDenseCOOToCSR(const COOMatrix& coo) {
  const int64_t N = coo.num_rows;
  const int64_t NNZ = coo.row->shape[0];
  const IdType* const row_data = static_cast<IdType*>(coo.row->data);
  const IdType* const col_data = static_cast<IdType*>(coo.col->data);
  const IdType* const data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
  NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
  NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);

  IdType* const Bp = static_cast<IdType*>(ret_indptr->data);
  Bp[0] = 0;
  IdType* const Bi = static_cast<IdType*>(ret_indices->data);
  IdType* const Bx = static_cast<IdType*>(ret_data->data);

  // the offset within each row, that each thread will write to
  std::vector<std::vector<IdType>> local_ptrs;
  std::vector<int64_t> thread_prefixsum;

#pragma omp parallel
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    CHECK_LT(thread_id, num_threads);

    const int64_t nz_chunk = (NNZ + num_threads - 1) / num_threads;
    const int64_t nz_start = thread_id * nz_chunk;
    const int64_t nz_end = std::min(NNZ, nz_start + nz_chunk);

    const int64_t n_chunk = (N + num_threads - 1) / num_threads;
    const int64_t n_start = thread_id * n_chunk;
    const int64_t n_end = std::min(N, n_start + n_chunk);

#pragma omp master
    {
      local_ptrs.resize(num_threads);
      thread_prefixsum.resize(num_threads + 1);
    }
#pragma omp barrier
    local_ptrs[thread_id].resize(N, 0);

    for (int64_t i = nz_start; i < nz_end; ++i) {
      ++local_ptrs[thread_id][row_data[i]];
    }
#pragma omp barrier

    // compute prefixsum in parallel
    int64_t sum = 0;
    for (int64_t i = n_start; i < n_end; ++i) {
      IdType tmp = 0;
      for (int j = 0; j < num_threads; ++j) {
        std::swap(tmp, local_ptrs[j][i]);
        tmp += local_ptrs[j][i];
      }
      sum += tmp;
      Bp[i + 1] = sum;
    }
    thread_prefixsum[thread_id + 1] = sum;
#pragma omp barrier

#pragma omp master
    {
      for (int64_t i = 0; i < num_threads; ++i) {
        thread_prefixsum[i + 1] += thread_prefixsum[i];
      }
      CHECK_EQ(thread_prefixsum[num_threads], NNZ);
    }
#pragma omp barrier

    sum = thread_prefixsum[thread_id];
    for (int64_t i = n_start; i < n_end; ++i) {
      Bp[i + 1] += sum;
    }
#pragma omp barrier

    for (int64_t i = nz_start; i < nz_end; ++i) {
      const IdType r = row_data[i];
      const int64_t index = Bp[r] + local_ptrs[thread_id][r]++;
      Bi[index] = col_data[i];
      Bx[index] = data ? data[i] : i;
    }
  }
  CHECK_EQ(Bp[N], NNZ);

  return CSRMatrix(
      coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data,
      coo.col_sorted);
}

}  // namespace

/**
 * Implementation and complexity details.
 *
 * N: num_nodes, NNZ: num_edges, P: num_threads.
 *
 * 1. If row is sorted in COO, SortedCOOToCSR<> is applied.
 *    Time: O(NNZ/P), space: O(1).
 * 2. If row is NOT sorted in COO and the graph is sparse (low average
 *    degree), UnSortedSparseCOOToCSR<> is applied.
 *    Time: O(NNZ/P + N/P + P^2), space: O(NNZ + P^2).
 * 3. If row is NOT sorted in COO and the graph is dense (medium/high average
 *    degree), UnSortedDenseCOOToCSR<> is applied.
 *    Time: O(NNZ/P + N/P), space: O(NNZ + N*P).
 */
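// For example, if num_threads is 8 and the graph has 1M rows, a COO with
// fewer than 2M edges (average degree below 2) satisfies
// num_threads * num_nodes > 4 * num_edges and takes the sparse path; denser
// graphs take the dense path.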
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
  if (!coo.row_sorted) {
    const int64_t num_threads = omp_get_num_threads();
    const int64_t num_nodes = coo.num_rows;
    const int64_t num_edges = coo.row->shape[0];
    // Besides graph density, num_threads is also taken into account. The
    // criterion below is set up according to the time/space complexity
    // difference between the two algorithms.
    if (num_threads * num_nodes > 4 * num_edges) {
      return UnSortedSparseCOOToCSR<IdType>(coo);
    }
    return UnSortedDenseCOOToCSR<IdType>(coo);
  }
  return SortedCOOToCSR<IdType>(coo);
}

template CSRMatrix COOToCSR<kDGLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCPU, int64_t>(COOMatrix coo);

///////////////////////////// COOSliceRows /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
  // TODO(minjie): use binary search when coo.row_sorted is true
  CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
  CHECK(end > 0 && end <= coo.num_rows) << "Invalid end row " << end;

  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
  const IdType* coo_data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;

  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType row_id = coo_row_data[i];
    const IdType col_id = coo_col_data[i];
    if (row_id < end && row_id >= start) {
      ret_row.push_back(row_id - start);
      ret_col.push_back(col_id);
      ret_data.push_back(coo_data ? coo_data[i] : i);
    }
  }
  return COOMatrix(
      end - start, coo.num_cols, NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),
      coo.row_sorted, coo.col_sorted);
}

template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
  const IdType* coo_data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;

  IdHashMap<IdType> hashmap(rows);

  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType row_id = coo_row_data[i];
    const IdType col_id = coo_col_data[i];
    const IdType mapped_row_id = hashmap.Map(row_id, -1);
    if (mapped_row_id != -1) {
      ret_row.push_back(mapped_row_id);
      ret_col.push_back(col_id);
      ret_data.push_back(coo_data ? coo_data[i] : i);
    }
  }

  return COOMatrix{
      rows->shape[0],
      coo.num_cols,
      NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col),
      NDArray::FromVector(ret_data),
      coo.row_sorted,
      coo.col_sorted};
}

template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, NDArray);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, NDArray);

///////////////////////////// COOSliceMatrix /////////////////////////////
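// Extracts the submatrix induced by the given row and column ids. Row and
// column ids are relabelled to their positions in `rows` and `cols`, and a
// non-zero is kept only if both of its endpoints appear in the corresponding
// id set.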
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
  const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
  const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
  const IdType* coo_data =
      COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;

  IdHashMap<IdType> row_map(rows), col_map(cols);

  std::vector<IdType> ret_row, ret_col;
  std::vector<IdType> ret_data;

  for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
    const IdType row_id = coo_row_data[i];
    const IdType col_id = coo_col_data[i];
    const IdType mapped_row_id = row_map.Map(row_id, -1);
    if (mapped_row_id != -1) {
      const IdType mapped_col_id = col_map.Map(col_id, -1);
      if (mapped_col_id != -1) {
        ret_row.push_back(mapped_row_id);
        ret_col.push_back(mapped_col_id);
        ret_data.push_back(coo_data ? coo_data[i] : i);
      }
    }
  }

  return COOMatrix(
      rows->shape[0], cols->shape[0], NDArray::FromVector(ret_row),
      NDArray::FromVector(ret_col), NDArray::FromVector(ret_data),
      coo.row_sorted, coo.col_sorted);
}

template COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(
    COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);

///////////////////////////// COOReorder /////////////////////////////

template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(
    COOMatrix coo, runtime::NDArray new_row_id_arr,
    runtime::NDArray new_col_id_arr) {
  CHECK_SAME_DTYPE(coo.row, new_row_id_arr);
  CHECK_SAME_DTYPE(coo.col, new_col_id_arr);

  // Input COO
  const IdType* in_rows = static_cast<IdType*>(coo.row->data);
  const IdType* in_cols = static_cast<IdType*>(coo.col->data);
  int64_t num_rows = coo.num_rows;
  int64_t num_cols = coo.num_cols;
  int64_t nnz = coo.row->shape[0];
  CHECK_EQ(num_rows, new_row_id_arr->shape[0])
      << "The length of the new row Id array must equal the number of rows "
         "of the COO";
  CHECK_EQ(num_cols, new_col_id_arr->shape[0])
      << "The length of the new col Id array must equal the number of cols "
         "of the COO";

  // New row/col Ids.
  const IdType* new_row_ids = static_cast<IdType*>(new_row_id_arr->data);
  const IdType* new_col_ids = static_cast<IdType*>(new_col_id_arr->data);

  // Output COO
  NDArray out_row_arr = NDArray::Empty({nnz}, coo.row->dtype, coo.row->ctx);
  NDArray out_col_arr = NDArray::Empty({nnz}, coo.col->dtype, coo.col->ctx);
  NDArray out_data_arr = COOHasData(coo) ? coo.data : NullArray();
  IdType* out_row = static_cast<IdType*>(out_row_arr->data);
  IdType* out_col = static_cast<IdType*>(out_col_arr->data);

  parallel_for(0, nnz, [=](size_t b, size_t e) {
    for (auto i = b; i < e; ++i) {
      out_row[i] = new_row_ids[in_rows[i]];
      out_col[i] = new_col_ids[in_cols[i]];
    }
  });
  return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);
}

template COOMatrix COOReorder<kDGLCPU, int32_t>(
    COOMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template COOMatrix COOReorder<kDGLCPU, int64_t>(
    COOMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);

}  // namespace impl
}  // namespace aten
}  // namespace dgl