Commit 3c7253aa authored by rusty1s

spspmm args

parent b2ba34bd
@@ -7,7 +7,7 @@ if torch.cuda.is_available():
     import spspmm_cuda
 
 
-class SpSpMM(torch.autograd.Function):
+def spspmm(indexA, valueA, indexB, valueB, m, k, n):
     """Matrix product of two sparse tensors. Both input sparse matrices need to
     be coalesced.
@@ -23,7 +23,10 @@ class SpSpMM(torch.autograd.Function):
     :rtype: (:class:`LongTensor`, :class:`Tensor`)
     """
-
+    return SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
+
+
+class SpSpMM(torch.autograd.Function):
     @staticmethod
     def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
         indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
@@ -53,9 +56,6 @@ class SpSpMM(torch.autograd.Function):
         return None, grad_valueA, None, grad_valueB, None, None, None
 
 
-spspmm = SpSpMM.apply
-
-
 def mm(indexA, valueA, indexB, valueB, m, k, n):
     assert valueA.dtype == valueB.dtype
...
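After this change, callers use the module-level `spspmm` wrapper instead of `SpSpMM.apply` directly; the wrapper carries the docstring and the public signature, while the `autograd.Function` stays an internal detail. A minimal usage sketch follows, assuming the function is importable as `torch_sparse.spspmm` (that import path and the concrete matrices are illustrative assumptions, not part of this diff):

```python
import torch
from torch_sparse import spspmm  # assumed import path for illustration

# Two coalesced sparse matrices in COO form:
# A is m x k (here 3 x 3), B is k x n (here 3 x 2).
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
valueA = torch.tensor([1., 2., 3., 4., 5.])

indexB = torch.tensor([[0, 0, 1, 2], [0, 1, 1, 0]])
valueB = torch.tensor([2., 4., 1., 3.])

# C = A @ B, returned as a coalesced (index, value) pair.
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
print(indexC)  # LongTensor of shape [2, nnz(C)]
print(valueC)  # Tensor of shape [nnz(C)]
```

As the docstring notes, both inputs must be coalesced before the call; the result is likewise returned as an index/value pair rather than a `torch.sparse` tensor.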