Commit 99a0c39e authored by xingjinliang's avatar xingjinliang

Sync latest code

parent 50fe58fa
Pipeline #2152 passed with stage
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -262,21 +262,42 @@ def _get_megatron_optimizer_based_on_param_groups(
Returns:
Instance of MegatronOptimizer.
"""
# when freezing sub-models we may have no trainable parameters on a rank and
# hence an empty param_groups. However, we still need to create an optimizer
# for the purposes of grad stats reductions
if param_groups:
if config.optimizer == 'adam':
optimizer = Adam(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps,
kwargs = {
"params": param_groups,
"lr": config.lr,
"weight_decay": config.weight_decay,
"betas": (config.adam_beta1, config.adam_beta2),
"eps": config.adam_eps,
}
if config.use_precision_aware_optimizer:
kwargs.update(
{
"master_weights": True,
"use_decoupled_grad": True,
"master_weight_dtype": config.main_params_dtype,
"exp_avg_dtype": config.exp_avg_dtype,
"exp_avg_sq_dtype": config.exp_avg_sq_dtype,
}
)
def init_state_fn(opt):
optimizer = Adam(**kwargs)
def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)
elif config.optimizer == 'sgd':
optimizer = SGD(
......@@ -288,6 +309,9 @@ def _get_megatron_optimizer_based_on_param_groups(
init_state_fn = None
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
else:
optimizer = None
init_state_fn = None
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
......@@ -407,6 +431,7 @@ def get_megatron_optimizer(
model_chunk.overlap_param_gather_with_optimizer_step = (
overlap_param_gather_with_optimizer_step
)
optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config,
......
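For readability, here is a hedged, standalone sketch of the precision-aware branch added above: the Adam kwargs gain the extra master-weight/dtype arguments, and the lazy init_state_fn defers to FusedAdam's own state initialization. The hyperparameter values and the build_precision_aware_adam helper name are illustrative only, not the repository's defaults.

# Hedged sketch (not the repo's code) of the precision-aware path above.
import torch
from transformer_engine.pytorch.optimizers import FusedAdam as Adam

def build_precision_aware_adam(param_groups):
    kwargs = {
        "params": param_groups,
        "lr": 1e-4,
        "weight_decay": 0.1,
        "betas": (0.9, 0.95),
        "eps": 1e-8,
        # Extra kwargs enabled by use_precision_aware_optimizer:
        "master_weights": True,              # FusedAdam keeps the FP32 master params
        "use_decoupled_grad": True,          # grads are read from param.decoupled_grad
        "master_weight_dtype": torch.float32,
        "exp_avg_dtype": torch.bfloat16,     # first moment stored in BF16
        "exp_avg_sq_dtype": torch.bfloat16,  # second moment stored in BF16
    }
    optimizer = Adam(**kwargs)

    def init_state_fn(opt, config=None):
        # Lazily create per-param state; FusedAdam allocates it in the configured dtypes.
        # (The non-precision-aware branch would instead create exp_avg / exp_avg_sq
        # with torch.zeros_like, as in the hunk above.)
        for group in opt.param_groups:
            for p in group["params"]:
                if len(opt.state[p]) == 0:
                    opt.initialize_state(p)

    return optimizer, init_state_fn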
......@@ -139,6 +139,7 @@ def clip_grad_by_total_norm_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
max_norm: Union[int, float],
total_norm: float,
use_decoupled_grad: bool = False,
):
"""Clips gradient of an iterable of parameters in fp32 by total norm.
......@@ -149,11 +150,19 @@ def clip_grad_by_total_norm_fp32(
single Tensor that will have gradients normalized.
max_norm (float or int): max norm of the gradients.
total_norm (float): total norm of the gradients.
use_decoupled_grad (bool, optional): whether to read the gradient from ".decoupled_grad"
instead of ".grad". Defaults to False.
"""
# Grads.
params = []
grads = []
for param in parameters:
if use_decoupled_grad:
if hasattr(param, "decoupled_grad") and param.decoupled_grad is not None:
assert param.decoupled_grad.dtype in [torch.float32, torch.bfloat16]
params.append(param)
grads.append(to_local_if_dtensor(param.decoupled_grad).detach())
else:
if param.grad is not None:
assert param.grad.type() == 'torch.cuda.FloatTensor'
params.append(param)
......@@ -171,6 +180,7 @@ def clip_grad_by_total_norm_fp32(
def count_zeros_fp32(
parameters: Union[List[torch.Tensor], torch.Tensor],
grad_stats_parallel_group: torch.distributed.ProcessGroup,
use_decoupled_grad: bool = False,
) -> float:
"""Counts the number of zeros in gradients associated with the passed-in list of
parameters.
......@@ -182,6 +192,8 @@ def count_zeros_fp32(
grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is
generally the model-parallel group for non-distributed optimizers, and the entire
world for the distributed optimizer.
use_decoupled_grad (bool, optional): whether to read the gradient from ".decoupled_grad"
instead of ".grad". Defaults to False.
"""
if isinstance(parameters, torch.Tensor):
......@@ -194,14 +206,14 @@ def count_zeros_fp32(
total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
data_parallel_group = None
for param in parameters:
grad_not_none = param.grad is not None
grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
grad_not_none = hasattr(param, grad_attr) and getattr(param, grad_attr) is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
data_parallel_group = get_data_parallel_group_if_dtensor(
param.grad, data_parallel_group
)
grad = to_local_if_dtensor(param.grad).detach()
grad_obj = getattr(param, grad_attr)
data_parallel_group = get_data_parallel_group_if_dtensor(grad_obj, data_parallel_group)
grad = to_local_if_dtensor(grad_obj).detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
......
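Both helpers above use the same attribute-selection idiom. A minimal, self-contained illustration follows; iter_grads is a hypothetical name used only for this sketch, not part of the change.

# Self-contained illustration of the ".grad" vs ".decoupled_grad" selection used above.
import torch

def iter_grads(parameters, use_decoupled_grad=False):
    grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
    for param in parameters:
        grad = getattr(param, grad_attr, None)
        if grad is not None:
            yield grad.detach()

p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
# The whole point of decoupled_grad: the grad dtype may differ from the param dtype.
p.decoupled_grad = torch.ones(4, dtype=torch.float32)
print([g.dtype for g in iter_grads([p], use_decoupled_grad=True)])  # [torch.float32]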
......@@ -293,6 +293,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
gbuf_ranges: List[Dict],
param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
opt_group_ranges: List,
config: OptimizerConfig,
):
"""
Create main parameter groups needed for the optimizer step.
......@@ -343,15 +344,23 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
# Clone model -> main.
# Generate sharded model param.
shard_model_param = model_param.detach().view(-1)[
param_range.start : param_range.end
]
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
# Generate main param.
if not config.use_precision_aware_optimizer:
# If we use FP8 params to initialize FP32 main params (compared to using the
# bf16/fp16 params to initialize the main params), there will be a loss of
# precision at the beginning of training (this problem will not occur if the
# training is long enough or if the main params are loaded from a checkpoint).
# training is long enough or if the main params are loaded from a
# checkpoint).
if is_float8tensor(model_param) and hasattr(
model_param, 'get_high_precision_init_val'
):
......@@ -366,15 +375,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
else:
shard_main_param = shard_model_param.clone().float()
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared
else:
# When using precision-aware optimizer, main params are held by FusedAdam.
shard_main_param = None
# Add to group.
model_float16_params_this_group.append(model_param)
......@@ -402,10 +410,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
)
# Update optimizer's params.
if not config.use_precision_aware_optimizer:
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
else:
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_float16_params_this_group,
]
return (
model_float16_groups,
......@@ -469,10 +483,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for model_chunk in self.model_chunks:
assert self.ddp_config == model_chunk.ddp_config
assert isinstance(
optimizer, Adam
assert (
isinstance(optimizer, Adam) or optimizer is None
), "Only Adam currently supported, due to checkpointing requirements."
# when freezing sub-models we have no real optimizer
# but still need a stub DistributedOptimizer class
if optimizer is None:
self.is_stub_optimizer = True
return
# Model grad buffer ranges.
assert per_model_buffers is not None, "per_model_buffers must be provided"
self.buffers = list(itertools.chain(*per_model_buffers.values()))
......@@ -528,7 +548,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
) = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
)
# Update optimizer groups.
......@@ -537,6 +557,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())
self.is_stub_optimizer = False
def _get_model_param_range_map(self, param: torch.nn.Parameter):
"""
Given a model param, get the index sub-range of the param that this
......@@ -655,9 +677,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
(numel,), dtype=torch.float32, device=torch.cuda.current_device()
)
state_dict_state.append(
(state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
)
tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard()}
if self.config.use_precision_aware_optimizer:
tensors["master_param"] = init_shard()
state_dict_state.append((state_order, tensors))
# Sort by state order (see method docstring for details).
state_dict_state.sort(key=lambda s: s[0])
......@@ -712,6 +735,55 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
else:
raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')
def _get_main_param_and_optimizer_states(self, model_param):
"""Return a dict containing the main param and optimizer states corresponding to the input
model_param.
The structure of the returned dict:
tensors = {
"param": torch.Tensor
"exp_avg": torch.Tensor
"exp_avg_sq": torch.Tensor
}
"""
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
tensors = {}
for k in self.optimizer.state[sharded_model_param]:
tensors[k] = self.optimizer.get_unscaled_state(sharded_model_param, k)
tensors["param"] = tensors.pop("master_param")
else:
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}
return tensors
def _set_main_param_and_optimizer_states(self, model_param, tensors):
"""Set the main param and optimizer states corresponding to the input model_param.
The structure of the input `tensors`:
tensors = {
"param": torch.Tensor
"exp_avg": torch.Tensor
"exp_avg_sq": torch.Tensor
}
"""
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
for k, v in tensors.items():
if k == "param":
self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
else:
self.optimizer.set_scaled_state(sharded_model_param, k, v)
else:
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(tensors[key])
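For the non-precision-aware branch, the returned bundle is simply the main param plus the optimizer's per-param state. A toy sketch using stock torch.optim.Adam as a stand-in (the real code uses the fused Adam; stock Adam additionally keeps a "step" counter):

# Toy sketch of the {"param", "exp_avg", "exp_avg_sq"} bundle described above.
import torch

main_param = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([main_param], lr=1e-3)
main_param.grad = torch.ones(4)
opt.step()  # materializes exp_avg / exp_avg_sq (and a step counter) in opt.state

tensors = {"param": main_param, **opt.state[main_param]}
print(sorted(tensors))  # ['exp_avg', 'exp_avg_sq', 'param', 'step']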
def get_parameter_state_fs_bucket_space(self):
"""Get internal representation of parameter state without any copies and modifications.
......@@ -734,18 +806,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
bucket_state = []
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {
"param": main_param,
**optim_state,
tensors = self._get_main_param_and_optimizer_states(model_param)
tensors.update(
{
"gbuf_local_start": param_range_map["gbuf_local"].start,
"gbuf_local_end": param_range_map["gbuf_local"].end,
}
)
bucket_state.append(tensors)
buckets_state.append(bucket_state)
dtype_state[dtype] = buckets_state
......@@ -810,13 +877,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Build contiguous DP rank shards (for param + optim states).
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}
tensors = self._get_main_param_and_optimizer_states(model_param)
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
......@@ -1108,13 +1169,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
group_index, group_order = self.model_param_group_index_map[model_param]
param_range = param_range_map['param']
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"fp32_param": main_param, **optim_state}
# Main param & optimizer states.
tensors = self._get_main_param_and_optimizer_states(model_param)
tensors["fp32_param"] = tensors.pop("param")
# Match optimizer parameter with model ShardedTensor (or
# ShardedTensorFactory).
try:
......@@ -1188,13 +1246,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
bucket_state, gbuf_range_map["param_map"].items()
):
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(src_tensors[key])
self._set_main_param_and_optimizer_states(model_param, src_tensors)
@torch.no_grad()
def load_parameter_state_from_fs_model_space(self, state_dict):
......@@ -1207,15 +1259,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
src_tensors = state_dict[param_idx]
dst_tensors = {"fp32_param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(src_tensors[key])
src_tensors = {}
for k, v in state_dict[param_idx].items():
if k == "fp32_param":
src_tensors["param"] = v
else:
src_tensors[k] = v
self._set_main_param_and_optimizer_states(model_param, src_tensors)
param_idx += 1
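The loop above only renames the on-disk key before handing off to the shared setter. The remapping in isolation:

# Pure-Python sketch of the key remapping above: fs_model_space entries store the main
# weights under "fp32_param", while _set_main_param_and_optimizer_states expects "param".
def remap_fs_model_space_entry(entry):
    return {("param" if k == "fp32_param" else k): v for k, v in entry.items()}

entry = {"fp32_param": [1.0], "exp_avg": [0.0], "exp_avg_sq": [0.0]}
print(remap_fs_model_space_entry(entry))
# {'param': [1.0], 'exp_avg': [0.0], 'exp_avg_sq': [0.0]}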
@classmethod
......@@ -1390,6 +1440,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
f"Number of unpadded elements must be same in current run "
f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
)
recv_tensors = {}
for key in ("param", "exp_avg", "exp_avg_sq"):
offset_in_world_tensors = 0
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
......@@ -1440,26 +1491,18 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
data_parallel_group_gloo,
)
# Copy local contiguous shards to param/optim shards.
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][
group_order
]
if key == "param":
tensor_to_copy_into = main_param
else:
optim_state = self.optimizer.state[main_param]
tensor_to_copy_into = optim_state[key]
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
tensor_to_copy_into.data.copy_(
recv_tensor[gbuf_local_start:gbuf_local_end]
)
if model_param not in recv_tensors:
recv_tensors[model_param] = {}
recv_tensors[model_param][key] = recv_tensor[
gbuf_local_start:gbuf_local_end
]
for model_param, tensors in recv_tensors.items():
self._set_main_param_and_optimizer_states(model_param, tensors)
def split_state_dict_if_needed(self, state_dict):
"""
......@@ -1600,6 +1643,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Args:
filename (str): path to load parameter state from.
"""
if self.is_stub_optimizer:
return
state_dict = None
if torch.distributed.get_rank(self.data_parallel_group) == 0:
state_dict = torch.load(filename)
......@@ -1618,23 +1663,38 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Args:
set_to_none (bool): if true, set grads to None.
"""
for groups in (
if self.is_stub_optimizer:
return
total_groups = [
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups, # grad empty/unused here?
self.shard_fp32_groups, # throws grad-access warning
self.shard_fp32_from_float16_groups,
):
]
if not self.config.use_precision_aware_optimizer:
total_groups.append(self.shard_fp32_from_float16_groups)
for groups in total_groups:
for group in groups:
_zero_grad_group_helper(group, set_to_none)
_zero_grad_group_helper(
group, set_to_none, self.config.use_precision_aware_optimizer
)
def _collect_main_grad_data_for_unscaling(self):
"""
Note: this should be equivalent to the float-16 optimizer's method,
but written differently, so the two should be combined.
"""
if self.config.use_precision_aware_optimizer:
return [
param.grad.data for group in self.optimizer.param_groups for param in group["params"]
param.decoupled_grad.data
for group in self.optimizer.param_groups
for param in group["params"]
]
else:
return [
param.grad.data
for group in self.optimizer.param_groups
for param in group["params"]
]
def _get_model_and_main_params_data_float16(self):
......@@ -1648,6 +1708,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
if self.config.use_precision_aware_optimizer:
main_data.append(None)
else:
main_data.append(main_param.data)
return model_data, main_data
......@@ -1659,6 +1722,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
buffer, this method is responsible for copying the updated grads
from the grad buffer to the main shard's grad field.
"""
if self.is_stub_optimizer:
return
# Utility method for copying group grads.
def copy_group_grads(model_groups, shard_main_groups):
......@@ -1671,9 +1736,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
model_grad = model_param.main_grad
shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
if self.config.use_precision_aware_optimizer:
# PyTorch requires a param and its grad to have the same dtype, but in the
# precision-aware optimizer we want their dtypes to differ. So we use
# ".decoupled_grad" in place of ".grad".
# Note that this requires a corresponding change in the optimizer (it must
# read gradients from ".decoupled_grad" instead of ".grad").
shard_main_param.decoupled_grad = shard_model_grad
else:
shard_main_param.grad = shard_model_grad.float()
# Copy model groups to shard groups.
if self.config.use_precision_aware_optimizer:
copy_group_grads(self.model_float16_groups, self.shard_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
else:
copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
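A compact, standalone sketch of the two grad-routing paths in copy_group_grads above; route_grad is a hypothetical stand-in for the inner loop.

# Hypothetical stand-in for the inner loop of copy_group_grads above: attach the bf16
# grad shard directly when the precision-aware optimizer is used, otherwise upcast to FP32.
import torch

def route_grad(shard_main_param, shard_model_grad, use_precision_aware_optimizer):
    if use_precision_aware_optimizer:
        # FusedAdam reads ".decoupled_grad", so no dtype match with the param is required.
        shard_main_param.decoupled_grad = shard_model_grad
    else:
        shard_main_param.grad = shard_model_grad.float()

w = torch.nn.Parameter(torch.zeros(8))                # fp32 main shard
g = torch.randn(8, dtype=torch.bfloat16)              # bf16 grad shard from the buffer
route_grad(w, g, use_precision_aware_optimizer=True)
print(w.decoupled_grad.dtype)                         # torch.bfloat16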
......@@ -1685,6 +1762,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the grad buffer.
"""
if self.is_stub_optimizer:
return
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
......@@ -1724,6 +1803,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
else:
shard_model_param.data.copy_(shard_main_param)
# When using precision-aware optimizer, main params are held by self.optimizer. It will also
# do the work of copying data from main params to model params.
if self.config.use_precision_aware_optimizer:
return
# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
copy_group_params(self.shard_fp32_groups, self.model_fp32_groups)
......@@ -1749,6 +1833,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_main_param.data.copy_(shard_model_param)
# When using precision-aware optimizer, main params are held by self.optimizer. It will also
# do the work of copying data from main params to model params.
if self.config.use_precision_aware_optimizer:
return
# Copy model groups to shard groups.
copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)
......@@ -1758,6 +1847,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
If detect FP8 parameters, update their `_scale_inv` and do reduce-max for their
`amax_history`.
"""
if self.is_stub_optimizer:
return
amaxes = []
scales = []
scale_invs = []
......
File mode changed from 100755 to 100644
......@@ -4,6 +4,7 @@
import copy
import math
import warnings
from abc import ABC, abstractmethod
from itertools import chain
from logging import getLogger
......@@ -52,21 +53,25 @@ from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)
def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
def _zero_grad_group_helper(
group: List[torch.nn.Parameter], set_to_none: bool, use_decoupled_grad: bool = False
):
"""
Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer.
"""
for param in group:
if param.grad is not None:
grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
if hasattr(param, grad_attr) and getattr(param, grad_attr) is not None:
if set_to_none:
param.grad = None
setattr(param, grad_attr, None)
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
grad_obj = getattr(param, grad_attr)
if grad_obj.grad_fn is not None:
grad_obj.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
grad_obj.requires_grad_(False)
grad_obj.zero_()
def _multi_tensor_copy_this_to_that(
......@@ -105,7 +110,11 @@ class MegatronOptimizer(ABC):
):
"""Input optimizer is the base optimizer (e.g., Adam)."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
if self.optimizer is None:
warnings.warn(
f"WARNING: there is no optimizer on RANK {torch.distributed.get_rank()}. "
"This may be expected if you have frozen sub-models."
)
self.config = config
self.init_state_fn = init_state_fn
......@@ -114,6 +123,7 @@ class MegatronOptimizer(ABC):
Get list of parameters wrapped in optimizer.
"""
params = []
if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
......@@ -131,6 +141,9 @@ class MegatronOptimizer(ABC):
params = self.get_parameters()
grads_for_norm = []
for param in params:
if self.config.use_precision_aware_optimizer:
grad = param.decoupled_grad if hasattr(param, "decoupled_grad") else None
else:
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
......@@ -182,18 +195,27 @@ class MegatronOptimizer(ABC):
def clip_grad_norm(self, clip_grad: float) -> float:
"""Compute and return grad norm, also clip grads."""
params = self.get_parameters()
if params:
grads_for_norm = self.get_main_grads_for_grad_norm()
else:
grads_for_norm = []
grad_norm = get_grad_norm_fp32(
grads_for_norm, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
)
clip_grad_by_total_norm_fp32(params, clip_grad, grad_norm)
if params:
clip_grad_by_total_norm_fp32(
params, clip_grad, grad_norm, self.config.use_precision_aware_optimizer
)
return grad_norm
def count_zeros(self) -> float:
"""Count number of zeros in model's gradients."""
params = self.get_parameters()
return count_zeros_fp32(
params, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
params,
grad_stats_parallel_group=self.get_grad_stats_parallel_group(),
use_decoupled_grad=self.config.use_precision_aware_optimizer,
)
@abstractmethod
......@@ -213,13 +235,6 @@ class MegatronOptimizer(ABC):
"""Simple scaling."""
return self.get_loss_scale() * loss
def start_param_sync(self, model_index: int, *unused):
"""
Start parameter synchronization for all optimizers.
This is a no-op for all non-distributed optimizers.
"""
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
......@@ -253,6 +268,9 @@ class MegatronOptimizer(ABC):
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
if self.is_stub_optimizer:
return []
else:
return self.optimizer.param_groups
def _set_param_groups(self, value):
......@@ -361,11 +379,13 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
if not self.is_stub_optimizer:
main_grads = self._collect_main_grad_data_for_unscaling()
# Reset found inf.
self.found_inf.fill_(0.0)
if not self.is_stub_optimizer:
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
......@@ -393,6 +413,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
if not self.is_stub_optimizer:
self._copy_model_grads_to_main_grads()
if timers is not None:
timers('optimizer-copy-to-main-grad').stop()
......@@ -427,6 +448,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-inner-step', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
if not self.is_stub_optimizer:
self.optimizer.step()
if timers is not None:
timers('optimizer-inner-step').stop()
......@@ -436,6 +458,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
if not self.is_stub_optimizer:
self._copy_main_params_to_model_params()
if timers is not None:
timers('optimizer-copy-main-to-model-params').stop()
......@@ -455,7 +478,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
grad_norm = None
grad_norm = 0.0
if self.config.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.config.clip_grad)
if timers is not None:
......@@ -466,7 +489,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-count-zeros', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else 0
if timers is not None:
timers('optimizer-count-zeros').stop()
......@@ -502,6 +525,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# Handle main parameters.
if optimizer:
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
......@@ -552,6 +576,9 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
self.is_stub_optimizer = False
else:
self.is_stub_optimizer = True
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
......@@ -559,6 +586,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
fp32_from_float16_groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
if self.is_stub_optimizer:
return
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_float16_groups:
......@@ -567,6 +596,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
_zero_grad_group_helper(group, set_to_none)
def _collect_main_grad_data_for_unscaling(self):
if self.is_stub_optimizer:
return
main_grads = []
......@@ -640,7 +671,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
):
if is_loading:
self.init_state_fn(self.optimizer)
self.init_state_fn(self.optimizer, self.config)
state_dict = self.state_dict()
......@@ -735,9 +766,12 @@ class FP32Optimizer(MegatronOptimizer):
super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn)
self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
self.is_stub_optimizer = True if optimizer is None else False
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
if self.is_stub_optimizer:
return
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group['params'], set_to_none)
......@@ -748,6 +782,8 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def prepare_grads(self) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
if self.is_stub_optimizer:
return False
timers = self.config.timers
# Copy main_grads to grads.
......@@ -767,6 +803,8 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful."""
if self.is_stub_optimizer:
return True
timers = self.config.timers
# Update parameters.
......@@ -832,7 +870,7 @@ class FP32Optimizer(MegatronOptimizer):
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
if is_loading:
self.init_state_fn(self.optimizer)
self.init_state_fn(self.optimizer, self.config)
state_dict = self.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
......@@ -900,6 +938,9 @@ class ChainedOptimizer(MegatronOptimizer):
def __init__(self, chained_optimizers: List[MegatronOptimizer]):
self.model_chunks = []
# chained_optimizers would be empty in the case that a rank
# has no trainable parameters
if chained_optimizers:
self.config = getattr(chained_optimizers[0], 'config', None)
for optimizer in chained_optimizers:
if hasattr(optimizer, 'model_chunks'):
......@@ -907,6 +948,9 @@ class ChainedOptimizer(MegatronOptimizer):
if model_chunk not in self.model_chunks:
self.model_chunks.append(model_chunk)
assert self.config == getattr(optimizer, 'config', None)
self.is_stub_optimizer = False
else:
self.is_stub_optimizer = True
self.chained_optimizers = chained_optimizers
@property
......@@ -930,7 +974,10 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.zero_grad(set_to_none)
def get_loss_scale(self):
if self.chained_optimizers:
return self.chained_optimizers[0].get_loss_scale()
else:
return torch.tensor([1.0], dtype=torch.float32, device=torch.cuda.current_device())
def reload_model_params(self):
for optimizer in self.chained_optimizers:
......@@ -987,6 +1034,8 @@ class ChainedOptimizer(MegatronOptimizer):
@torch.no_grad()
def step(self):
"""ChainedOptimizer will step all optimizers one by one."""
if self.is_stub_optimizer:
return True, 0.0, 0
found_inf_flag = self.prepare_grads()
if found_inf_flag:
return False, None, None
......@@ -1005,6 +1054,7 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.get_parameters(),
max_norm=optimizer.config.clip_grad,
total_norm=grad_norm,
use_decoupled_grad=optimizer.config.use_precision_aware_optimizer,
)
# Count the zeros in the grads.
......@@ -1062,8 +1112,3 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.load_parameter_state_from_dp_zero(
state_dict, update_legacy_format=update_legacy_format
)
def start_param_sync(self, model_index: int, *unused):
"""Start parameter synchronization for all optimizers."""
for optimizer in self.chained_optimizers:
optimizer.start_param_sync(model_index, *unused)
......@@ -47,6 +47,23 @@ class OptimizerConfig:
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights. Defaults to torch.float32."""
use_precision_aware_optimizer: bool = False
"""If true, allows optimizer-related tensors (master_param, gradients and optimizer states)
to be set to lower precision. Defaults to False.
"""
main_grads_dtype: torch.dtype = torch.float32
"""dtype of main grads when enabling precision-aware-optimizer"""
main_params_dtype: torch.dtype = torch.float32
"""dtype of main params when enabling precision-aware-optimizer"""
exp_avg_dtype: torch.dtype = torch.float32
"""dtype of exp_avg when enabling precision-aware-optimizer"""
exp_avg_sq_dtype: torch.dtype = torch.float32
"""dtype of exp_avg_sq when enabling precision-aware-optimizer"""
###############
# Loss scaling
###############
......@@ -114,3 +131,51 @@ class OptimizerConfig:
config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""
def __post_init__(self):
"""Check the validity of the config."""
if self.use_precision_aware_optimizer:
assert (
self.optimizer == 'adam'
), '--use-precision-aware-optimizer only supported with adam'
assert (
self.use_distributed_optimizer
), '--use-precision-aware-optimizer only supported with distributed optimizer'
# Only the FusedAdam in TE supports --use-precision-aware-optimizer.
# TODO: Remove this check when apex's FusedAdam is no longer used.
try:
import inspect
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
adam_args = inspect.signature(Adam).parameters
arg_names = [
'master_weight_dtype',
'exp_avg_dtype',
'exp_avg_sq_dtype',
'use_decoupled_grad',
]
for name in arg_names:
assert name in adam_args, (
"Current FusedAdam of TE doesn't support --use-precision-aware-optimizer, "
"please update TE version."
)
except ImportError:
raise RuntimeError(
'--use-precision-aware-optimizer requires FusedAdam from TransformerEngine, '
'but not found.'
)
else:
assert (
self.main_grads_dtype == torch.float32
), "main_grads_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.main_params_dtype == torch.float32
), "main_params_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.exp_avg_dtype == torch.float32
), "exp_avg_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.exp_avg_sq_dtype == torch.float32
), "exp_avg_sq_dtype can only be fp32 when not using precision-aware optimizer"
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -9,6 +9,7 @@ from torch.autograd.variable import Variable
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
......@@ -496,6 +497,9 @@ def forward_backward_no_pipelining(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
......@@ -1479,6 +1483,9 @@ def forward_backward_pipelining_with_interleaving(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
......@@ -1874,4 +1881,7 @@ def forward_backward_pipelining_without_interleaving(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
File mode changed from 100755 to 100644
......@@ -12,6 +12,9 @@ from typing import Any, Callable, Iterable, NamedTuple, Optional, Set, Tuple, Un
import numpy as np
import torch
import megatron.core.parallel_state as mpu
from megatron.core.dist_checkpointing.mapping import ShardedObject
"""DISCLAIMER: THIS IS AN EXPERIMENTAL FEATURE.
The rerun state machine implementation in this file is alpha-level code to help
......@@ -34,6 +37,7 @@ EXIT_CODE_RESUME_TO_DISAMBIGUATE: int = 16
EXIT_CODE_FAILED_ON_RESULT_VALIDATION: int = 17
SerializableStateType = Union[list, dict]
DataIteratorArgType = Optional[Union["RerunDataIterator", list["RerunDataIterator"]]]
class Caller(NamedTuple):
......@@ -203,11 +207,13 @@ class RerunStateMachine:
self.saved_results: dict[Call, Any] = {}
self.stats: dict[Caller, QuickStats] = defaultdict(lambda: QuickStats())
if _safe_get_rank() == 0:
logger.warning(f"RerunStateMachine initialized in mode {mode}")
def set_mode(self, mode: RerunMode) -> None:
"""Method to set the operating mode"""
if _safe_get_rank() == 0:
logger.warning(f"Setting RerunStateMachine mode {mode}")
self.mode = mode
......@@ -216,9 +222,7 @@ class RerunStateMachine:
return self.mode
def should_run_forward_backward(
self, data_iterator: Optional[Union["RerunDataIterator", list]]
) -> bool:
def should_run_forward_backward(self, data_iterator: DataIteratorArgType) -> bool:
"""Method instructing whether to (re)run the forward-backward pass.
Args:
......@@ -243,17 +247,7 @@ class RerunStateMachine:
self.validation_counts = defaultdict(int)
data_iterators: list[RerunDataIterator] = []
if self.mode != RerunMode.DISABLED and data_iterator is not None:
if not isinstance(data_iterator, list):
data_iterators = [data_iterator]
else:
data_iterators = data_iterator
for d in data_iterators:
assert (
isinstance(d, RerunDataIterator),
"data iterator is not wrapped with RerunDataIterator",
)
data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)
# Are we about to start the initial run?
if self.state == RerunState.NOT_RUNNING_YET:
......@@ -263,10 +257,9 @@ class RerunStateMachine:
if self.data_iterator_checkpoints is not None:
assert (
len(self.data_iterator_checkpoints) == len(data_iterators),
"data_iterator has different length than checkpointed data iterator",
)
), "data iterator has different length than checkpointed data iterator"
for i, d in enumerate(data_iterators):
d.set_checkpoint_state(self.data_iterator_checkpoints[i])
d.load_state_dict(self.data_iterator_checkpoints[i])
self.data_iterator_checkpoints = None
self._save_state()
if data_iterators:
......@@ -632,17 +625,15 @@ class RerunStateMachine:
self.last_loss = loss
return result
def get_checkpoint_state(
self, data_iterator: Optional[Union["RerunDataIterator", list]]
) -> list[dict[str, Any]]:
def state_dict(self, data_iterator: DataIteratorArgType, use_dist_ckpt: bool) -> dict[str, Any]:
"""Method that returns a state dict to be checkpointed.
Args:
data_iterator: the data iterator that needs to be checkpointed (or None
if this checkpoint is not requested by the rerun state machine).
use_dist_ckpt: generate a distributed checkpoint.
Returns:
A list of state dicts, each state dict representing the rerun state machine
for one rank.
A state dict representing the rerun state machine.
Example usage:
......@@ -651,26 +642,15 @@ class RerunStateMachine:
...
rerun_state_machine = get_rerun_state_machine()
checkpoint['rerun_state_machine'] = (
rerun_state_machine.get_checkpoint_state(data_iterator)
rerun_state_machine.state_dict(data_iterator, False)
)
...
return checkpoint
"""
data_iterators: list[RerunDataIterator]
if self.mode == RerunMode.DISABLED:
data_iterators = []
elif isinstance(data_iterator, (list, tuple)):
data_iterators = data_iterator
else:
data_iterators = [data_iterator] if data_iterator is not None else []
for d in data_iterators:
assert (
isinstance(d, RerunDataIterator),
"data iterator is not wrapped with RerunDataIterator",
)
data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)
state: dict[str, Any] = {
state_dict: dict[str, Any] = {
'mode': self.mode,
'state': self.state,
'current_iteration': self.current_iteration,
......@@ -679,7 +659,7 @@ class RerunStateMachine:
'restart_again_requested': self.restart_again_requested,
'continue_requested': self.continue_requested,
# logged_sdc_enabled should not be saved (set at the job startup time).
'error_injector_checkpoint': self.error_injector.get_checkpoint_state(),
'error_injector_checkpoint': self.error_injector.state_dict(),
# validation_counts should not be saved (reset at the beginning of the training loop).
'failed_validation_call': self.failed_validation_call,
'initial_result': self.initial_result,
......@@ -687,29 +667,31 @@ class RerunStateMachine:
'suspicious_device': self.suspicious_device,
# No need to save saved_state (RNG state already captured in checkpoint).
'data_iterator_checkpoints': (
[d.get_checkpoint_state() for d in data_iterators] if data_iterators else None
[d.state_dict() for d in data_iterators] if data_iterators else None
),
'last_loss': self.last_loss,
# No need to save saved_results and stats (resets when job resumes).
}
state_list: list[dict[str, Any]]
if (
torch.distributed.is_initialized()
and torch.distributed.get_world_size() > 1
and self.mode != RerunMode.DISABLED
):
state_list = [None for i in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(state_list, state)
else:
state_list = [state]
return state_list
if use_dist_ckpt:
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
state_dict = ShardedObject(
'rerun_state_machine_state',
state_dict,
(pp_size, tp_size),
(pp_rank, tp_rank),
replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
)
return state_dict
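A hedged sketch of the ShardedObject wrapping above, with the parallel-state queries replaced by hard-coded ranks so the shape of the call is visible standalone.

# Hedged sketch of the wrapping above; pp/tp/dp values are stand-ins for the
# mpu.get_*_rank()/world_size() calls.
from megatron.core.dist_checkpointing.mapping import ShardedObject

pp_size, tp_size = 2, 4
pp_rank, tp_rank, dp_rank = 0, 1, 0

state_dict = {'current_iteration': 0, 'rerun_requested': False}  # toy per-rank state
sharded_state = ShardedObject(
    'rerun_state_machine_state',  # checkpoint key
    state_dict,                   # the per-rank state assembled above
    (pp_size, tp_size),           # global grid of shards
    (pp_rank, tp_rank),           # this rank's coordinate in the grid
    replica_id=dp_rank,           # data-parallel replicas share the same shard
)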
def set_checkpoint_state(self, state_list: list[dict[str, Any]]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Method that restores the state from a checkpoint.
Args:
state_list: the list of state dicts saved in the checkpoint and originally
obtained from get_checkpoint_state().
state_dict: the state dict saved in the checkpoint and originally
obtained from state_dict().
Returns:
None
......@@ -719,31 +701,43 @@ class RerunStateMachine:
...
if 'rerun_state_machine' in checkpoint:
rerun_state_machine = get_rerun_state_machine()
rerun_state_machine.set_checkpoint_state(checkpoint['rerun_state_machine'])
rerun_state_machine.load_state_dict(checkpoint['rerun_state_machine'])
"""
if self.mode == RerunMode.DISABLED:
return
rank: int = _safe_get_rank()
if rank == 0:
logger.warning(
"Getting RerunStaeMachine state from checkpoint, args rerun options ignored"
)
state = state_list[rank]
self.mode = state['mode']
self.state = state['state']
self.current_iteration = state['current_iteration']
self.rerun_requested = state['rerun_requested']
self.checkpoint_requested = state['checkpoint_requested']
self.restart_again_requested = state['restart_again_requested']
self.continue_requested = state['continue_requested']
self.error_injector.set_checkpoint_state(state['error_injector_checkpoint'])
self.failed_validation_call = state['failed_validation_call']
self.initial_result = state['initial_result']
self.suspicious_node = state['suspicious_node']
self.suspicious_device = state['suspicious_device']
self.data_iterator_checkpoints = state['data_iterator_checkpoints']
self.last_loss = state['last_loss']
logger.warning("Getting RerunStaeMachine state from checkpoint, args rerun options ignored")
self.mode = state_dict['mode']
self.state = state_dict['state']
self.current_iteration = state_dict['current_iteration']
self.rerun_requested = state_dict['rerun_requested']
self.checkpoint_requested = state_dict['checkpoint_requested']
self.restart_again_requested = state_dict['restart_again_requested']
self.continue_requested = state_dict['continue_requested']
self.error_injector.load_state_dict(state_dict['error_injector_checkpoint'])
self.failed_validation_call = state_dict['failed_validation_call']
self.initial_result = state_dict['initial_result']
self.suspicious_node = state_dict['suspicious_node']
self.suspicious_device = state_dict['suspicious_device']
self.data_iterator_checkpoints = state_dict['data_iterator_checkpoints']
self.last_loss = state_dict['last_loss']
def _sanitize_data_iterators(
self, data_iterator: DataIteratorArgType
) -> list["RerunDataIterator"]:
data_iterators: list[RerunDataIterator]
if self.mode == RerunMode.DISABLED:
data_iterators = []
elif not isinstance(data_iterator, list):
data_iterators = [data_iterator]
else:
data_iterators = data_iterator
data_iterators = [d for d in data_iterators if d is not None]
for d in data_iterators:
assert isinstance(
d, RerunDataIterator
), "data iterator is not wrapped with RerunDataIterator"
return data_iterators
def _get_validation_call_info(self) -> Call:
"""Internal method to get the context about the caller to validate_result()."""
......@@ -837,8 +831,8 @@ class RerunDataIterator:
replay_data_iterator = RerunDataIterator(data_iterator)
"""
def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
def __init__(self, iterable: Iterable[Any]) -> None:
self.iterable: Iterable[Any] = iterable
self.saved_microbatches: list[Any] = []
self.replaying: bool = False
self.replay_pos: int = 0
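Since the constructor no longer calls iter() on its argument, callers now wrap an iterator rather than a raw iterable. A short hedged example; my_batches is a made-up stand-in for a real data loader.

# Hedged example: with make_iterable removed above, the caller supplies an iterator.
from megatron.core.rerun_state_machine import RerunDataIterator

my_batches = [{"tokens": [1, 2]}, {"tokens": [3, 4]}]
replay_data_iterator = RerunDataIterator(iter(my_batches))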
......@@ -870,7 +864,7 @@ class RerunDataIterator:
self.replaying = False
self.saved_microbatches = []
def get_checkpoint_state(self) -> SerializableStateType:
def state_dict(self) -> SerializableStateType:
"""Method to capture the state of the iterator as a serializable dict."""
return {
......@@ -879,7 +873,7 @@ class RerunDataIterator:
'replay_pos': self.replay_pos,
}
def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
def load_state_dict(self, state_dict: SerializableStateType) -> None:
"""Method to restore the state saved as a serializable dict."""
self.saved_microbatches = state_dict['saved_microbatches']
......@@ -1051,7 +1045,7 @@ class RerunErrorInjector:
else:
raise RuntimeError("Should not be here")
def get_checkpoint_state(self) -> SerializableStateType:
def state_dict(self) -> SerializableStateType:
"""Method to capture the state of the error injector as a serializable dict."""
return {
......@@ -1061,7 +1055,7 @@ class RerunErrorInjector:
'injected_error_type': self.injected_error_type,
}
def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
def load_state_dict(self, state_dict: SerializableStateType) -> None:
"""Method to restore the state saved as a serializable dict."""
self.error_injection_rate = state_dict['error_injection_rate']
......@@ -1107,7 +1101,14 @@ def _set_rerun_state_machine(rerun_state_machine) -> None:
def _safe_get_rank() -> int:
"""Internal function that safely checks and returns the rank of the caller."""
return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
# If torch.distributed is not initialized, try to read environment variables.
try:
return int(os.environ.get("RANK", 0))
except (ValueError, TypeError):
return 0
def _compare_floats(a: torch.Tensor, b: torch.Tensor) -> float:
......