try:
    from torch_scatter import scatter_add
except ImportError:  # torch_scatter missing: fall back to plain tensor ops

    def scatter_add(src, index, dim, dim_size):
        """Sum slices of ``src`` along ``dim`` into ``dim_size`` buckets.

        Minimal stand-in that only covers the call made by :func:`spmm`:
        ``index`` must be a 1-D :class:`LongTensor` holding one target
        position per slice of ``src`` along ``dim``.
        """
        dim = dim % src.dim()  # normalise a negative dim
        shape = list(src.shape)
        shape[dim] = dim_size
        return src.new_zeros(shape).index_add_(dim, index, src)


def spmm(index, value, m, n, matrix):
    """Matrix product of sparse matrix with dense matrix.

    The sparse matrix is given in coordinate form: ``index`` unpacks into
    the row and column indices of its non-zero entries and ``value`` holds
    the matching entries.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix. It
            must unpack into exactly two 1-D tensors ``(row, col)``.
        value (:class:`Tensor`): The value tensor of sparse matrix, one
            entry per ``(row, col)`` pair.
        m (int): The first dimension of corresponding dense matrix, i.e.
            the number of rows of the result.
        n (int): The second dimension of corresponding dense matrix. Must
            equal the second-to-last dimension of :attr:`matrix`.
        matrix (:class:`Tensor`): The dense matrix. Leading dimensions are
            kept as they are; a 1-D tensor is treated as a single column.

    Raises:
        ValueError: If :attr:`n` does not match the number of rows of
            :attr:`matrix`.

    :rtype: :class:`Tensor` whose last two dimensions are ``m`` and the
        last dimension of :attr:`matrix` (``1`` for a 1-D input).
    """
    row, col = index

    # Promote a vector to a one-column matrix *before* looking at
    # ``shape[-2]``; a 1-D tensor has no second-to-last dimension.
    if matrix.dim() < 2:
        matrix = matrix.unsqueeze(-1)

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if matrix.shape[-2] != n:
        raise ValueError(
            f"Expected dense matrix with {n} rows, "
            f"but got {matrix.shape[-2]}"
        )

    # One dense row per non-zero entry, selected by its column index ...
    gathered = matrix[..., col, :]
    # ... scaled by that entry's value (broadcast over the feature axis) ...
    weighted = gathered * value.unsqueeze(-1)
    # ... and summed into the output row named by its row index.
    return scatter_add(weighted, row, dim=-2, dim_size=m)