Commit 87c88d95 authored by rusty1s's avatar rusty1s
Browse files

enable autocast

parent bdd1ced8
......@@ -11,6 +11,9 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc = src.storage._csr2csc
colptr = src.storage._colptr
if value is not None:
value = value.to(other.dtype)
if value is not None and value.requires_grad:
row = src.storage.row()
......@@ -35,6 +38,9 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc = src.storage._csr2csc
colptr = src.storage._colptr
if value is not None:
value = value.to(other.dtype)
if value is not None and value.requires_grad:
row = src.storage.row()
......@@ -51,12 +57,20 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
def spmm_min(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
if value is not None:
value = value.to(other.dtype)
return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)
def spmm_max(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
if value is not None:
value = value.to(other.dtype)
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
......@@ -81,8 +95,8 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
value = valueA
if valueA is not None and valueA.dtype == torch.half:
valueA = valueA.to(torch.float)
if valueB is not None and valueB.dtype == torch.half:
valueB = valueB.to(torch.float)
if valueB is not None:
valueB = valueB.to(valueA.dtype)
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment