Commit 0fa23d82 authored by rusty1s

fixed true divide

parent 1538c112
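Context for the diff below: around PyTorch 1.5, integer division through div_ / the / operator on integer tensors was deprecated, asking callers to choose true_divide (always floating point) or floor_divide (integer semantics) explicitly. The hunks that follow make scatter_mean's averaging step branch on dtype and switch the backward passes to true division. A minimal sketch of the behavior the branch targets, with illustrative values:

import torch

out = torch.tensor([31, 12])   # per-segment integer sums
count = torch.tensor([2, 2])   # per-segment element counts

# div_ on integer tensors is deprecated; the explicit ops pick the semantics:
out.floor_divide_(count)       # integer result: tensor([15, 6])

outf = torch.tensor([31.0, 12.0])
outf.true_divide_(count)       # exact mean: tensor([15.5000, 6.0000])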
@@ -93,7 +93,11 @@ public:
     auto count = std::get<0>(result);
     count.clamp_(1);
     count = broadcast(count, out, dim);
-    out.div_(count);
+    if (out.is_floating_point())
+      out.true_divide_(count);
+    else
+      out.floor_divide_(count);
     ctx->save_for_backward({index, count});
     if (optional_out.has_value())
@@ -110,7 +114,7 @@ public:
     auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
     count = torch::gather(count, dim, index, false);
     auto grad_in = torch::gather(grad_out, dim, index, false);
-    grad_in.div_(count);
+    grad_in.true_divide_(count);
     return {grad_in, Variable(), Variable(), Variable(), Variable()};
   }
 };
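For orientation, a rough Python equivalent of the C++ backward above (the names here are illustrative, not the library's API): each input position gathers its segment's output gradient and segment count, then divides. Gradients are always floating point, which is why the backward uses true_divide_ unconditionally rather than branching on dtype.

import torch

# Sketch of ScatterMean's backward along dim 0: gather the output gradient
# and the (clamped) per-segment count back to input positions, then divide.
index = torch.tensor([0, 0, 1, 1, 1])
grad_out = torch.tensor([1.0, 3.0])
count = torch.tensor([2.0, 3.0])

grad_in = torch.gather(grad_out, 0, index)           # tensor([1., 1., 3., 3., 3.])
grad_in.true_divide_(torch.gather(count, 0, index))  # tensor([0.5, 0.5, 1., 1., 1.])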
@@ -97,7 +97,7 @@ public:
     count = gather_coo_fw(count, index, torch::nullopt);
     for (auto i = 0; i < grad_out.dim() - index.dim(); i++)
       count = count.unsqueeze(-1);
-    grad_in.div_(count);
+    grad_in.true_divide_(count);
     return {grad_in, Variable(), Variable(), Variable()};
   }
 };
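The unsqueeze loop above exists because count only carries the index dimensions, while grad_out may have trailing feature dimensions; appending size-1 dimensions lets the in-place division broadcast. A small sketch with assumed shapes:

import torch

grad_in = torch.ones(5, 4)                  # gathered gradient with one feature dim
count = torch.tensor([2., 2., 3., 3., 3.])  # shape [5], index dims only

for _ in range(grad_in.dim() - count.dim()):
    count = count.unsqueeze(-1)             # [5] -> [5, 1]

grad_in.true_divide_(count)                 # broadcasts across the feature dim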
@@ -95,7 +95,7 @@ public:
       count = gather_csr_fw(count, indptr, torch::nullopt);
       for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
         count = count.unsqueeze(-1);
-      grad_in.div_(count);
+      grad_in.true_divide_(count);
     }
     return {grad_in, Variable(), Variable()};
   }
@@ -11,4 +11,4 @@ if torch.cuda.is_available():
 def tensor(x, dtype, device):
-    return None if x is None else torch.tensor(x, dtype=dtype, device=device)
+    return None if x is None else torch.tensor(x, device=device).to(dtype)
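The test helper now constructs the tensor with the dtype inferred from the data and casts afterwards. One plausible reading, not stated in the commit, is that this gives well-defined truncation when fractional test data is requested at an integer dtype; a hypothetical probe:

import torch

# Build as float first (dtype inferred from the data), then cast;
# .to() truncates toward zero.
x = [1.5, 2.5, 3.5]
t = torch.tensor(x, device='cpu').to(torch.long)
print(t)  # tensor([1, 2, 3])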
@@ -49,7 +49,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     count = scatter_sum(ones, index, index_dim, None, dim_size)
     count.clamp_(1)
     count = broadcast(count, out, dim)
-    out.div_(count)
+    if torch.is_floating_point(out):
+        out.true_divide_(count)
+    else:
+        out.floor_divide_(count)
     return out
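End to end, the Python path now mirrors the C++ branch. A short usage sketch against torch_scatter's public scatter_mean, with illustrative values:

import torch
from torch_scatter import scatter_mean

src = torch.tensor([10, 21, 5, 7])
index = torch.tensor([0, 0, 1, 1])

scatter_mean(src, index)          # integer dtype, floor division: tensor([15, 6])
scatter_mean(src.float(), index)  # float dtype, true division: tensor([15.5000, 6.0000])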