Unverified Commit b97ddcf7 authored by Ziyue Yang's avatar Ziyue Yang Committed by GitHub
Browse files

Fix wrong torch usage in communication wrapper for Distributed Inference Benchmark (#505)

**Description**
This commit fixes wrong `torch.empty_like` usage and missing dtype and
device argument in communication wrappers.
parent 9d250cdd
...@@ -121,7 +121,7 @@ def __all_gather_wrapper(self, x): ...@@ -121,7 +121,7 @@ def __all_gather_wrapper(self, x):
Return: Return:
Tensor after all-gather. Tensor after all-gather.
""" """
output = torch.empty_like([x.shape[0] * self.num_ranks] + list(x.shape[1:])) output = torch.empty([x.shape[0] * self.num_ranks] + list(x.shape[1:]), dtype=x.dtype, device=x.device)
dist.all_gather_into_tensor(output, x) dist.all_gather_into_tensor(output, x)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment