Unverified Commit db772831 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Only support scalar or vector values in SparseMatrix creators (#5204)



* [Sparse] Only support scalar or vector values in SparseMatrix creators

* Update

* Update sparse_matrix.py
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent c7e3754d
......@@ -362,6 +362,9 @@ def diag(
>>> D.nnz
5
"""
assert (
val.dim() <= 2
), "The values of a DiagMatrix can only be scalars or vectors."
# NOTE(Mufei): this may not be needed if DiagMatrix is simple enough
return DiagMatrix(val, shape)
......
......@@ -481,6 +481,10 @@ def from_coo(
if val is None:
val = torch.ones(row.shape[0]).to(row.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.from_coo(row, col, val, shape))
......@@ -562,6 +566,10 @@ def from_csr(
if val is None:
val = torch.ones(indices.shape[0]).to(indptr.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(
torch.ops.dgl_sparse.from_csr(indptr, indices, val, shape)
)
......@@ -645,6 +653,10 @@ def from_csc(
if val is None:
val = torch.ones(indices.shape[0]).to(indptr.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(
torch.ops.dgl_sparse.from_csc(indptr, indices, val, shape)
)
......@@ -681,6 +693,10 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
values=tensor([2, 2, 2]),
shape=(3, 5), nnz=3)
"""
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment