"dgl_sparse/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "da1d5de2f16d2f0dc04589c5569d7afaaf1f486b"
Unverified Commit 16dd9584 authored by Hongzhi (Steve), Chen; committed by GitHub

[Sparse] Polish softmax.py. (#5182)



* polish

* [Sparse] Polish softmax.py.

* fix parameter and example
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 82eb3d71
"""Softmax op for SparseMatrix""" """Softmax op for SparseMatrix"""
# pylint: disable=invalid-name # pylint: disable=invalid-name, W0622
import torch import torch
...@@ -8,17 +8,17 @@ from .sparse_matrix import SparseMatrix ...@@ -8,17 +8,17 @@ from .sparse_matrix import SparseMatrix
__all__ = ["softmax"] __all__ = ["softmax"]
def softmax(A: SparseMatrix) -> SparseMatrix: def softmax(input: SparseMatrix) -> SparseMatrix:
"""Apply row-wise softmax to the non-zero entries of the sparse matrix. """Apply row-wise softmax to the non-zero elements of the sparse matrix.
If :attr:`A.val` takes shape :attr:`(nnz, D)`, then the output matrix If :attr:`input.val` takes shape :attr:`(nnz, D)`, then the output matrix
:attr:`A'` and :attr:`A'.val` take the same shape as :attr:`A` and :attr:`output` and :attr:`output.val` take the same shape as :attr:`input`
:attr:`A.val`. :attr:`A'.val[:, i]` is calculated based on and :attr:`input.val`. :attr:`output.val[:, i]` is calculated based on
:attr:`A.val[:, i]`. :attr:`input.val[:, i]`.
Parameters Parameters
---------- ----------
A : SparseMatrix input : SparseMatrix
The input sparse matrix The input sparse matrix
Returns Returns
...@@ -35,8 +35,8 @@ def softmax(A: SparseMatrix) -> SparseMatrix: ...@@ -35,8 +35,8 @@ def softmax(A: SparseMatrix) -> SparseMatrix:
>>> col = torch.tensor([1, 2, 2, 0]) >>> col = torch.tensor([1, 2, 2, 0])
>>> nnz = len(row) >>> nnz = len(row)
>>> val = torch.arange(nnz).float() >>> val = torch.arange(nnz).float()
>>> A = from_coo(row, col, val) >>> A = dglsp.from_coo(row, col, val)
>>> softmax(A) >>> dglsp.softmax(A)
SparseMatrix(indices=tensor([[0, 0, 1, 2], SparseMatrix(indices=tensor([[0, 0, 1, 2],
[1, 2, 2, 0]]), [1, 2, 2, 0]]),
values=tensor([0.2689, 0.7311, 1.0000, 1.0000]), values=tensor([0.2689, 0.7311, 1.0000, 1.0000]),
...@@ -45,8 +45,8 @@ def softmax(A: SparseMatrix) -> SparseMatrix: ...@@ -45,8 +45,8 @@ def softmax(A: SparseMatrix) -> SparseMatrix:
Case2: matrix with values of shape (nnz, D) Case2: matrix with values of shape (nnz, D)
>>> val = torch.tensor([[0., 7.], [1., 3.], [2., 2.], [3., 1.]]) >>> val = torch.tensor([[0., 7.], [1., 3.], [2., 2.], [3., 1.]])
>>> A = from_coo(row, col, val) >>> A = dglsp.from_coo(row, col, val)
>>> softmax(A) >>> dglsp.softmax(A)
SparseMatrix(indices=tensor([[0, 0, 1, 2], SparseMatrix(indices=tensor([[0, 0, 1, 2],
[1, 2, 2, 0]]), [1, 2, 2, 0]]),
values=tensor([[0.2689, 0.9820], values=tensor([[0.2689, 0.9820],
...@@ -55,7 +55,7 @@ def softmax(A: SparseMatrix) -> SparseMatrix: ...@@ -55,7 +55,7 @@ def softmax(A: SparseMatrix) -> SparseMatrix:
[1.0000, 1.0000]]), [1.0000, 1.0000]]),
shape=(3, 3), nnz=4) shape=(3, 3), nnz=4)
""" """
return SparseMatrix(torch.ops.dgl_sparse.softmax(A.c_sparse_matrix)) return SparseMatrix(torch.ops.dgl_sparse.softmax(input.c_sparse_matrix))
SparseMatrix.softmax = softmax SparseMatrix.softmax = softmax
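
For readers skimming the diff: the op groups the non-zero values by their row index and applies an ordinary softmax within each group, independently per value column when the values are 2-D. Below is a minimal plain-PyTorch sketch of that behaviour for the 1-D example in the docstring; it is an illustration only, not DGL's implementation (the library dispatches to the C++ kernel via torch.ops.dgl_sparse.softmax).

import torch

# Row indices and values of the 1-D docstring example; the column indices do
# not affect the result, since the reduction is purely row-wise.
row = torch.tensor([0, 0, 1, 2])
val = torch.arange(len(row)).float()  # tensor([0., 1., 2., 3.])

out = torch.empty_like(val)
for r in row.unique():
    mask = row == r                              # non-zero entries in row r
    out[mask] = torch.softmax(val[mask], dim=0)  # softmax over that row only

print(out)  # tensor([0.2689, 0.7311, 1.0000, 1.0000]), matching dglsp.softmax(A).val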