"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7de2e51b5ec9f21685df56be42e41b5b3e6938a8"
Unverified Commit 866c70da authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Fix bugs in Diag format conversions (#5465)

parent 4cf5f682
...@@ -113,7 +113,8 @@ std::shared_ptr<CSR> DiagToCSR( ...@@ -113,7 +113,8 @@ std::shared_ptr<CSR> DiagToCSR(
const c10::TensorOptions& indices_options) { const c10::TensorOptions& indices_options) {
int64_t nnz = std::min(diag->num_rows, diag->num_cols); int64_t nnz = std::min(diag->num_rows, diag->num_cols);
auto indptr = torch::full(diag->num_rows + 1, nnz, indices_options); auto indptr = torch::full(diag->num_rows + 1, nnz, indices_options);
torch::arange_out(indptr, nnz + 1); auto nnz_range = torch::arange(nnz + 1, indices_options);
indptr.index_put_({nnz_range}, nnz_range);
auto indices = torch::arange(nnz, indices_options); auto indices = torch::arange(nnz, indices_options);
return std::make_shared<CSR>( return std::make_shared<CSR>(
CSR{diag->num_rows, diag->num_cols, indptr, indices, CSR{diag->num_rows, diag->num_cols, indptr, indices,
...@@ -125,7 +126,8 @@ std::shared_ptr<CSR> DiagToCSC( ...@@ -125,7 +126,8 @@ std::shared_ptr<CSR> DiagToCSC(
const c10::TensorOptions& indices_options) { const c10::TensorOptions& indices_options) {
int64_t nnz = std::min(diag->num_rows, diag->num_cols); int64_t nnz = std::min(diag->num_rows, diag->num_cols);
auto indptr = torch::full(diag->num_cols + 1, nnz, indices_options); auto indptr = torch::full(diag->num_cols + 1, nnz, indices_options);
torch::arange_out(indptr, nnz + 1); auto nnz_range = torch::arange(nnz + 1, indices_options);
indptr.index_put_({nnz_range}, nnz_range);
auto indices = torch::arange(nnz, indices_options); auto indices = torch::arange(nnz, indices_options);
return std::make_shared<CSR>( return std::make_shared<CSR>(
CSR{diag->num_cols, diag->num_rows, indptr, indices, CSR{diag->num_cols, diag->num_rows, indptr, indices,
......
...@@ -345,6 +345,28 @@ def test_csr_to_csc(dense_dim, indptr, indices, shape): ...@@ -345,6 +345,28 @@ def test_csr_to_csc(dense_dim, indptr, indices, shape):
assert torch.allclose(mat_indices, indices) assert torch.allclose(mat_indices, indices)
@pytest.mark.parametrize("shape", [(3, 5), (5, 5), (5, 4)])
def test_diag_conversions(shape):
n_rows, n_cols = shape
nnz = min(shape)
ctx = F.ctx()
val = torch.randn(nnz).to(ctx)
D = diag(val, shape)
row, col = D.coo()
assert torch.allclose(row, torch.arange(nnz).to(ctx))
assert torch.allclose(col, torch.arange(nnz).to(ctx))
indptr, indices, _ = D.csr()
exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_rows - nnz)
assert torch.allclose(indptr, torch.tensor(exp_indptr).to(ctx))
assert torch.allclose(indices, torch.arange(nnz).to(ctx))
indptr, indices, _ = D.csc()
exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_cols - nnz)
assert torch.allclose(indptr, torch.tensor(exp_indptr).to(ctx))
assert torch.allclose(indices, torch.arange(nnz).to(ctx))
@pytest.mark.parametrize("val_shape", [(3), (3, 2)]) @pytest.mark.parametrize("val_shape", [(3), (3, 2)])
@pytest.mark.parametrize("shape", [(3, 5), (5, 5)]) @pytest.mark.parametrize("shape", [(3, 5), (5, 5)])
def test_val_like(val_shape, shape): def test_val_like(val_shape, shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment