Unverified Commit 73a508e1 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Stack SparseMatrix COO row and column coordinates into one tensor. (#5314)

parent 5ea04713
...@@ -25,10 +25,10 @@ enum SparseFormat { kCOO, kCSR, kCSC }; ...@@ -25,10 +25,10 @@ enum SparseFormat { kCOO, kCSR, kCSC };
struct COO { struct COO {
/** @brief The shape of the matrix. */ /** @brief The shape of the matrix. */
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
/** @brief COO format row indices array of the matrix. */ /**
torch::Tensor row; * @brief COO tensor of shape (2, nnz), stacking the row and column indices.
/** @brief COO format column indices array of the matrix. */ */
torch::Tensor col; torch::Tensor indices;
/** @brief Whether the row indices are sorted. */ /** @brief Whether the row indices are sorted. */
bool row_sorted = false; bool row_sorted = false;
/** @brief Whether the column indices per row are sorted. */ /** @brief Whether the column indices per row are sorted. */
......
...@@ -51,10 +51,10 @@ torch::Tensor ReduceAlong( ...@@ -51,10 +51,10 @@ torch::Tensor ReduceAlong(
torch::Tensor idx; torch::Tensor idx;
if (dim == 0) { if (dim == 0) {
output_shape[0] = coo->num_cols; output_shape[0] = coo->num_cols;
idx = coo->col.view(view_dims).expand_as(value); idx = coo->indices.index({1}).view(view_dims).expand_as(value);
} else if (dim == 1) { } else if (dim == 1) {
output_shape[0] = coo->num_rows; output_shape[0] = coo->num_rows;
idx = coo->row.view(view_dims).expand_as(value); idx = coo->indices.index({0}).view(view_dims).expand_as(value);
} }
torch::Tensor out = torch::zeros(output_shape, value.options()); torch::Tensor out = torch::zeros(output_shape, value.options());
......
...@@ -18,14 +18,15 @@ std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo) { ...@@ -18,14 +18,15 @@ std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo) {
auto row = DGLArrayToTorchTensor(dgl_coo.row); auto row = DGLArrayToTorchTensor(dgl_coo.row);
auto col = DGLArrayToTorchTensor(dgl_coo.col); auto col = DGLArrayToTorchTensor(dgl_coo.col);
TORCH_CHECK(aten::IsNullArray(dgl_coo.data)); TORCH_CHECK(aten::IsNullArray(dgl_coo.data));
auto indices = torch::stack({row, col});
return std::make_shared<COO>( return std::make_shared<COO>(
COO{dgl_coo.num_rows, dgl_coo.num_cols, row, col, dgl_coo.row_sorted, COO{dgl_coo.num_rows, dgl_coo.num_cols, indices, dgl_coo.row_sorted,
dgl_coo.col_sorted}); dgl_coo.col_sorted});
} }
aten::COOMatrix COOToOldDGLCOO(const std::shared_ptr<COO>& coo) { aten::COOMatrix COOToOldDGLCOO(const std::shared_ptr<COO>& coo) {
auto row = TorchTensorToDGLArray(coo->row); auto row = TorchTensorToDGLArray(coo->indices.index({0}));
auto col = TorchTensorToDGLArray(coo->col); auto col = TorchTensorToDGLArray(coo->indices.index({1}));
return aten::COOMatrix( return aten::COOMatrix(
coo->num_rows, coo->num_cols, row, col, aten::NullArray(), coo->num_rows, coo->num_cols, row, col, aten::NullArray(),
coo->row_sorted, coo->col_sorted); coo->row_sorted, coo->col_sorted);
...@@ -50,14 +51,13 @@ aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr) { ...@@ -50,14 +51,13 @@ aten::CSRMatrix CSRToOldDGLCSR(const std::shared_ptr<CSR>& csr) {
torch::Tensor COOToTorchCOO( torch::Tensor COOToTorchCOO(
const std::shared_ptr<COO>& coo, torch::Tensor value) { const std::shared_ptr<COO>& coo, torch::Tensor value) {
std::vector<torch::Tensor> indices = {coo->row, coo->col}; torch::Tensor indices = coo->indices;
if (value.ndimension() == 2) { if (value.ndimension() == 2) {
return torch::sparse_coo_tensor( return torch::sparse_coo_tensor(
torch::stack(indices), value, indices, value, {coo->num_rows, coo->num_cols, value.size(1)});
{coo->num_rows, coo->num_cols, value.size(1)});
} else { } else {
return torch::sparse_coo_tensor( return torch::sparse_coo_tensor(
torch::stack(indices), value, {coo->num_rows, coo->num_cols}); indices, value, {coo->num_rows, coo->num_cols});
} }
} }
......
...@@ -30,12 +30,10 @@ SparseMatrix::SparseMatrix( ...@@ -30,12 +30,10 @@ SparseMatrix::SparseMatrix(
// device. Do we allow the graph structure and values are on different // device. Do we allow the graph structure and values are on different
// devices? // devices?
if (coo != nullptr) { if (coo != nullptr) {
TORCH_CHECK(coo->row.dim() == 1); TORCH_CHECK(coo->indices.dim() == 2);
TORCH_CHECK(coo->col.dim() == 1); TORCH_CHECK(coo->indices.size(0) == 2);
TORCH_CHECK(coo->row.size(0) == coo->col.size(0)); TORCH_CHECK(coo->indices.size(1) == value.size(0));
TORCH_CHECK(coo->row.size(0) == value.size(0)); TORCH_CHECK(coo->indices.device() == value.device());
TORCH_CHECK(coo->row.device() == value.device());
TORCH_CHECK(coo->col.device() == value.device());
} }
if (csr != nullptr) { if (csr != nullptr) {
TORCH_CHECK(csr->indptr.dim() == 1); TORCH_CHECK(csr->indptr.dim() == 1);
...@@ -76,8 +74,8 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer( ...@@ -76,8 +74,8 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO( c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
torch::Tensor row, torch::Tensor col, torch::Tensor value, torch::Tensor row, torch::Tensor col, torch::Tensor value,
const std::vector<int64_t>& shape) { const std::vector<int64_t>& shape) {
auto coo = auto coo = std::make_shared<COO>(
std::make_shared<COO>(COO{shape[0], shape[1], row, col, false, false}); COO{shape[0], shape[1], torch::stack({row, col}), false, false});
return SparseMatrix::FromCOOPointer(coo, value, shape); return SparseMatrix::FromCOOPointer(coo, value, shape);
} }
...@@ -141,7 +139,7 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() { ...@@ -141,7 +139,7 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() { std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
auto coo = COOPtr(); auto coo = COOPtr();
auto val = value(); auto val = value();
return std::make_tuple(coo->row, coo->col); return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
} }
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
......
...@@ -64,8 +64,8 @@ torch::Tensor _CSRMask( ...@@ -64,8 +64,8 @@ torch::Tensor _CSRMask(
const c10::intrusive_ptr<SparseMatrix>& sub_mat) { const c10::intrusive_ptr<SparseMatrix>& sub_mat) {
auto csr = CSRToOldDGLCSR(mat->CSRPtr()); auto csr = CSRToOldDGLCSR(mat->CSRPtr());
auto val = TorchTensorToDGLArray(value); auto val = TorchTensorToDGLArray(value);
auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row); auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->indices.index({0}));
auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col); auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->indices.index({1}));
runtime::NDArray ret = aten::CSRGetFloatingData(csr, row, col, val, 0.); runtime::NDArray ret = aten::CSRGetFloatingData(csr, row, col, val, 0.);
return DGLArrayToTorchTensor(ret); return DGLArrayToTorchTensor(ret);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment