Unverified commit 4463b3d6, authored by czkkkkkk and committed by GitHub
Browse files

[Sparse] Support strided tensor when calling old DGL APIs (#5506)

parent d45eafd4
......@@ -83,9 +83,9 @@ torch::Tensor SDDMMNoAutoGrad(
if (mat1.dim() >= 3) {
shape.push_back(mat1.size(2));
// (N, K, B) -> (N, B, K)
mat1 = mat1.transpose(1, 2).contiguous();
mat1 = mat1.transpose(1, 2);
// (M, K, B) -> (M, B, K)
mat2_tr = mat2_tr.transpose(1, 2).contiguous();
mat2_tr = mat2_tr.transpose(1, 2);
}
auto ret = torch::zeros(shape, mat1.options());
const std::string op = "dot";
......
......@@ -68,7 +68,7 @@ void _SDDMMSanityCheck(
torch::Tensor SDDMMAutoGrad::forward(
AutogradContext* ctx, const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
torch::Tensor mat1, torch::Tensor mat2) {
auto mat2_tr = mat2.transpose(0, 1).contiguous();
auto mat2_tr = mat2.transpose(0, 1);
auto ret = SDDMMNoAutoGrad(sparse_mat, mat1, mat2_tr);
torch::Tensor cache_mat1, cache_mat2;
if (mat1.requires_grad()) {
......@@ -94,13 +94,12 @@ tensor_list SDDMMAutoGrad::backward(
torch::Tensor mat1_grad, mat2_grad;
if (ctx->saved_data["mat1_requires_grad"].toBool()) {
// SDDMM(M, A, B) = C. dA = SpMM(dC, B^T)
mat1_grad = SpMMNoAutoGrad(
sparse_mat, grad, mat2.transpose(0, 1).contiguous(), false);
mat1_grad = SpMMNoAutoGrad(sparse_mat, grad, mat2.transpose(0, 1), false);
}
if (ctx->saved_data["mat2_requires_grad"].toBool()) {
// SDDMM(M, A, B) = C. dB = SpMM(dC^T, A)^T
auto mat2_tr_grad = SpMMNoAutoGrad(sparse_mat, grad, mat1, true);
mat2_grad = mat2_tr_grad.transpose(0, 1).contiguous();
mat2_grad = mat2_tr_grad.transpose(0, 1);
}
return {torch::Tensor(), mat1_grad, mat2_grad};
}
......
......@@ -52,7 +52,7 @@ inline static void ElementwiseOpSanityCheck(
/** @brief Convert a Torch tensor to a DGL array.
 *
 * The tensor is made contiguous before export: DLPack consumers on the DGL
 * side assume a densely packed layout, so strided views (e.g. produced by
 * transpose/slicing) must be materialized first. `contiguous()` is a no-op
 * for tensors that are already contiguous.
 */
inline static runtime::NDArray TorchTensorToDGLArray(torch::Tensor tensor) {
  return runtime::DLPackConvert::FromDLPack(at::toDLPack(tensor.contiguous()));
}
/** @brief Convert a DGL array to a Torch tensor. */
......
......@@ -13,6 +13,7 @@ from .utils import (
rand_coo,
rand_csc,
rand_csr,
rand_stride,
sparse_matrix_to_dense,
sparse_matrix_to_torch_sparse,
)
......@@ -30,6 +31,7 @@ def test_spmm(create_func, shape, nnz, out_dim):
else:
X = torch.randn(shape[1], requires_grad=True, device=dev)
X = rand_stride(X)
sparse_result = matmul(A, X)
grad = torch.randn_like(sparse_result)
sparse_result.backward(grad)
......@@ -56,6 +58,7 @@ def test_bspmm(create_func, shape, nnz):
dev = F.ctx()
A = create_func(shape, nnz, dev, 2)
X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)
X = rand_stride(X)
sparse_result = matmul(A, X)
grad = torch.randn_like(sparse_result)
......
......@@ -6,7 +6,13 @@ import torch
from dgl.sparse import bsddmm, sddmm
from .utils import clone_detach_and_grad, rand_coo, rand_csc, rand_csr
from .utils import (
clone_detach_and_grad,
rand_coo,
rand_csc,
rand_csr,
rand_stride,
)
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
......@@ -23,6 +29,9 @@ def test_sddmm(create_func, shape, nnz, hidden):
B = torch.rand(shape[0], requires_grad=True, device=dev)
C = torch.rand(shape[1], requires_grad=True, device=dev)
B = rand_stride(B)
C = rand_stride(C)
A_val_clone = clone_detach_and_grad(A.val)
dense_B = clone_detach_and_grad(B)
dense_C = clone_detach_and_grad(C)
......@@ -58,6 +67,9 @@ def test_bsddmm(create_func, shape, nnz, nz_dim):
B = torch.rand(shape[0], hidden, nz_dim, requires_grad=True, device=dev)
C = torch.rand(hidden, shape[1], nz_dim, requires_grad=True, device=dev)
B = rand_stride(B)
C = rand_stride(C)
A_val_clone = clone_detach_and_grad(A.val)
dense_B = clone_detach_and_grad(B)
dense_C = clone_detach_and_grad(C)
......
import numpy as np
import torch
from dgl.sparse import from_coo, from_csc, from_csr, SparseMatrix
from dgl.sparse import from_csc, from_csr, SparseMatrix, spmatrix
np.random.seed(42)
torch.random.manual_seed(42)
......@@ -13,6 +13,16 @@ def clone_detach_and_grad(t):
return t
def rand_stride(t):
    """Add stride to the last dimension of a tensor."""
    # Replicate the tensor along a new trailing axis, then slice out the
    # first copy: the slice is a view whose last-dimension stride is the
    # replication factor (> 1), i.e. a non-contiguous tensor with the same
    # shape and values as ``t``.
    factor = np.random.randint(2, 4)
    strided = torch.stack([t] * factor, dim=-1)[..., 0].detach()
    # Re-enable autograd for floating-point tensors so gradient paths in
    # the tests still work after the detach above.
    if torch.is_floating_point(t):
        strided.requires_grad_()
    return strided
def rand_coo(shape, nnz, dev, nz_dim=None):
# Create a sparse matrix without duplicate entries.
nnzid = np.random.choice(shape[0] * shape[1], nnz, replace=False)
......@@ -23,7 +33,10 @@ def rand_coo(shape, nnz, dev, nz_dim=None):
val = torch.randn(nnz, device=dev, requires_grad=True)
else:
val = torch.randn(nnz, nz_dim, device=dev, requires_grad=True)
return from_coo(row, col, val, shape)
indices = torch.stack([row, col])
indices = rand_stride(indices)
val = rand_stride(val)
return spmatrix(indices, val, shape)
def rand_csr(shape, nnz, dev, nz_dim=None):
......@@ -42,6 +55,9 @@ def rand_csr(shape, nnz, dev, nz_dim=None):
indptr = torch.cumsum(indptr, 0)
row_sorted, row_sorted_idx = torch.sort(row)
indices = col[row_sorted_idx]
indptr = rand_stride(indptr)
indices = rand_stride(indices)
val = rand_stride(val)
return from_csr(indptr, indices, val, shape=shape)
......@@ -61,6 +77,9 @@ def rand_csc(shape, nnz, dev, nz_dim=None):
indptr = torch.cumsum(indptr, 0)
col_sorted, col_sorted_idx = torch.sort(col)
indices = row[col_sorted_idx]
indptr = rand_stride(indptr)
indices = rand_stride(indices)
val = rand_stride(val)
return from_csc(indptr, indices, val, shape=shape)
......@@ -69,7 +88,9 @@ def rand_coo_uncoalesced(shape, nnz, dev):
row = torch.randint(shape[0], (nnz,), device=dev)
col = torch.randint(shape[1], (nnz,), device=dev)
val = torch.randn(nnz, device=dev, requires_grad=True)
return from_coo(row, col, val, shape)
indices = torch.stack([row, col])
indices = rand_stride(indices)
return spmatrix(indices, val, shape)
def rand_csr_uncoalesced(shape, nnz, dev):
......@@ -83,6 +104,9 @@ def rand_csr_uncoalesced(shape, nnz, dev):
indptr = torch.cumsum(indptr, 0)
row_sorted, row_sorted_idx = torch.sort(row)
indices = col[row_sorted_idx]
indptr = rand_stride(indptr)
indices = rand_stride(indices)
val = rand_stride(val)
return from_csr(indptr, indices, val, shape=shape)
......@@ -97,6 +121,9 @@ def rand_csc_uncoalesced(shape, nnz, dev):
indptr = torch.cumsum(indptr, 0)
col_sorted, col_sorted_idx = torch.sort(col)
indices = row[col_sorted_idx]
indptr = rand_stride(indptr)
indices = rand_stride(indices)
val = rand_stride(val)
return from_csc(indptr, indices, val, shape=shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment