Commit a9aca7be authored by rusty1s's avatar rusty1s
Browse files

add __eq__

parent fdc31cc4
...@@ -37,3 +37,19 @@ def test_to_symmetric(device): ...@@ -37,3 +37,19 @@ def test_to_symmetric(device):
[6, 0, 5], [6, 0, 5],
[3, 5, 0], [3, 5, 0],
] ]
def test_equal():
row = torch.tensor([0, 0, 0, 1, 1])
col = torch.tensor([0, 1, 2, 0, 2])
value = torch.arange(1, 6)
matA = SparseTensor(row=row, col=col, value=value)
matB = SparseTensor(row=row, col=col, value=value)
col = torch.tensor([0, 1, 2, 0, 1])
matC = SparseTensor(row=row, col=col, value=value)
assert id(matA) != id(matB)
assert matA == matB
assert id(matA) != id(matC)
assert matA != matC
...@@ -197,6 +197,28 @@ class SparseTensor(object): ...@@ -197,6 +197,28 @@ class SparseTensor(object):
self.storage.clear_cache_() self.storage.clear_cache_()
return self return self
def __eq__(self, other) -> bool:
if not isinstance(other, self.__class__):
return False
if self.sizes() != other.sizes():
return False
rowptrA, colA, valueA = self.csr()
rowptrB, colB, valueB = other.csr()
if valueA is None and valueB is not None:
return False
if valueA is not None and valueB is None:
return False
if not torch.equal(rowptrA, rowptrB):
return False
if not torch.equal(colA, colB):
return False
if valueA is None and valueB is None:
return True
return torch.equal(valueA, valueB)
# Utility functions ####################################################### # Utility functions #######################################################
def fill_value_(self, fill_value: float, dtype: Optional[int] = None): def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment