Commit 6c5e08e7 authored by rusty1s's avatar rusty1s
Browse files

offset implementation

parent d8a8ebab
...@@ -8,16 +8,25 @@ from .utils import dtypes, devices, tensor ...@@ -8,16 +8,25 @@ from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_cat(dtype, device): def test_remove_diag(dtype, device):
index = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device) index = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device) value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(index, value) mat = SparseTensor(index, value)
mat.fill_cache_() mat.fill_cache_()
mat = mat.remove_diag() mat = mat.remove_diag()
index, value = mat.coo() assert mat.storage.index.tolist() == [[0, 1], [1, 2]]
assert index.tolist() == [[0, 1], [1, 2]] assert mat.storage.value.tolist() == [2, 3]
assert value.tolist() == [2, 3]
assert len(mat.cached_keys()) == 2 assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 1, 0] assert mat.storage.rowcount.tolist() == [1, 1, 0]
assert mat.storage.colcount.tolist() == [0, 1, 1] assert mat.storage.colcount.tolist() == [0, 1, 1]
mat = SparseTensor(index, value)
mat.fill_cache_()
mat = mat.remove_diag(k=1)
assert mat.storage.index.tolist() == [[0, 2], [0, 2]]
assert mat.storage.value.tolist() == [1, 4]
assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 0, 1]
assert mat.storage.colcount.tolist() == [1, 0, 1]
import torch
def add_diag(src, value=None, k=0): def add_diag(src, value=None, k=0):
pass pass
...@@ -9,14 +6,16 @@ def remove_diag(src, k=0): ...@@ -9,14 +6,16 @@ def remove_diag(src, k=0):
index, value = src.coo() index, value = src.coo()
row, col = index row, col = index
mask = row == col if k == 0 else row == (col + k) inv_mask = row != col if k == 0 else row != (col - k)
inv_mask = ~mask
index = index[:, inv_mask] index = index[:, inv_mask]
if src.has_value(): if src.has_value():
value = value[inv_mask] value = value[inv_mask]
if src.storage.has_rowcount() or src.storage.has_colcount():
mask = ~inv_mask
rowcount = None rowcount = None
if src.storage.has_rowcount(): if src.storage.has_rowcount():
rowcount = src.storage.rowcount.clone() rowcount = src.storage.rowcount.clone()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment