Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled
......@@ -262,32 +262,56 @@ def _get_megatron_optimizer_based_on_param_groups(
Returns:
Instance of MegatronOptimizer.
"""
if config.optimizer == 'adam':
optimizer = Adam(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps,
)
def init_state_fn(opt):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
elif config.optimizer == 'sgd':
optimizer = SGD(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
momentum=config.sgd_momentum,
)
init_state_fn = None
# when freezing sub-models we may have no trainable parameters on a rank and
# hence an empty param_groups. However, we still need to create an optimizer
# for the purposes of grad stats reductions
if param_groups:
if config.optimizer == 'adam':
kwargs = {
"params": param_groups,
"lr": config.lr,
"weight_decay": config.weight_decay,
"betas": (config.adam_beta1, config.adam_beta2),
"eps": config.adam_eps,
}
if config.use_precision_aware_optimizer:
kwargs.update(
{
"master_weights": True,
"use_decoupled_grad": True,
"master_weight_dtype": config.main_params_dtype,
"exp_avg_dtype": config.exp_avg_dtype,
"exp_avg_sq_dtype": config.exp_avg_sq_dtype,
}
)
optimizer = Adam(**kwargs)
def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)
elif config.optimizer == 'sgd':
optimizer = SGD(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
momentum=config.sgd_momentum,
)
init_state_fn = None
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
optimizer = None
init_state_fn = None
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
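For reference, when config.use_precision_aware_optimizer is set, the kwargs assembled above are meant for TE's FusedAdam (the same class the config validation below imports as Adam). A minimal sketch of the resulting call, assuming param_groups and config are in scope as above:

from transformer_engine.pytorch.optimizers import FusedAdam as Adam

optimizer = Adam(
    params=param_groups,
    lr=config.lr,
    weight_decay=config.weight_decay,
    betas=(config.adam_beta1, config.adam_beta2),
    eps=config.adam_eps,
    # Precision-aware extras: FusedAdam owns the master weights and reads
    # gradients from param.decoupled_grad instead of param.grad.
    master_weights=True,
    use_decoupled_grad=True,
    master_weight_dtype=config.main_params_dtype,
    exp_avg_dtype=config.exp_avg_dtype,
    exp_avg_sq_dtype=config.exp_avg_sq_dtype,
)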
......@@ -407,6 +431,7 @@ def get_megatron_optimizer(
model_chunk.overlap_param_gather_with_optimizer_step = (
overlap_param_gather_with_optimizer_step
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
......
......@@ -139,6 +139,7 @@ def clip_grad_by_total_norm_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
max_norm: Union[int, float],
total_norm: float,
use_decoupled_grad: bool = False,
):
"""Clips gradient of an iterable of parameters in fp32 by total norm.
......@@ -149,15 +150,23 @@ def clip_grad_by_total_norm_fp32(
single Tensor that will have gradients normalized.
max_norm (float or int): max norm of the gradients.
total_norm (float): total norm of the gradients.
use_decoupled_grad (bool, optional): whether to read the gradient from ".decoupled_grad"
instead of ".grad". Defaults to False.
"""
# Grads.
params = []
grads = []
for param in parameters:
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
params.append(param)
grads.append(to_local_if_dtensor(param.grad).detach())
if use_decoupled_grad:
if hasattr(param, "decoupled_grad") and param.decoupled_grad is not None:
assert param.decoupled_grad.dtype in [torch.float32, torch.bfloat16]
params.append(param)
grads.append(to_local_if_dtensor(param.decoupled_grad).detach())
else:
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
params.append(param)
grads.append(to_local_if_dtensor(param.grad).detach())
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
......@@ -171,6 +180,7 @@ def clip_grad_by_total_norm_fp32(
def count_zeros_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
grad_stats_parallel_group: torch.distributed.ProcessGroup,
use_decoupled_grad: bool = False,
) -> float:
"""Counts the number of zeros in gradients associated with the passed-in list of
parameters.
......@@ -182,6 +192,8 @@ def count_zeros_fp32(
grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is
generally the model-parallel group for non-distributed optimizers, and the entire
world for the distributed optimizer.
use_decoupled_grad (bool, optional): whether to read the gradient from ".decoupled_grad"
instead of ".grad". Defaults to False.
"""
if isinstance(parameters, torch.Tensor):
......@@ -194,14 +206,14 @@ def count_zeros_fp32(
total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
data_parallel_group = None
for param in parameters:
grad_not_none = param.grad is not None
grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
grad_not_none = hasattr(param, grad_attr) and getattr(param, grad_attr) is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
data_parallel_group = get_data_parallel_group_if_dtensor(
param.grad, data_parallel_group
)
grad = to_local_if_dtensor(param.grad).detach()
grad_obj = getattr(param, grad_attr)
data_parallel_group = get_data_parallel_group_if_dtensor(grad_obj, data_parallel_group)
grad = to_local_if_dtensor(grad_obj).detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
......
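Side note on the decoupled-grad convention the two helpers above now support: with the precision-aware optimizer, the gradient is attached to the param as .decoupled_grad so its dtype can differ from the param's dtype. A minimal sketch with toy tensors (process-group arguments omitted, helper calls left commented):

import torch

p = torch.nn.Parameter(torch.randn(4, device='cuda', dtype=torch.bfloat16))
# The gradient lives outside ".grad", so it can be fp32 while the param is bf16.
p.decoupled_grad = torch.randn(4, device='cuda', dtype=torch.float32)

# The helpers then read the gradient via the flag instead of ".grad", e.g.:
# clip_grad_by_total_norm_fp32([p], max_norm=1.0, total_norm=2.0, use_decoupled_grad=True)
# count_zeros_fp32([p], grad_stats_parallel_group=group, use_decoupled_grad=True)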
......@@ -293,6 +293,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
gbuf_ranges: List[Dict],
param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
opt_group_ranges: List,
config: OptimizerConfig,
):
"""
Create main parameter groups needed for the optimizer step.
......@@ -343,38 +344,45 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
# Clone model -> main.
# Generate sharded model param.
shard_model_param = model_param.detach().view(-1)[
param_range.start : param_range.end
]
# If we use FP8 params to initialize FP32 main params (compared to using the
# bf16/fp16 params to initialize the main params), there will be a loss of
# precision at the beginning of training (this problem will not occur if the
# training is long enough or if the main params are loaded from a checkpoint).
if is_float8tensor(model_param) and hasattr(
model_param, 'get_high_precision_init_val'
):
shard_main_param = (
model_param.get_high_precision_init_val()
.view(-1)[param_range.start : param_range.end]
.clone()
.to(shard_model_param.device)
.float()
)
model_param.clear_high_precision_init_val()
else:
shard_main_param = shard_model_param.clone().float()
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared
# Generate main param.
if not config.use_precision_aware_optimizer:
# If we use FP8 params to initialize FP32 main params (compared to using the
# bf16/fp16 params to initialize the main params), there will be a loss of
# precision at the beginning of training (this problem will not occur if the
# training is long enough or if the main params are loaded from a
# checkpoint).
if is_float8tensor(model_param) and hasattr(
model_param, 'get_high_precision_init_val'
):
shard_main_param = (
model_param.get_high_precision_init_val()
.view(-1)[param_range.start : param_range.end]
.clone()
.to(shard_model_param.device)
.float()
)
model_param.clear_high_precision_init_val()
else:
shard_main_param = shard_model_param.clone().float()
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_main_param.shared = model_param.shared
else:
# When using the precision-aware optimizer, main params are held by FusedAdam.
shard_main_param = None
# Add to group.
model_float16_params_this_group.append(model_param)
......@@ -402,10 +410,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
)
# Update optimizer's params.
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
if not config.use_precision_aware_optimizer:
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
else:
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_float16_params_this_group,
]
return (
model_float16_groups,
......@@ -469,10 +483,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for model_chunk in self.model_chunks:
assert self.ddp_config == model_chunk.ddp_config
assert isinstance(
optimizer, Adam
assert (
isinstance(optimizer, Adam) or optimizer is None
), "Only Adam currently supported, due to checkpointing requirements."
# when freezing sub-models we have no real optimizer
# but still need a stub DistributedOptimizer class
if optimizer is None:
self.is_stub_optimizer = True
return
# Model grad buffer ranges.
assert per_model_buffers is not None, "per_model_buffers must be provided"
self.buffers = list(itertools.chain(*per_model_buffers.values()))
......@@ -528,7 +548,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
) = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
)
# Update optimizer groups.
......@@ -537,6 +557,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())
self.is_stub_optimizer = False
def _get_model_param_range_map(self, param: torch.nn.Parameter):
"""
Given a model param, get the index sub-range of the param that this
......@@ -655,9 +677,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
(numel,), dtype=torch.float32, device=torch.cuda.current_device()
)
state_dict_state.append(
(state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
)
tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard()}
if self.config.use_precision_aware_optimizer:
tensors["master_param"] = init_shard()
state_dict_state.append((state_order, tensors))
# Sort by state order (see method docstring for details).
state_dict_state.sort(key=lambda s: s[0])
......@@ -712,6 +735,55 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
else:
raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')
def _get_main_param_and_optimizer_states(self, model_param):
"""Return a dict containing the main param and optimizer states corresponding to the input
model_param.
The structure of the returned dict:
tensors = {
"param": torch.Tensor
"exp_avg": torch.Tensor
"exp_avg_sq": torch.Tensor
}
"""
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
tensors = {}
for k in self.optimizer.state[sharded_model_param]:
tensors[k] = self.optimizer.get_unscaled_state(sharded_model_param, k)
tensors["param"] = tensors.pop("master_param")
else:
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}
return tensors
def _set_main_param_and_optimizer_states(self, model_param, tensors):
"""Set the main param and optimizer states corresponding to the input model_param.
The structure of the input `tensors`:
tensors = {
"param": torch.Tensor
"exp_avg": torch.Tensor
"exp_avg_sq": torch.Tensor
}
"""
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
for k, v in tensors.items():
if k == "param":
self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
else:
self.optimizer.set_scaled_state(sharded_model_param, k, v)
else:
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(tensors[key])
def get_parameter_state_fs_bucket_space(self):
"""Get internal representation of parameter state without any copies and modifications.
......@@ -734,18 +806,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
bucket_state = []
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {
"param": main_param,
**optim_state,
"gbuf_local_start": param_range_map["gbuf_local"].start,
"gbuf_local_end": param_range_map["gbuf_local"].end,
}
tensors = self._get_main_param_and_optimizer_states(model_param)
tensors.update(
{
"gbuf_local_start": param_range_map["gbuf_local"].start,
"gbuf_local_end": param_range_map["gbuf_local"].end,
}
)
bucket_state.append(tensors)
buckets_state.append(bucket_state)
dtype_state[dtype] = buckets_state
......@@ -810,13 +877,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Build contiguous DP rank shards (for param + optim states).
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}
tensors = self._get_main_param_and_optimizer_states(model_param)
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
......@@ -1108,13 +1169,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
group_index, group_order = self.model_param_group_index_map[model_param]
param_range = param_range_map['param']
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"fp32_param": main_param, **optim_state}
# Main param & optimizer states.
tensors = self._get_main_param_and_optimizer_states(model_param)
tensors["fp32_param"] = tensors.pop("param")
# Match optimizer parameter with model ShardedTensor (or
# ShardedTensorFactory).
try:
......@@ -1188,13 +1246,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
bucket_state, gbuf_range_map["param_map"].items()
):
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(src_tensors[key])
self._set_main_param_and_optimizer_states(model_param, src_tensors)
@torch.no_grad()
def load_parameter_state_from_fs_model_space(self, state_dict):
......@@ -1207,15 +1259,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
src_tensors = state_dict[param_idx]
dst_tensors = {"fp32_param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(src_tensors[key])
src_tensors = {}
for k, v in state_dict[param_idx].items():
if k == "fp32_param":
src_tensors["param"] = v
else:
src_tensors[k] = v
self._set_main_param_and_optimizer_states(model_param, src_tensors)
param_idx += 1
@classmethod
......@@ -1390,6 +1440,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
f"Number of unpadded elements must be same in current run "
f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
)
recv_tensors = {}
for key in ("param", "exp_avg", "exp_avg_sq"):
offset_in_world_tensors = 0
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
......@@ -1440,26 +1491,18 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
data_parallel_group_gloo,
)
# Copy local contiguous shards to param/optim shards.
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][
group_order
]
if key == "param":
tensor_to_copy_into = main_param
else:
optim_state = self.optimizer.state[main_param]
tensor_to_copy_into = optim_state[key]
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
tensor_to_copy_into.data.copy_(
recv_tensor[gbuf_local_start:gbuf_local_end]
)
if model_param not in recv_tensors:
recv_tensors[model_param] = {}
recv_tensors[model_param][key] = recv_tensor[
gbuf_local_start:gbuf_local_end
]
for model_param, tensors in recv_tensors.items():
self._set_main_param_and_optimizer_states(model_param, tensors)
def split_state_dict_if_needed(self, state_dict):
"""
......@@ -1600,6 +1643,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Args:
filename (str): path to load parameter state from.
"""
if self.is_stub_optimizer:
return
state_dict = None
if torch.distributed.get_rank(self.data_parallel_group) == 0:
state_dict = torch.load(filename)
......@@ -1618,24 +1663,39 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Args:
set_to_none (bool): if true, set grads to None.
"""
for groups in (
if self.is_stub_optimizer:
return
total_groups = [
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups, # grad empty/unused here?
self.shard_fp32_groups, # throws grad-access warning
self.shard_fp32_from_float16_groups,
):
]
if not self.config.use_precision_aware_optimizer:
total_groups.append(self.shard_fp32_from_float16_groups)
for groups in total_groups:
for group in groups:
_zero_grad_group_helper(group, set_to_none)
_zero_grad_group_helper(
group, set_to_none, self.config.use_precision_aware_optimizer
)
def _collect_main_grad_data_for_unscaling(self):
"""
Note: this should be equivalent to the float-16 optimizer's method,
but written differently, so the two should be combined.
"""
return [
param.grad.data for group in self.optimizer.param_groups for param in group["params"]
]
if self.config.use_precision_aware_optimizer:
return [
param.decoupled_grad.data
for group in self.optimizer.param_groups
for param in group["params"]
]
else:
return [
param.grad.data
for group in self.optimizer.param_groups
for param in group["params"]
]
def _get_model_and_main_params_data_float16(self):
"""
......@@ -1648,7 +1708,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
if self.config.use_precision_aware_optimizer:
main_data.append(None)
else:
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self):
......@@ -1659,6 +1722,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
buffer, this method is responsible for copying the updated grads
from the grad buffer to the main shard's grad field.
"""
if self.is_stub_optimizer:
return
# Utility method for copying group grads.
def copy_group_grads(model_groups, shard_main_groups):
......@@ -1671,11 +1736,23 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
model_grad = model_param.main_grad
shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
shard_main_param.grad = shard_model_grad.float()
if self.config.use_precision_aware_optimizer:
# PyTorch requires a param and its grad to have the same dtype, but the
# precision-aware optimizer lets them differ. So we store the gradient in
# ".decoupled_grad" instead of ".grad".
# Note that this requires a corresponding change in the optimizer: it must
# read gradients from ".decoupled_grad" instead of ".grad".
shard_main_param.decoupled_grad = shard_model_grad
else:
shard_main_param.grad = shard_model_grad.float()
# Copy model groups to shard groups.
copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
if self.config.use_precision_aware_optimizer:
copy_group_grads(self.model_float16_groups, self.shard_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
else:
copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
def _copy_main_params_to_model_params(self):
"""
......@@ -1685,6 +1762,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the grad buffer.
"""
if self.is_stub_optimizer:
return
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
......@@ -1724,6 +1803,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
else:
shard_model_param.data.copy_(shard_main_param)
# When using the precision-aware optimizer, main params are held by self.optimizer, which
# also handles copying data from main params to model params.
if self.config.use_precision_aware_optimizer:
return
# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
copy_group_params(self.shard_fp32_groups, self.model_fp32_groups)
......@@ -1749,6 +1833,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_main_param.data.copy_(shard_model_param)
# When using the precision-aware optimizer, main params are held by self.optimizer, which
# also keeps them in sync with the model params.
if self.config.use_precision_aware_optimizer:
return
# Copy model groups to shard groups.
copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)
......@@ -1758,6 +1847,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
If detect FP8 parameters, update their `_scale_inv` and do reduce-max for their
`amax_history`.
"""
if self.is_stub_optimizer:
return
amaxes = []
scales = []
scale_invs = []
......
......@@ -4,6 +4,7 @@
import copy
import math
import warnings
from abc import ABC, abstractmethod
from itertools import chain
from logging import getLogger
......@@ -52,21 +53,25 @@ from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)
def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
def _zero_grad_group_helper(
group: List[torch.nn.Parameter], set_to_none: bool, use_decoupled_grad: bool = False
):
"""
Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer.
"""
for param in group:
if param.grad is not None:
grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
if hasattr(param, grad_attr) and getattr(param, grad_attr) is not None:
if set_to_none:
param.grad = None
setattr(param, grad_attr, None)
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
grad_obj = getattr(param, grad_attr)
if grad_obj.grad_fn is not None:
grad_obj.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
grad_obj.requires_grad_(False)
grad_obj.zero_()
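A hedged usage sketch of the extended helper, assuming _zero_grad_group_helper is importable from this module and the parameter carries a .decoupled_grad as in the precision-aware path:

import torch

p = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
p.decoupled_grad = torch.randn(4, dtype=torch.float32)

# Zero the decoupled grad in place (set_to_none=False keeps the tensor allocated) ...
_zero_grad_group_helper([p], set_to_none=False, use_decoupled_grad=True)
assert torch.all(p.decoupled_grad == 0)

# ... or drop the gradient entirely so its memory can be freed.
_zero_grad_group_helper([p], set_to_none=True, use_decoupled_grad=True)
assert p.decoupled_grad is None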
def _multi_tensor_copy_this_to_that(
......@@ -105,7 +110,11 @@ class MegatronOptimizer(ABC):
):
"""Input optimizer is the base optimizer (e.g., Adam)."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
if self.optimizer is None:
warnings.warn(
f"WARNING: there is no optimizer on RANK {torch.distributed.get_rank()}. "
"This may be expected if you have frozen sub-models."
)
self.config = config
self.init_state_fn = init_state_fn
......@@ -114,9 +123,10 @@ class MegatronOptimizer(ABC):
Get list of parameters wrapped in optimizer.
"""
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
return params
def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
......@@ -131,7 +141,10 @@ class MegatronOptimizer(ABC):
params = self.get_parameters()
grads_for_norm = []
for param in params:
grad = param.grad
if self.config.use_precision_aware_optimizer:
grad = param.decoupled_grad if hasattr(param, "decoupled_grad") else None
else:
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
......@@ -182,18 +195,27 @@ class MegatronOptimizer(ABC):
def clip_grad_norm(self, clip_grad: float) -> float:
"""Compute and return grad norm, also clip grads."""
params = self.get_parameters()
grads_for_norm = self.get_main_grads_for_grad_norm()
if params:
grads_for_norm = self.get_main_grads_for_grad_norm()
else:
grads_for_norm = []
grad_norm = get_grad_norm_fp32(
grads_for_norm, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
)
clip_grad_by_total_norm_fp32(params, clip_grad, grad_norm)
if params:
clip_grad_by_total_norm_fp32(
params, clip_grad, grad_norm, self.config.use_precision_aware_optimizer
)
return grad_norm
def count_zeros(self) -> float:
"""Count number of zeros in model's gradients."""
params = self.get_parameters()
return count_zeros_fp32(
params, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
params,
grad_stats_parallel_group=self.get_grad_stats_parallel_group(),
use_decoupled_grad=self.config.use_precision_aware_optimizer,
)
@abstractmethod
......@@ -213,13 +235,6 @@ class MegatronOptimizer(ABC):
"""Simple scaling."""
return self.get_loss_scale() * loss
def start_param_sync(self, model_index: int, *unused):
"""
Start parameter synchronization for all optimizers.
This is a no-op for all non-distributed optimizers.
"""
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
......@@ -253,7 +268,10 @@ class MegatronOptimizer(ABC):
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
if self.is_stub_optimizer:
return []
else:
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
......@@ -361,15 +379,17 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
main_grads = self._collect_main_grad_data_for_unscaling()
if not self.is_stub_optimizer:
main_grads = self._collect_main_grad_data_for_unscaling()
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
)
if not self.is_stub_optimizer:
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
)
# Update across all model parallel instances.
torch.distributed.all_reduce(
......@@ -393,7 +413,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_model_grads_to_main_grads()
if not self.is_stub_optimizer:
self._copy_model_grads_to_main_grads()
if timers is not None:
timers('optimizer-copy-to-main-grad').stop()
......@@ -427,7 +448,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-inner-step', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self.optimizer.step()
if not self.is_stub_optimizer:
self.optimizer.step()
if timers is not None:
timers('optimizer-inner-step').stop()
......@@ -436,7 +458,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_main_params_to_model_params()
if not self.is_stub_optimizer:
self._copy_main_params_to_model_params()
if timers is not None:
timers('optimizer-copy-main-to-model-params').stop()
......@@ -455,7 +478,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
grad_norm = None
grad_norm = 0.0
if self.config.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.config.clip_grad)
if timers is not None:
......@@ -466,7 +489,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-count-zeros', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else 0
if timers is not None:
timers('optimizer-count-zeros').stop()
......@@ -502,56 +525,60 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# Handle main parameters.
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] = self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type())
)
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
if optimizer:
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] = self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type())
)
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
self.is_stub_optimizer = False
else:
self.is_stub_optimizer = True
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
......@@ -559,6 +586,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
fp32_from_float16_groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
if self.is_stub_optimizer:
return
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_float16_groups:
......@@ -567,6 +596,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
_zero_grad_group_helper(group, set_to_none)
def _collect_main_grad_data_for_unscaling(self):
if self.is_stub_optimizer:
return []
main_grads = []
......@@ -640,7 +671,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
):
if is_loading:
self.init_state_fn(self.optimizer)
self.init_state_fn(self.optimizer, self.config)
state_dict = self.state_dict()
......@@ -735,9 +766,12 @@ class FP32Optimizer(MegatronOptimizer):
super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn)
self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
self.is_stub_optimizer = optimizer is None
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
if self.is_stub_optimizer:
return
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group['params'], set_to_none)
......@@ -748,6 +782,8 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def prepare_grads(self) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
if self.is_stub_optimizer:
return False
timers = self.config.timers
# Copy main_grads to grads.
......@@ -767,6 +803,8 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful."""
if self.is_stub_optimizer:
return True
timers = self.config.timers
# Update parameters.
......@@ -832,7 +870,7 @@ class FP32Optimizer(MegatronOptimizer):
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
if is_loading:
self.init_state_fn(self.optimizer)
self.init_state_fn(self.optimizer, self.config)
state_dict = self.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
......@@ -900,13 +938,19 @@ class ChainedOptimizer(MegatronOptimizer):
def __init__(self, chained_optimizers: List[MegatronOptimizer]):
self.model_chunks = []
self.config = getattr(chained_optimizers[0], 'config', None)
for optimizer in chained_optimizers:
if hasattr(optimizer, 'model_chunks'):
for model_chunk in optimizer.model_chunks:
if model_chunk not in self.model_chunks:
self.model_chunks.append(model_chunk)
assert self.config == getattr(optimizer, 'config', None)
# chained_optimizers can be empty when a rank has no trainable parameters.
if chained_optimizers:
self.config = getattr(chained_optimizers[0], 'config', None)
for optimizer in chained_optimizers:
if hasattr(optimizer, 'model_chunks'):
for model_chunk in optimizer.model_chunks:
if model_chunk not in self.model_chunks:
self.model_chunks.append(model_chunk)
assert self.config == getattr(optimizer, 'config', None)
self.is_stub_optimizer = False
else:
self.is_stub_optimizer = True
self.chained_optimizers = chained_optimizers
@property
......@@ -930,7 +974,10 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.zero_grad(set_to_none)
def get_loss_scale(self):
return self.chained_optimizers[0].get_loss_scale()
if self.chained_optimizers:
return self.chained_optimizers[0].get_loss_scale()
else:
return torch.tensor([1.0], dtype=torch.float32, device=torch.cuda.current_device())
def reload_model_params(self):
for optimizer in self.chained_optimizers:
......@@ -987,6 +1034,8 @@ class ChainedOptimizer(MegatronOptimizer):
@torch.no_grad()
def step(self):
"""ChainedOptimizer will step all optimizers one by one."""
if self.is_stub_optimizer:
return True, 0.0, 0
found_inf_flag = self.prepare_grads()
if found_inf_flag:
return False, None, None
......@@ -1005,6 +1054,7 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.get_parameters(),
max_norm=optimizer.config.clip_grad,
total_norm=grad_norm,
use_decoupled_grad=optimizer.config.use_precision_aware_optimizer,
)
# Count the zeros in the grads.
......@@ -1062,8 +1112,3 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.load_parameter_state_from_dp_zero(
state_dict, update_legacy_format=update_legacy_format
)
def start_param_sync(self, model_index: int, *unused):
"""Start parameter synchronization for all optimizers."""
for optimizer in self.chained_optimizers:
optimizer.start_param_sync(model_index, *unused)
......@@ -47,6 +47,23 @@ class OptimizerConfig:
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights. Defaults to torch.float32."""
use_precision_aware_optimizer: bool = False
"""If true, allows optimizer-related tensors (master_param, gradients and optimizer states)
to be set to lower precision. Defaults to False.
"""
main_grads_dtype: torch.dtype = torch.float32
"""dtype of main grads when the precision-aware optimizer is enabled."""
main_params_dtype: torch.dtype = torch.float32
"""dtype of main params when the precision-aware optimizer is enabled."""
exp_avg_dtype: torch.dtype = torch.float32
"""dtype of exp_avg when the precision-aware optimizer is enabled."""
exp_avg_sq_dtype: torch.dtype = torch.float32
"""dtype of exp_avg_sq when the precision-aware optimizer is enabled."""
###############
# Loss scaling
###############
......@@ -114,3 +131,51 @@ class OptimizerConfig:
config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""
def __post_init__(self):
"""Check the validity of the config."""
if self.use_precision_aware_optimizer:
assert (
self.optimizer == 'adam'
), '--use-precision-aware-optimizer only supported with adam'
assert (
self.use_distributed_optimizer
), '--use-precision-aware-optimizer only supported with distributed optimizer'
# Only the FusedAdam in TE supports --use-precision-aware-optimizer.
# TODO: Remove this check when apex's FusedAdam is no longer used.
try:
import inspect
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
adam_args = inspect.signature(Adam).parameters
arg_names = [
'master_weight_dtype',
'exp_avg_dtype',
'exp_avg_sq_dtype',
'use_decoupled_grad',
]
for name in arg_names:
assert name in adam_args, (
"Current FusedAdam of TE doesn't support --use-precision-aware-optimizer, "
"please update TE version."
)
except ImportError:
raise RuntimeError(
'--use-precision-aware-optimizer requires FusedAdam from TransformerEngine, '
'but it was not found.'
)
else:
assert (
self.main_grads_dtype == torch.float32
), "main_grads_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.main_params_dtype == torch.float32
), "main_params_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.exp_avg_dtype == torch.float32
), "exp_avg_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.exp_avg_sq_dtype == torch.float32
), "exp_avg_sq_dtype can only be fp32 when not using precision-aware optimizer"
......@@ -9,6 +9,7 @@ from torch.autograd.variable import Variable
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
......@@ -496,6 +497,9 @@ def forward_backward_no_pipelining(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
......@@ -1479,6 +1483,9 @@ def forward_backward_pipelining_with_interleaving(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
......@@ -1874,4 +1881,7 @@ def forward_backward_pipelining_without_interleaving(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store