Unverified Commit 2108f20e authored by msbaines, committed by GitHub

[refactor] moe: use all_to_all_single (#168)

parent c5e5ff78
@@ -27,9 +27,7 @@ class _AllToAll(torch.autograd.Function):
         world_size = dist.get_world_size(group)
         input = input.contiguous()
         output = torch.empty_like(input)
-        input_chunks = list(input.chunk(world_size))
-        output_chunks = list(output.chunk(world_size))
-        dist.all_to_all(output_chunks, input_chunks, group=group)
+        dist.all_to_all_single(output, input, group=group)
         return output
 
     @staticmethod
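Note on the change: dist.all_to_all exchanges Python lists of per-rank chunks, while dist.all_to_all_single exchanges a single contiguous tensor that is split evenly across ranks when no split sizes are given, so the chunk lists above become unnecessary. Below is a minimal sketch (not part of the commit) contrasting the two calls; it assumes an initialized process group and an input whose first dimension divides evenly by the world size.

import torch
import torch.distributed as dist

def exchange_with_lists(input: torch.Tensor, group=None) -> torch.Tensor:
    # Old approach: chunk into world_size pieces and exchange lists of tensors.
    world_size = dist.get_world_size(group)
    output = torch.empty_like(input)
    input_chunks = list(input.contiguous().chunk(world_size))
    output_chunks = list(output.chunk(world_size))
    dist.all_to_all(output_chunks, input_chunks, group=group)
    return output

def exchange_with_single(input: torch.Tensor, group=None) -> torch.Tensor:
    # New approach: one call on the whole tensor; with no explicit split sizes
    # the tensor is split evenly across ranks along dim 0.
    output = torch.empty_like(input)
    dist.all_to_all_single(output, input.contiguous(), group=group)
    return output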
@@ -35,7 +35,8 @@ def is_initialized() -> bool: ...
 def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
 def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
-def all_to_all(output: List[Tensor], intput: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
+def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
+def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional[List[int]] = None, input_split_size: Optional[List[int]] = None, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
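For reference, a hedged sketch (also not part of the commit) of calling the newly stubbed all_to_all_single with explicit split sizes; the split-size arguments are passed positionally, and the equal splits used here match the default behavior. Assumes an initialized process group.

import torch
import torch.distributed as dist

def exchange_explicit_splits(input: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    rows_per_rank = input.size(0) // world_size
    splits = [rows_per_rank] * world_size
    output = torch.empty_like(input)
    # Output and input split sizes are passed positionally.
    dist.all_to_all_single(output, input.contiguous(), splits, splits, group=group)
    return output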