Commit c72f36c0 authored by rusty1s's avatar rusty1s
Browse files

upgrade to PyTorch 1.10

parent 605566ec
...@@ -11,7 +11,7 @@ jobs: ...@@ -11,7 +11,7 @@ jobs:
matrix: matrix:
os: [ubuntu-latest, windows-latest] os: [ubuntu-latest, windows-latest]
python-version: [3.6] python-version: [3.6]
torch-version: [1.8.0, 1.9.0] torch-version: [1.9.0, 1.10.0]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
......
...@@ -274,7 +274,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -274,7 +274,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
if (out.is_floating_point()) if (out.is_floating_point())
out.true_divide_(count); out.true_divide_(count);
else else
out.floor_divide_(count); out.div_(count, "floor");
} }
}); });
}); });
......
...@@ -132,7 +132,7 @@ public: ...@@ -132,7 +132,7 @@ public:
if (out.is_floating_point()) if (out.is_floating_point())
out.true_divide_(count); out.true_divide_(count);
else else
out.floor_divide_(count); out.div_(count, "floor");
ctx->save_for_backward({index, count}); ctx->save_for_backward({index, count});
if (optional_out.has_value()) if (optional_out.has_value())
......
...@@ -38,7 +38,6 @@ def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -38,7 +38,6 @@ def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
out = scatter_sum(src, index, dim, out, dim_size) out = scatter_sum(src, index, dim, out, dim_size)
dim_size = out.size(dim) dim_size = out.size(dim)
...@@ -55,7 +54,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -55,7 +54,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
if out.is_floating_point(): if out.is_floating_point():
out.true_divide_(count) out.true_divide_(count)
else: else:
out.floor_divide_(count) out.div_(count, rounding_mode='floor')
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment