"vscode:/vscode.git/clone" did not exist on "578d3d9d0927977769db6218a61cd1a7ccf65427"
Commit 0ed2f6ac authored by Jared Casper

Merge branch 'all_gather_base' into 'main'

changed all_gather to _all_gather_base in distributed checkpointing

See merge request ADLR/megatron-lm!395
parents 8c8063eb 90ce932d
@@ -87,17 +87,21 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
-    world_size = get_tensor_model_parallel_world_size()
-    numel = torch.numel(tensor)
-    numel_gathered = world_size * numel
+    numel_gathered = torch.numel(tensor) * \
+        get_tensor_model_parallel_world_size()
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
-    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
-    torch.distributed.all_gather(chunks, tensor,
-                                 group=get_tensor_model_parallel_group())
+    # TODO: This API is experimental in pytorch (as of Feb 2022) and
+    # this might break in future pytorch releases. We chose this API
+    # as opposed to torch.distributed.all_gather for efficiency reasons.
+    # This API calls directly NCCL all-gather versus the former does
+    # internal copies and can potentially cause slow down.
+    torch.distributed._all_gather_base(gathered, tensor,
+                                       group=get_tensor_model_parallel_group())
     return gathered


 def _kernel_make_viewless_tensor(inp, requires_grad):
     '''Make a viewless tensor.
...
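For context on the change, below is a minimal, runnable single-process sketch contrasting the two call patterns. It is an illustration only, not part of the commit: it assumes a "gloo" process group and a loopback address/port chosen here for convenience (Megatron runs NCCL on GPUs), and a PyTorch build that still exposes the experimental _all_gather_base (later releases rename it all_gather_into_tensor); the tensor values are arbitrary.

import torch
import torch.distributed as dist

# Single-process group purely for demonstration; address, port, and
# backend are assumptions of this sketch, not of the commit.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
world_size = dist.get_world_size()
local = torch.arange(4, dtype=torch.float32)

# Old path: gather into a list of per-rank tensors. With NCCL, the
# backend gathers into an internal flat buffer and then copies each
# chunk out into the list entries.
chunks = [torch.empty_like(local) for _ in range(world_size)]
dist.all_gather(chunks, local)

# New path: gather straight into one preallocated contiguous buffer,
# matching NCCL's native all-gather semantics and skipping the copies.
gathered = torch.empty(world_size * local.numel(), dtype=local.dtype)
dist._all_gather_base(gathered, local)

assert torch.equal(gathered, torch.cat(chunks))
dist.destroy_process_group()

This is why the commit preallocates the flat `gathered` buffer and passes it to the collective, rather than slicing it into `chunks` for `all_gather` to fill.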