Unverified Commit 6ed6ed29 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

fix device issue (#20227)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent d3d5fa3e
...@@ -1784,9 +1784,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name): ...@@ -1784,9 +1784,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
# changed "view" by "reshape" in the following line # changed "view" by "reshape" in the following line
flat_values = values.reshape(flattened_shape.tolist()) flat_values = values.reshape(flattened_shape.tolist())
out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype) out = torch.zeros(int(flat_index.num_segments), dtype=torch.float, device=flat_values.device)
segment_means = out.scatter_reduce( segment_means = out.scatter_reduce(
dim=0, index=flat_index.indices.long(), src=flat_values, reduce=segment_reduce_fn, include_self=False dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
) )
# Unflatten the values. # Unflatten the values.
...@@ -1799,7 +1799,7 @@ def _segment_reduce(values, index, segment_reduce_fn, name): ...@@ -1799,7 +1799,7 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
dim=0, dim=0,
) )
output_values = segment_means.clone().view(new_shape.tolist()) output_values = segment_means.clone().view(new_shape.tolist()).to(values.dtype)
output_index = range_index_map(index.batch_shape(), index.num_segments) output_index = range_index_map(index.batch_shape(), index.num_segments)
return output_values, output_index return output_values, output_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment