Unverified Commit 6ed6ed29 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

fix device issue (#20227)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent d3d5fa3e
......@@ -1784,9 +1784,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
# changed "view" by "reshape" in the following line
flat_values = values.reshape(flattened_shape.tolist())
out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype)
out = torch.zeros(int(flat_index.num_segments), dtype=torch.float, device=flat_values.device)
segment_means = out.scatter_reduce(
dim=0, index=flat_index.indices.long(), src=flat_values, reduce=segment_reduce_fn, include_self=False
dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
)
# Unflatten the values.
......@@ -1799,7 +1799,7 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
dim=0,
)
output_values = segment_means.clone().view(new_shape.tolist())
output_values = segment_means.clone().view(new_shape.tolist()).to(values.dtype)
output_index = range_index_map(index.batch_shape(), index.num_segments)
return output_values, output_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment