Commit 37181ef4 authored by mshoeybi's avatar mshoeybi
Browse files

changed all_gather to _all_gather_base in distributed checkpointing

parent 8c8063eb
...@@ -87,17 +87,16 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): ...@@ -87,17 +87,16 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
def gather_split_1d_tensor(tensor): def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks.""" """Opposite of above function, gather values from model parallel ranks."""
world_size = get_tensor_model_parallel_world_size() numel_gathered = torch.numel(tensor) * \
numel = torch.numel(tensor) get_tensor_model_parallel_world_size()
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
requires_grad=False) requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)] torch.distributed._all_gather_base(gathered, tensor,
torch.distributed.all_gather(chunks, tensor, group=get_tensor_model_parallel_group())
group=get_tensor_model_parallel_group())
return gathered return gathered
def _kernel_make_viewless_tensor(inp, requires_grad): def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor. '''Make a viewless tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment