"git@developer.sourcefind.cn:OpenDAS/nerfacc.git" did not exist on "b44a4218ff8a1a2a9afef6bfe20794ae7d4d4a58"
Commit 0fa23d82 authored by rusty1s's avatar rusty1s
Browse files

fixed true divide

parent 1538c112
...@@ -93,7 +93,11 @@ public: ...@@ -93,7 +93,11 @@ public:
auto count = std::get<0>(result); auto count = std::get<0>(result);
count.clamp_(1); count.clamp_(1);
count = broadcast(count, out, dim); count = broadcast(count, out, dim);
out.div_(count);
if (out.is_floating_point())
out.true_divide_(count);
else
out.floor_divide_(count);
ctx->save_for_backward({index, count}); ctx->save_for_backward({index, count});
if (optional_out.has_value()) if (optional_out.has_value())
...@@ -110,7 +114,7 @@ public: ...@@ -110,7 +114,7 @@ public:
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
count = torch::gather(count, dim, index, false); count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false); auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.div_(count); grad_in.true_divide_(count);
return {grad_in, Variable(), Variable(), Variable(), Variable()}; return {grad_in, Variable(), Variable(), Variable(), Variable()};
} }
}; };
......
...@@ -97,7 +97,7 @@ public: ...@@ -97,7 +97,7 @@ public:
count = gather_coo_fw(count, index, torch::nullopt); count = gather_coo_fw(count, index, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - index.dim(); i++) for (auto i = 0; i < grad_out.dim() - index.dim(); i++)
count = count.unsqueeze(-1); count = count.unsqueeze(-1);
grad_in.div_(count); grad_in.true_divide_(count);
return {grad_in, Variable(), Variable(), Variable()}; return {grad_in, Variable(), Variable(), Variable()};
} }
}; };
......
...@@ -95,7 +95,7 @@ public: ...@@ -95,7 +95,7 @@ public:
count = gather_csr_fw(count, indptr, torch::nullopt); count = gather_csr_fw(count, indptr, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++) for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
count = count.unsqueeze(-1); count = count.unsqueeze(-1);
grad_in.div_(count); grad_in.true_divide_(count);
} }
return {grad_in, Variable(), Variable()}; return {grad_in, Variable(), Variable()};
} }
......
...@@ -11,4 +11,4 @@ if torch.cuda.is_available(): ...@@ -11,4 +11,4 @@ if torch.cuda.is_available():
def tensor(x, dtype, device): def tensor(x, dtype, device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device) return None if x is None else torch.tensor(x, device=device).to(dtype)
...@@ -49,7 +49,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -49,7 +49,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
count = scatter_sum(ones, index, index_dim, None, dim_size) count = scatter_sum(ones, index, index_dim, None, dim_size)
count.clamp_(1) count.clamp_(1)
count = broadcast(count, out, dim) count = broadcast(count, out, dim)
out.div_(count) if torch.is_floating_point(out):
out.true_divide_(count)
else:
out.floor_divide_(count)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment