Commit 85ce67e4 authored by rusty1s's avatar rusty1s
Browse files

fix matmul jit bug

parent 91ab2667
...@@ -92,11 +92,11 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: ...@@ -92,11 +92,11 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0) assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr() rowptrA, colA, valueA = src.csr()
rowptrB, colB, valueB = other.csr() rowptrB, colB, valueB = other.csr()
value = valueA value = valueA if valueA is not None else valueB
if valueA is not None and valueA.dtype == torch.half: if valueA is not None and valueA.dtype == torch.half:
valueA = valueA.to(torch.float) valueA = valueA.to(torch.float)
if valueB is not None: if valueB is not None and valueB.dtype == torch.half:
valueB = valueB.to(valueA.dtype) valueB = valueB.to(torch.float)
M, K = src.sparse_size(0), other.sparse_size(1) M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum( rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K) rowptrA, colA, valueA, rowptrB, colB, valueB, K)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment