Commit d49dcbbd authored by rusty1s

diag and matmul fixes

parent 1a5fb80c
@@ -4,23 +4,24 @@
 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
 
-at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
-  CHECK_CPU(index);
-
-  int64_t E = index.size(1);
-  index = index.contiguous();
-  auto index_data = index.DATA_PTR<int64_t>();
+at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
+                         int64_t k) {
+  CHECK_CPU(row);
+  CHECK_CPU(col);
 
+  int64_t E = row.size(0);
   int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
-  auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
+
+  auto row_data = row.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
+  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();
 
   int64_t r, c;
   if (k < 0) {
     for (int64_t i = 0; i < E; i++) {
-      r = index_data[i], c = index_data[i + E];
+      r = row_data[i], c = col_data[i];
       if (r + k < 0) {
         mask_data[i] = true;
       } else if (r + k >= N) {
@@ -33,7 +34,7 @@ at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
     }
   } else {
     for (int64_t i = 0; i < E; i++) {
-      r = index_data[i], c = index_data[i + E];
+      r = row_data[i], c = col_data[i];
       if (r + k >= N) {
         mask_data[i + num_diag] = true;
       } else if (r + k > c) {
...
@@ -174,36 +174,28 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return std::make_tuple(out, arg_out);
 }
 
-at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
-                       at::Tensor grad, std::string reduce) {
-  CHECK_CPU(index);
+at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
+                       at::Tensor mat, at::Tensor grad, std::string reduce) {
+  CHECK_CPU(row);
   CHECK_CPU(rowptr);
+  CHECK_CPU(col);
   CHECK_CPU(mat);
   CHECK_CPU(grad);
 
-  AT_ASSERTM(index.dim() == 2, "Input mismatch");
-  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
-  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
-  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
-  AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
-  AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
-                 reduce2REDUCE.at(reduce) == MEAN,
-             "Reduce operation not supported");
-
-  index = index.contiguous();
   mat = mat.contiguous();
   grad = grad.contiguous();
 
   auto M = grad.size(-2);
   auto N = mat.size(-2);
-  auto E = index.size(1);
+  auto E = row.numel();
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
 
-  auto out = at::zeros(index.size(1), grad.options());
+  auto out = at::zeros(row.numel(), grad.options());
 
-  auto index_data = index.DATA_PTR<int64_t>();
+  auto row_data = row.DATA_PTR<int64_t>();
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
 
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw", [&] {
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto grad_data = grad.DATA_PTR<scalar_t>();
@@ -214,7 +206,7 @@ at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (int b = 0; b < B; b++) {
         for (int e = 0; e < E; e++) {
-          row = index_data[e], col = index_data[E + e], val = (scalar_t)0;
+          row = row_data[e], col = col_data[e], val = (scalar_t)0;
           for (int k = 0; k < K; k++) {
             val += mat_data[b * N * K + col * K + k] *
                    grad_data[b * M * K + row * K + k];
...
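For orientation, the per-edge quantity that the reduce='sum' branch of spmm_val_bw accumulates can be written in a few lines of PyTorch. This is an illustrative sketch for a single batch; the helper name spmm_val_bw_ref and the example shapes are assumptions, not part of the commit.

import torch

def spmm_val_bw_ref(row, col, mat, grad):
    # For each nonzero e, the gradient w.r.t. its value is the inner product
    # of the matched dense rows: <mat[col[e], :], grad[row[e], :]>.
    return (mat[col] * grad[row]).sum(dim=-1)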
@@ -2,12 +2,14 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 
-at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
-                              int64_t k);
+at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
+                              int64_t N, int64_t k);
 
-at::Tensor non_diag_mask(at::Tensor index, int64_t M, int64_t N, int64_t k) {
-  CHECK_CUDA(index);
-  return non_diag_mask_cuda(index, M, N, k);
+at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
+                         int64_t k) {
+  CHECK_CUDA(row);
+  CHECK_CUDA(col);
+  return non_diag_mask_cuda(row, col, M, N, k);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -5,14 +5,15 @@
 #define THREADS 1024
 
-__global__ void non_diag_mask_kernel(const int64_t *index_data, bool *out_data,
+__global__ void non_diag_mask_kernel(const int64_t *row_data,
+                                     const int64_t *col_data, bool *out_data,
                                      int64_t N, int64_t k, int64_t num_diag,
                                      int64_t numel) {
 
   int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
 
   if (thread_idx < numel) {
-    int64_t r = index_data[thread_idx], c = index_data[thread_idx + numel];
+    int64_t r = row_data[thread_idx], c = col_data[thread_idx];
     if (k < 0) {
       if (r + k < 0) {
@@ -37,21 +38,20 @@ __global__ void non_diag_mask_kernel(const int64_t *index_data, bool *out_data,
   }
 }
 
-at::Tensor non_diag_mask_cuda(at::Tensor index, int64_t M, int64_t N,
-                              int64_t k) {
-  int64_t E = index.size(1);
-  index = index.contiguous();
-  auto index_data = index.DATA_PTR<int64_t>();
+at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
+                              int64_t N, int64_t k) {
+  int64_t E = row.size(0);
 
   int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
-  auto mask = at::zeros(E + num_diag, index.options().dtype(at::kBool));
+
+  auto row_data = row.DATA_PTR<int64_t>();
+  auto col_data = col.DATA_PTR<int64_t>();
+  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();
 
   auto stream = at::cuda::getCurrentCUDAStream();
   non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
-      index_data, mask_data, N, k, num_diag, E);
+      row_data, col_data, mask_data, N, k, num_diag, E);
 
   return mask;
 }
@@ -20,13 +20,14 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return spmm_cuda(rowptr, col, value_opt, mat, reduce);
 }
 
-at::Tensor spmm_val_bw(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
-                       at::Tensor grad, std::string reduce) {
-  CHECK_CUDA(index);
+at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
+                       at::Tensor mat, at::Tensor grad, std::string reduce) {
+  CHECK_CUDA(row);
   CHECK_CUDA(rowptr);
+  CHECK_CUDA(col);
   CHECK_CUDA(mat);
   CHECK_CUDA(grad);
-  return spmm_val_bw_cuda(index, rowptr, mat, grad, reduce);
+  return spmm_val_bw_cuda(row, rowptr, col, mat, grad, reduce);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -210,17 +210,18 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
 
 template <typename scalar_t, ReductionType REDUCE>
 __global__ void
-spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
-                   const scalar_t *mat_data, const scalar_t *grad_data,
-                   scalar_t *out_data, int B, int M, int N, int E, int K) {
+spmm_val_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
+                   const int64_t *col_data, const scalar_t *mat_data,
+                   const scalar_t *grad_data, scalar_t *out_data, int B, int M,
+                   int N, int E, int K) {
 
   int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
   int index_idx = (thread_idx >> 5);    // thread_idx / 32
   int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
 
   if (index_idx < E) {
-    int row = __ldg(index_data + index_idx);
-    int col = __ldg(index_data + E + index_idx);
+    int row = __ldg(row_data + index_idx);
+    int col = __ldg(col_data + index_idx);
     scalar_t val = (scalar_t)0;
 
     for (int b = 0; b < B; b++) {
@@ -246,43 +247,35 @@ spmm_val_bw_kernel(const int64_t *index_data, const int64_t *rowptr_data,
   }
 }
 
-at::Tensor spmm_val_bw_cuda(at::Tensor index, at::Tensor rowptr, at::Tensor mat,
-                            at::Tensor grad, std::string reduce) {
-
-  AT_ASSERTM(index.dim() == 2, "Input mismatch");
-  AT_ASSERTM(index.size(0) == 2, "Input mismatch");
-  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
-  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
-  AT_ASSERTM(mat.dim() == grad.dim(), "Input mismatch");
-  AT_ASSERTM(reduce2REDUCE.at(reduce) == SUM ||
-                 reduce2REDUCE.at(reduce) == MEAN,
-             "Reduce operation not supported");
-
-  index = index.contiguous();
+at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
+                            at::Tensor mat, at::Tensor grad,
+                            std::string reduce) {
   mat = mat.contiguous();
   grad = grad.contiguous();
 
   auto M = grad.size(-2);
   auto N = mat.size(-2);
-  auto E = index.size(1);
+  auto E = row.numel();
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
   auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
 
-  auto out = at::empty(index.size(1), grad.options());
+  auto out = at::zeros(row.numel(), grad.options());
   auto stream = at::cuda::getCurrentCUDAStream();
 
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
-    auto index_data = index.DATA_PTR<int64_t>();
+    auto row_data = row.DATA_PTR<int64_t>();
    auto rowptr_data = rowptr.DATA_PTR<int64_t>();
+    auto col_data = col.DATA_PTR<int64_t>();
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto grad_data = grad.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-      spmm_val_bw_kernel<scalar_t, REDUCE>
-          <<<BLOCKS, THREADS, 0, stream>>>(index_data, rowptr_data, mat_data,
-                                           grad_data, out_data, B, M, N, E, K);
+      spmm_val_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
+          row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
+          N, E, K);
     });
   });
...
@@ -121,14 +121,10 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
         descr, valueC_data, rowptrC_data, colC_data, info, buffer);
   });
 
-  auto rowC = at::empty_like(colC);
-  auto rowC_data = rowC.DATA_PTR<int>();
-  cusparseXcsr2coo(handle, rowptrC_data, nnzC, M, rowC_data,
-                   CUSPARSE_INDEX_BASE_ZERO);
-  cusparseDestroyCsrgemm2Info(info);
-
-  auto indexC = at::stack({rowC.toType(at::kLong), colC.toType(at::kLong)}, 0);
-  return std::make_tuple(indexC, rowptrC.toType(at::kLong), valueC);
+  rowptrC = rowptrC.toType(at::kLong);
+  colC = colC.toType(at::kLong);
+
+  return std::make_tuple(rowptrC, colC, valueC);
 }
 
 // #define THREADS 1024
...
@@ -9,12 +9,9 @@ except ImportError:
 
 
 def remove_diag(src, k=0):
-    index, value = src.coo()
-    row, col = index
+    row, col, value = src.coo()
 
     inv_mask = row != col if k == 0 else row != (col - k)
-
-    index = index[:, inv_mask]
+    row, col = row[inv_mask], col[inv_mask]
 
     if src.has_value():
         value = value[inv_mask]
@@ -32,7 +29,7 @@ def remove_diag(src, k=0):
     colcount = src.storage.colcount.clone()
     colcount[col[mask]] -= 1
 
-    storage = src.storage.__class__(index, value,
+    storage = src.storage.__class__(row=row, col=col, value=value,
                                     sparse_size=src.sparse_size(),
                                     rowcount=rowcount, colcount=colcount,
                                     is_sorted=True)
@@ -45,26 +42,26 @@ def set_diag(src, values=None, k=0):
     src = src.remove_diag(k=0)
 
-    index, value = src.coo()
+    row, col, value = src.coo()
 
-    func = diag_cuda if index.is_cuda else diag_cpu
-    mask = func.non_diag_mask(index, src.size(0), src.size(1), k)
+    func = diag_cuda if row.is_cuda else diag_cpu
+    mask = func.non_diag_mask(row, col, src.size(0), src.size(1), k)
     inv_mask = ~mask
 
-    new_index = index.new_empty((2, mask.size(0)))
-    new_index[:, mask] = index
-    num_diag = mask.numel() - index.size(1)
-    start = -k if k < 0 else 0
+    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
+    diag = torch.arange(start, start + num_diag, device=src.device)
 
-    diag_row = torch.arange(start, start + num_diag, device=src.device)
-    new_index[0, inv_mask] = diag_row
-    diag_col = diag_row.add_(k)
-    new_index[1, inv_mask] = diag_col
+    new_row = row.new_empty(mask.size(0))
+    new_row[mask] = row
+    new_row[inv_mask] = diag
+
+    new_col = col.new_empty(mask.size(0))
+    new_col[mask] = col
+    new_col[inv_mask] = diag.add_(k)
 
     new_value = None
     if src.has_value():
-        new_value = torch.new_empty((mask.size(0), ) + mask.size()[1:])
+        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
         new_value[mask] = value
         new_value[inv_mask] = values if values is not None else 1
@@ -78,8 +75,9 @@ def set_diag(src, values=None, k=0):
     colcount = src.storage.colcount.clone()
     colcount[start + k:start + num_diag + k] += 1
 
-    storage = src.storage.__class__(new_index, new_value,
+    storage = src.storage.__class__(row=new_row, col=new_col, value=new_value,
                                     sparse_size=src.sparse_size(),
                                     rowcount=rowcount, colcount=colcount,
                                     is_sorted=True)
 
     return src.__class__.from_storage(storage)
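As a quick sanity check on the masking rule used by remove_diag for a diagonal offset k, the same filter can be reproduced in plain PyTorch; the example tensors below are made up for illustration only.

import torch

row = torch.tensor([0, 0, 1, 2, 2])
col = torch.tensor([0, 2, 1, 1, 2])
k = 0
# Keep every entry that does not sit on the k-th diagonal (row == col - k).
inv_mask = row != col if k == 0 else row != (col - k)
row, col = row[inv_mask], col[inv_mask]  # drops (0, 0), (1, 1) and (2, 2)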
@@ -20,14 +20,13 @@ def spmm(is_cuda):
 
 class SPMM(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, index, rowcount, rowptr, colptr, csr2csc, value, mat,
+    def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
                 reduce):
-        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, index[1], value, mat,
-                                              reduce)
+        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
         ctx.reduce = reduce
-        ctx.save_for_backward(index, rowcount, rowptr, colptr, csr2csc, value,
-                              mat, arg_out)
+        ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,
+                              csr2csc, arg_out)
 
         if reduce == 'min' or reduce == 'max':
             ctx.mark_non_differentiable(arg_out)
@@ -37,27 +36,27 @@ class SPMM(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_out, *args):
-        data = ctx.saved_tensors
-        index, rowcount, rowptr, colptr, csr2csc, value, mat, arg_out = data
+        (row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
+         arg_out) = ctx.saved_tensors
 
         invalid_arg_mask = arg_out_ind = None
         if ctx.reduce in ['min', 'max'] and (ctx.needs_input_grad[5]
                                              or ctx.needs_input_grad[6]):
-            invalid_arg_mask = arg_out == index.size(1)
+            invalid_arg_mask = arg_out == row.size(0)
             arg_out_ind = arg_out.masked_fill(invalid_arg_mask, -1)
 
         grad_value = None
-        if ctx.needs_input_grad[5]:
+        if ctx.needs_input_grad[3]:
             if ctx.reduce in ['sum', 'add']:
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    index, rowptr, mat, grad_out, ctx.reduce)
+                    row, rowptr, col, mat, grad_out, ctx.reduce)
             if ctx.reduce == 'mean':
                 grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
-                    index, rowptr, mat, grad_out, ctx.reduce)
+                    row, rowptr, col, mat, grad_out, ctx.reduce)
             elif ctx.reduce in ['min', 'max']:
-                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
+                col = col[arg_out_ind.flatten()].view_as(arg_out)
                 out = mat.gather(-2, col).mul_(grad_out)
                 out.masked_fill_(invalid_arg_mask, 0)
                 grad_value = scatter_add(out.flatten(), arg_out.flatten(),
@@ -65,16 +64,16 @@ class SPMM(torch.autograd.Function):
                 grad_value = grad_value[:-1]
 
         grad_mat = None
-        if ctx.needs_input_grad[6]:
+        if ctx.needs_input_grad[4]:
             if ctx.reduce in ['sum', 'add']:
                 value = value[csr2csc] if value is not None else value
                 grad_mat, _ = spmm(grad_out.is_cuda).spmm(
-                    colptr, index[0][csr2csc], value, grad_out, 'sum')
+                    colptr, row[csr2csc], value, grad_out, 'sum')
             elif ctx.reduce == 'mean':
-                count = rowcount[index[0]].to(mat.dtype).clamp_(min=1)
+                count = rowcount[row].to(mat.dtype).clamp_(min=1)
                 value = count.pow_(-1) if value is None else value / count
-                row = index[0][csr2csc]
+                row = row[csr2csc]
                 value = value[csr2csc] if value is not None else value
                 grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                     colptr, row, value, grad_out, 'sum')
@@ -86,19 +85,20 @@ class SPMM(torch.autograd.Function):
             else:
                 value = grad_out
                 value.masked_fill_(invalid_arg_mask, 0)
-                col = index[1][arg_out_ind.flatten()].view_as(arg_out)
+                col = col[arg_out_ind.flatten()].view_as(arg_out)
                 grad_mat = scatter_add(value, col, dim=-2,
                                        dim_size=mat.size(-2))
 
-        return None, None, None, None, None, grad_value, grad_mat, None
+        return None, None, None, grad_value, grad_mat, None, None, None, None
 
 
 class SPSPMM(torch.autograd.Function):
     @staticmethod
     def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
         if rowptrA.is_cuda:
-            indexC, rowptrC, valueC = spspmm_cuda.spspmm(
-                rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K)
+            rowptrC, colC, valueC = spspmm_cuda.spspmm(rowptrA, colA, valueA,
+                                                       rowptrB, colB, valueB,
+                                                       M, N, K)
         else:
             dtype = None
             if valueA is not None:
@@ -116,21 +116,18 @@ class SPSPMM(torch.autograd.Function):
 
             C = A @ B
 
-            valueC = torch.from_numpy(
-                C.data).to(dtype) if dtype is not None else None
             rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
-            C = C.tocoo()
-            rowC = torch.from_numpy(C.row).to(torch.int64)
-            colC = torch.from_numpy(C.col).to(torch.int64)
-            indexC = torch.stack([rowC, colC], dim=0)
+            colC = torch.from_numpy(C.indices).to(torch.int64)
+            valueC = torch.from_numpy(C.data)
+            valueC = valueC.to(dtype) if dtype is not None else valueC
 
-        ctx.mark_non_differentiable(indexC, rowptrC)
+        ctx.mark_non_differentiable(rowptrC, colC)
 
         # We cannot return `NoneType` in torch.autograd :(
         if valueC is None:
-            return indexC, rowptrC
+            return rowptrC, colC
         else:
-            return indexC, rowptrC, valueC
+            return rowptrC, colC, valueC
 
     @staticmethod
     def backward(ctx, grad_indexC, grad_rowptrC, *args):
@@ -152,7 +149,12 @@ def matmul(src, other, reduce='sum'):
     # Sparse-Dense Matrix Multiplication.
     if torch.is_tensor(other):
         assert reduce in ['sum', 'add', 'mean', 'min', 'max']
-        (index, value), rowptr = src.coo(), src.storage.rowptr
+        rowptr, col, value = src.csr()
+
+        row = None
+        if reduce in ['sum', 'add'] and (src.requires_grad
+                                         or other.requires_grad):
+            row = src.storage.row
 
         rowcount = None
         if other.requires_grad and reduce in ['mean']:
@@ -162,8 +164,8 @@ def matmul(src, other, reduce='sum'):
         if other.requires_grad and reduce in ['sum', 'add', 'mean']:
             csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
 
-        return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
-                          other, reduce)
+        return SPMM.apply(row, rowptr, col, value, other, rowcount, colptr,
+                          csr2csc, reduce)
 
     # Sparse-Sparse Matrix Multiplication.
     elif isinstance(other, src.__class__):
@@ -171,10 +173,9 @@ def matmul(src, other, reduce='sum'):
         assert src.dim() == 2 and other.dim() == 2
         data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
                             other.size(1))
-        data = data if len(data) == 3 else data + (None, )
+        (rowptr, col), value = data[:2], data[2] if len(data) == 3 else None
         sparse_size = torch.Size([src.size(0), other.size(1)])
-        out = src.__class__(data[0], data[2], sparse_size, is_sorted=True)
-        out.storage._rowptr = data[1]
-        return out
+        return src.__class__(rowptr=rowptr, col=col, value=value,
+                             sparse_size=sparse_size, is_sorted=True)
 
     raise ValueError
@@ -78,6 +78,7 @@ class SparseStorage(object):
         assert col is not None
         assert col.dtype == torch.long
         assert col.dim() == 1
+        col = col.contiguous()
 
         if sparse_size is None:
             M = rowptr.numel() - 1 if row is None else row.max().item() + 1
@@ -89,46 +90,54 @@
             assert row.device == col.device
             assert row.dim() == 1
             assert row.numel() == col.numel()
+            row = row.contiguous()
 
         if rowptr is not None:
             assert rowptr.dtype == torch.long
             assert rowptr.device == col.device
             assert rowptr.dim() == 1
             assert rowptr.numel() - 1 == sparse_size[0]
+            rowptr = rowptr.contiguous()
 
         if value is not None:
             assert value.device == col.device
             assert value.size(0) == col.size(0)
+            value = value.contiguous()
 
         if rowcount is not None:
             assert rowcount.dtype == torch.long
             assert rowcount.device == col.device
             assert rowcount.dim() == 1
             assert rowcount.numel() == sparse_size[0]
+            rowcount = rowcount.contiguous()
 
         if colptr is not None:
             assert colptr.dtype == torch.long
             assert colptr.device == col.device
             assert colptr.dim() == 1
             assert colptr.numel() - 1 == sparse_size[1]
+            colptr = colptr.contiguous()
 
         if colcount is not None:
             assert colcount.dtype == torch.long
             assert colcount.device == col.device
             assert colcount.dim() == 1
             assert colcount.numel() == sparse_size[1]
+            colcount = colcount.contiguous()
 
         if csr2csc is not None:
             assert csr2csc.dtype == torch.long
             assert csr2csc.device == col.device
             assert csr2csc.dim() == 1
            assert csr2csc.numel() == col.size(0)
+            csr2csc = csr2csc.contiguous()
 
         if csc2csr is not None:
             assert csc2csr.dtype == torch.long
             assert csc2csr.device == col.device
             assert csc2csr.dim() == 1
             assert csc2csr.numel() == col.size(0)
+            csc2csr = csc2csr.contiguous()
 
         self._row = row
         self._rowptr = rowptr
...
@@ -11,7 +11,7 @@ from torch_sparse.select import select
 from torch_sparse.index_select import index_select, index_select_nnz
 from torch_sparse.masked_select import masked_select, masked_select_nnz
 import torch_sparse.reduce
-from torch_sparse.diag import remove_diag
+from torch_sparse.diag import remove_diag, set_diag
 from torch_sparse.matmul import matmul
 from torch_sparse.add import add, add_, add_nnz, add_nnz_
@@ -482,8 +482,9 @@ SparseTensor.sum = torch_sparse.reduce.sum
 SparseTensor.mean = torch_sparse.reduce.mean
 SparseTensor.min = torch_sparse.reduce.min
 SparseTensor.max = torch_sparse.reduce.max
-SparseTensor.remove_diag = remove_diag
-SparseTensor.matmul = matmul
+SparseTensor.remove_diag = remove_diag  # TODO
+SparseTensor.set_diag = set_diag  # TODO
+SparseTensor.matmul = matmul  # TODO
 SparseTensor.add = add
 SparseTensor.add_ = add_
 SparseTensor.add_nnz = add_nnz
...
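Taken together, the Python-side changes replace the stacked index tensor with separate row/col (or rowptr/col) tensors throughout. A rough end-to-end usage sketch of the registered methods follows; it is not part of the commit, and the import path and the constructor keywords (row=, col=, value=, sparse_size=) are assumptions based on the code above.

import torch
from torch_sparse.tensor import SparseTensor  # assumed import path

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 0, 1])
value = torch.ones(4)
mat = SparseTensor(row=row, col=col, value=value,
                   sparse_size=torch.Size([3, 3]))

mat = mat.set_diag()                   # insert the main diagonal (k=0), filled with 1
dense = torch.randn(3, 4)
out = mat.matmul(dense, reduce='sum')  # sparse-dense matmul, result shape [3, 4]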