Commit 9c0419db authored by rusty1s's avatar rusty1s
Browse files

eye func

parent 8daf945a
from torch_sparse import eye
def test_eye():
index, value = eye(3)
assert index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert value.tolist() == [1, 1, 1]
from .coalesce import coalesce from .coalesce import coalesce
from .transpose import transpose from .transpose import transpose
from .eye import eye
from .spmm import spmm from .spmm import spmm
from .spspmm import spspmm from .spspmm import spspmm
...@@ -9,6 +10,7 @@ __all__ = [ ...@@ -9,6 +10,7 @@ __all__ = [
'__version__', '__version__',
'coalesce', 'coalesce',
'transpose', 'transpose',
'eye',
'spmm', 'spmm',
'spspmm', 'spspmm',
] ]
import torch
def eye(m, dtype=None, device=None):
"""Returns a sparse matrix with ones on the diagonal and zeros elsewhere.
Args:
m (int): The first dimension of sparse matrix.
dtype (`torch.dtype`, optional): The desired data type of returned
value vector. (default is set by `torch.set_default_tensor_type()`)
device (`torch.device`, optional): The desired device of returned
tensors. (default is set by `torch.set_default_tensor_type()`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
row = torch.arange(m, dtype=torch.long, device=device)
index = torch.stack([row, row], dim=0)
value = torch.ones(m, dtype=dtype, device=device)
return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment