"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "27d11a0094e292a8d790714d1b5cdf5e9186814d"
Commit 4f8a6cc9 authored by rusty1s's avatar rusty1s
Browse files

cleanup

parent 9a4ea01b
...@@ -15,12 +15,12 @@ def spmm(index, value, m, n, matrix): ...@@ -15,12 +15,12 @@ def spmm(index, value, m, n, matrix):
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
assert n == matrix.shape[-2] assert n == matrix.size(-2)
row, col = index row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix[..., col, :] out = matrix.index_select(-2, col)
out = out * value.unsqueeze(-1) out = out * value.unsqueeze(-1)
out = scatter_add(out, row, dim=-2, dim_size=m) out = scatter_add(out, row, dim=-2, dim_size=m)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment