Commit 85ce67e4 authored by rusty1s's avatar rusty1s
Browse files

fix matmul jit bug

parent 91ab2667
......@@ -92,11 +92,11 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr()
rowptrB, colB, valueB = other.csr()
value = valueA
value = valueA if valueA is not None else valueB
if valueA is not None and valueA.dtype == torch.half:
valueA = valueA.to(torch.float)
if valueB is not None:
valueB = valueB.to(valueA.dtype)
if valueB is not None and valueB.dtype == torch.half:
valueB = valueB.to(torch.float)
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment