"vscode:/vscode.git/clone" did not exist on "f7dfcfd971131fbaae43d9d9f59e5e3a9aa6234a"
Commit 3c7253aa authored by rusty1s's avatar rusty1s
Browse files

spspmm args

parent b2ba34bd
......@@ -7,7 +7,7 @@ if torch.cuda.is_available():
import spspmm_cuda
class SpSpMM(torch.autograd.Function):
def spspmm(indexA, valueA, indexB, valueB, m, k, n):
"""Matrix product of two sparse tensors. Both input sparse matrices need to
be coalesced.
......@@ -23,7 +23,10 @@ class SpSpMM(torch.autograd.Function):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
return SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
class SpSpMM(torch.autograd.Function):
@staticmethod
def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
indexC, valueC = mm(indexA, valueA, indexB, valueB, m, k, n)
......@@ -53,9 +56,6 @@ class SpSpMM(torch.autograd.Function):
return None, grad_valueA, None, grad_valueB, None, None, None
spspmm = SpSpMM.apply
def mm(indexA, valueA, indexB, valueB, m, k, n):
assert valueA.dtype == valueB.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment