Commit 66bcc36e authored by rusty1s's avatar rusty1s
Browse files

fix pytorch 1.8.0 compatibility

parent 1f49a3a5
...@@ -277,9 +277,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -277,9 +277,9 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
for (int i = dim + 1; i < out.dim(); i++) for (int i = dim + 1; i < out.dim(); i++)
count = count.unsqueeze(-1); count = count.unsqueeze(-1);
if (out.is_floating_point()) if (out.is_floating_point())
out.div_(count); out.true_divide_(count);
else else
out.div_(count, "floor"); out.floor_divide_(count);
} }
}); });
}); });
......
...@@ -130,9 +130,9 @@ public: ...@@ -130,9 +130,9 @@ public:
count.masked_fill_(count < 1, 1); count.masked_fill_(count < 1, 1);
count = broadcast(count, out, dim); count = broadcast(count, out, dim);
if (out.is_floating_point()) if (out.is_floating_point())
out.div_(count); out.true_divide_(count);
else else
out.div_(count, "floor"); out.floor_divide_(count);
ctx->save_for_backward({index, count}); ctx->save_for_backward({index, count});
if (optional_out.has_value()) if (optional_out.has_value())
......
...@@ -52,8 +52,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -52,8 +52,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
count = scatter_sum(ones, index, index_dim, None, dim_size) count = scatter_sum(ones, index, index_dim, None, dim_size)
count[count < 1] = 1 count[count < 1] = 1
count = broadcast(count, out, dim) count = broadcast(count, out, dim)
rounding_mode = None if torch.is_floating_point(out) else 'floor' if out.is_floating_point():
out.div_(count, rounding_mode=rounding_mode) out.true_divide_(count)
else:
out.floor_divide_(count)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment