Commit 57bfa7ca authored by Lawrence McAfee, committed by Jared Casper

Perform distributed optimizer's all-gather in param dtype (instead of grad dtype)

parent 1afe3541
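Editorial note: the point of the change is communication volume. The all-gather of updated parameters previously ran over the grad buffer in the grad dtype; running it in the (never larger) param dtype shrinks the bytes on the wire. A back-of-the-envelope sketch with assumed dtypes (fp32 grads, fp16 params) and a hypothetical buffer size, not taken from the commit:

import torch

# Illustrative arithmetic only: all-gathering in fp16 instead of fp32
# halves the bytes each data-parallel rank sends and receives.
numel = 10 * 1024 ** 2                                              # hypothetical 10M-element buffer
grad_elem = torch.tensor([], dtype=torch.float32).element_size()   # 4 bytes
param_elem = torch.tensor([], dtype=torch.float16).element_size()  # 2 bytes
print(f"all-gather in grad dtype : {numel * grad_elem / 2**20:.0f} MiB")
print(f"all-gather in param dtype: {numel * param_elem / 2**20:.0f} MiB")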
@@ -145,6 +145,7 @@ def get_megatron_optimizer(model,
                   args.use_contiguous_buffers_in_local_ddp,
                   args.fp16,
                   args.bf16,
+                  args.params_dtype,
                   grad_scaler,
                   model)
@@ -351,7 +351,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, grad_scaler, models):
+                 fp16, bf16, params_dtype, grad_scaler, models):
         """
         See top of class definition for argument descriptions.
@@ -365,7 +365,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            fp16, bf16, grad_scaler, models)
+            fp16, bf16, params_dtype, grad_scaler, models)

         # Verify that contiguous buffers are being used.
         # - Note: this should already be checked in arguments.py.
@@ -394,6 +394,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             self.model_param_gbuf_map,
             self.opt_group_ranges)

+        # Initialize param buffers.
+        # - These are views on the DDP model's grad buffers, that share
+        #   storage & have their own dtype. This is safe because the param
+        #   dtype size is always <= grad dtype size.
+        self.param_buffers = []
+        for model_index, model in enumerate(self.models):
+            current_param_buffers = {}
+            for dtype, grad_buffer in model._grad_buffers.items():
+                param_buffer = torch.tensor(grad_buffer.data.storage()._untyped(),
+                                            dtype = params_dtype,
+                                            device = grad_buffer.data.device)
+                param_buffer = param_buffer[:grad_buffer.numel_padded]
+                current_param_buffers[dtype] = param_buffer
+            self.param_buffers.append(current_param_buffers)
+
         # Update optimizer groups.
         # - Also, leverage state_dict() and load_state_dict() to
         #   recast preexisting per-param state tensors.
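Editorial note: the param buffer added above allocates no new memory; it reinterprets the grad buffer's untyped storage in the param dtype, which is safe because the param element size never exceeds the grad element size. Below is a minimal standalone sketch of the same aliasing idea, written with the public Tensor.view(dtype) API (recent PyTorch) instead of the private storage()._untyped() call used in the commit; dtypes (fp32 grads, fp16 params) and sizes are assumptions for illustration:

import torch

# Stand-in for a DDP grad buffer: flat, contiguous, grad dtype (fp32 here).
grad_buffer = torch.zeros(16, dtype=torch.float32)

# Reinterpret the same bytes as fp16. The fp16 view has 2x as many elements,
# so keep only the leading `numel` slots, mirroring
# `param_buffer = param_buffer[:grad_buffer.numel_padded]` in the commit.
param_buffer = grad_buffer.view(torch.float16)[:grad_buffer.numel()]

# Both tensors alias the same storage: writing params overwrites grad bytes,
# which is tolerable here because the grads have already been reduced and
# consumed by the optimizer step before the params are written back.
param_buffer.fill_(1.0)
assert param_buffer.data_ptr() == grad_buffer.data_ptr()
print(param_buffer.dtype, param_buffer.numel(), grad_buffer.dtype, grad_buffer.numel())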
@@ -488,36 +503,48 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             _zero_grad_group_helper(group, set_to_none)

-    def get_model_grad_buffer_dp_views(self):
+    @staticmethod
+    def get_model_buffer_dp_views(model_buffers):
         """
-        Get shard views of each of the DDP's grad buffers.
+        Get shard views of each of the DDP's param/grad buffers.

         In this nested list, the top level is grouped by the virtual model
-        index and the grad buffer's data type. The sub-level is a list of
-        shards of that grad buffer, where each shard in the list represents
-        a contiguous view of the grad buffer, that is owned by a data-parallel
+        index and the buffer's data type. The sub-level is a list of
+        shards of that buffer, where each shard in the list represents
+        a contiguous view of the buffer, that is owned by a data-parallel
         rank. The shard boundary does not respect parameter boundaries, and
         so the elements of some parameters are split across data parallel
         ranks.

-        Additionally, return references to the entire grad buffers, for use
+        Additionally, return references to the entire buffers, for use
         in _reduce_scatter_base and _all_gather_base.
         """

         data_parallel_world_size = mpu.get_data_parallel_world_size()

-        # Grad buffer views.
-        gbuf_view_items = []
-        for model_index, model in enumerate(self.models):
-            for dtype, gbuf in model._grad_buffers.items():
-
-                assert gbuf.numel_padded % data_parallel_world_size == 0
-                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
-                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
-                              for r in range(data_parallel_world_size)]
-                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))
-
-        return gbuf_view_items
+        # Buffer views.
+        view_items = []
+        for model_index, buffers in enumerate(model_buffers):
+            for dtype, buf in buffers.items():
+
+                assert buf.numel() % data_parallel_world_size == 0
+                shard_size = int(buf.numel() / data_parallel_world_size)
+                buf_views = [buf[(r*shard_size):((r+1)*shard_size)]
+                             for r in range(data_parallel_world_size)]
+                view_items.append((model_index, dtype, buf, buf_views))
+
+        return view_items
+
+    def get_model_grad_buffer_dp_views(self):
+        return self.get_model_buffer_dp_views([
+            {dtype : mem_buffer.data}
+            for model in self.models
+            for dtype, mem_buffer in model._grad_buffers.items()])
+
+    def get_model_param_buffer_dp_views(self):
+        return self.get_model_buffer_dp_views(self.param_buffers)

     def reduce_model_grads(self, args, timers):
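Editorial note: get_model_buffer_dp_views simply cuts each flat buffer into world-size equal, contiguous shards; because the buffers are padded to a multiple of the data-parallel world size, every rank's shard has the same length, and shard boundaries ignore parameter boundaries. A small sketch of that slicing with hypothetical sizes and no Megatron dependencies:

import torch

def buffer_dp_views(buf: torch.Tensor, data_parallel_world_size: int):
    # Equal, contiguous shards; shard boundaries ignore parameter boundaries.
    assert buf.numel() % data_parallel_world_size == 0
    shard_size = buf.numel() // data_parallel_world_size
    return [buf[r * shard_size:(r + 1) * shard_size]
            for r in range(data_parallel_world_size)]

# Hypothetical padded buffer: 12 elements, 4 data-parallel ranks -> 3 each.
buf = torch.arange(12, dtype=torch.float32)
views = buffer_dp_views(buf, 4)
assert all(v.numel() == 3 for v in views)
assert views[2].data_ptr() == buf[6:].data_ptr()   # views alias the buffer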
@@ -574,9 +601,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         """
         All-gather updated model params.

-        The DDP's grad buffer is used for the all-gather, and thus no
+        The DDP's param buffer is used for the all-gather, and thus no
         tensors are dynamically allocated. After the all-gather, the params
-        can be copied from param.main_grad to param.
+        can be copied from the param buffer to the param.
         """
         timers('params-all-gather', log_level=1).start(
@@ -586,26 +613,28 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_group = mpu.get_data_parallel_group()

         # All-gather updated main params.
-        # - All grad buffer views are guaranteed to have the same num elements
-        #   across all data parallel ranks, with grad buffer padding that is done
-        #   in distributed.py. Thus, all sub-views will have consistent start/end
-        #   indexes across data parallel ranks.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for index, (model_index, dtype, gbuf, gbuf_views) \
-            in enumerate(gbuf_view_items):
+        # - All param buffer views are guaranteed to have the same num elements
+        #   across all data parallel ranks, due to grad buffer padding that is
+        #   done in distributed.py, and extended to the param buffers. Thus,
+        #   all sub-views will have consistent start/end indexes across data
+        #   parallel ranks.
+        pbuf_view_items = self.get_model_param_buffer_dp_views()
+        for index, (model_index, dtype, pbuf, pbuf_views) \
+            in enumerate(pbuf_view_items):
             torch.distributed._all_gather_base(
-                gbuf,
-                gbuf_views[data_parallel_rank],
+                pbuf,
+                pbuf_views[data_parallel_rank],
                 group = data_parallel_group,
             )

-        # Each model param now contains its updated values in its
-        # '.main_grad' field.
-        for model in self.models:
+        # Copy from param buffer to each param.
+        for model_id, model in enumerate(self.models):
             for dtype, param_map in model._grad_buffer_param_index_map.items():
-                for param in param_map:
-                    param.detach().copy_(param.main_grad)
+                for param, buf_range in param_map.items():
+                    param_buf = self.param_buffers[model_id][dtype]
+                    param_buf_shard = param_buf[buf_range[0]:buf_range[1]]
+                    param.view(-1).detach().copy_(param_buf_shard)

         timers('params-all-gather').stop()
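Editorial note: the all-gather runs over the flat param buffer, each rank contributing the shard it just updated, so no temporary tensors are allocated; the private _all_gather_base call used here was later renamed to the public all_gather_into_tensor. Below is a single-process sketch of the pattern on recent PyTorch (gloo backend, world size 1, hypothetical port and index range); unlike this sketch, Megatron passes the rank's view of the output buffer directly as the input:

import os
import torch
import torch.distributed as dist

# Single-process stand-in for the data-parallel group (assumed setup).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

world_size = dist.get_world_size()
rank = dist.get_rank()

# Flat param buffer, padded so it divides evenly across ranks.
param_buffer = torch.zeros(8, dtype=torch.float32)
shard_size = param_buffer.numel() // world_size
my_shard = param_buffer[rank * shard_size:(rank + 1) * shard_size]
my_shard.fill_(float(rank + 1))          # this rank's freshly updated params

# Public successor of torch.distributed._all_gather_base: gathers every
# rank's shard back into the full flat buffer.
dist.all_gather_into_tensor(param_buffer, my_shard.clone())

# Afterwards, each param is copied out of the buffer by its flat index range,
# analogous to `_grad_buffer_param_index_map` above (range is hypothetical).
param = torch.empty(2, 2)
param.view(-1).detach().copy_(param_buffer[0:4])

dist.destroy_process_group()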
@@ -685,14 +714,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                          model_group):

                 param_range_map = self.get_model_param_range_map(model_param)
-                param_range = param_range_map["param"]
-                assert param_range.size == shard_main_param.nelement()
+                world_range = param_range_map["gbuf_world"]

-                model_grad = model_param.main_grad
-                shard_model_grad = model_grad.view(-1) \
-                    [param_range.start:param_range.end]
+                assert world_range.size == shard_main_param.nelement()

-                shard_model_grad.data.copy_(shard_main_param)
+                model_id, dtype = self.model_param_gbuf_map[model_param]
+                model_param_buffer = self.param_buffers[model_id][dtype]
+
+                shard_model_param = model_param_buffer.view(-1) \
+                    [world_range.start:world_range.end]
+
+                shard_model_param.data.copy_(shard_main_param)

         # Copy shard groups to model groups.
         copy_group_params(self.shard_fp32_from_float16_groups,
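Editorial note: after the optimizer step, each rank writes its fp32 main-param shard into the param buffer at that parameter's world index range ("gbuf_world"), so the subsequent all-gather distributes it in the param dtype. A toy illustration of that copy, with hypothetical range, sizes, and dtypes:

import torch

# Hypothetical flat param buffer (param dtype) and one rank's fp32 shard of a
# parameter that occupies world indexes [4, 10) in that buffer.
param_buffer = torch.zeros(16, dtype=torch.float16)
shard_main_param = torch.randn(6, dtype=torch.float32)    # optimizer-owned fp32 copy
start, end = 4, 10                                         # stand-in for world_range

shard_model_param = param_buffer.view(-1)[start:end]
shard_model_param.data.copy_(shard_main_param)             # implicit fp32 -> fp16 cast

assert shard_model_param.numel() == shard_main_param.numel()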
@@ -335,6 +335,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
             is using a contiguous buffer to hold the model grads.
         fp16: if true, the model is running in fp16.
         bf16: if true, the model is running in bfloat16.
+        params_dtype: used by distributed optimizer.
         grad_scaler: used for scaling gradients. Note that this can be
             None. This case happens when `bf16 = True` and we don't
             use any loss scale. Note that for `bf16 = True`, we can have
@@ -346,7 +347,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, grad_scaler,
+                 fp16, bf16, params_dtype, grad_scaler,
                  models):

         super().__init__(
@@ -356,6 +357,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

         self.fp16 = fp16
         self.bf16 = bf16
+        self.params_dtype = params_dtype
         self.grad_scaler = grad_scaler

         # None grad scaler is only supported for bf16.
@@ -505,12 +507,12 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, grad_scaler, models):
+                 fp16, bf16, params_dtype, grad_scaler, models):

         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            fp16, bf16, grad_scaler, models)
+            fp16, bf16, params_dtype, grad_scaler, models)

         # ======================
         # main parameter stuff