"...csrc/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f5c0bfa55c7d8900a8259dffc5c404d71c826641"
Unverified Commit db772831 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Only support scalar or vector values in SparseMatrix creators (#5204)



* [Sparse] Only support scalar or vector values in SparseMatrix creators

* Update

* Update sparse_matrix.py
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent c7e3754d
...@@ -362,6 +362,9 @@ def diag( ...@@ -362,6 +362,9 @@ def diag(
>>> D.nnz >>> D.nnz
5 5
""" """
assert (
val.dim() <= 2
), "The values of a DiagMatrix can only be scalars or vectors."
# NOTE(Mufei): this may not be needed if DiagMatrix is simple enough # NOTE(Mufei): this may not be needed if DiagMatrix is simple enough
return DiagMatrix(val, shape) return DiagMatrix(val, shape)
......
...@@ -481,6 +481,10 @@ def from_coo( ...@@ -481,6 +481,10 @@ def from_coo(
if val is None: if val is None:
val = torch.ones(row.shape[0]).to(row.device) val = torch.ones(row.shape[0]).to(row.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.from_coo(row, col, val, shape)) return SparseMatrix(torch.ops.dgl_sparse.from_coo(row, col, val, shape))
...@@ -562,6 +566,10 @@ def from_csr( ...@@ -562,6 +566,10 @@ def from_csr(
if val is None: if val is None:
val = torch.ones(indices.shape[0]).to(indptr.device) val = torch.ones(indices.shape[0]).to(indptr.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix( return SparseMatrix(
torch.ops.dgl_sparse.from_csr(indptr, indices, val, shape) torch.ops.dgl_sparse.from_csr(indptr, indices, val, shape)
) )
...@@ -645,6 +653,10 @@ def from_csc( ...@@ -645,6 +653,10 @@ def from_csc(
if val is None: if val is None:
val = torch.ones(indices.shape[0]).to(indptr.device) val = torch.ones(indices.shape[0]).to(indptr.device)
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix( return SparseMatrix(
torch.ops.dgl_sparse.from_csc(indptr, indices, val, shape) torch.ops.dgl_sparse.from_csc(indptr, indices, val, shape)
) )
...@@ -681,6 +693,10 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix: ...@@ -681,6 +693,10 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
values=tensor([2, 2, 2]), values=tensor([2, 2, 2]),
shape=(3, 5), nnz=3) shape=(3, 5), nnz=3)
""" """
assert (
val.dim() <= 2
), "The values of a SparseMatrix can only be scalars or vectors."
return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val)) return SparseMatrix(torch.ops.dgl_sparse.val_like(mat.c_sparse_matrix, val))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment