Commit 8481d0d6 authored by rusty1s's avatar rusty1s
Browse files

masked select fix

parent e7f4ef9f
...@@ -10,56 +10,54 @@ def masked_select(src, dim, mask): ...@@ -10,56 +10,54 @@ def masked_select(src, dim, mask):
storage = src.storage storage = src.storage
if dim == 0: if dim == 0:
(row, col), value = src.coo() row, col, value = src.coo()
rowcount = src.storage.rowcount rowcount = src.storage.rowcount
row_mask = mask[row]
rowcount = rowcount[mask] rowcount = rowcount[mask]
idx = torch.arange(rowcount.size(0), device=rowcount.device)
row = idx.repeat_interleave(rowcount) mask = mask[row]
col = col[row_mask] row = torch.arange(rowcount.size(0),
index = torch.stack([row, col], dim=0) device=row.device).repeat_interleave(rowcount)
col = col[mask]
if src.has_value(): if src.has_value():
value = value[row_mask] value = value[mask]
sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)]) sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
storage = src.storage.__class__( storage = src.storage.__class__(row=row, col=col, value=value,
index, value, sparse_size, rowcount=rowcount, is_sorted=True) sparse_size=sparse_size,
rowcount=rowcount, is_sorted=True)
elif dim == 1: elif dim == 1:
row, col, value = src.coo()
csr2csc = src.storage.csr2csc csr2csc = src.storage.csr2csc
row = src.storage.row[csr2csc] row, col = row[csr2csc], col[csr2csc]
col = src.storage.col[csr2csc]
colcount = src.storage.colcount colcount = src.storage.colcount
col_mask = mask[col]
colcount = colcount[mask] colcount = colcount[mask]
tmp = torch.arange(colcount.size(0), device=row.device)
col = tmp.repeat_interleave(colcount) mask = mask[col]
row = row[col_mask] col = torch.arange(colcount.size(0),
device=col.device).repeat_interleave(colcount)
row = row[mask]
csc2csr = (colcount.size(0) * row + col).argsort() csc2csr = (colcount.size(0) * row + col).argsort()
index = torch.stack([row, col], dim=0)[:, csc2csr] row, col = row[csc2csr], col[csc2csr]
value = src.storage.value
if src.has_value(): if src.has_value():
value = value[csr2csc][col_mask][csc2csr] value = value[csr2csc][mask][csc2csr]
sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)]) sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
storage = src.storage.__class__( storage = src.storage.__class__(row=row, col=col, value=value,
index, sparse_size=sparse_size,
value, colcount=colcount, csc2csr=csc2csr,
sparse_size,
colcount=colcount,
csc2csr=csc2csr,
is_sorted=True) is_sorted=True)
else: else:
idx = mask.nonzero().view(-1) idx = mask.nonzero().view(-1)
storage = src.storage.apply_value(lambda x: x.index_select( storage = src.storage.apply_value(
dim - 1, idx)) lambda x: x.index_select(dim - 1, idx))
return src.from_storage(storage) return src.from_storage(storage)
...@@ -70,14 +68,15 @@ def masked_select_nnz(src, mask, layout=None): ...@@ -70,14 +68,15 @@ def masked_select_nnz(src, mask, layout=None):
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
mask = mask[src.storage.csc2csr] mask = mask[src.storage.csc2csr]
index, value = src.coo() row, col, value = src.coo()
row, col = row[mask], col[mask]
index = index[:, mask]
if src.has_value(): if src.has_value():
value = value[mask] value = value[mask]
# There is no other information we can maintain... # There is no other information we can maintain...
storage = src.storage.__class__( storage = src.storage.__class__(row=row, col=col, value=value,
index, value, src.sparse_size(), is_sorted=True) sparse_size=src.sparse_size(),
is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment