Unverified Commit 5ffd2a02 authored by czkkkkkk, committed by GitHub

[Sparse] Support column-wise softmax (#5377)



* [Sparse] Support column-wise softmax

* Update python/dgl/sparse/softmax.py
Co-authored-by: Mufei Li <mufeili1996@gmail.com>

---------
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent c42fa8a5
......@@ -12,17 +12,20 @@ namespace dgl {
namespace sparse {
/**
* @brief Apply row-wise softmax to the non-zero entries of the sparse matrix.
* @brief Apply softmax to the non-zero entries of the sparse matrix along
* dimension dim. dim = 0 or 1 indicates column-wise or row-wise softmax,
* respectively.
*
* This function supports autograd for the sparse matrix, but it does not
* support higher order gradient.
*
* @param sparse_mat The sparse matrix
* @param dim The dimension along which to apply softmax
*
* @return Sparse matrix
*/
c10::intrusive_ptr<SparseMatrix> Softmax(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat);
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim);
} // namespace sparse
} // namespace dgl
......
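For intuition, here is a minimal Python sketch of what the new dim argument selects: entries that share a row are normalized together for dim = 1, entries that share a column for dim = 0. It uses plain torch scatter ops rather than the DGL kernels, omits the max-subtraction the real implementation uses for stability, and the helper name sparse_softmax_reference is made up for illustration.

import torch

def sparse_softmax_reference(indices, val, size, dim):
    # Hypothetical reference helper: group the nnz entries by row (dim=1)
    # or by column (dim=0) and normalize each group to sum to one.
    group = indices[0] if dim == 1 else indices[1]
    exp = val.exp()
    denom = torch.zeros(size).scatter_add_(0, group, exp)
    return exp / denom[group]

indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
val = torch.tensor([0., 1., 2., 3.])
print(sparse_softmax_reference(indices, val, 3, dim=1))  # row-wise
print(sparse_softmax_reference(indices, val, 3, dim=0))  # column-wise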
......@@ -112,7 +112,7 @@ torch::Tensor SDDMMNoAutoGrad(
torch::Tensor BroadcastOpNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
const std::string& op) {
const std::string& op, int64_t dim) {
auto sparse_val = sparse_mat->value();
const int64_t out_row = sparse_mat->nnz();
const std::vector<int64_t> shape({out_row, sparse_val.size(1)});
......@@ -121,6 +121,9 @@ torch::Tensor BroadcastOpNoAutoGrad(
auto dgl_sparse_val = TorchTensorToDGLArray(sparse_val);
auto dgl_dense_mat = TorchTensorToDGLArray(dense_mat);
auto dgl_ret = TorchTensorToDGLArray(ret);
// Setting dgl_rhs_target to 0 or 2 means using row or column coordinates
// to access dgl_dense_mat for each edge, respectively.
auto dgl_rhs_target = dim == 0 ? 2 : 0;
// The format for calculation will be chosen in the following order: COO,
// CSR. COO is created if the sparse matrix only has CSC format.
......@@ -130,32 +133,32 @@ torch::Tensor BroadcastOpNoAutoGrad(
auto coo = COOToOldDGLCOO(sparse_mat->COOPtr());
aten::COOSDDMM(
op.c_str(), coo, dgl_sparse_val, dgl_dense_mat, dgl_ret,
1 /* Lhs target: e */, 0 /* rhs target: u due to transpose */);
1 /* Lhs target: e */, dgl_rhs_target);
} else {
auto csr = CSRToOldDGLCSR(sparse_mat->CSRPtr());
aten::CSRSDDMM(
op.c_str(), csr, dgl_sparse_val, dgl_dense_mat, dgl_ret,
1 /* Lhs target: e */, 0 /* rhs target: u due to transpose */);
1 /* Lhs target: e */, dgl_rhs_target);
}
return ret;
}
torch::Tensor BroadcastSubNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
torch::Tensor dense_mat) {
return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "sub");
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
int64_t dim) {
return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "sub", dim);
}
torch::Tensor BroadcastDivNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
torch::Tensor dense_mat) {
return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "div");
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
int64_t dim) {
return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "div", dim);
}
torch::Tensor BroadcastMulNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
torch::Tensor dense_mat) {
return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "mul");
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
int64_t dim) {
return BroadcastOpNoAutoGrad(sparse_mat, dense_mat, "mul", dim);
}
c10::intrusive_ptr<SparseMatrix> SpSpMMNoAutoGrad(
......
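As a rough reference for the rhs_target mapping above: for each non-zero entry, dim = 1 reads the dense feature at the entry's row index and dim = 0 at its column index. A hedged Python sketch follows; it uses plain torch indexing rather than the SDDMM kernels, and the function name broadcast_op_reference is illustrative only.

import torch

def broadcast_op_reference(indices, sparse_val, dense_mat, op, dim):
    # Hypothetical reference: dim=1 indexes dense_mat by the row index of each
    # non-zero entry, dim=0 by its column index (mirroring rhs_target 0 vs 2).
    idx = indices[0] if dim == 1 else indices[1]
    x_v = dense_mat[idx]            # one dense row per non-zero entry, (nnz, D)
    if op == "sub":
        return sparse_val - x_v
    if op == "div":
        return sparse_val / x_v
    if op == "mul":
        return sparse_val * x_v
    return sparse_val + x_v         # "add"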
......@@ -57,68 +57,76 @@ torch::Tensor SDDMMNoAutoGrad(
/**
* @brief Broadcast the dense feature to the nonzero entries and then compute
* x_e = \phi(x_e, x_v), where x_e is the nonzero value, x_v is the dense
* feature, and \phi is add, sub, mul, or div.
* x_e = \phi(x_e, x_v) on the dimension dim, where x_e is the nonzero value,
* x_v is the dense feature, and \phi is add, sub, mul, or div. dim = 0 or 1
* means column-wise or row-wise broadcast respectively.
*
* This function does not take care of autograd.
*
* @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
* @param dense_mat Dense feature of shape (N, D)
* @param op Operator, can be add, sub, mul, or div
* @param dim The dimension to broadcast.
*
* @return Dense tensor of shape (nnz, D)
*/
torch::Tensor BroadcastOpNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
const std::string& op);
const std::string& op, int64_t dim);
/**
* @brief Broadcast the dense feature to the nonzero entries and then compute
* x_e = x_e - x_v, where x_e is the nonzero value, x_v is the dense
* feature.
* x_e = x_e - x_v on the dimension dim, where x_e is the nonzero value, x_v is
* the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast
* respectively.
*
* This function does not take care of autograd.
*
* @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
* @param dense_mat Dense feature of shape (N, D)
* @param dim The dimension to broadcast.
*
* @return Dense tensor of shape (nnz, D)
*/
torch::Tensor BroadcastSubNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
torch::Tensor dense_mat);
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
int64_t dim);
/**
* @brief Broadcast the dense feature to the nonzero entries and then compute
* x_e = x_e / x_v, where x_e is the nonzero value, x_v is the dense
* feature.
* x_e = x_e / x_v on the dimension dim, where x_e is the nonzero value, x_v is
* the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast
* respectively.
*
* This function does not take care of autograd.
*
* @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
* @param dense_mat Dense feature of shape (N, D)
* @param dim The dimension to broadcast.
*
* @return Dense tensor of shape (nnz, D)
*/
torch::Tensor BroadcastDivNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
torch::Tensor dense_mat);
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
int64_t dim);
/**
* @brief Broadcast the dense feature to the nonzero entries and then compute
* x_e = x_e * x_v, where x_e is the nonzero value, x_v is the dense
* feature.
* x_e = x_e * x_v on the dimension dim, where x_e is the nonzero value, x_v is
* the dense feature. dim = 0 or 1 means column-wise or row-wise broadcast
* respectively.
*
* This function does not take care of autograd.
*
* @param sparse_mat The sparse matrix with N rows and (nnz, D) nonzero values
* @param dense_mat Dense feature of shape (N, D)
* @param dim The dimension to broadcast.
*
* @return Dense tensor of shape (nnz, D)
*/
torch::Tensor BroadcastMulNoAutoGrad(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
torch::Tensor dense_mat);
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor dense_mat,
int64_t dim);
/**
* @brief Perform a sparse-sparse matrix multiplication with possibly different
......
......@@ -20,22 +20,22 @@ class SoftmaxAutoGrad : public Function<SoftmaxAutoGrad> {
public:
static torch::Tensor forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
torch::Tensor sparse_val);
torch::Tensor sparse_val, int64_t dim);
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};
torch::Tensor SoftmaxAutoGrad::forward(
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
torch::Tensor sparse_val) {
torch::Tensor sparse_val, int64_t dim) {
// Reduce by columns with dim 1.
auto sparse_val_max = ReduceMax(sparse_mat, 1);
auto sparse_val_max = ReduceMax(sparse_mat, dim);
auto sparse_val_exp =
BroadcastSubNoAutoGrad(sparse_mat, sparse_val_max).exp();
BroadcastSubNoAutoGrad(sparse_mat, sparse_val_max, dim).exp();
auto sparse_val_sum =
ReduceSum(SparseMatrix::ValLike(sparse_mat, sparse_val_exp), 1);
ReduceSum(SparseMatrix::ValLike(sparse_mat, sparse_val_exp), dim);
auto sparse_score = BroadcastDivNoAutoGrad(
SparseMatrix::ValLike(sparse_mat, sparse_val_exp), sparse_val_sum);
SparseMatrix::ValLike(sparse_mat, sparse_val_exp), sparse_val_sum, dim);
const bool sparse_requires_grad = sparse_val.requires_grad();
torch::Tensor cache_sparse_score;
......@@ -44,6 +44,7 @@ torch::Tensor SoftmaxAutoGrad::forward(
}
ctx->saved_data["sparse_matrix"] = sparse_mat;
ctx->saved_data["sparse_requires_grad"] = sparse_requires_grad;
ctx->saved_data["dim"] = dim;
ctx->save_for_backward({cache_sparse_score});
return sparse_score;
}
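The forward pass follows the usual max-shift recipe for numerical stability. A tiny standalone illustration (not DGL code) of why subtracting the per-group maximum leaves the softmax unchanged while keeping exp() finite:

import torch

x = torch.tensor([100., 101., 102.])        # exp(100) overflows float32
naive = x.exp() / x.exp().sum()             # tensor([nan, nan, nan])
shifted = (x - x.max()).exp()
stable = shifted / shifted.sum()            # tensor([0.0900, 0.2447, 0.6652])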
......@@ -58,21 +59,22 @@ tensor_list SoftmaxAutoGrad::backward(
ctx->saved_data["sparse_matrix"].toCustomClass<SparseMatrix>();
const bool sparse_requires_grad =
ctx->saved_data["sparse_requires_grad"].toBool();
const int64_t dim = ctx->saved_data["dim"].toInt();
torch::Tensor sparse_val_grad;
if (sparse_requires_grad) {
auto sds = sparse_score * output_grad;
auto accum = ReduceSum(SparseMatrix::ValLike(sparse_mat, sds), 1);
auto accum = ReduceSum(SparseMatrix::ValLike(sparse_mat, sds), dim);
sparse_val_grad =
sds - BroadcastMulNoAutoGrad(
SparseMatrix::ValLike(sparse_mat, sparse_score), accum);
SparseMatrix::ValLike(sparse_mat, sparse_score), accum, dim);
}
return {torch::Tensor(), sparse_val_grad};
return {torch::Tensor(), sparse_val_grad, torch::Tensor()};
}
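The backward pass computes the standard softmax vector-Jacobian product within each row or column group. A small illustrative sanity check against torch autograd for a single group (not DGL code):

import torch

x = torch.tensor([0., 1., 2.], requires_grad=True)
g = torch.tensor([0.3, -0.2, 0.5])          # upstream gradient
y = x.softmax(dim=0)
y.backward(g)
s = y.detach()
manual = s * g - s * (s * g).sum()          # mirrors the broadcast/reduce calls
print(torch.allclose(x.grad, manual))       # True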
c10::intrusive_ptr<SparseMatrix> Softmax(
const c10::intrusive_ptr<SparseMatrix>& sparse_mat) {
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim) {
auto sparse_val = sparse_mat->value();
bool expand_dim = false;
auto new_sparse_mat = sparse_mat;
......@@ -82,7 +84,7 @@ c10::intrusive_ptr<SparseMatrix> Softmax(
new_sparse_mat = SparseMatrix::ValLike(sparse_mat, sparse_val);
}
auto new_sparse_val = SoftmaxAutoGrad::apply(new_sparse_mat, sparse_val);
auto new_sparse_val = SoftmaxAutoGrad::apply(new_sparse_mat, sparse_val, dim);
if (expand_dim) {
new_sparse_val = new_sparse_val.view(-1);
......
......@@ -8,11 +8,10 @@ from .sparse_matrix import SparseMatrix
__all__ = ["softmax"]
def softmax(input: SparseMatrix) -> SparseMatrix:
"""Applies row-wise softmax to the non-zero elements of the sparse matrix.
Equivalently, applies softmax to the non-zero elements of the sparse
matrix along the column (``dim=1``) dimension.
def softmax(input: SparseMatrix, dim: int = 1) -> SparseMatrix:
"""Applies softmax to the non-zero elements of the sparse matrix on the
dimension :attr:``dim``. dim = 0 or 1 indicates column-wise or row-wise
softmax respectively.
If :attr:`input.val` takes shape ``(nnz, D)``, then the output matrix
:attr:`output` and :attr:`output.val` take the same shape as :attr:`input`
......@@ -32,11 +31,10 @@ def softmax(input: SparseMatrix) -> SparseMatrix:
Examples
--------
Case1: matrix with values of shape (nnz)
Case1: row-wise softmax on matrix with values of shape (nnz)
>>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
>>> nnz = len(row)
>>> val = torch.arange(nnz).float()
>>> val = torch.tensor([0., 1., 2., 3.])
>>> A = dglsp.spmatrix(indices, val)
>>> dglsp.softmax(A)
SparseMatrix(indices=tensor([[0, 0, 1, 2],
......@@ -44,7 +42,7 @@ def softmax(input: SparseMatrix) -> SparseMatrix:
values=tensor([0.2689, 0.7311, 1.0000, 1.0000]),
shape=(3, 3), nnz=4)
Case2: matrix with values of shape (nnz, D)
Case2: row-wise softmax on matrix with values of shape (nnz, D)
>>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
>>> val = torch.tensor([[0., 7.], [1., 3.], [2., 2.], [3., 1.]])
......@@ -57,8 +55,21 @@ def softmax(input: SparseMatrix) -> SparseMatrix:
[1.0000, 1.0000],
[1.0000, 1.0000]]),
shape=(3, 3), nnz=4, val_size=(2,))
Case3: column-wise softmax on matrix with values of shape (nnz)
>>> indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
>>> val = torch.tensor([0., 1., 2., 3.])
>>> A = dglsp.spmatrix(indices, val)
>>> dglsp.softmax(A, 0)
SparseMatrix(indices=tensor([[0, 0, 1, 2],
[1, 2, 2, 0]]),
values=tensor([1.0000, 0.2689, 0.7311, 1.0000]),
shape=(3, 3), nnz=4)
"""
return SparseMatrix(torch.ops.dgl_sparse.softmax(input.c_sparse_matrix))
return SparseMatrix(
torch.ops.dgl_sparse.softmax(input.c_sparse_matrix, dim)
)
SparseMatrix.softmax = softmax
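Since softmax is also attached as a SparseMatrix method by the assignment above, both call forms below should work. A minimal usage sketch, assuming dgl.sparse is imported as dglsp as in the docstring examples:

import torch
import dgl.sparse as dglsp

indices = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
val = torch.tensor([0., 1., 2., 3.])
A = dglsp.spmatrix(indices, val)

A.softmax()      # row-wise (default dim=1), same as dglsp.softmax(A)
A.softmax(0)     # column-wise, same as dglsp.softmax(A, 0)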
......@@ -10,7 +10,8 @@ from dgl.sparse import from_coo, softmax
@pytest.mark.parametrize("val_D", [None, 2])
@pytest.mark.parametrize("csr", [True, False])
def test_softmax(val_D, csr):
@pytest.mark.parametrize("dim", [0, 1])
def test_softmax(val_D, csr, dim):
dev = F.ctx()
row = torch.tensor([0, 0, 1, 1]).to(dev)
col = torch.tensor([0, 2, 1, 2]).to(dev)
......@@ -27,8 +28,11 @@ def test_softmax(val_D, csr):
# Test CSR
A.csr()
A_max = softmax(A)
g = dgl.graph((col, row), num_nodes=max(A.shape))
A_max = softmax(A, dim)
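# edge_softmax normalizes over the in-edges of each destination node, so
# dim=1 (row-wise) uses dst=row and dim=0 (column-wise) uses dst=col below.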
if dim == 1:
g = dgl.graph((col, row), num_nodes=max(A.shape))
else:
g = dgl.graph((row, col), num_nodes=max(A.shape))
val_g = val.clone().requires_grad_()
score = dgl.nn.functional.edge_softmax(g, val_g)
assert torch.allclose(A_max.val, score)
......