Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled with stages
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -262,32 +262,56 @@ def _get_megatron_optimizer_based_on_param_groups(
    Returns:
        Instance of MegatronOptimizer.
    """
-    if config.optimizer == 'adam':
-        optimizer = Adam(
-            param_groups,
-            lr=config.lr,
-            weight_decay=config.weight_decay,
-            betas=(config.adam_beta1, config.adam_beta2),
-            eps=config.adam_eps,
-        )
-
-        def init_state_fn(opt):
-            for group in opt.param_groups:
-                for p in group['params']:
-                    if len(opt.state[p]) == 0:
-                        opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
-                        opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
-
-    elif config.optimizer == 'sgd':
-        optimizer = SGD(
-            param_groups,
-            lr=config.lr,
-            weight_decay=config.weight_decay,
-            momentum=config.sgd_momentum,
-        )
-        init_state_fn = None
-    else:
-        raise Exception('{} optimizer is not supported.'.format(config.optimizer))
+    # when freezing sub-models we may have no trainable parameters on a rank and
+    # hence an empty param_groups. However, we still need to create an optimizer
+    # for the purposes of grad stats reductions
+    if param_groups:
+        if config.optimizer == 'adam':
+            kwargs = {
+                "params": param_groups,
+                "lr": config.lr,
+                "weight_decay": config.weight_decay,
+                "betas": (config.adam_beta1, config.adam_beta2),
+                "eps": config.adam_eps,
+            }
+
+            if config.use_precision_aware_optimizer:
+                kwargs.update(
+                    {
+                        "master_weights": True,
+                        "use_decoupled_grad": True,
+                        "master_weight_dtype": config.main_params_dtype,
+                        "exp_avg_dtype": config.exp_avg_dtype,
+                        "exp_avg_sq_dtype": config.exp_avg_sq_dtype,
+                    }
+                )
+
+            optimizer = Adam(**kwargs)
+
+            def init_state_fn(opt, config=None):
+                for group in opt.param_groups:
+                    for p in group['params']:
+                        if len(opt.state[p]) == 0:
+                            if config is None or not config.use_precision_aware_optimizer:
+                                opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
+                                opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
+                            else:
+                                opt.initialize_state(p)
+
+        elif config.optimizer == 'sgd':
+            optimizer = SGD(
+                param_groups,
+                lr=config.lr,
+                weight_decay=config.weight_decay,
+                momentum=config.sgd_momentum,
+            )
+            init_state_fn = None
+        else:
+            raise Exception('{} optimizer is not supported.'.format(config.optimizer))
+    else:
+        optimizer = None
+        init_state_fn = None

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
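The reworked init_state_fn pre-populates Adam's exp_avg/exp_avg_sq buffers before the first step, which matters when optimizer state must exist at checkpoint-load time. A minimal standalone sketch of that idea, with torch.optim.Adam standing in for the fused optimizer (the precision-aware branch would instead call the fused optimizer's initialize_state, which plain Adam does not have):

import torch

def init_state_fn(opt, config=None):
    # Pre-create Adam state so exp_avg / exp_avg_sq exist before any step().
    for group in opt.param_groups:
        for p in group["params"]:
            if len(opt.state[p]) == 0:
                opt.state[p]["exp_avg"] = torch.zeros_like(p.data)
                opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data)

param = torch.nn.Parameter(torch.randn(8))
opt = torch.optim.Adam([param], lr=1e-3)
init_state_fn(opt)
print(sorted(opt.state[param].keys()))  # ['exp_avg', 'exp_avg_sq']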
@@ -407,6 +431,7 @@ def get_megatron_optimizer(
                model_chunk.overlap_param_gather_with_optimizer_step = (
                    overlap_param_gather_with_optimizer_step
                )

        optimizers.append(
            _get_megatron_optimizer_based_on_param_groups(
                config,
...
@@ -139,6 +139,7 @@ def clip_grad_by_total_norm_fp32(
    parameters: Union[List[torch.Tensor], torch.Tensor],
    max_norm: Union[int, float],
    total_norm: float,
+    use_decoupled_grad: bool = False,
):
    """Clips gradient of an iterable of parameters in fp32 by total norm.
@@ -149,15 +150,23 @@ def clip_grad_by_total_norm_fp32(
            single Tensor that will have gradients normalized.
        max_norm (float or int): max norm of the gradients.
        total_norm (float): total norm of the gradients.
+        use_decoupled_grad (bool, optional): whether to read grad from ".grad" or
+            ".decoupled_grad", default value is False.
    """
    # Grads.
    params = []
    grads = []
    for param in parameters:
-        if param.grad is not None:
-            assert param.grad.type() == 'torch.cuda.FloatTensor'
-            params.append(param)
-            grads.append(to_local_if_dtensor(param.grad).detach())
+        if use_decoupled_grad:
+            if hasattr(param, "decoupled_grad") and param.decoupled_grad is not None:
+                assert param.decoupled_grad.dtype in [torch.float32, torch.bfloat16]
+                params.append(param)
+                grads.append(to_local_if_dtensor(param.decoupled_grad).detach())
+        else:
+            if param.grad is not None:
+                assert param.grad.type() == 'torch.cuda.FloatTensor'
+                params.append(param)
+                grads.append(to_local_if_dtensor(param.grad).detach())

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
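A self-contained sketch of the clipping step above, with the DTensor and distributed pieces stripped out; the attribute selection mirrors the new use_decoupled_grad switch:

import torch
from typing import List

def clip_by_total_norm(params: List[torch.nn.Parameter], max_norm: float,
                       total_norm: float, use_decoupled_grad: bool = False) -> None:
    # Gather whichever gradient attribute is in use, then scale in place.
    grads = []
    for p in params:
        g = getattr(p, "decoupled_grad", None) if use_decoupled_grad else p.grad
        if g is not None:
            grads.append(g)
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)

p = torch.nn.Parameter(torch.ones(3))
p.decoupled_grad = torch.full((3,), 10.0)
clip_by_total_norm([p], max_norm=1.0, total_norm=p.decoupled_grad.norm().item(),
                   use_decoupled_grad=True)
print(p.decoupled_grad.norm())  # ~1.0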
@@ -171,6 +180,7 @@ def clip_grad_by_total_norm_fp32(
def count_zeros_fp32(
    parameters: Union[List[torch.Tensor], torch.Tensor],
    grad_stats_parallel_group: torch.distributed.ProcessGroup,
+    use_decoupled_grad: bool = False,
) -> float:
    """Counts the number of zeros in gradients associated with the passed-in list of
    parameters.
@@ -182,6 +192,8 @@ def count_zeros_fp32(
        grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is
            generally the model-parallel group for non-distributed optimizers, and the entire
            world for the distributed optimizer.
+        use_decoupled_grad (bool, optional): whether to read grad from ".grad" or
+            ".decoupled_grad", default value is False.
    """

    if isinstance(parameters, torch.Tensor):
@@ -194,14 +206,14 @@ def count_zeros_fp32(
    total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
    data_parallel_group = None
    for param in parameters:
-        grad_not_none = param.grad is not None
+        grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
+        grad_not_none = hasattr(param, grad_attr) and getattr(param, grad_attr) is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            data_parallel_group = get_data_parallel_group_if_dtensor(
-                param.grad, data_parallel_group
-            )
-            grad = to_local_if_dtensor(param.grad).detach()
+            grad_obj = getattr(param, grad_attr)
+            data_parallel_group = get_data_parallel_group_if_dtensor(grad_obj, data_parallel_group)
+            grad = to_local_if_dtensor(grad_obj).detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros
...
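The same attribute switch drives the zero count. A local sketch, without the all-reduce over grad_stats_parallel_group that the real function performs:

import torch

def count_zero_grads(params, use_decoupled_grad: bool = False) -> int:
    # Local version only: the real function also reduces across the process group.
    attr = "decoupled_grad" if use_decoupled_grad else "grad"
    total = 0
    for p in params:
        g = getattr(p, attr, None)
        if g is not None:
            total += g.numel() - int(torch.count_nonzero(g))
    return total

p = torch.nn.Parameter(torch.ones(4))
p.decoupled_grad = torch.tensor([0.0, 1.0, 0.0, 2.0])
print(count_zero_grads([p], use_decoupled_grad=True))  # 2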
@@ -293,6 +293,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        gbuf_ranges: List[Dict],
        param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
        opt_group_ranges: List,
+        config: OptimizerConfig,
    ):
        """
        Create main parameter groups needed for the optimizer step.
@@ -343,38 +344,45 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:

-                    # Clone model -> main.
+                    # Generate sharded model param.
                    shard_model_param = model_param.detach().view(-1)[
                        param_range.start : param_range.end
                    ]
-                    # If we use FP8 params to initialize FP32 main params (compared to using the
-                    # bf16/fp16 params to initialize the main params), there will be a loss of
-                    # precision at the beginning of training (this problem will not occur if the
-                    # training is long enough or if the main params are loaded from a checkpoint).
-                    if is_float8tensor(model_param) and hasattr(
-                        model_param, 'get_high_precision_init_val'
-                    ):
-                        shard_main_param = (
-                            model_param.get_high_precision_init_val()
-                            .view(-1)[param_range.start : param_range.end]
-                            .clone()
-                            .to(shard_model_param.device)
-                            .float()
-                        )
-                        model_param.clear_high_precision_init_val()
-                    else:
-                        shard_main_param = shard_model_param.clone().float()
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param
                    )
-                    tensor_parallel.copy_tensor_model_parallel_attributes(
-                        shard_main_param, model_param
-                    )
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
-                        shard_main_param.shared = model_param.shared
+
+                    # Generate main param.
+                    if not config.use_precision_aware_optimizer:
+                        # If we use FP8 params to initialize FP32 main params (compared to using the
+                        # bf16/fp16 params to initialize the main params), there will be a loss of
+                        # precision at the beginning of training (this problem will not occur if the
+                        # training is long enough or if the main params are loaded from a
+                        # checkpoint).
+                        if is_float8tensor(model_param) and hasattr(
+                            model_param, 'get_high_precision_init_val'
+                        ):
+                            shard_main_param = (
+                                model_param.get_high_precision_init_val()
+                                .view(-1)[param_range.start : param_range.end]
+                                .clone()
+                                .to(shard_model_param.device)
+                                .float()
+                            )
+                            model_param.clear_high_precision_init_val()
+                        else:
+                            shard_main_param = shard_model_param.clone().float()
+                        tensor_parallel.copy_tensor_model_parallel_attributes(
+                            shard_main_param, model_param
+                        )
+                        if hasattr(model_param, 'shared'):
+                            shard_main_param.shared = model_param.shared
+                    else:
+                        # When using precision-aware optimizer, main params are held by FusedAdam.
+                        shard_main_param = None

                    # Add to group.
                    model_float16_params_this_group.append(model_param)
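What "sharded model param" and "shard main param" mean here, as a small sketch with made-up sizes (the FP8 and tensor-parallel attribute handling is omitted):

import torch

model_param = torch.randn(4, 8, dtype=torch.bfloat16)   # hypothetical bf16 model weight
start, end = 8, 24                                       # this rank owns flat elements [8, 24)

shard_model_param = model_param.detach().view(-1)[start:end]

# Without the precision-aware optimizer, an fp32 main copy of the shard is kept here;
# with it, shard_main_param stays None because the fused optimizer holds the master weights.
use_precision_aware_optimizer = False
shard_main_param = None if use_precision_aware_optimizer else shard_model_param.clone().float()

print(shard_model_param.dtype, None if shard_main_param is None else shard_main_param.dtype)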
@@ -402,10 +410,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                )

                # Update optimizer's params.
-                group_range["orig_group"]["params"] = [
-                    *shard_fp32_params_this_group,
-                    *shard_fp32_from_float16_params_this_group,
-                ]
+                if not config.use_precision_aware_optimizer:
+                    group_range["orig_group"]["params"] = [
+                        *shard_fp32_params_this_group,
+                        *shard_fp32_from_float16_params_this_group,
+                    ]
+                else:
+                    group_range["orig_group"]["params"] = [
+                        *shard_fp32_params_this_group,
+                        *shard_float16_params_this_group,
+                    ]

        return (
            model_float16_groups,
@@ -469,10 +483,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        for model_chunk in self.model_chunks:
            assert self.ddp_config == model_chunk.ddp_config

-        assert isinstance(
-            optimizer, Adam
+        assert (
+            isinstance(optimizer, Adam) or optimizer is None
        ), "Only Adam currently supported, due to checkpointing requirements."

+        # when freezing sub-models we have no real optimizer
+        # but still need a stub DistributedOptimizer class
+        if optimizer is None:
+            self.is_stub_optimizer = True
+            return
+
        # Model grad buffer ranges.
        assert per_model_buffers is not None, "per_model_buffers must be provided"
        self.buffers = list(itertools.chain(*per_model_buffers.values()))
@@ -528,7 +548,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self._build_model_and_main_param_groups(
-            self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
+            self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
        )

        # Update optimizer groups.
@@ -537,6 +557,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

+        self.is_stub_optimizer = False
+
    def _get_model_param_range_map(self, param: torch.nn.Parameter):
        """
        Given a model param, get the index sub-range of the param that this
@@ -655,9 +677,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                (numel,), dtype=torch.float32, device=torch.cuda.current_device()
            )

-            state_dict_state.append(
-                (state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
-            )
+            tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard()}
+            if self.config.use_precision_aware_optimizer:
+                tensors["master_param"] = init_shard()
+            state_dict_state.append((state_order, tensors))

        # Sort by state order (see method docstring for details).
        state_dict_state.sort(key=lambda s: s[0])
@@ -712,6 +735,55 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        else:
            raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')

+    def _get_main_param_and_optimizer_states(self, model_param):
+        """Return a dict containing the main param and optimizer states corresponding to the
+        input model_param.
+
+        The structure of the returned dict:
+            tensors = {
+                "param": torch.Tensor,
+                "exp_avg": torch.Tensor,
+                "exp_avg_sq": torch.Tensor,
+            }
+        """
+        group_index, group_order = self.model_param_group_index_map[model_param]
+        if self.config.use_precision_aware_optimizer:
+            sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
+            tensors = {}
+            for k in self.optimizer.state[sharded_model_param]:
+                tensors[k] = self.optimizer.get_unscaled_state(sharded_model_param, k)
+            tensors["param"] = tensors.pop("master_param")
+        else:
+            main_param = self.optimizer.param_groups[group_index]["params"][group_order]
+            optim_state = self.optimizer.state[main_param]
+            tensors = {"param": main_param, **optim_state}
+        return tensors
+
+    def _set_main_param_and_optimizer_states(self, model_param, tensors):
+        """Set the main param and optimizer states corresponding to the input model_param.
+
+        The structure of the input `tensors`:
+            tensors = {
+                "param": torch.Tensor,
+                "exp_avg": torch.Tensor,
+                "exp_avg_sq": torch.Tensor,
+            }
+        """
+        group_index, group_order = self.model_param_group_index_map[model_param]
+        if self.config.use_precision_aware_optimizer:
+            sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
+            for k, v in tensors.items():
+                if k == "param":
+                    self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
+                else:
+                    self.optimizer.set_scaled_state(sharded_model_param, k, v)
+        else:
+            main_param = self.optimizer.param_groups[group_index]["params"][group_order]
+            optim_state = self.optimizer.state[main_param]
+            dst_tensors = {"param": main_param, **optim_state}
+            for key in dst_tensors:
+                dst_tensors[key].copy_(tensors[key])
+
    def get_parameter_state_fs_bucket_space(self):
        """Get internal representation of parameter state without any copies and modifications.
@@ -734,18 +806,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                bucket_state = []
                for model_param, param_range_map in gbuf_range_map["param_map"].items():
-
-                    # Main param & optimizer states.
-                    group_index, group_order = self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups[group_index]["params"][group_order]
-                    optim_state = self.optimizer.state[main_param]
-
-                    tensors = {
-                        "param": main_param,
-                        **optim_state,
-                        "gbuf_local_start": param_range_map["gbuf_local"].start,
-                        "gbuf_local_end": param_range_map["gbuf_local"].end,
-                    }
+                    tensors = self._get_main_param_and_optimizer_states(model_param)
+                    tensors.update(
+                        {
+                            "gbuf_local_start": param_range_map["gbuf_local"].start,
+                            "gbuf_local_end": param_range_map["gbuf_local"].end,
+                        }
+                    )
                    bucket_state.append(tensors)
                buckets_state.append(bucket_state)
            dtype_state[dtype] = buckets_state
@@ -810,13 +877,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            # Build contiguous DP rank shards (for param + optim states).
            for model_param, param_range_map in gbuf_range_map["param_map"].items():
-
-                # Main param & optimizer states.
-                group_index, group_order = self.model_param_group_index_map[model_param]
-                main_param = self.optimizer.param_groups[group_index]["params"][group_order]
-                optim_state = self.optimizer.state[main_param]
-                tensors = {"param": main_param, **optim_state}
+                tensors = self._get_main_param_and_optimizer_states(model_param)

                # Copy states into contiguous shard.
                gbuf_local_start = param_range_map["gbuf_local"].start
@@ -1108,13 +1169,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
            for gbuf_range_map in gbuf_range_map_for_all_buckets:
                for model_param, param_range_map in gbuf_range_map["param_map"].items():
-                    group_index, group_order = self.model_param_group_index_map[model_param]
                    param_range = param_range_map['param']

-                    # Main param & optimizer states.
-                    main_param = self.optimizer.param_groups[group_index]["params"][group_order]
-                    optim_state = self.optimizer.state[main_param]
-                    tensors = {"fp32_param": main_param, **optim_state}
+                    tensors = self._get_main_param_and_optimizer_states(model_param)
+                    tensors["fp32_param"] = tensors.pop("param")

                    # Match optimizer parameter with model ShardedTensor (or
                    # ShardedTensorFactory).
                    try:
@@ -1188,13 +1246,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    bucket_state, gbuf_range_map["param_map"].items()
                ):
                    # Main param & optimizer states.
-                    group_index, group_order = self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups[group_index]["params"][group_order]
-                    optim_state = self.optimizer.state[main_param]
-                    dst_tensors = {"param": main_param, **optim_state}
-                    for key in dst_tensors:
-                        dst_tensors[key].copy_(src_tensors[key])
+                    self._set_main_param_and_optimizer_states(model_param, src_tensors)

    @torch.no_grad()
    def load_parameter_state_from_fs_model_space(self, state_dict):
@@ -1207,15 +1259,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
            for gbuf_range_map in gbuf_range_map_for_all_buckets:
                for model_param, param_range_map in gbuf_range_map["param_map"].items():
-                    group_index, group_order = self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups[group_index]["params"][group_order]
-                    optim_state = self.optimizer.state[main_param]
-
-                    src_tensors = state_dict[param_idx]
-                    dst_tensors = {"fp32_param": main_param, **optim_state}
-                    for key in dst_tensors:
-                        dst_tensors[key].copy_(src_tensors[key])
+                    src_tensors = {}
+                    for k, v in state_dict[param_idx].items():
+                        if k == "fp32_param":
+                            src_tensors["param"] = v
+                        else:
+                            src_tensors[k] = v
+                    self._set_main_param_and_optimizer_states(model_param, src_tensors)

                    param_idx += 1

    @classmethod
@@ -1390,6 +1440,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            f"Number of unpadded elements must be same in current run "
            f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
        )
+        recv_tensors = {}
        for key in ("param", "exp_avg", "exp_avg_sq"):
            offset_in_world_tensors = 0
            for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
@@ -1440,26 +1491,18 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    data_parallel_group_gloo,
                )

-                # Copy local contiguous shards to param/optim shards.
                for model_param, param_range_map in gbuf_range_map["param_map"].items():

-                    # Main param & optimizer states.
-                    group_index, group_order = self.model_param_group_index_map[model_param]
-                    main_param = self.optimizer.param_groups[group_index]["params"][
-                        group_order
-                    ]
-                    if key == "param":
-                        tensor_to_copy_into = main_param
-                    else:
-                        optim_state = self.optimizer.state[main_param]
-                        tensor_to_copy_into = optim_state[key]
-
                    # Copy states into contiguous shard.
                    gbuf_local_start = param_range_map["gbuf_local"].start
                    gbuf_local_end = param_range_map["gbuf_local"].end
-                    tensor_to_copy_into.data.copy_(
-                        recv_tensor[gbuf_local_start:gbuf_local_end]
-                    )
+                    if model_param not in recv_tensors:
+                        recv_tensors[model_param] = {}
+                    recv_tensors[model_param][key] = recv_tensor[
+                        gbuf_local_start:gbuf_local_end
+                    ]

+        for model_param, tensors in recv_tensors.items():
+            self._set_main_param_and_optimizer_states(model_param, tensors)

    def split_state_dict_if_needed(self, state_dict):
        """
@@ -1600,6 +1643,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        Args:
            filename (str): path to load parameter state from.
        """
+        if self.is_stub_optimizer:
+            return
        state_dict = None
        if torch.distributed.get_rank(self.data_parallel_group) == 0:
            state_dict = torch.load(filename)
@@ -1618,24 +1663,39 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        Args:
            set_to_none (bool): if true, set grads to None.
        """
-        for groups in (
+        if self.is_stub_optimizer:
+            return
+        total_groups = [
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,  # grad empty/unused here?
            self.shard_fp32_groups,  # throws grad-access warning
-            self.shard_fp32_from_float16_groups,
-        ):
+        ]
+        if not self.config.use_precision_aware_optimizer:
+            total_groups.append(self.shard_fp32_from_float16_groups)
+        for groups in total_groups:
            for group in groups:
-                _zero_grad_group_helper(group, set_to_none)
+                _zero_grad_group_helper(
+                    group, set_to_none, self.config.use_precision_aware_optimizer
+                )

    def _collect_main_grad_data_for_unscaling(self):
        """
        Note: this should be equivalent to the float-16 optimizer's method,
        but written differently, so the two should be combined.
        """
-        return [
-            param.grad.data for group in self.optimizer.param_groups for param in group["params"]
-        ]
+        if self.config.use_precision_aware_optimizer:
+            return [
+                param.decoupled_grad.data
+                for group in self.optimizer.param_groups
+                for param in group["params"]
+            ]
+        else:
+            return [
+                param.grad.data
+                for group in self.optimizer.param_groups
+                for param in group["params"]
+            ]
    def _get_model_and_main_params_data_float16(self):
        """
@@ -1648,7 +1708,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        ):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
-                main_data.append(main_param.data)
+                if self.config.use_precision_aware_optimizer:
+                    main_data.append(None)
+                else:
+                    main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """
@@ -1659,6 +1722,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """
+        if self.is_stub_optimizer:
+            return

        # Utility method for copying group grads.
        def copy_group_grads(model_groups, shard_main_groups):
@@ -1671,11 +1736,23 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
-                    shard_main_param.grad = shard_model_grad.float()
+                    if self.config.use_precision_aware_optimizer:
+                        # Pytorch requires a param and its grad to be the same dtype, but we want
+                        # their types to be different in the precision-aware optimizer. So we use
+                        # ".decoupled_grad" to replace ".grad".
+                        # Note that this requires corresponding modifications in the optimizer
+                        # (the optimizer reads gradients from ".decoupled_grad" instead of ".grad").
+                        shard_main_param.decoupled_grad = shard_model_grad
+                    else:
+                        shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
-        copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
-        copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
+        if self.config.use_precision_aware_optimizer:
+            copy_group_grads(self.model_float16_groups, self.shard_float16_groups)
+            copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
+        else:
+            copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
+            copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
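The comment above is the crux of the decoupled-grad scheme; a quick demonstration of the dtype restriction it works around, in plain torch with no Megatron code involved:

import torch

p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
try:
    p.grad = torch.zeros(4, dtype=torch.float32)   # .grad must match the param dtype
    print("accepted")
except (RuntimeError, TypeError) as err:
    print("rejected:", err)

# A plain attribute has no such restriction, so a bf16 param can carry an fp32 gradient,
# as long as the optimizer knows to read ".decoupled_grad" instead of ".grad".
p.decoupled_grad = torch.zeros(4, dtype=torch.float32)
print(p.decoupled_grad.dtype)  # torch.float32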
    def _copy_main_params_to_model_params(self):
        """
@@ -1685,6 +1762,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """
+        if self.is_stub_optimizer:
+            return

        # Utility method for copying group params.
        def copy_group_params(shard_main_groups, model_groups):
@@ -1724,6 +1803,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    else:
                        shard_model_param.data.copy_(shard_main_param)

+        # When using precision-aware optimizer, main params are held by self.optimizer. It will
+        # also do the work of copying data from main params to model params.
+        if self.config.use_precision_aware_optimizer:
+            return
+
        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups, self.model_fp32_groups)
@@ -1749,6 +1833,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                    shard_main_param.data.copy_(shard_model_param)

+        # When using precision-aware optimizer, main params are held by self.optimizer. It will
+        # also do the work of copying data between model params and main params.
+        if self.config.use_precision_aware_optimizer:
+            return
+
        # Copy model groups to shard groups.
        copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
        copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)
@@ -1758,6 +1847,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        If detect FP8 parameters, update their `_scale_inv` and do reduce-max for their
        `amax_history`.
        """
+        if self.is_stub_optimizer:
+            return
        amaxes = []
        scales = []
        scale_invs = []
...
File mode changed from 100755 to 100644
@@ -4,6 +4,7 @@
 import copy
 import math
+import warnings
 from abc import ABC, abstractmethod
 from itertools import chain
 from logging import getLogger
@@ -52,21 +53,25 @@ from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)

-def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
+def _zero_grad_group_helper(
+    group: List[torch.nn.Parameter], set_to_none: bool, use_decoupled_grad: bool = False
+):
    """
    Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer.
    """
    for param in group:
-        if param.grad is not None:
+        grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
+        if hasattr(param, grad_attr) and getattr(param, grad_attr) is not None:
            if set_to_none:
-                param.grad = None
+                setattr(param, grad_attr, None)
            else:
-                if param.grad.grad_fn is not None:
-                    param.grad.detach_()
+                grad_obj = getattr(param, grad_attr)
+                if grad_obj.grad_fn is not None:
+                    grad_obj.detach_()
                else:
-                    param.grad.requires_grad_(False)
-                param.grad.zero_()
+                    grad_obj.requires_grad_(False)
+                grad_obj.zero_()
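A usage sketch for the generalized helper, assuming megatron.core is importable (the import path is inferred from this file's location):

import torch
from megatron.core.optimizer.optimizer import _zero_grad_group_helper  # path assumed

params = [torch.nn.Parameter(torch.ones(2)) for _ in range(2)]
for p in params:
    p.decoupled_grad = torch.ones(2)

# set_to_none=True drops the attribute entirely; False zeroes it in place.
_zero_grad_group_helper([params[0]], set_to_none=True, use_decoupled_grad=True)
_zero_grad_group_helper([params[1]], set_to_none=False, use_decoupled_grad=True)
print(params[0].decoupled_grad, params[1].decoupled_grad)  # None tensor([0., 0.])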
def _multi_tensor_copy_this_to_that(
@@ -105,7 +110,11 @@ class MegatronOptimizer(ABC):
    ):
        """Input optimizer is the base optimizer (e.g., Adam)."""
        self.optimizer = optimizer
-        assert self.optimizer, 'no optimizer is provided.'
+        if self.optimizer is None:
+            warnings.warn(
+                f"WARNING: there is no optimizer on RANK {torch.distributed.get_rank()}. "
+                "This may be expected if you have frozen sub-models."
+            )
        self.config = config
        self.init_state_fn = init_state_fn
@@ -114,9 +123,10 @@ class MegatronOptimizer(ABC):
        Get list of parameters wrapped in optimizer.
        """
        params = []
-        for param_group in self.optimizer.param_groups:
-            for param in param_group['params']:
-                params.append(param)
+        if hasattr(self.optimizer, 'param_groups'):
+            for param_group in self.optimizer.param_groups:
+                for param in param_group['params']:
+                    params.append(param)
        return params

    def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
@@ -131,7 +141,10 @@ class MegatronOptimizer(ABC):
        params = self.get_parameters()
        grads_for_norm = []
        for param in params:
-            grad = param.grad
+            if self.config.use_precision_aware_optimizer:
+                grad = param.decoupled_grad if hasattr(param, "decoupled_grad") else None
+            else:
+                grad = param.grad
            grad_not_none = grad is not None
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
@@ -182,18 +195,27 @@ class MegatronOptimizer(ABC):
    def clip_grad_norm(self, clip_grad: float) -> float:
        """Compute and return grad norm, also clip grads."""
        params = self.get_parameters()
-        grads_for_norm = self.get_main_grads_for_grad_norm()
+        if params:
+            grads_for_norm = self.get_main_grads_for_grad_norm()
+        else:
+            grads_for_norm = []
        grad_norm = get_grad_norm_fp32(
            grads_for_norm, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
        )
-        clip_grad_by_total_norm_fp32(params, clip_grad, grad_norm)
+        if params:
+            clip_grad_by_total_norm_fp32(
+                params, clip_grad, grad_norm, self.config.use_precision_aware_optimizer
+            )
        return grad_norm

    def count_zeros(self) -> float:
        """Count number of zeros in model's gradients."""
        params = self.get_parameters()
        return count_zeros_fp32(
-            params, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
+            params,
+            grad_stats_parallel_group=self.get_grad_stats_parallel_group(),
+            use_decoupled_grad=self.config.use_precision_aware_optimizer,
        )
    @abstractmethod
@@ -213,13 +235,6 @@ class MegatronOptimizer(ABC):
        """Simple scaling."""
        return self.get_loss_scale() * loss

-    def start_param_sync(self, model_index: int, *unused):
-        """
-        Start parameter synchronization for all optimizers.
-        This is a no-op for all non-distributed optimizers.
-        """
-        pass
-
    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
@@ -253,7 +268,10 @@ class MegatronOptimizer(ABC):
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
-        return self.optimizer.param_groups
+        if self.is_stub_optimizer:
+            return []
+        else:
+            return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value
@@ -361,15 +379,17 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
    def _unscale_main_grads_and_check_for_nan(self):

        # Collect main grads.
-        main_grads = self._collect_main_grad_data_for_unscaling()
+        if not self.is_stub_optimizer:
+            main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

-        # Unscale and set found inf/nan
-        torch._amp_foreach_non_finite_check_and_unscale_(
-            main_grads, self.found_inf, self.grad_scaler.inv_scale
-        )
+        if not self.is_stub_optimizer:
+            # Unscale and set found inf/nan
+            torch._amp_foreach_non_finite_check_and_unscale_(
+                main_grads, self.found_inf, self.grad_scaler.inv_scale
+            )

        # Update across all model parallel instances.
        torch.distributed.all_reduce(
@@ -393,7 +413,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
            timers('optimizer-copy-to-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
-        self._copy_model_grads_to_main_grads()
+        if not self.is_stub_optimizer:
+            self._copy_model_grads_to_main_grads()
        if timers is not None:
            timers('optimizer-copy-to-main-grad').stop()
@@ -427,7 +448,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
            timers('optimizer-inner-step', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
-        self.optimizer.step()
+        if not self.is_stub_optimizer:
+            self.optimizer.step()
        if timers is not None:
            timers('optimizer-inner-step').stop()
@@ -436,7 +458,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
            timers('optimizer-copy-main-to-model-params', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
-        self._copy_main_params_to_model_params()
+        if not self.is_stub_optimizer:
+            self._copy_main_params_to_model_params()
        if timers is not None:
            timers('optimizer-copy-main-to-model-params').stop()
@@ -455,7 +478,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
            timers('optimizer-clip-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
-        grad_norm = None
+        grad_norm = 0.0
        if self.config.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.config.clip_grad)
        if timers is not None:
@@ -466,7 +489,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
            timers('optimizer-count-zeros', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
-        num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
+        num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else 0
        if timers is not None:
            timers('optimizer-count-zeros').stop()
@@ -502,56 +525,60 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
        # Handle main parameters.

+        if optimizer:
            # Three groups of parameters:
            #   float16_groups: original float16 parameters
            #   fp32_from_float16_groups: fp32 copy of float16 parameters
            #   fp32_from_fp32_groups: original fp32 parameters
            self.float16_groups = []
            self.fp32_from_float16_groups = []
            self.fp32_from_fp32_groups = []

            # For all the groups in the original optimizer:
            for param_group in self.optimizer.param_groups:
                float16_params_this_group = []
                fp32_params_this_group = []
                fp32_from_float16_params_this_group = []
                # For all the parameters in this group:
                for i, param in enumerate(param_group['params']):
                    if param.requires_grad:

                        # float16 params:
                        if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
                            float16_params_this_group.append(param)
                            # Create a copy
                            main_param = param.detach().clone().float()
                            # Copy tensor model parallel attributes.
                            tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
                            if hasattr(param, 'shared'):
                                main_param.shared = param.shared
                            # Replace the optimizer params with the new fp32 copy.
                            param_group['params'][i] = main_param

                            fp32_from_float16_params_this_group.append(main_param)
                            # Reset existing state dict key to the new main param.
                            if param in self.optimizer.state:
                                self.optimizer.state[main_param] = self.optimizer.state.pop(param)

                        # fp32 params.
                        elif param.type() == 'torch.cuda.FloatTensor':
                            fp32_params_this_group.append(param)
                            param_group['params'][i] = param

                        else:
                            raise TypeError(
                                'Wrapped parameters must be one of '
                                'torch.cuda.FloatTensor, '
                                'torch.cuda.HalfTensor, or '
                                'torch.cuda.BFloat16Tensor. '
                                'Received {}'.format(param.type())
                            )

                self.float16_groups.append(float16_params_this_group)
                self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
                self.fp32_from_fp32_groups.append(fp32_params_this_group)
+            self.is_stub_optimizer = False
+        else:
+            self.is_stub_optimizer = True
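The core of the float16 wrapper is unchanged by the re-indentation: every fp16/bf16 parameter gets an fp32 master copy that replaces it inside the wrapped optimizer's param group, and any existing state is re-keyed to the new tensor. A sketch with torch.optim.SGD standing in for the wrapped Adam:

import torch

param = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
opt = torch.optim.SGD([param], lr=0.1)

main_param = param.detach().clone().float()      # fp32 master copy
opt.param_groups[0]["params"][0] = main_param    # optimizer now steps the fp32 copy
if param in opt.state:                           # carry any existing state to the new key
    opt.state[main_param] = opt.state.pop(param)

print(opt.param_groups[0]["params"][0].dtype)    # torch.float32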
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
@@ -559,6 +586,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
+        if self.is_stub_optimizer:
+            return
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
@@ -567,6 +596,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
            _zero_grad_group_helper(group, set_to_none)

    def _collect_main_grad_data_for_unscaling(self):
+        if self.is_stub_optimizer:
+            return

        main_grads = []
@@ -640,7 +671,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    ):
        if is_loading:
-            self.init_state_fn(self.optimizer)
+            self.init_state_fn(self.optimizer, self.config)

        state_dict = self.state_dict()
@@ -735,9 +766,12 @@ class FP32Optimizer(MegatronOptimizer):
        super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn)
        self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
+        self.is_stub_optimizer = True if optimizer is None else False

    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer"""
+        if self.is_stub_optimizer:
+            return
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)
@@ -748,6 +782,8 @@ class FP32Optimizer(MegatronOptimizer):
    @torch.no_grad()
    def prepare_grads(self) -> bool:
        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
+        if self.is_stub_optimizer:
+            return False
        timers = self.config.timers

        # Copy main_grads to grads.
@@ -767,6 +803,8 @@ class FP32Optimizer(MegatronOptimizer):
    @torch.no_grad()
    def step_with_ready_grads(self) -> bool:
        """Step the optimizer with ready gradients, return successful."""
+        if self.is_stub_optimizer:
+            return True
        timers = self.config.timers

        # Update parameters.
@@ -832,7 +870,7 @@ class FP32Optimizer(MegatronOptimizer):
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ):
        if is_loading:
-            self.init_state_fn(self.optimizer)
+            self.init_state_fn(self.optimizer, self.config)

        state_dict = self.state_dict()
        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
@@ -900,13 +938,19 @@ class ChainedOptimizer(MegatronOptimizer):
    def __init__(self, chained_optimizers: List[MegatronOptimizer]):
        self.model_chunks = []
-        self.config = getattr(chained_optimizers[0], 'config', None)
-        for optimizer in chained_optimizers:
-            if hasattr(optimizer, 'model_chunks'):
-                for model_chunk in optimizer.model_chunks:
-                    if model_chunk not in self.model_chunks:
-                        self.model_chunks.append(model_chunk)
-            assert self.config == getattr(optimizer, 'config', None)
+        # chained_optimizers would be empty in the case that a rank
+        # has no trainable parameters
+        if chained_optimizers:
+            self.config = getattr(chained_optimizers[0], 'config', None)
+            for optimizer in chained_optimizers:
+                if hasattr(optimizer, 'model_chunks'):
+                    for model_chunk in optimizer.model_chunks:
+                        if model_chunk not in self.model_chunks:
+                            self.model_chunks.append(model_chunk)
+                assert self.config == getattr(optimizer, 'config', None)
+            self.is_stub_optimizer = False
+        else:
+            self.is_stub_optimizer = True
        self.chained_optimizers = chained_optimizers

    @property
@@ -930,7 +974,10 @@ class ChainedOptimizer(MegatronOptimizer):
            optimizer.zero_grad(set_to_none)

    def get_loss_scale(self):
-        return self.chained_optimizers[0].get_loss_scale()
+        if self.chained_optimizers:
+            return self.chained_optimizers[0].get_loss_scale()
+        else:
+            return torch.tensor([1.0], dtype=torch.float32, device=torch.cuda.current_device())

    def reload_model_params(self):
        for optimizer in self.chained_optimizers:
@@ -987,6 +1034,8 @@ class ChainedOptimizer(MegatronOptimizer):
    @torch.no_grad()
    def step(self):
        """ChainedOptimizer will step all optimizers one by one."""
+        if self.is_stub_optimizer:
+            return True, 0.0, 0
        found_inf_flag = self.prepare_grads()
        if found_inf_flag:
            return False, None, None
@@ -1005,6 +1054,7 @@ class ChainedOptimizer(MegatronOptimizer):
                    optimizer.get_parameters(),
                    max_norm=optimizer.config.clip_grad,
                    total_norm=grad_norm,
+                    use_decoupled_grad=optimizer.config.use_precision_aware_optimizer,
                )

        # Count the zeros in the grads.
@@ -1062,8 +1112,3 @@ class ChainedOptimizer(MegatronOptimizer):
            optimizer.load_parameter_state_from_dp_zero(
                state_dict, update_legacy_format=update_legacy_format
            )
-
-    def start_param_sync(self, model_index: int, *unused):
-        """Start parameter synchronization for all optimizers."""
-        for optimizer in self.chained_optimizers:
-            optimizer.start_param_sync(model_index, *unused)
@@ -47,6 +47,23 @@ class OptimizerConfig:
    params_dtype: torch.dtype = torch.float32
    """dtype used when intializing the weights. Defaults to torch.float32."""

+    use_precision_aware_optimizer: bool = False
+    """If true, allows optimizer-related tensors (master_param, gradients and optimizer states)
+    to be set to lower precision. Defaults to False.
+    """
+
+    main_grads_dtype: torch.dtype = torch.float32
+    """dtype of main grads when enabling precision-aware-optimizer"""
+
+    main_params_dtype: torch.dtype = torch.float32
+    """dtype of main params when enabling precision-aware-optimizer"""
+
+    exp_avg_dtype: torch.dtype = torch.float32
+    """dtype of exp_avg when enabling precision-aware-optimizer"""
+
+    exp_avg_sq_dtype: torch.dtype = torch.float32
+    """dtype of exp_avg_sq when enabling precision-aware-optimizer"""
+
    ###############
    # Loss scaling
    ###############
@@ -114,3 +131,51 @@ class OptimizerConfig:
    config_logger_dir: str = ""
    """When non-empty, dumps entry-point configs to config_logger_dir"""

+    def __post_init__(self):
+        """Check the validity of the config."""
+        if self.use_precision_aware_optimizer:
+            assert (
+                self.optimizer == 'adam'
+            ), '--use-precision-aware-optimizer only supported with adam'
+            assert (
+                self.use_distributed_optimizer
+            ), '--use-precision-aware-optimizer only supported with distributed optimizer'
+
+            # Only the FusedAdam in TE supports --use-precision-aware-optimizer.
+            # TODO: Remove this check when apex's FusedAdam is no longer used.
+            try:
+                import inspect
+
+                from transformer_engine.pytorch.optimizers import FusedAdam as Adam
+
+                adam_args = inspect.signature(Adam).parameters
+                arg_names = [
+                    'master_weight_dtype',
+                    'exp_avg_dtype',
+                    'exp_avg_sq_dtype',
+                    'use_decoupled_grad',
+                ]
+                for name in arg_names:
+                    assert name in adam_args, (
+                        "Current FusedAdam of TE doesn't support --use-precision-aware-optimizer, "
+                        "please update TE version."
+                    )
+            except ImportError:
+                raise RuntimeError(
+                    '--use-precision-aware-optimizer requires FusedAdam from TransformerEngine, '
+                    'but it was not found.'
+                )
+        else:
+            assert (
+                self.main_grads_dtype == torch.float32
+            ), "main_grads_dtype can only be fp32 when not using precision-aware optimizer"
+            assert (
+                self.main_params_dtype == torch.float32
+            ), "main_params_dtype can only be fp32 when not using precision-aware optimizer"
+            assert (
+                self.exp_avg_dtype == torch.float32
+            ), "exp_avg_dtype can only be fp32 when not using precision-aware optimizer"
+            assert (
+                self.exp_avg_sq_dtype == torch.float32
+            ), "exp_avg_sq_dtype can only be fp32 when not using precision-aware optimizer"
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -9,6 +9,7 @@ from torch.autograd.variable import Variable
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
 from megatron.core.utils import (
     drain_embedding_wgrad_compute,
@@ -496,6 +497,9 @@ def forward_backward_no_pipelining(
    if config.timers is not None:
        config.timers('forward-backward').stop()

+    if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
+        create_cudagraphs()
+
    return forward_data_store
@@ -1479,6 +1483,9 @@ def forward_backward_pipelining_with_interleaving(
    if config.timers is not None:
        config.timers('forward-backward').stop()

+    if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
+        create_cudagraphs()
+
    return forward_data_store
@@ -1874,4 +1881,7 @@ def forward_backward_pipelining_without_interleaving(
    if config.timers is not None:
        config.timers('forward-backward').stop()

+    if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
+        create_cudagraphs()
+
    return forward_data_store
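The hasattr guard keeps older configs that lack enable_cuda_graph working; functionally it is a getattr with a default. create_cudagraphs is the real hook, while the config object below is only a stand-in:

class _Cfg:                       # stand-in config without the new flag
    pass

def maybe_capture(config):
    if getattr(config, "enable_cuda_graph", False):
        print("would call create_cudagraphs() here")
    else:
        print("cuda-graph capture skipped")

maybe_capture(_Cfg())             # skipped

cfg = _Cfg()
cfg.enable_cuda_graph = True
maybe_capture(cfg)                # would call create_cudagraphs() here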
File mode changed from 100755 to 100644