Unverified Commit b97ddcf7 authored by Ziyue Yang's avatar Ziyue Yang Committed by GitHub
Browse files

Fix wrong torch usage in communication wrapper for Distributed Inference Benchmark (#505)

**Description**
This commit fixes wrong `torch.empty_like` usage and missing dtype and
device argument in communication wrappers.
parent 9d250cdd
......@@ -121,7 +121,7 @@ def __all_gather_wrapper(self, x):
Return:
Tensor after all-gather.
"""
output = torch.empty_like([x.shape[0] * self.num_ranks] + list(x.shape[1:]))
output = torch.empty([x.shape[0] * self.num_ranks] + list(x.shape[1:]), dtype=x.dtype, device=x.device)
dist.all_gather_into_tensor(output, x)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment