import torch


def spmm(index, value, m, n, matrix):
    """Matrix product of sparse matrix with dense matrix.

    The sparse matrix is given in coordinate (COO) form: ``index`` holds the
    row and column index of every stored entry and ``value`` the entries
    themselves. Duplicate ``(row, col)`` pairs are summed.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix,
            unpacked as ``row, col = index``.
        value (:class:`Tensor`): The value tensor of sparse matrix, one
            entry per ``(row, col)`` pair in ``index``.
        m (int): The first dimension of sparse matrix (rows of the output).
        n (int): The second dimension of sparse matrix.
        matrix (:class:`Tensor`): The dense matrix. Its second-to-last
            dimension must equal ``n``; leading dimensions are treated as
            batch dimensions. A 1-D tensor of length ``n`` is treated as a
            single column, so the result then has shape ``(m, 1)``.

    Raises:
        ValueError: If ``matrix.size(-2) != n`` (after promoting a 1-D
            ``matrix`` to a column).

    :rtype: :class:`Tensor`
    """
    # Promote a vector to a one-column matrix *before* validating its shape:
    # ``size(-2)`` does not exist on a 1-D tensor.
    if matrix.dim() == 1:
        matrix = matrix.unsqueeze(-1)

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if matrix.size(-2) != n:
        raise ValueError(
            f"spmm: expected matrix.size(-2) == n == {n}, "
            f"got {matrix.size(-2)}"
        )

    row, col = index

    # Gather the dense row referenced by each stored entry and scale it by
    # that entry's value.
    out = matrix.index_select(-2, col) * value.unsqueeze(-1)

    # Sum each scaled row into its output row; output rows with no stored
    # entries remain zero. ``new_zeros`` keeps the dtype/device of ``out``.
    shape = list(out.shape)
    shape[-2] = m
    return out.new_zeros(shape).index_add_(-2, row, out)