"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "53377ef83c6446033f3ee506e3ef718db817b293"
Commit fb28fe78 authored by rusty1s's avatar rusty1s
Browse files

add coalesced argument to spspmm call

parent 37fb98cd
import torch import torch
from torch_sparse import transpose, to_scipy, from_scipy from torch_sparse import transpose, to_scipy, from_scipy, coalesce
import torch_sparse.spspmm_cpu import torch_sparse.spspmm_cpu
...@@ -7,9 +7,9 @@ if torch.cuda.is_available(): ...@@ -7,9 +7,9 @@ if torch.cuda.is_available():
import torch_sparse.spspmm_cuda import torch_sparse.spspmm_cuda
def spspmm(indexA, valueA, indexB, valueB, m, k, n): def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
"""Matrix product of two sparse tensors. Both input sparse matrices need to """Matrix product of two sparse tensors. Both input sparse matrices need to
be coalesced. be coalesced (use the :obj:`coalesce` attribute to force).
Args: Args:
indexA (:class:`LongTensor`): The index tensor of first sparse matrix. indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
...@@ -20,9 +20,15 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n): ...@@ -20,9 +20,15 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n):
k (int): The second dimension of first corresponding dense matrix and k (int): The second dimension of first corresponding dense matrix and
first dimension of second corresponding dense matrix. first dimension of second corresponding dense matrix.
n (int): The second dimension of second corresponding dense matrix. n (int): The second dimension of second corresponding dense matrix.
coalesced (bool, optional): If set to :obj:`False`, will coalesce both
input sparse matrices (default: :obj:`True`).
:rtype: (:class:`LongTensor`, :class:`Tensor`) :rtype: (:class:`LongTensor`, :class:`Tensor`)
""" """
if indexA.is_cuda and coalesced:
indexA, valueA = coalesce(indexA, valueA, m, k)
indexB, valueB = coalesce(indexB, valueB, k, n)
index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n) index, value = SpSpMM.apply(indexA, valueA, indexB, valueB, m, k, n)
return index.detach(), value return index.detach(), value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment