Unverified commit 5ffd2a02 authored by czkkkkkk, committed by GitHub

[Sparse] Support column-wise softmax (#5377)



* [Sparse] Support column-wise softmax

* Update python/dgl/sparse/softmax.py
Co-authored-by: Mufei Li <mufeili1996@gmail.com>

---------
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent c42fa8a5
@@ -12,17 +12,20 @@ namespace dgl {
 namespace sparse {
 
 /**
- * @brief Apply row-wise softmax to the non-zero entries of the sparse matrix.
+ * @brief Apply softmax to the non-zero entries of the sparse matrix on the
+ * dimension dim. dim = 0 or 1 indicates column-wise or row-wise softmax
+ * respectively.
  *
  * This function supports autograd for the sparse matrix, but it does not
  * support higher order gradient.
  *
  * @param sparse_mat The sparse matrix
+ * @param dim The dimension to apply softmax
  *
  * @return Sparse matrix
  */
 c10::intrusive_ptr<SparseMatrix> Softmax(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat);
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim);
 
 }  // namespace sparse
 }  // namespace dgl
...
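For readers skimming the diff, here is the behaviour the new dim argument encodes, written out as a formula. This is added for context and restates the doc comment above rather than quoting the patch. For a sparse matrix A with nonzero set E:

\mathrm{softmax}(A, \mathrm{dim}=1)_{ij} = \frac{\exp(A_{ij})}{\sum_{k:\,(i,k)\in E} \exp(A_{ik})},
\qquad
\mathrm{softmax}(A, \mathrm{dim}=0)_{ij} = \frac{\exp(A_{ij})}{\sum_{k:\,(k,j)\in E} \exp(A_{kj})},
\qquad (i,j)\in E,

i.e. each nonzero is normalized against the other nonzeros in its row (dim = 1) or in its column (dim = 0); zero entries stay zero.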
@@ -112,7 +112,7 @@ torch::Tensor SDDMMNoAutoGrad(
 torch::Tensor BroadcastOpNoAutoGrad(
     const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
-    const std::string& op) {
+    const std::string& op, int64_t dim) {
   auto sparse_val = sparse_mat->value();
   const int64_t out_row = sparse_mat->nnz();
   const std::vector<int64_t> shape({out_row, sparse_val.size(1)});
@@ -121,6 +121,9 @@ torch::Tensor BroadcastOpNoAutoGrad(
   auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);
   auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);
   auto dgl_ret = TorchTensorToDGLArray(ret);
+  // Setting dgl_rhs_target to 0 or 2 means using row or column coordinates
+  // to access dgl_dense_mat for each edge, respectively.
+  auto dgl_rhs_target = dim == 0 ? 2 : 0;
   // The format for calculation will be chosen in the following order: COO, CSR
   // . COO is created if the sparse matrix only has CSC format.
@@ -130,32 +133,32 @@ torch::Tensor BroadcastOpNoAutoGrad(
     auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
     aten::COOSDDMM(
         op.c_str(), coo, dgl_sparse_val, dgl_dense_mat, dgl_ret,
-        1 /* Lhs target: e */, 0 /* rhs target: u due to transpose */);
+        1 /* Lhs target: e */, dgl_rhs_target);
   } else {
     auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
     aten::CSRSDDMM(
         op.c_str(), csr, dgl_sparse_val, dgl_dense_mat, dgl_ret,
-        1 /* Lhs target: e */, 0 /* rhs target: u due to transpose */);
+        1 /* Lhs target: e */, dgl_rhs_target);
   }
   return ret;
 }
 
 torch::Tensor BroadcastSubNoAutoGrad(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
-    torch::Tensor dense_mat) {
-  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "sub");
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
+    int64_t dim) {
+  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "sub", dim);
 }
 
 torch::Tensor BroadcastDivNoAutoGrad(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
-    torch::Tensor dense_mat) {
-  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "div");
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
+    int64_t dim) {
+  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "div", dim);
 }
 
 torch::Tensor BroadcastMulNoAutoGrad(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
-    torch::Tensor dense_mat) {
-  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "mul");
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
+    int64_t dim) {
+  return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "mul", dim);
 }
 
 c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
...
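To make the dim-to-target mapping concrete, here is a small Python sketch of what the broadcast helpers compute; the function and variable names are illustrative assumptions, not part of the patch. For each nonzero at (row, col), dim = 1 looks up the dense statistic by the row coordinate (rhs target 0, u) and dim = 0 by the column coordinate (rhs target 2, v):

import torch

def broadcast_op_reference(row, col, sparse_val, dense_mat, op, dim):
    # Assumed semantics of BroadcastOpNoAutoGrad: gather one dense row per
    # nonzero and combine it elementwise with the nonzero value.
    idx = col if dim == 0 else row      # dim=0: column stats, dim=1: row stats
    gathered = dense_mat[idx]           # shape (nnz, D)
    return op(sparse_val, gathered)

# Example: subtract a (hypothetical) per-row maximum from each nonzero value,
# as the softmax forward pass does for dim=1.
row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 2, 0])
val = torch.tensor([[0.], [1.], [2.], [3.]])
row_max = torch.tensor([[1.], [2.], [3.]])
print(broadcast_op_reference(row, col, val, row_max, torch.sub, dim=1))
# tensor([[-1.], [0.], [0.], [0.]])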
@@ -57,68 +57,76 @@ torch::Tensor SDDMMNoAutoGrad(
 /**
  * @brief Broadcast the dense feature to the nonzero entries and then compute
- * x_e = \phi(x_e, x_v), where x_e is the nonzero value, x_v is the dense
- * feature, and \phi is add, sub, mul, or div.
+ * x_e = \phi(x_e, x_v) on the dimension dim, where x_e is the nonzero value,
+ * x_v is the dense feature, and \phi is add, sub, mul, or div. dim = 0 or 1
+ * means column-wise or row-wise broadcast respectively.
  *
  * This function does not take care of autograd.
  *
  * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
  * @param dense_mat Dense feature of shape (N, D)
  * @param op Operator, can be add, sub, mul, or div
+ * @param dim The dimension to broadcast.
  *
  * @return Dense tensor of shape (nnz, D)
  */
 torch::Tensor BroadcastOpNoAutoGrad(
     const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
-    const std::string& op);
+    const std::string& op, int64_t dim);
 
 /**
  * @brief Broadcast the dense feature to the nonzero entries and then compute
- * x_e = x_e - x_v, where x_e is the nonzero value, x_v is the dense
- * feature.
+ * x_e = x_e - x_v on the dimension dim, where x_e is the nonzero value, x_v is
+ * the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast
+ * respectively.
  *
  * This function does not take care of autograd.
  *
  * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
  * @param dense_mat Dense feature of shape (N, D)
+ * @param dim The dimension to broadcast.
  *
  * @return Dense tensor of shape (nnz, D)
  */
 torch::Tensor BroadcastSubNoAutoGrad(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
-    torch::Tensor dense_mat);
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
+    int64_t dim);
 
 /**
  * @brief Broadcast the dense feature to the nonzero entries and then compute
- * x_e = x_e / x_v, where x_e is the nonzero value, x_v is the dense
- * feature.
+ * x_e = x_e / x_v on the dimension dim, where x_e is the nonzero value, x_v is
+ * the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast
+ * respectively.
  *
  * This function does not take care of autograd.
  *
  * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
  * @param dense_mat Dense feature of shape (N, D)
+ * @param dim The dimension to broadcast.
  *
  * @return Dense tensor of shape (nnz, D)
  */
 torch::Tensor BroadcastDivNoAutoGrad(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
-    torch::Tensor dense_mat);
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
+    int64_t dim);
 
 /**
  * @brief Broadcast the dense feature to the nonzero entries and then compute
- * x_e = x_e * x_v, where x_e is the nonzero value, x_v is the dense
- * feature.
+ * x_e = x_e * x_v on the dimension dim, where x_e is the nonzero value, x_v is
+ * the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast
+ * respectively.
  *
  * This function does not take care of autograd.
  *
 * @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
  * @param dense_mat Dense feature of shape (N, D)
+ * @param dim The dimension to broadcast.
  *
  * @return Dense tensor of shape (nnz, D)
  */
 torch::Tensor BroadcastMulNoAutoGrad(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
-    torch::Tensor dense_mat);
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
+    int64_t dim);
 
 /**
  * @brief Perform a sparse-sparse matrix multiplication with possibly different
...
@@ -20,22 +20,22 @@ class SoftmaxAutoGrad : public Function<SoftmaxAutoGrad> {
  public:
   static torch::Tensor forward(
       AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
-      torch::Tensor sparse_val);
+      torch::Tensor sparse_val, int64_t dim);
   static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
 };
 
 torch::Tensor SoftmaxAutoGrad::forward(
     AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
-    torch::Tensor sparse_val) {
+    torch::Tensor sparse_val, int64_t dim) {
   // Reduce by columns with dim 1.
-  auto sparse_val_max = ReduceMax(sparse_mat, 1);
+  auto sparse_val_max = ReduceMax(sparse_mat, dim);
   auto sparse_val_exp =
-      BroadcastSubNoAutoGrad(sparse_mat, sparse_val_max).exp();
+      BroadcastSubNoAutoGrad(sparse_mat, sparse_val_max, dim).exp();
   auto sparse_val_sum =
-      ReduceSum(SparseMatrix::ValLike(sparse_mat, sparse_val_exp), 1);
+      ReduceSum(SparseMatrix::ValLike(sparse_mat, sparse_val_exp), dim);
   auto sparse_score = BroadcastDivNoAutoGrad(
-      SparseMatrix::ValLike(sparse_mat, sparse_val_exp), sparse_val_sum);
+      SparseMatrix::ValLike(sparse_mat, sparse_val_exp), sparse_val_sum, dim);
 
   const bool sparse_requires_grad = sparse_val.requires_grad();
   torch::Tensor cache_sparse_score;
@@ -44,6 +44,7 @@ torch::Tensor SoftmaxAutoGrad::forward(
   }
   ctx->saved_data["sparse_matrix"] = sparse_mat;
   ctx->saved_data["sparse_requires_grad"] = sparse_requires_grad;
+  ctx->saved_data["dim"] = dim;
   ctx->save_for_backward({cache_sparse_score});
   return sparse_score;
 }
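The forward pass above uses the standard numerically stable formulation: the per-group maximum is subtracted before exponentiating, which does not change the result because softmax is invariant to a constant shift. As a reminder (added for context, not part of the patch):

\mathrm{softmax}(x)_e = \frac{\exp(x_e - m)}{\sum_{e'} \exp(x_{e'} - m)}, \qquad m = \max_{e'} x_{e'},

where e' ranges over the nonzeros in the same row (dim = 1) or column (dim = 0) as e.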
@@ -58,21 +59,22 @@ tensor_list SoftmaxAutoGrad::backward(
       ctx->saved_data["sparse_matrix"].toCustomClass<SparseMatrix>();
   const bool sparse_requires_grad =
       ctx->saved_data["sparse_requires_grad"].toBool();
+  const int64_t dim = ctx->saved_data["dim"].toInt();
 
   torch::Tensor sparse_val_grad;
   if (sparse_requires_grad) {
     auto sds = sparse_score * output_grad;
-    auto accum = ReduceSum(SparseMatrix::ValLike(sparse_mat, sds), 1);
+    auto accum = ReduceSum(SparseMatrix::ValLike(sparse_mat, sds), dim);
     sparse_val_grad =
         sds - BroadcastMulNoAutoGrad(
-                  SparseMatrix::ValLike(sparse_mat, sparse_score), accum);
+                  SparseMatrix::ValLike(sparse_mat, sparse_score), accum, dim);
   }
 
-  return {torch::Tensor(), sparse_val_grad};
+  return {torch::Tensor(), sparse_val_grad, torch::Tensor()};
 }
 
 c10::intrusive_ptr<SparseMatrix> Softmax(
-    const c10::intrusive_ptr<SparseMatrix>& sparse_mat) {
+    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim) {
   auto sparse_val = sparse_mat->value();
   bool expand_dim = false;
   auto new_sparse_mat = sparse_mat;
@@ -82,7 +84,7 @@ c10::intrusive_ptr<SparseMatrix> Softmax(
     new_sparse_mat = SparseMatrix::ValLike(sparse_mat, sparse_val);
   }
 
-  auto new_sparse_val = SoftmaxAutoGrad::apply(new_sparse_mat, sparse_val);
+  auto new_sparse_val = SoftmaxAutoGrad::apply(new_sparse_mat, sparse_val, dim);
   if (expand_dim) {
     new_sparse_val = new_sparse_val.view(-1);
...
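For reference, the backward pass implements the usual softmax Jacobian-vector product, applied independently to each row (dim = 1) or column (dim = 0). With s the cached softmax output and g the incoming gradient (this note is added for context, not part of the patch):

\frac{\partial L}{\partial x_e} = s_e\, g_e - s_e \sum_{e' \in \mathcal{G}(e)} s_{e'}\, g_{e'},

where \mathcal{G}(e) is the group of nonzeros sharing e's row or column. In the code, s_e g_e is sds, the group sum is accum, and the final subtraction is the BroadcastMulNoAutoGrad term.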
@@ -8,11 +8,10 @@ from .sparse_matrix import SparseMatrix
 __all__ = ["softmax"]
 
 
-def softmax(input: SparseMatrix) -> SparseMatrix:
-    """Applies row-wise softmax to the non-zero elements of the sparse matrix.
-
-    Equivalently, applies softmax to the non-zero elements of the sparse
-    matrix along the column (``dim=1``) dimension.
+def softmax(input: SparseMatrix, dim: int = 1) -> SparseMatrix:
+    """Applies softmax to the non-zero elements of the sparse matrix on the
+    dimension :attr:`dim`. dim = 0 or 1 indicates column-wise or row-wise
+    softmax respectively.
 
     If :attr:`input.val` takes shape ``(nnz, D)``, then the output matrix
     :attr:`output` and :attr:`output.val` take the same shape as :attr:`input`
@@ -32,11 +31,10 @@ def softmax(input: SparseMatrix) -> SparseMatrix:
     Examples
     --------
 
-    Case1: matrix with values of shape (nnz)
+    Case1: row-wise softmax on matrix with values of shape (nnz)
 
     >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
-    >>> nnz = len(row)
-    >>> val = torch.arange(nnz).float()
+    >>> val = torch.tensor([0., 1., 2., 3.])
     >>> A = dglsp.spmatrix(indices, val)
     >>> dglsp.softmax(A)
     SparseMatrix(indices=tensor([[0, 0, 1, 2],
@@ -44,7 +42,7 @@ def softmax(input: SparseMatrix) -> SparseMatrix:
                  values=tensor([0.2689, 0.7311, 1.0000, 1.0000]),
                  shape=(3, 3), nnz=4)
 
-    Case2: matrix with values of shape (nnz, D)
+    Case2: row-wise softmax on matrix with values of shape (nnz, D)
 
     >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
     >>> val = torch.tensor([[0., 7.], [1., 3.], [2., 2.], [3., 1.]])
@@ -57,8 +55,21 @@ def softmax(input: SparseMatrix) -> SparseMatrix:
                           [1.0000, 1.0000],
                           [1.0000, 1.0000]]),
                  shape=(3, 3), nnz=4, val_size=(2,))
+
+    Case3: column-wise softmax on matrix with values of shape (nnz)
+
+    >>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
+    >>> val = torch.tensor([0., 1., 2., 3.])
+    >>> A = dglsp.spmatrix(indices, val)
+    >>> dglsp.softmax(A, 0)
+    SparseMatrix(indices=tensor([[0, 0, 1, 2],
+                                 [1, 2, 2, 0]]),
+                 values=tensor([1.0000, 0.2689, 0.7311, 1.0000]),
+                 shape=(3, 3), nnz=4)
     """
-    return SparseMatrix(torch.ops.dgl_sparse.softmax(input.c_sparse_matrix))
+    return SparseMatrix(
+        torch.ops.dgl_sparse.softmax(input.c_sparse_matrix, dim)
+    )
 
 
 SparseMatrix.softmax = softmax
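As a sanity check on the docstring examples, the same numbers can be reproduced with a dense reference: densify the matrix, place -inf at the structural zeros, and apply torch.softmax along the requested dimension. This is only an illustrative cross-check (the helper name is mine, not part of the library), and it assumes every row or column being normalized has at least one nonzero:

import torch

def dense_softmax_reference(indices, val, shape, dim):
    # Densify, mask structural zeros with -inf, softmax along `dim`,
    # and read the result back at the nonzero positions.
    dense = torch.full(shape, float("-inf"))
    dense[indices[0], indices[1]] = val
    return torch.softmax(dense, dim=dim)[indices[0], indices[1]]

indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
val = torch.tensor([0., 1., 2., 3.])
print(dense_softmax_reference(indices, val, (3, 3), dim=1))
# tensor([0.2689, 0.7311, 1.0000, 1.0000])  -- matches Case1
print(dense_softmax_reference(indices, val, (3, 3), dim=0))
# tensor([1.0000, 0.2689, 0.7311, 1.0000])  -- matches Case3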
@@ -10,7 +10,8 @@ from dgl.sparse import from_coo, softmax
 @pytest.mark.parametrize("val_D", [None, 2])
 @pytest.mark.parametrize("csr", [True, False])
-def test_softmax(val_D, csr):
+@pytest.mark.parametrize("dim", [0, 1])
+def test_softmax(val_D, csr, dim):
     dev = F.ctx()
     row = torch.tensor([0, 0, 1, 1]).to(dev)
     col = torch.tensor([0, 2, 1, 2]).to(dev)
@@ -27,8 +28,11 @@ def test_softmax(val_D, csr, dim):
         # Test CSR
         A.csr()
 
-    A_max = softmax(A)
-    g = dgl.graph((col, row), num_nodes=max(A.shape))
+    A_max = softmax(A, dim)
+    if dim == 1:
+        g = dgl.graph((col, row), num_nodes=max(A.shape))
+    else:
+        g = dgl.graph((row, col), num_nodes=max(A.shape))
     val_g = val.clone().requires_grad_()
     score = dgl.nn.functional.edge_softmax(g, val_g)
     assert torch.allclose(A_max.val, score)
...
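A brief note on why the test flips the edge direction (my reading of the test, not text from the patch): dgl.nn.functional.edge_softmax normalizes edge values over the edges that share a destination node, so row-wise softmax (dim = 1) needs the row index to be the destination, i.e. edges (col -> row), while column-wise softmax (dim = 0) uses (row -> col). A hedged standalone sketch of the same cross-check:

import torch
import dgl
import dgl.sparse as dglsp

row = torch.tensor([0, 0, 1, 1])
col = torch.tensor([0, 2, 1, 2])
val = torch.randn(len(row))
A = dglsp.spmatrix(torch.stack([row, col]), val, shape=(3, 3))

# dim=1: destinations are the row indices, so each softmax group is one row.
g = dgl.graph((col, row), num_nodes=3)
assert torch.allclose(dglsp.softmax(A, 1).val,
                      dgl.nn.functional.edge_softmax(g, val))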