Unverified Commit 16dd9584 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Sparse] Polish softmax.py. (#5182)



* polish

* [Sparse] Polish softmax.py.

* fix parameter and example
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 82eb3d71
"""Softmax op for SparseMatrix"""
# pylint: disable=invalid-name
# pylint: disable=invalid-name, W0622
import torch
......@@ -8,17 +8,17 @@ from .sparse_matrix import SparseMatrix
__all__ = ["softmax"]
def softmax(A: SparseMatrix) -> SparseMatrix:
"""Apply row-wise softmax to the non-zero entries of the sparse matrix.
def softmax(input: SparseMatrix) -> SparseMatrix:
"""Apply row-wise softmax to the non-zero elements of the sparse matrix.
If :attr:`A.val` takes shape :attr:`(nnz, D)`, then the output matrix
:attr:`A'` and :attr:`A'.val` take the same shape as :attr:`A` and
:attr:`A.val`. :attr:`A'.val[:, i]` is calculated based on
:attr:`A.val[:, i]`.
If :attr:`input.val` takes shape :attr:`(nnz, D)`, then the output matrix
:attr:`output` and :attr:`output.val` take the same shape as :attr:`input`
and :attr:`input.val`. :attr:`output.val[:, i]` is calculated based on
:attr:`input.val[:, i]`.
Parameters
----------
A : SparseMatrix
input : SparseMatrix
The input sparse matrix
Returns
......@@ -35,8 +35,8 @@ def softmax(A: SparseMatrix) -> SparseMatrix:
>>> col = torch.tensor([1, 2, 2, 0])
>>> nnz = len(row)
>>> val = torch.arange(nnz).float()
>>> A = from_coo(row, col, val)
>>> softmax(A)
>>> A = dglsp.from_coo(row, col, val)
>>> dglsp.softmax(A)
SparseMatrix(indices=tensor([[0, 0, 1, 2],
[1, 2, 2, 0]]),
values=tensor([0.2689, 0.7311, 1.0000, 1.0000]),
......@@ -45,8 +45,8 @@ def softmax(A: SparseMatrix) -> SparseMatrix:
Case2: matrix with values of shape (nnz, D)
>>> val = torch.tensor([[0., 7.], [1., 3.], [2., 2.], [3., 1.]])
>>> A = from_coo(row, col, val)
>>> softmax(A)
>>> A = dglsp.from_coo(row, col, val)
>>> dglsp.softmax(A)
SparseMatrix(indices=tensor([[0, 0, 1, 2],
[1, 2, 2, 0]]),
values=tensor([[0.2689, 0.9820],
......@@ -55,7 +55,7 @@ def softmax(A: SparseMatrix) -> SparseMatrix:
[1.0000, 1.0000]]),
shape=(3, 3), nnz=4)
"""
return SparseMatrix(torch.ops.dgl_sparse.softmax(A.c_sparse_matrix))
return SparseMatrix(torch.ops.dgl_sparse.softmax(input.c_sparse_matrix))
SparseMatrix.softmax = softmax
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment