Commit d49dcbbd authored by rusty1s

diag and matmul fixes

parent 1a5fb80c
......@@ -4,23 +4,24 @@
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
CHECK_CPU(index);
int64_t E = index.size(1);
index = index.contiguous();
auto index_data = index.DATA_PTR<int64_t>();
at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
int64_t k) {
CHECK_CPU(row);
CHECK_CPU(col);
int64_t E = row.size(0);
int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
auto row_data = row.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
auto mask_data = mask.DATA_PTR<bool>();
int64_t r, c;
if (k < 0) {
for (int64_t i = 0; i < E; i++) {
r = index_data[i], c = index_data[i + E];
r = row_data[i], c = col_data[i];
if (r + k < 0) {
mask_data[i] = true;
} else if (r + k >= N) {
......@@ -33,7 +34,7 @@ at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
}
} else {
for (int64_t i = 0; i < E; i++) {
r = index_data[i], c = index_data[i + E];
r = row_data[i], c = col_data[i];
if (r + k >= N) {
mask_data[i + num_diag] = true;
} else if (r + k > c) {
......
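Note: both the CPU and CUDA paths size the output mask as `E + num_diag`, where `num_diag = k < 0 ? min(M + k, N) : min(M, N - k)` counts the entries on the k-th diagonal of an M x N matrix. A minimal Python sketch of that count (hypothetical helper, not part of the extension):

```python
def num_diag(M, N, k):
    # Entries on the k-th diagonal of an M x N matrix
    # (k < 0 below the main diagonal, k > 0 above it).
    return min(M + k, N) if k < 0 else min(M, N - k)

# For a 3 x 5 matrix: 3 entries on the main diagonal,
# 3 on the first super-diagonal, 2 on the first sub-diagonal.
assert num_diag(3, 5, 0) == 3
assert num_diag(3, 5, 1) == 3
assert num_diag(3, 5, -1) == 2
```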
......@@ -174,36 +174,28 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return std::make_tuple(out, arg_out);
}
at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
at::Tensor grad, std::string reduce) {
CHECK_CPU(index);
at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
at::Tensor mat, at::Tensor grad, std::string reduce) {
CHECK_CPU(row);
CHECK_CPU(rowptr);
CHECK_CPU(col);
CHECK_CPU(mat);
CHECK_CPU(grad);
AT_ASSERTM(index.dim() == 2, "Input mismatch");
AT_ASSERTM(index.size(0) == 2, "Input mismatch");
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
reduce2REDUCE.at(reduce) == MEAN,
"Reduce operation not supported");
index = index.contiguous();
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = index.size(1);
auto E = row.numel();
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto out = at::zeros(index.size(1), grad.options());
auto out = at::zeros(row.numel(), grad.options());
auto index_data = index.DATA_PTR<int64_t>();
auto row_data = row.DATA_PTR<int64_t>();
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
auto mat_data = mat.DATA_PTR<scalar_t>();
auto grad_data = grad.DATA_PTR<scalar_t>();
......@@ -214,7 +206,7 @@ at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int b = 0; b < B; b++) {
for (int e = 0; e < E; e++) {
row = index_data[e], col = index_data[E + e], val = (scalar_t)0;
row = row_data[e], col = col_data[e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
......
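For the sum/add case, the inner loop computes `grad_value[e]` as the dot product of row `col[e]` of `mat` with row `row[e]` of `grad`. A dense PyTorch reference, ignoring batching and the mean normalization handled in the elided code (a sketch for checking, not the extension's API):

```python
import torch

def spmm_val_bw_reference(row, col, mat, grad):
    # d(out)/d(A[r, c]) for out = A @ mat with 'sum' reduction selects
    # row c of `mat`, so grad_value[e] = <mat[col[e]], grad[row[e]]>.
    return (mat[col] * grad[row]).sum(dim=-1)

row = torch.tensor([0, 0, 1])
col = torch.tensor([1, 2, 0])
mat = torch.randn(3, 4)   # N x K
grad = torch.randn(2, 4)  # M x K
print(spmm_val_bw_reference(row, col, mat, grad))  # shape [E] = [3]
```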
......@@ -2,12 +2,14 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
int64_t k);
at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
int64_t N, int64_t k);
at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
CHECK_CUDA(index);
return non_diag_mask_cuda(index, M, N, k);
at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
return non_diag_mask_cuda(row, col, M, N, k);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
......@@ -5,14 +5,15 @@
#define THREADS 1024
__global__ void non_diag_mask_kernel(const int64_t *index_data, bool *out_data,
__global__ void non_diag_mask_kernel(const int64_t *row_data,
const int64_t *col_data, bool *out_data,
int64_t N, int64_t k, int64_t num_diag,
int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t r = index_data[thread_idx], c = index_data[thread_idx + numel];
int64_t r = row_data[thread_idx], c = col_data[thread_idx];
if (k < 0) {
if (r + k < 0) {
......@@ -37,21 +38,20 @@ __global__ void non_diag_mask_kernel(const int64_t *index_data, bool *out_data,
}
}
at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
int64_t k) {
int64_t E = index.size(1);
index = index.contiguous();
auto index_data = index.DATA_PTR<int64_t>();
at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
int64_t N, int64_t k) {
int64_t E = row.size(0);
int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
auto row_data = row.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
auto mask_data = mask.DATA_PTR<bool>();
auto stream = at::cuda::getCurrentCUDAStream();
non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
index_data, mask_data, N, k, num_diag, E);
row_data, col_data, mask_data, N, k, num_diag, E);
return mask;
}
......@@ -20,13 +20,14 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
return spmm_cuda(rowptr, col, value_opt, mat, reduce);
}
at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
at::Tensor grad, std::string reduce) {
CHECK_CUDA(index);
at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
at::Tensor mat, at::Tensor grad, std::string reduce) {
CHECK_CUDA(row);
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
return spmm_val_bw_cuda(index, rowptr, mat, grad, reduce);
return spmm_val_bw_cuda(row, rowptr, col, mat, grad, reduce);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......
......@@ -210,17 +210,18 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
const scalar_t *mat_data, const scalar_t *grad_data,
scalar_t *out_data, int B, int M, int N, int E, int K) {
spmm_val_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
const int64_t *col_data, const scalar_t *mat_data,
const scalar_t *grad_data, scalar_t *out_data, int B, int M,
int N, int E, int K) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int index_idx = (thread_idx >> 5); // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
if (index_idx < E) {
int row = __ldg(index_data + index_idx);
int col = __ldg(index_data + E + index_idx);
int row = __ldg(row_data + index_idx);
int col = __ldg(col_data + index_idx);
scalar_t val = (scalar_t)0;
for (int b = 0; b < B; b++) {
......@@ -246,43 +247,35 @@ spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
}
}
at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
at::Tensor grad, std::string reduce) {
at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
at::Tensor mat, at::Tensor grad,
std::string reduce) {
AT_ASSERTM(index.dim() == 2, "Input mismatch");
AT_ASSERTM(index.size(0) == 2, "Input mismatch");
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
reduce2REDUCE.at(reduce) == MEAN,
"Reduce operation not supported");
index = index.contiguous();
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = index.size(1);
auto E = row.numel();
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = at::empty(index.size(1), grad.options());
auto out = at::zeros(row.numel(), grad.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
auto index_data = index.DATA_PTR<int64_t>();
auto row_data = row.DATA_PTR<int64_t>();
auto rowptr_data = rowptr.DATA_PTR<int64_t>();
auto col_data = col.DATA_PTR<int64_t>();
auto mat_data = mat.DATA_PTR<scalar_t>();
auto grad_data = grad.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
spmm_val_bw_kernel<scalar_t, REDUCE>
<<<BLOCKS, THREADS, 0, stream>>>(index_data, rowptr_data, mat_data,
grad_data, out_data, B, M, N, E, K);
spmm_val_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
N, E, K);
});
});
......
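The CUDA kernel keeps its one-warp-per-nonzero layout: 32 consecutive threads share one edge, and the launch uses `(E * 32 + THREADS - 1) / THREADS` blocks. A small sketch of that arithmetic (plain Python, mirroring the bit tricks in the kernel):

```python
THREADS = 1024
WARP = 32

def num_blocks(E):
    # One warp per nonzero edge, THREADS threads per block.
    return (E * WARP + THREADS - 1) // THREADS

def thread_to_edge(thread_idx):
    # Mirrors `index_idx = thread_idx >> 5` and `lane_idx = thread_idx & 31`.
    return thread_idx >> 5, thread_idx & (WARP - 1)

print(num_blocks(1000))      # 32 blocks
print(thread_to_edge(70))    # edge 2, lane 6
```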
......@@ -121,14 +121,10 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
descr, valueC_data, rowptrC_data, colC_data, info, buffer);
});
auto rowC = at::empty_like(colC);
auto rowC_data = rowC.DATA_PTR<int>();
cusparseXcsr2coo(handle, rowptrC_data, nnzC, M, rowC_data,
CUSPARSE_INDEX_BASE_ZERO);
cusparseDestroyCsrgemm2Info(info);
auto indexC = at::stack({rowC.toType(at::kLong), colC.toType(at::kLong)}, 0);
return std::make_tuple(indexC, rowptrC.toType(at::kLong), valueC);
rowptrC = rowptrC.toType(at::kLong);
colC = colC.toType(at::kLong);
return std::make_tuple(rowptrC, colC, valueC);
}
// #define THREADS 1024
......
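With `cusparseXcsr2coo` gone, the CUDA product now comes back in plain CSR form as `(rowptrC, colC, valueC)`. If per-nonzero row indices are needed again downstream, they can be rebuilt on the Python side, for example with a sketch like this (assumed helper name):

```python
import torch

def csr_row_from_rowptr(rowptr):
    # Expand a CSR row pointer of length M + 1 into per-nonzero row
    # indices, i.e. the information csr2coo used to provide.
    counts = rowptr[1:] - rowptr[:-1]
    rows = torch.arange(rowptr.numel() - 1, device=rowptr.device)
    return torch.repeat_interleave(rows, counts)

rowptr = torch.tensor([0, 2, 2, 5])       # 3 rows, 5 nonzeros
print(csr_row_from_rowptr(rowptr))        # tensor([0, 0, 2, 2, 2])
```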
......@@ -9,12 +9,9 @@ except ImportError:
def remove_diag(src, k=0):
index, value = src.coo()
row, col = index
row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k)
index = index[:, inv_mask]
row, col = row[inv_mask], col[inv_mask]
if src.has_value():
value = value[inv_mask]
......@@ -32,7 +29,7 @@ def remove_diag(src, k=0):
colcount = src.storage.colcount.clone()
colcount[col[mask]] -= 1
storage = src.storage.__class__(index, value,
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=src.sparse_size(),
rowcount=rowcount, colcount=colcount,
is_sorted=True)
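`inv_mask` keeps exactly the entries that are not on the k-th diagonal, i.e. those with `row != col - k`. A minimal COO-level example of the filter:

```python
import torch

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 2, 1, 0])
k = 0  # main diagonal

inv_mask = row != col if k == 0 else row != (col - k)
row, col = row[inv_mask], col[inv_mask]
print(row.tolist(), col.tolist())  # keeps (0, 2) and (2, 0)
```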
......@@ -45,26 +42,26 @@ def set_diag(src, values=None, k=0):
src = src.remove_diag(k=k)
index, value = src.coo()
row, col, value = src.coo()
func = diag_cuda if index.is_cuda else diag_cpu
mask = func.non_diag_mask(index, src.size(0), src.size(1), k)
func = diag_cuda if row.is_cuda else diag_cpu
mask = func.non_diag_mask(row, col, src.size(0), src.size(1), k)
inv_mask = ~mask
new_index = index.new_empty((2, mask.size(0)))
new_index[:, mask] = index
start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
diag = torch.arange(start, start + num_diag, device=src.device)
num_diag = mask.numel() - index.size(1)
start = -k if k < 0 else 0
new_row = row.new_empty(mask.size(0))
new_row[mask] = row
new_row[inv_mask] = diag
diag_row = torch.arange(start, start + num_diag, device=src.device)
new_index[0, inv_mask] = diag_row
diag_col = diag_row.add_(k)
new_index[1, inv_mask] = diag_col
new_col = col.new_empty(mask.size(0))
new_col[mask] = col
new_col[inv_mask] = diag.add_(k)
new_value = None
if src.has_value():
new_value = torch.new_empty((mask.size(0), ) + mask.size()[1:])
new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
new_value[mask] = value
new_value[inv_mask] = values if values is not None else 1
......@@ -78,8 +75,9 @@ def set_diag(src, values=None, k=0):
colcount = src.storage.colcount.clone()
colcount[start + k:start + num_diag + k] += 1
storage = src.storage.__class__(new_index, new_value,
storage = src.storage.__class__(row=new_row, col=new_col, value=new_value,
sparse_size=src.sparse_size(),
rowcount=rowcount, colcount=colcount,
is_sorted=True)
return src.__class__.from_storage(storage)
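The mask from `non_diag_mask` drives a simple interleave: slots where `mask` is true receive the existing (off-diagonal) entries in order, and the remaining slots receive the `num_diag` new diagonal entries, so the result stays sorted. A pure-PyTorch sketch of that step for the row and column indices (values follow the same pattern; hypothetical helper):

```python
import torch

def interleave_diag(row, col, mask, k):
    # `mask` has length E + num_diag; True marks slots for the original
    # entries, False marks slots for the new diagonal entries.
    inv_mask = ~mask
    num_diag = int(inv_mask.sum())
    start = -k if k < 0 else 0
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    new_col[inv_mask] = diag + k
    return new_row, new_col
```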
......@@ -20,14 +20,13 @@ def spmm(is_cuda):
class SPMM(torch.autograd.Function):
@staticmethod
def forward(ctx, index, rowcount, rowptr, colptr, csr2csc, value, mat,
def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
reduce):
out, arg_out = spmm(mat.is_cuda).spmm(rowptr, index[1], value, mat,
reduce)
out, arg_out = spmm(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
ctx.reduce = reduce
ctx.save_for_backward(index, rowcount, rowptr, colptr, csr2csc, value,
mat, arg_out)
ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
csr2csc, arg_out)
if reduce == 'min' or reduce == 'max':
ctx.mark_non_differentiable(arg_out)
......@@ -37,27 +36,27 @@ class SPMM(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_out, *args):
data = ctx.saved_tensors
index, rowcount, rowptr, colptr, csr2csc, value, mat, arg_out = data
(row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
arg_out) = ctx.saved_tensors
invalid_arg_mask = arg_out_ind = None
if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[3]
or ctx.needs_input_grad[4]):
invalid_arg_mask = arg_out == index.size(1)
invalid_arg_mask = arg_out == row.size(0)
arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)
grad_value = None
if ctx.needs_input_grad[5]:
if ctx.needs_input_grad[3]:
if ctx.reduce in ['sum', 'add']:
grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
index, rowptr, mat, grad_out, ctx.reduce)
row, rowptr, col, mat, grad_out, ctx.reduce)
if ctx.reduce == 'mean':
grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
index, rowptr, mat, grad_out, ctx.reduce)
row, rowptr, col, mat, grad_out, ctx.reduce)
elif ctx.reduce in ['min', 'max']:
col = index[1][arg_out_ind.flatten()].view_as(arg_out)
col = col[arg_out_ind.flatten()].view_as(arg_out)
out = mat.gather(-2, col).mul_(grad_out)
out.masked_fill_(invalid_arg_mask, 0)
grad_value = scatter_add(out.flatten(), arg_out.flatten(),
......@@ -65,16 +64,16 @@ class SPMM(torch.autograd.Function):
grad_value = grad_value[:-1]
grad_mat = None
if ctx.needs_input_grad[6]:
if ctx.needs_input_grad[4]:
if ctx.reduce in ['sum', 'add']:
value = value[csr2csc] if value is not None else value
grad_mat, _ = spmm(grad_out.is_cuda).spmm(
colptr, index[0][csr2csc], value, grad_out, 'sum')
colptr, row[csr2csc], value, grad_out, 'sum')
elif ctx.reduce == 'mean':
count = rowcount[index[0]].to(mat.dtype).clamp_(min=1)
count = rowcount[row].to(mat.dtype).clamp_(min=1)
value = count.pow_(-1) if value is None else value / count
row = index[0][csr2csc]
row = row[csr2csc]
value = value[csr2csc] if value is not None else value
grad_mat, _ = spmm(grad_out.is_cuda).spmm(
colptr, row, value, grad_out, 'sum')
......@@ -86,19 +85,20 @@ class SPMM(torch.autograd.Function):
else:
value = grad_out
value.masked_fill_(invalid_arg_mask, 0)
col = index[1][arg_out_ind.flatten()].view_as(arg_out)
col = col[arg_out_ind.flatten()].view_as(arg_out)
grad_mat = scatter_add(value, col, dim=-2,
dim_size=mat.size(-2))
return None, None, None, None, None, grad_value, grad_mat, None
return None, None, None, grad_value, grad_mat, None, None, None, None
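With `forward` now taking `(row, rowptr, col, value, mat, rowcount, colptr, csr2csc, reduce)`, `value` sits at position 3 and `mat` at position 4 of `needs_input_grad`, and `backward` returns one entry per forward argument. A quick positional check (a sketch, not part of the module):

```python
ARGS = ('row', 'rowptr', 'col', 'value', 'mat',
        'rowcount', 'colptr', 'csr2csc', 'reduce')
assert ARGS.index('value') == 3
assert ARGS.index('mat') == 4
assert len(ARGS) == 9  # backward returns nine gradients
```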
class SPSPMM(torch.autograd.Function):
@staticmethod
def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
if rowptrA.is_cuda:
indexC, rowptrC, valueC = spspmm_cuda.spspmm(
rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K)
rowptrC, colC, valueC = spspmm_cuda.spspmm(rowptrA, colA, valueA,
rowptrB, colB, valueB,
M, N, K)
else:
dtype = None
if valueA is not None:
......@@ -116,21 +116,18 @@ class SPSPMM(torch.autograd.Function):
C = A @ B
valueC = torch.from_numpy(
C.data).to(dtype) if dtype is not None else None
rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
C = C.tocoo()
rowC = torch.from_numpy(C.row).to(torch.int64)
colC = torch.from_numpy(C.col).to(torch.int64)
indexC = torch.stack([rowC, colC], dim=0)
colC = torch.from_numpy(C.indices).to(torch.int64)
valueC = torch.from_numpy(C.data)
valueC = valueC.to(dtype) if dtype is not None else valueC
ctx.mark_non_differentiable(indexC, rowptrC)
ctx.mark_non_differentiable(rowptrC, colC)
# We cannot return `NoneType` in torch.autograd :(
if valueC is None:
return indexC, rowptrC
return rowptrC, colC
else:
return indexC, rowptrC, valueC
return rowptrC, colC, valueC
@staticmethod
def backward(ctx, grad_indexC, grad_rowptrC, *args):
......@@ -152,7 +149,12 @@ def matmul(src, other, reduce='sum'):
# Sparse-Dense Matrix Multiplication.
if torch.is_tensor(other):
assert reduce in ['sum', 'add', 'mean', 'min', 'max']
(index, value), rowptr = src.coo(), src.storage.rowptr
rowptr, col, value = src.csr()
row = None
if reduce in ['sum', 'add'] and (src.requires_grad
or other.requires_grad):
row = src.storage.row
rowcount = None
if other.requires_grad and reduce in ['mean']:
......@@ -162,8 +164,8 @@ def matmul(src, other, reduce='sum'):
if other.requires_grad and reduce in ['sum', 'add', 'mean']:
csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
other, reduce)
return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
csr2csc, reduce)
# Sparse-Sparse Matrix Multiplication.
elif isinstance(other, src.__class__):
......@@ -171,10 +173,9 @@ def matmul(src, other, reduce='sum'):
assert src.dim() == 2 and other.dim() == 2
data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
other.size(1))
data = data if len(data) == 3 else data + (None, )
(rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
sparse_size = torch.Size([src.size(0), other.size(1)])
out = src.__class__(data[0], data[2], sparse_size, is_sorted=True)
out.storage._rowptr = data[1]
return out
return src.__class__(rowptr=rowptr, col=col, value=value,
sparse_size=sparse_size, is_sorted=True)
raise ValueError
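End-to-end, the dispatch looks like this from the user side: a hedged usage sketch, assuming the keyword constructor mirrors the storage signature used in this diff:

```python
import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 3, 2, 0])
value = torch.ones(4, requires_grad=True)
A = SparseTensor(row=row, col=col, value=value, sparse_size=(3, 4))

x = torch.randn(4, 8, requires_grad=True)
out = A.matmul(x, reduce='sum')   # sparse-dense path (SPMM), shape [3, 8]

B = SparseTensor(row=torch.tensor([0, 1, 2, 3]),
                 col=torch.tensor([0, 1, 1, 2]),
                 value=torch.ones(4), sparse_size=(4, 3))
C = A.matmul(B)                   # sparse-sparse path (SPSPMM), 3 x 3 result
```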
......@@ -78,6 +78,7 @@ class SparseStorage(object):
assert col is not None
assert col.dtype == torch.long
assert col.dim() == 1
col = col.contiguous()
if sparse_size is None:
M = rowptr.numel() - 1 if row is None else row.max().item() + 1
......@@ -89,46 +90,54 @@ class SparseStorage(object):
assert row.device == col.device
assert row.dim() == 1
assert row.numel() == col.numel()
row = row.contiguous()
if rowptr is not None:
assert rowptr.dtype == torch.long
assert rowptr.device == col.device
assert rowptr.dim() == 1
assert rowptr.numel() - 1 == sparse_size[0]
rowptr = rowptr.contiguous()
if value is not None:
assert value.device == col.device
assert value.size(0) == col.size(0)
value = value.contiguous()
if rowcount is not None:
assert rowcount.dtype == torch.long
assert rowcount.device == col.device
assert rowcount.dim() == 1
assert rowcount.numel() == sparse_size[0]
rowcount = rowcount.contiguous()
if colptr is not None:
assert colptr.dtype == torch.long
assert colptr.device == col.device
assert colptr.dim() == 1
assert colptr.numel() - 1 == sparse_size[1]
colptr = colptr.contiguous()
if colcount is not None:
assert colcount.dtype == torch.long
assert colcount.device == col.device
assert colcount.dim() == 1
assert colcount.numel() == sparse_size[1]
colcount = colcount.contiguous()
if csr2csc is not None:
assert csr2csc.dtype == torch.long
assert csr2csc.device == col.device
assert csr2csc.dim() == 1
assert csr2csc.numel() == col.size(0)
csr2csc = csr2csc.contiguous()
if csc2csr is not None:
assert csc2csr.dtype == torch.long
assert csc2csr.device == col.device
assert csc2csr.dim() == 1
assert csc2csr.numel() == col.size(0)
csc2csr = csc2csr.contiguous()
self._row = row
self._rowptr = rowptr
......
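The explicit `.contiguous()` calls were dropped from the C++ kernels above; instead `SparseStorage` now makes every cached tensor contiguous once at construction, so the raw `DATA_PTR` reads in the extensions stay valid. A small illustration of why that matters (plain PyTorch, not repo code):

```python
import torch

t = torch.arange(12).view(3, 4).t()[0]   # a non-contiguous view
assert not t.is_contiguous()
# Raw pointer access assumes elements are laid out in logical order;
# a non-contiguous view is not, so C++ code indexing data_ptr() would
# read the wrong values. .contiguous() materializes the expected layout.
t = t.contiguous()
assert t.is_contiguous()
```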
......@@ -11,7 +11,7 @@ from torch_sparse.select import select
from torch_sparse.index_select import index_select, index_select_nnz
from torch_sparse.masked_select import masked_select, masked_select_nnz
import torch_sparse.reduce
from torch_sparse.diag import remove_diag
from torch_sparse.diag import remove_diag, set_diag
from torch_sparse.matmul import matmul
from torch_sparse.add import add, add_, add_nnz, add_nnz_
......@@ -482,8 +482,9 @@ SparseTensor.sum = torch_sparse.reduce.sum
SparseTensor.mean = torch_sparse.reduce.mean
SparseTensor.min = torch_sparse.reduce.min
SparseTensor.max = torch_sparse.reduce.max
SparseTensor.remove_diag = remove_diag
SparseTensor.matmul = matmul
SparseTensor.remove_diag = remove_diag #TODO
SparseTensor.set_diag = set_diag #TODO
SparseTensor.matmul = matmul # TODO
SparseTensor.add = add
SparseTensor.add_ = add_
SparseTensor.add_nnz = add_nnz
......
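With the new bindings, the diagonal helpers are available directly on `SparseTensor`. A hedged usage sketch, constructor keywords assumed as above:

```python
import torch
from torch_sparse import SparseTensor

A = SparseTensor(row=torch.tensor([0, 1, 1]),
                 col=torch.tensor([1, 0, 1]),
                 value=torch.ones(3), sparse_size=(2, 3))

A = A.remove_diag()                      # drops the (1, 1) entry
A = A.set_diag(torch.tensor([5., 7.]))   # writes the two main-diagonal values
```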