Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from .hybrid_optimizer import HybridDeviceOptimizer
# Copyright (c) 2025, NVIDIA CORPORATION and Alibaba PAI. All rights reserved.
from collections import defaultdict
from typing import Dict
import torch
def _param_generator(cpu_optimizer):
for group in cpu_optimizer.param_groups:
for param in group["params"]:
yield param
class HybridDeviceOptimizer(torch.optim.Optimizer):
"""
HybridDeviceOptimizer is a custom optimizer designed to facilitate
hybrid parameter updates across GPU and CPU. This optimizer allows
users to adjust the fraction of parameters updated on the CPU and
GPU through the `offload_fraction` parameter.
It supports bf16 mixed-precision training. Additionally, the optimizer
implements overlapping operations for improved performance, including
gradient transfer from device to host (D2H) and parameter transfer
from host to device (H2D).
Example:
from transformer_engine.pytorch.optimizers import FusedAdam as GPUAdam
from torch.optim import AdamW as CPUAdam
optimizer = HybridDeviceOptimizer(
param_groups,
cpu_optimizer_cls=CPUAdam,
gpu_optimizer_cls=GPUAdam,
offload_fraction=0.5,
param_update_in_fp32=True,
overlap_cpu_optimizer_d2h_h2d=True,
)
optimizer.step()
Note:
This optimizer is particularly useful in scenarios where memory
constraints are present or when leveraging both CPU and GPU resources
can lead to performance improvements.
"""
def __init__(
self,
params,
offload_fraction=0.5,
cpu_optimizer_cls=None,
gpu_optimizer_cls=None,
param_update_in_fp32: bool = False,
pin_cpu_grads: bool = True,
pin_cpu_params: bool = True,
overlap_cpu_optimizer_d2h_h2d: bool = True,
**kwargs
):
super(HybridDeviceOptimizer, self).__init__(
params,
defaults={
"offload_fraction": offload_fraction,
"cpu_optimizer_cls": cpu_optimizer_cls,
"gpu_optimizer_cls": gpu_optimizer_cls,
"param_update_in_fp32": param_update_in_fp32,
"pin_cpu_grads": pin_cpu_grads,
"pin_cpu_params": pin_cpu_params,
"overlap_cpu_optimizer_d2h_h2d": overlap_cpu_optimizer_d2h_h2d,
**kwargs,
},
)
self.offload_fraction = offload_fraction
self.cpu_optimizer_cls = cpu_optimizer_cls
self.gpu_optimizer_cls = gpu_optimizer_cls
self.pin_cpu_grads = pin_cpu_grads
self.pin_cpu_params = pin_cpu_params
self.overlap_cpu_optimizer_d2h_h2d = overlap_cpu_optimizer_d2h_h2d
self.param_update_in_fp32 = param_update_in_fp32
self.sub_optimizer_kwargs = kwargs
self._init_sub_optimizers()
self._register_load_state_dict_hooks()
def _set_sub_optimizer_grads(self):
if self.param_update_in_fp32:
for param in self.param_to_fp32_param:
if param in self.gpu_params_map_cpu_copy:
# Skip params offloaded to CPU; they are handled in the CPU-optimizer
# section below.
continue
fp32_param = self.param_to_fp32_param[param]
grad = getattr(param, "decoupled_grad", param.grad)
if grad is not None:
fp32_param.grad = grad.to(fp32_param.dtype)
fp32_param.requires_grad = True
else:
fp32_param.requires_grad = False
# Sync the grads from GPU to CPU.
for optimizer in self.cpu_optimizers:
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
grad = getattr(gpu_param, "decoupled_grad", gpu_param.grad)
if grad is None:
param.requires_grad = False
continue
param.requires_grad = True
if param not in self.cpu_copy_map_grad:
self.cpu_copy_map_grad[param] = torch.empty(
param.shape, dtype=param.dtype, pin_memory=self.pin_cpu_grads, device="cpu"
)
param.grad = self.cpu_copy_map_grad[param]
self.cpu_copy_map_grad[param].data.copy_(grad, non_blocking=True)
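# Record one D2H completion event per CPU optimizer; step() waits on this event
# right before stepping that optimizer, so grad copies and CPU updates can overlap.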
self._cpu_optimizer_map_data_event[optimizer] = self._d2h_stream.record_event()
def _register_param_copy_back_gpu_hook(self):
def param_copy_back_gpu_hook_closure():
def param_copy_back_gpu_hook(optimizer, args, kwargs):
self._h2d_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._h2d_stream):
for param in _param_generator(optimizer):
gpu_param = self.cpu_copys_map_gpu_param[param]
gpu_param.data.copy_(param.data, non_blocking=True)
self._h2d_stream.record_event().wait(torch.cuda.current_stream())
return param_copy_back_gpu_hook
def fp32_param_copy_back_gpu_hook_closure():
def fp32_param_copy_back_gpu_hook(optimizer, args, kwargs):
for group in self.param_groups:
for param in group["params"]:
if param in self.gpu_params_map_cpu_copy:
# Skip params offloaded to CPU; they were already copied back
# by the previous hook.
continue
if param in self.param_to_fp32_param:
fp32_param = self.param_to_fp32_param[param]
param.data.copy_(fp32_param.data)
return fp32_param_copy_back_gpu_hook
for optimizer in self.sub_optimizers:
if optimizer is not self.gpu_optimizer:
optimizer.register_step_post_hook(param_copy_back_gpu_hook_closure())
elif self.param_update_in_fp32:
optimizer.register_step_post_hook(fp32_param_copy_back_gpu_hook_closure())
def step(self, closure=None):
"""
Override the step method to perform the following operations:
1. Sync the HDO param_groups to sub-optimizers.
2. Sync the grads from GPU to CPU.
3. Step the sub-optimizers.
4. Sync the sub-optimizers state to HDO.
"""
# Sync param_groups to sub-optimizers before each step to make sure
# the lr, wd, etc. are up-to-date.
self._sync_hdo_param_groups_to_sub_optimizers()
self._d2h_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._d2h_stream):
self._set_sub_optimizer_grads()
# Step the sub-optimizers.
if self.gpu_optimizer:
self.gpu_optimizer.step(closure)
for cpu_optimizer in self.cpu_optimizers:
d2h_event = self._cpu_optimizer_map_data_event.pop(cpu_optimizer, None)
if d2h_event is not None:
d2h_event.synchronize()
cpu_optimizer.step(closure)
# Sync state and param_groups to HDO after each step.
# NOTE: It is possible for the optimizer to change the properties
# in param_groups.
self._sync_sub_optimizers_state_to_hdo()
def _init_sub_optimizers(self):
(
self.cpu_param_groups,
self.gpu_param_groups,
self.gpu_params_map_cpu_copy,
self.cpu_copys_map_gpu_param,
self.param_to_fp32_param,
) = self._get_sub_optimizer_param_groups(self.offload_fraction)
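# Mapping summary: gpu_params_map_cpu_copy maps an offloaded GPU param to its
# pinned CPU copy, cpu_copys_map_gpu_param is the inverse, and param_to_fp32_param
# maps an original low-precision param to its fp32 master copy
# (when param_update_in_fp32 is enabled).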
self.param_to_inner_param = {}
self.inner_param_to_orig_param = {}
for group in self.param_groups:
for param in group["params"]:
if param in self.param_to_fp32_param:
inner_param = self.param_to_fp32_param[param]
elif param in self.gpu_params_map_cpu_copy:
inner_param = self.gpu_params_map_cpu_copy[param]
else:
inner_param = param
self.param_to_inner_param[param] = inner_param
self.inner_param_to_orig_param[inner_param] = param
self.fp32_param_to_orig_param = {v: k for k, v in self.param_to_fp32_param.items()}
self.cpu_optimizers = []
if self.overlap_cpu_optimizer_d2h_h2d:
self.cpu_optimizers = self.build_cpu_optimizer_list(
self.cpu_optimizer_cls, self.cpu_param_groups
)
elif len(self.cpu_param_groups) > 0:
self.cpu_optimizers = [self.cpu_optimizer_cls(self.cpu_param_groups)]
if len(self.gpu_param_groups) > 0:
self.gpu_optimizer = self.gpu_optimizer_cls(self.gpu_param_groups)
else:
self.gpu_optimizer = None
self.cpu_copy_map_grad: Dict[torch.Tensor, torch.Tensor] = defaultdict(torch.Tensor)
self._d2h_stream = torch.cuda.current_stream()
self._h2d_stream = torch.cuda.current_stream()
if self.overlap_cpu_optimizer_d2h_h2d:
self._d2h_stream = torch.cuda.Stream()
self._h2d_stream = torch.cuda.Stream()
self._cpu_optimizer_map_data_event = dict()
self._register_param_copy_back_gpu_hook()
@staticmethod
def build_cpu_optimizer_list(cpu_optimizer_cls, cpu_param_groups):
"""Build several cpu optimizers to enable overlap. Currently we naively
assign each parameter to an individual optimizer.
Args:
cpu_optimizer_cls (Type[torch.optim.Optimizer]): A torch optimizer class
cpu_param_groups (List[Dict[str, Any]]): The CPU parameter groups
"""
cpu_optimizers = []
if len(cpu_param_groups) == 0:
return cpu_optimizers
for group in cpu_param_groups:
group_defaults = group.copy()
params = group_defaults.pop("params")
if isinstance(params, torch.Tensor):
params = [params]
for param in params:
_cpu_param_group = group_defaults.copy()
_cpu_param_group["params"] = [param]
cpu_optimizers.append(cpu_optimizer_cls([_cpu_param_group]))
return cpu_optimizers
def _get_sub_optimizer_param_groups(self, offload_fraction: float):
params = []
for group in self.param_groups:
params.extend(group["params"])
params_total_numel = sum([param.numel() for param in params])
gpu_params_total_numel = sum([param.numel() for param in params if param.is_cuda])
cpu_params_total_numel = params_total_numel - gpu_params_total_numel
offload_threshold = gpu_params_total_numel * offload_fraction
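# Parameters are offloaded greedily, in param_groups order, until the cumulative
# numel of offloaded params reaches this threshold.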
offload_params_numel = 0
cpu_param_groups = []
gpu_param_groups = []
gpu_params_map_cpu_copy = {}
cpu_copys_map_gpu_param = {}
param_to_fp32_param = {}
for group in self.param_groups:
gpu_group = group.copy()
cpu_group = group.copy()
gpu_group["params"] = []
cpu_group["params"] = []
for param in group["params"]:
orig_param = param
cpu_copy = False
if offload_params_numel < offload_threshold and param.is_cuda:
param = param.detach().clone().cpu().pin_memory()
offload_params_numel += param.numel()
cpu_copy = True
if self.param_update_in_fp32 and param.dtype != torch.float32:
param = param.detach().clone().float()
param_to_fp32_param[orig_param] = param
if cpu_copy:
gpu_params_map_cpu_copy[orig_param] = param
cpu_copys_map_gpu_param[param] = orig_param
if param.is_cuda:
gpu_group["params"].append(param)
else:
cpu_group["params"].append(param)
if len(gpu_group["params"]) != 0:
gpu_param_groups.append(gpu_group)
if len(cpu_group["params"]) != 0:
cpu_param_groups.append(cpu_group)
return (
cpu_param_groups,
gpu_param_groups,
gpu_params_map_cpu_copy,
cpu_copys_map_gpu_param,
param_to_fp32_param,
)
def _sync_sub_optimizers_state_to_hdo(self):
"""
Sync the sub-optimizers' state back into the HDO `state` attribute.
"""
# optimizer.state:
# {
# torch.nn.Parameter: {
# str: Any,
# },
# ...
# }
new_state = defaultdict(dict)
for optimizer in self.sub_optimizers:
for param in optimizer.state:
orig_param = self.inner_param_to_orig_param[param]
new_state[orig_param] = optimizer.state[param]
if self.param_update_in_fp32:
new_state[orig_param]["master_param"] = param
self.state = new_state
def _sync_hdo_state_to_sub_optimizers(self):
for optimizer in self.sub_optimizers:
new_state = defaultdict(dict)
for group in optimizer.param_groups:
for param in group["params"]:
orig_param = self.inner_param_to_orig_param[param]
new_state[param] = self.state[orig_param]
optimizer.state = new_state
self._update_fp32_params_by_new_state()
self._move_new_state_to_right_device()
def _sync_hdo_param_groups_to_sub_optimizers(self):
"""Sync HDO new param_groups attribute (e.g. lr, wd, etc.) to sub-optimizers."""
param_in_param_group_index = {}
for i, group in enumerate(self.param_groups):
for p_id, param in enumerate(group["params"]):
inner_param = self.param_to_inner_param[param]
param_in_param_group_index[inner_param] = (i, p_id)
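# param_in_param_group_index maps each inner (sub-optimizer) param back to its
# (group index, position) in the HDO param_groups, so the per-group
# hyperparameters (lr, wd, ...) can be looked up below.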
for optimizer in self.sub_optimizers:
new_param_groups = []
for group in optimizer.param_groups:
new_group = group.copy()
# After syncing the sub-optimizer's latest state, propagate the HDO's current
# param_group attributes (lr, wd, ...) to this sub-optimizer group.
assert len(group["params"]) > 0, "param_groups should not be empty"
group_id, _ = param_in_param_group_index[group["params"][0]]
update_group_attrs = self.param_groups[group_id].copy()
del update_group_attrs["params"]
new_group.update(update_group_attrs)
new_param_groups.append(new_group)
optimizer.param_groups = new_param_groups
def _move_new_state_to_right_device(self):
for optimizer in self.sub_optimizers:
for param, state in optimizer.state.items():
for k, v in state.items():
if not isinstance(v, torch.Tensor):
continue
orig_param = self.inner_param_to_orig_param.get(param, param)
if isinstance(optimizer, self.defaults["cpu_optimizer_cls"]):
self.state[orig_param][k] = state[k] = v.to("cpu")
else:
self.state[orig_param][k] = state[k] = v.to("cuda")
def _update_fp32_params_by_new_state(self):
if not self.param_update_in_fp32:
return
for param, v in self.state.items():
fp32_param = self.param_to_fp32_param[param]
fp32_param.data.copy_(v["master_param"])
def _register_load_state_dict_hooks(self):
def pre_load_state_dict_hook(self, state_dict):
"""
Pre-load state dictionary hook to prevent loss of precision in
mixed-precision training.
When loading a state dictionary via `load_state_dict`, optimizer states are
reset and may be cast from `float32` to `bfloat16`/`float16`, losing
precision. This hook temporarily replaces parameters with their `float32`
copies to mitigate this issue.
Args:
state_dict (dict): The state dictionary to be loaded.
Returns:
dict: The modified state dictionary with `float32` parameters.
"""
if not self.param_update_in_fp32:
return state_dict
new_state = {}
for param, v in self.state.items():
param = self.param_to_fp32_param.get(param, param)
new_state[param] = v
self.state = new_state
for group in self.param_groups:
for i, param in enumerate(group["params"]):
group["params"][i] = self.param_to_fp32_param.get(param, param)
return state_dict
self.register_load_state_dict_pre_hook(pre_load_state_dict_hook)
def post_load_state_dict_hook(self):
# 1. Swap the temporarily substituted fp32 parameters back to the originals.
# See the documentation of `pre_load_state_dict_hook`.
if self.param_update_in_fp32:
new_state = {}
for param, v in self.state.items():
orig_param = self.fp32_param_to_orig_param.get(param, param)
new_state[orig_param] = v
self.state = new_state
for group in self.param_groups:
for i, param in enumerate(group["params"]):
group["params"][i] = self.fp32_param_to_orig_param.get(param, param)
# 2. After loading state_dict, the parameters may change, and we need to
# reinitialize the sub-optimizers to regenerate the new parameters and
# cpu copy pairs.
self._init_sub_optimizers()
self._sync_hdo_param_groups_to_sub_optimizers()
self._sync_hdo_state_to_sub_optimizers()
self.register_load_state_dict_post_hook(post_load_state_dict_hook)
def zero_grad(self, set_to_none: bool = True):
"""
Zero the gradients of all parameters (or set them to None when `set_to_none`
is True), including any `decoupled_grad` attributes.
"""
super(HybridDeviceOptimizer, self).zero_grad(set_to_none)
for group in self.param_groups:
for param in group["params"]:
if hasattr(param, "decoupled_grad"):
if set_to_none:
param.decoupled_grad = None
else:
param.decoupled_grad.zero_()
def dummy_step(self):
"""
Run a dummy optimizer step to materialize `optimizer.state`. This is needed
when a checkpoint is loaded via in-place operations (e.g. loading a torch
distributed checkpoint), which requires the state tensors to already exist.
"""
for group in self.param_groups:
for param in group["params"]:
param.grad = torch.randn_like(param)
self.step()
self.zero_grad()
@property
def sub_optimizers(self):
"""
Return the list of sub-optimizers.
"""
if self.gpu_optimizer is not None:
return self.cpu_optimizers + [self.gpu_optimizer]
return self.cpu_optimizers
......@@ -21,6 +21,8 @@ except ImportError:
HAVE_APEX_OR_TE = False
from megatron.core.optimizer.cpu_offloading import HybridDeviceOptimizer
from .. import tensor_parallel
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..dist_checkpointing import ShardedTensor
......@@ -33,8 +35,13 @@ from ..dist_checkpointing.mapping import (
)
from ..dist_checkpointing.utils import extract_sharded_tensors_and_factories
from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
from ..fp8_utils import (
get_fp8_scale_and_amax,
is_float8tensor,
is_mxfp8tensor,
quantize_param_fragment,
)
from ..transformer.module import MegatronModule
from ..utils import is_float8tensor
from .grad_scaler import MegatronGradScaler
from .optimizer import (
MixedPrecisionOptimizer,
......@@ -43,14 +50,6 @@ from .optimizer import (
)
from .optimizer_config import OptimizerConfig
try:
# This will be used when "--fp8-param-gather" is enabled.
# When BF16/FP16 parameters don't exist, we need to cast the FP32 main parameters to
# FP8 directly in the optimizer.
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8
except:
pass
logger = getLogger(__name__)
......@@ -293,6 +292,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
gbuf_ranges: List[Dict],
param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
opt_group_ranges: List,
config: OptimizerConfig,
):
"""
Create main parameter groups needed for the optimizer step.
......@@ -343,38 +343,53 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
# Clone model -> main.
# Generate sharded model param.
shard_model_param = model_param.detach().view(-1)[
param_range.start : param_range.end
]
# If we use FP8 params to initialize FP32 main params (compared to using the
# bf16/fp16 params to initialize the main params), there will be a loss of
# precision at the beginning of training (this problem will not occur if the
# training is long enough or if the main params are loaded from a checkpoint).
if is_float8tensor(model_param) and hasattr(
model_param, 'get_high_precision_init_val'
):
shard_main_param = (
model_param.get_high_precision_init_val()
.view(-1)[param_range.start : param_range.end]
.clone()
.to(shard_model_param.device)
.float()
)
model_param.clear_high_precision_init_val()
else:
shard_main_param = shard_model_param.clone().float()
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param
)
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
shard_main_param.shared = model_param.shared
# Generate main param.
if not config.use_precision_aware_optimizer:
# If we use FP8 params to initialize FP32 main params (compared to using the
# bf16/fp16 params to initialize the main params), there will be a loss of
# precision at the beginning of training (this problem will not occur if the
# training is long enough or if the main params are loaded from a
# checkpoint).
if is_float8tensor(model_param) and hasattr(
model_param, 'get_high_precision_init_val'
):
shard_main_param = (
model_param.get_high_precision_init_val()
.view(-1)[param_range.start : param_range.end]
.clone()
.to(shard_model_param.device)
.float()
)
model_param.clear_high_precision_init_val()
elif is_mxfp8tensor(model_param):
raise ValueError(
"Distributed optimizer currently does not support MXFP8 parameters."
)
else:
shard_main_param = shard_model_param.clone().float()
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param
)
if hasattr(model_param, 'shared'):
shard_main_param.shared = model_param.shared
else:
# When using precision-aware optimizer, main params are held by FusedAdam.
shard_main_param = None
# Store handle to main_param.
model_param.main_param = shard_main_param
model_param.main_param_sharded = True
# Add to group.
model_float16_params_this_group.append(model_param)
......@@ -402,10 +417,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
)
# Update optimizer's params.
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
if not config.use_precision_aware_optimizer:
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_fp32_from_float16_params_this_group,
]
else:
group_range["orig_group"]["params"] = [
*shard_fp32_params_this_group,
*shard_float16_params_this_group,
]
return (
model_float16_groups,
......@@ -424,7 +445,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
model_chunks: List[MegatronModule],
per_model_buffers: Dict[int, List[_ParamAndGradBuffer]],
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_group_gloo: torch.distributed.ProcessGroup,
data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup],
data_parallel_group_idx: int,
distributed_optimizer_instance_id: int,
):
......@@ -468,10 +489,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.ddp_config = self.model_chunks[0].ddp_config
for model_chunk in self.model_chunks:
assert self.ddp_config == model_chunk.ddp_config
self.distributed_optimizer_instance_id = distributed_optimizer_instance_id
assert isinstance(optimizer, (Adam, HybridDeviceOptimizer)) or optimizer is None, (
"Only Adam and HybridDeviceOptimizer currently supported, "
"due to checkpointing requirements."
)
assert isinstance(
optimizer, Adam
), "Only Adam currently supported, due to checkpointing requirements."
# when freezing sub-models we have no real optimizer
# but still need a stub DistributedOptimizer class
if optimizer is None:
self.is_stub_optimizer = True
return
self.is_stub_optimizer = False
if self.ddp_config.use_custom_fsdp:
return
# Model grad buffer ranges.
assert per_model_buffers is not None, "per_model_buffers must be provided"
......@@ -480,7 +513,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.data_parallel_group = data_parallel_group
self.data_parallel_group_gloo = data_parallel_group_gloo
self.data_parallel_group_idx = data_parallel_group_idx
self.distributed_optimizer_instance_id = distributed_optimizer_instance_id
self.gbuf_idx_to_model_idx_map = {}
gbuf_idx = 0
......@@ -515,6 +547,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.gbuf_ranges.append(self._build_gbuf_range_map(buffer))
self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges)
# Add main_param field to each parameter. We will use this fp32 copy to compute
# the param norm.
# For parameters with optimizer state on this rank, None will be overwritten by
# the corresponding sharded main_param tensor.
for param_group in self.optimizer.param_groups:
# For all the parameters in this group.
for param in param_group['params']:
if param.requires_grad:
# fp32 copy only needed for 16-bit parameters.
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
param.main_param = None
param.main_param_sharded = True
# Optimizer ranges.
(self.model_param_group_index_map, self.opt_group_ranges) = (
self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
......@@ -528,14 +573,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
) = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
)
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())
if isinstance(self.optimizer, HybridDeviceOptimizer):
self.optimizer = HybridDeviceOptimizer(
params=[g["orig_group"] for g in self.opt_group_ranges], **self.optimizer.defaults
)
else:
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())
def _get_model_param_range_map(self, param: torch.nn.Parameter):
"""
......@@ -571,6 +618,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
steps = list(set([s["step"].item() for s in inner_state_dict["state"].values()]))
assert len(steps) == 1
step = steps[0]
elif isinstance(self.optimizer, HybridDeviceOptimizer):
step = None
for optimizer in self.optimizer.sub_optimizers:
if isinstance(optimizer, torch.optim.AdamW):
if len(optimizer.state) == 0:
continue
steps = list(set([s["step"].item() for s in optimizer.state.values()]))
assert len(steps) == 1, f"steps: {optimizer.state}"
step = steps[0]
break
# Optimizer state (do not store parameter state here).
state_dict['optimizer'] = {k: v for k, v in inner_state_dict.items() if k != "state"}
......@@ -579,6 +636,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
if not HAVE_APEX_OR_TE:
# Native PyTorch param group requires step (i.e., iteration).
param_group["step"] = step
elif isinstance(self.optimizer, HybridDeviceOptimizer) and step is not None:
param_group["step"] = int(step)
# Grad scaler state.
if self.grad_scaler:
......@@ -615,6 +674,23 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
- state_order : The index of a parameter within the shared parameter
list.
"""
if len(self.optimizer.state) == 0:
if isinstance(self.optimizer, HybridDeviceOptimizer):
self.optimizer.dummy_step()
elif self.ddp_config.use_custom_fsdp:
# Initializes optimizer states with dummy values.
# This step is necessary to ensure that the optimizer's states are
# initialized correctly. These dummy states will be replaced in-place
# during the loading of distributed checkpoints.
for group in self.optimizer.param_groups:
for param in group["params"]:
if param.numel() == 0:
# Avoid FusedAdam errors on empty tensor input.
continue
param.grad = torch.randn_like(param)
self.optimizer.step()
self.optimizer.zero_grad()
# Get the Torch optimizer's state dict.
# - This 'inner' optimizer at this point is unallocated, and only
......@@ -655,9 +731,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
(numel,), dtype=torch.float32, device=torch.cuda.current_device()
)
state_dict_state.append(
(state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
)
tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard()}
if self.config.use_precision_aware_optimizer:
tensors["master_param"] = init_shard()
state_dict_state.append((state_order, tensors))
# Sort by state order (see method docstring for details).
state_dict_state.sort(key=lambda s: s[0])
......@@ -676,6 +753,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for s in state_dict_state.values():
# Native PyTorch state dict requires step (i.e., iteration).
s["step"] = step
elif isinstance(self.optimizer, HybridDeviceOptimizer):
# Handle the torch AdamW special case: unlike FusedAdam, torch AdamW keeps an
# extra "step" entry in its optimizer state.
steps = list(
set([g["step"] for g in state_dict["optimizer"]["param_groups"] if "step" in g])
)
if len(steps) != 0:
assert len(steps) == 1, f"steps: {steps}"
step = torch.tensor(steps[0], dtype=torch.float32, device="cpu")
for v in self.optimizer.state.values():
v["step"] = step.detach().clone()
# Optimizer.
self.optimizer.load_state_dict(
......@@ -702,6 +790,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
assert 'param_state_sharding_type' in state_dict, state_dict.keys()
param_state = state_dict['param_state']
sharding_type = state_dict['param_state_sharding_type']
if self.ddp_config.use_custom_fsdp:
assert (
sharding_type == "fully_sharded_model_space"
), "Only fully sharded model space is supported"
logger.info(f'Loading distributed optimizer sharded state of type {sharding_type}')
if sharding_type == 'dp_zero_gather_scatter':
self.load_parameter_state_from_dp_zero(param_state)
......@@ -712,6 +804,92 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
else:
raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')
def _get_main_param_and_optimizer_states(self, model_param):
"""Return a dict containing the main param and optimizer states corresponding to the input
model_param.
The structure of the returned dict:
tensors = {
"param": torch.Tensor
"exp_avg": torch.Tensor
"exp_avg_sq": torch.Tensor
}
"""
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
if model_param not in pg_buffer.param_to_name:
continue
param_name = pg_buffer.param_to_name[model_param]
main_param = dict(pg_buffer.optimizer_named_parameters)[param_name]
assert param_name is not None, f"Not found main_param"
return {"param": main_param, **self.optimizer.state[main_param]}
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
tensors = {}
for k in self.optimizer.state[sharded_model_param]:
if isinstance(self.optimizer, HybridDeviceOptimizer):
tensors[k] = self.optimizer.state[sharded_model_param][k]
continue
tensors[k] = self.optimizer.get_unscaled_state(sharded_model_param, k)
tensors["param"] = tensors.pop("master_param")
else:
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}
return tensors
def _set_main_param_and_optimizer_states(self, model_param, tensors):
"""Set the main param and optimizer states corresponding to the input model_param.
The structure of the input `tensors`:
tensors = {
"param": torch.Tensor
"exp_avg": torch.Tensor
"exp_avg_sq": torch.Tensor
}
"""
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
if model_param not in pg_buffer.param_to_name:
continue
param_name = pg_buffer.param_to_name[model_param]
main_param = dict(pg_buffer.optimizer_named_parameters)[param_name]
assert param_name is not None, f"Not found parameter"
for key in tensors:
if key == "param":
main_param.copy_(tensors[key])
else:
self.optimizer.state[main_param][key] = tensors[key]
return
group_index, group_order = self.model_param_group_index_map[model_param]
if self.config.use_precision_aware_optimizer:
sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
for k, v in tensors.items():
if isinstance(self.optimizer, HybridDeviceOptimizer):
if k == "param":
k = "master_param"
self.optimizer.state[sharded_model_param][k] = v
continue
if k == "param":
self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
else:
self.optimizer.set_scaled_state(sharded_model_param, k, v)
else:
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(tensors[key])
def get_parameter_state_fs_bucket_space(self):
"""Get internal representation of parameter state without any copies and modifications.
......@@ -734,18 +912,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
bucket_state = []
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {
"param": main_param,
**optim_state,
"gbuf_local_start": param_range_map["gbuf_local"].start,
"gbuf_local_end": param_range_map["gbuf_local"].end,
}
tensors = self._get_main_param_and_optimizer_states(model_param)
tensors.update(
{
"gbuf_local_start": param_range_map["gbuf_local"].start,
"gbuf_local_end": param_range_map["gbuf_local"].end,
}
)
bucket_state.append(tensors)
buckets_state.append(bucket_state)
dtype_state[dtype] = buckets_state
......@@ -762,8 +935,35 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
- Gather contiguous buffers on DP rank 0 and concatenate to world
buffers.
"""
if self.ddp_config.use_custom_fsdp:
state = {"buckets_coalesced": True}
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
for group_id, group in enumerate(pg_buffer.parameter_groups):
this_group_state = {}
mbuf = group.master_weight_buffer
for item_id, _ in enumerate(group.params):
main_param = mbuf.get_item(item_id)
optim_state = self.optimizer.state[main_param]
object_list = [None] * mbuf.dp_world_size
torch.distributed.all_gather_object(
object_list, optim_state, group=mbuf.data_parallel_group
)
for rank, obj in enumerate(object_list):
for name, value in obj.items():
assert torch.is_tensor(value), f"Expected tensor, got {type(value)}"
this_group_state.setdefault(name, []).append(value)
for name, values in this_group_state.items():
this_group_state[name] = torch.cat(values).cpu()
state[f"group_{group_id}"] = this_group_state
return state
# Data parallelism variables.
assert self.data_parallel_group_gloo is not None
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
......@@ -810,13 +1010,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Build contiguous DP rank shards (for param + optim states).
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}
tensors = self._get_main_param_and_optimizer_states(model_param)
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
......@@ -895,6 +1089,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
' Please switch to `full_sharded_model_space`.'
)
if self.ddp_config.use_custom_fsdp:
assert sharding_type == 'fully_sharded_model_space', (
f'For FSDP, only `fully_sharded_model_space` is supported. ' f'Got: {sharding_type}'
)
state_dict = self.state_dict()
if sharding_type != 'fully_sharded_model_space':
# State dict differs between different model parallel groups
......@@ -914,7 +1113,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# which conditionally skips re-allocating the optimizer's state if
# already initialized, which in turn reduces memory fragmentation.
self.load_state_dict(self.state_dict())
if sharding_type == 'fully_sharded_bucket_space':
param_state = self.sharded_param_state_fs_bucket_space(
model_sharded_state_dict, is_loading
......@@ -956,7 +1154,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Fixed TPxPP. Save on DP rank 0 only
param_state = ShardedObject(
f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.param_state',
param_state_data,
param_state_data, # pylint: disable=E0606
(1,),
(0,),
)
......@@ -1101,6 +1299,72 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
prefix = 'optimizer.state'
state = {}
# Not stored in the checkpoint, used only to identify params in
# `sharded_param_state_fs_model_space`.
def _get_param_state_sharded_tensors(model_param, item_slice):
# Main param & optimizer states.
tensors = self._get_main_param_and_optimizer_states(model_param)
tensors["fp32_param"] = tensors.pop("param")
# Match optimizer parameter with model ShardedTensor (or
# ShardedTensorFactory).
if self.ddp_config.use_custom_fsdp:
model_param = getattr(model_param, "fully_shard_param_local_shard", model_param)
try:
sharded_metadata = param_to_sharded_metadata[model_param]
except KeyError as e:
raise ValueError(
f'Model param {model_param} not in model_sharded_state_dict'
) from e
# Set DP corresponding replica_id coordinate to 0.
assert (
len(sharded_metadata.replica_id) == 3
), f'Expected replica_id format (PP, TP, DP), got: {sharded_metadata}'
replica_id = (*sharded_metadata.replica_id[:2], self.distributed_optimizer_instance_id)
# Instantiate ShardedTensor (or ShardedTensorFactory) for optimizer
# params.
for state_key, state_ten in tensors.items():
if state_key == 'step':
# Note that "step" is a 0-dim tensor, unlike the other states,
# which have the same size as the parameter. The "step" optimizer
# state is handled separately and is read from param_groups.
continue
replace_kwargs = dict(
key=f'{prefix}.{state_key}.{sharded_metadata.key}',
data=state_ten,
dtype=state_ten.dtype,
flattened_range=item_slice,
replica_id=replica_id,
)
if isinstance(sharded_metadata, ShardedTensorFactory):
replace_kwargs.pop('dtype')
tensors[state_key] = replace(sharded_metadata, **replace_kwargs)
tensors[state_key].validate_metadata_integrity()
return tensors
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
for pg in pg_buffer.parameter_groups:
gbuf = pg.main_grad_buffer
if gbuf is None:
continue
for model_param in gbuf.params:
item_id = gbuf.param_idx[model_param]
param_name = pg_buffer.param_to_name[model_param]
item_slice = gbuf._get_item_slice_in_shard(item_id)
if item_slice[0] == item_slice[1]:
# This param is not in this shard.
continue
state[param_name] = _get_param_state_sharded_tensors(
model_param, slice(*item_slice)
)
return state
# Not stored in the checkpoint, used only to identify params in
# `sharded_param_state_fs_model_space`.
param_idx = 0
......@@ -1108,45 +1372,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
group_index, group_order = self.model_param_group_index_map[model_param]
param_range = param_range_map['param']
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"fp32_param": main_param, **optim_state}
# Match optimizer parameter with model ShardedTensor (or
# ShardedTensorFactory).
try:
sharded_metadata = param_to_sharded_metadata[model_param]
except KeyError as e:
raise ValueError(
f'Model param {model_param} not in model_sharded_state_dict'
) from e
# Set DP corresponding replica_id coordinate to 0.
assert (
len(sharded_metadata.replica_id) == 3
), f'Expected replica_id format (PP, TP, DP), got: {sharded_metadata}'
replica_id = (
*sharded_metadata.replica_id[:2],
self.distributed_optimizer_instance_id,
tensors = _get_param_state_sharded_tensors(
model_param, slice(param_range.start, param_range.end)
)
# Instantiate ShardedTensor (or ShardedTensorFactory) for optimizer
# params.
for state_key, state_ten in tensors.items():
replace_kwargs = dict(
key=f'{prefix}.{state_key}.{sharded_metadata.key}',
data=state_ten,
dtype=state_ten.dtype,
flattened_range=slice(param_range.start, param_range.end),
replica_id=replica_id,
)
if isinstance(sharded_metadata, ShardedTensorFactory):
replace_kwargs.pop('dtype')
tensors[state_key] = replace(sharded_metadata, **replace_kwargs)
tensors[state_key].validate_metadata_integrity()
state[param_idx] = tensors
param_idx += 1
return state
......@@ -1188,13 +1417,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
bucket_state, gbuf_range_map["param_map"].items()
):
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(src_tensors[key])
self._set_main_param_and_optimizer_states(model_param, src_tensors)
@torch.no_grad()
def load_parameter_state_from_fs_model_space(self, state_dict):
......@@ -1202,21 +1425,41 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Inverse of the `sharded_param_state_fs_model_space` method.
"""
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
pg_buffer = model_chunk.param_and_grad_buffer
for model_param in pg_buffer.params:
param_name = pg_buffer.param_to_name[model_param]
if param_name not in state_dict:
continue
src_tensors = {}
for k, v in state_dict[param_name].items():
if k == "fp32_param":
src_tensors["param"] = v
else:
src_tensors[k] = v
self._set_main_param_and_optimizer_states(model_param, src_tensors)
return
param_idx = 0 # matching order with `sharded_param_state_fs_model_space`
for gbuf_range_maps in self.gbuf_ranges:
for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
for gbuf_range_map in gbuf_range_map_for_all_buckets:
for model_param, param_range_map in gbuf_range_map["param_map"].items():
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
src_tensors = state_dict[param_idx]
dst_tensors = {"fp32_param": main_param, **optim_state}
for key in dst_tensors:
dst_tensors[key].copy_(src_tensors[key])
src_tensors = {}
for k, v in state_dict[param_idx].items():
if k == "step":
# Handle torch Adam "step" state separately.
continue
if k == "fp32_param":
src_tensors["param"] = v
else:
src_tensors[k] = v
self._set_main_param_and_optimizer_states(model_param, src_tensors)
param_idx += 1
if isinstance(self.optimizer, HybridDeviceOptimizer):
self.optimizer._sync_hdo_state_to_sub_optimizers()
@classmethod
def _update_legacy_world_tensors(cls, old_tensors, new_numels):
......@@ -1254,6 +1497,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
"""
# Data parallelism variables.
assert self.data_parallel_group_gloo is not None
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
......@@ -1369,6 +1613,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
return self.load_parameter_state_from_dp_zero_legacy(state_dict)
# Data parallelism variables.
assert self.data_parallel_group_gloo is not None
data_parallel_world_size = self.data_parallel_group_gloo.size()
data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
data_parallel_group_gloo = self.data_parallel_group_gloo
......@@ -1390,6 +1635,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
f"Number of unpadded elements must be same in current run "
f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
)
recv_tensors = {}
for key in ("param", "exp_avg", "exp_avg_sq"):
offset_in_world_tensors = 0
for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
......@@ -1440,26 +1686,18 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
data_parallel_group_gloo,
)
# Copy local contiguous shards to param/optim shards.
for model_param, param_range_map in gbuf_range_map["param_map"].items():
# Main param & optimizer states.
group_index, group_order = self.model_param_group_index_map[model_param]
main_param = self.optimizer.param_groups[group_index]["params"][
group_order
]
if key == "param":
tensor_to_copy_into = main_param
else:
optim_state = self.optimizer.state[main_param]
tensor_to_copy_into = optim_state[key]
# Copy states into contiguous shard.
gbuf_local_start = param_range_map["gbuf_local"].start
gbuf_local_end = param_range_map["gbuf_local"].end
tensor_to_copy_into.data.copy_(
recv_tensor[gbuf_local_start:gbuf_local_end]
)
if model_param not in recv_tensors:
recv_tensors[model_param] = {}
recv_tensors[model_param][key] = recv_tensor[
gbuf_local_start:gbuf_local_end
]
for model_param, tensors in recv_tensors.items():
self._set_main_param_and_optimizer_states(model_param, tensors)
def split_state_dict_if_needed(self, state_dict):
"""
......@@ -1600,6 +1838,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Args:
filename (str): path to load parameter state from.
"""
if self.is_stub_optimizer:
return
state_dict = None
if torch.distributed.get_rank(self.data_parallel_group) == 0:
state_dict = torch.load(filename)
......@@ -1618,24 +1858,44 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
Args:
set_to_none (bool): if true, set grads to None.
"""
for groups in (
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.zero_grad_buffer()
return
if self.is_stub_optimizer:
return
total_groups = [
self.model_float16_groups,
self.model_fp32_groups,
self.shard_float16_groups, # grad empty/unused here?
self.shard_fp32_groups, # throws grad-access warning
self.shard_fp32_from_float16_groups,
):
]
if not self.config.use_precision_aware_optimizer:
total_groups.append(self.shard_fp32_from_float16_groups)
for groups in total_groups:
for group in groups:
_zero_grad_group_helper(group, set_to_none)
_zero_grad_group_helper(
group, set_to_none, self.config.use_precision_aware_optimizer
)
def _collect_main_grad_data_for_unscaling(self):
"""
Note: this should be equivalent to the float-16 optimizer's method,
but written differently, so the two should be combined.
"""
return [
param.grad.data for group in self.optimizer.param_groups for param in group["params"]
]
if self.config.use_precision_aware_optimizer:
return [
param.decoupled_grad.data
for group in self.optimizer.param_groups
for param in group["params"]
]
else:
return [
param.grad.data
for group in self.optimizer.param_groups
for param in group["params"]
]
def _get_model_and_main_params_data_float16(self):
"""
......@@ -1648,7 +1908,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
):
for model_param, main_param in zip(model_group, main_group):
model_data.append(model_param.data)
main_data.append(main_param.data)
if self.config.use_precision_aware_optimizer:
main_data.append(None)
else:
main_data.append(main_param.data)
return model_data, main_data
def _copy_model_grads_to_main_grads(self):
......@@ -1659,6 +1922,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
buffer, this method is responsible for copying the updated grads
from the grad buffer to the main shard's grad field.
"""
if self.is_stub_optimizer:
return
if self.ddp_config.use_custom_fsdp:
return
# Utility method for copying group grads.
def copy_group_grads(model_groups, shard_main_groups):
......@@ -1671,11 +1939,23 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
model_grad = model_param.main_grad
shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
shard_main_param.grad = shard_model_grad.float()
if self.config.use_precision_aware_optimizer:
# PyTorch requires a param and its grad to have the same dtype, but the
# precision-aware optimizer wants them to differ, so ".decoupled_grad" is
# used in place of ".grad". Note that this requires a matching change in the
# optimizer (it must read gradients from ".decoupled_grad" instead of ".grad").
shard_main_param.decoupled_grad = shard_model_grad
else:
shard_main_param.grad = shard_model_grad.float()
# Copy model groups to shard groups.
copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
if self.config.use_precision_aware_optimizer:
copy_group_grads(self.model_float16_groups, self.shard_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
else:
copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
def _copy_main_params_to_model_params(self):
"""
......@@ -1685,6 +1965,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the grad buffer.
"""
if self.is_stub_optimizer:
return
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.param_and_grad_buffer.copy_main_weights_to_model_weights()
return
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
......@@ -1714,16 +2001,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# be deleted if it is not necessary.
shard_main_param = shard_main_param.to(model_param.dtype)
cast_to_fp8(
shard_main_param.view(1, -1),
model_param._fp8_meta['scaling_fwd'],
model_param._fp8_meta_index,
model_param._fp8_dtype,
out=shard_model_param.view(1, -1),
quantize_param_fragment(
shard_main_param, out=shard_model_param, param=model_param
)
else:
shard_model_param.data.copy_(shard_main_param)
# When using precision-aware optimizer, main params are held by self.optimizer. It will also
# do the work of copying data from main params to model params.
if self.config.use_precision_aware_optimizer:
return
# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
copy_group_params(self.shard_fp32_groups, self.model_fp32_groups)
......@@ -1736,6 +2024,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
the model params. This copy does not make use of the grad buffer as
an intermediary.
"""
if isinstance(self.optimizer, HybridDeviceOptimizer):
return
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.param_and_grad_buffer.copy_model_weights_to_main_weights()
return
# Utility method for copying group params.
def copy_group_params(model_groups, shard_main_groups):
......@@ -1749,6 +2044,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
shard_main_param.data.copy_(shard_model_param)
# When using precision-aware optimizer, main params are held by self.optimizer. It will also
# do the work of copying data from main params to model params.
if self.config.use_precision_aware_optimizer:
return
# Copy model groups to shard groups.
copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)
......@@ -1758,42 +2058,58 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
If detect FP8 parameters, update their `_scale_inv` and do reduce-max for their
`amax_history`.
"""
amaxes = []
scales = []
scale_invs = []
if self.is_stub_optimizer:
return
if self.ddp_config.use_custom_fsdp:
buffers = []
for m in self.model_chunks:
for group in m.param_and_grad_buffer.parameter_groups:
mbuf = group.model_weight_buffer
buffers.append(mbuf)
else:
buffers = self.buffers
# Iterate over all parameters inside this optimizer to find FP8 parameters.
for buffer in self.buffers:
for bucket in buffer.buckets:
for param in bucket.params_list:
if is_float8tensor(param):
fp8_meta = param._fp8_meta['scaling_fwd']
fp8_meta_index = param._fp8_meta_index
amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1))
scales.append(fp8_meta.scale[fp8_meta_index].view(1))
scale_invs.append(param._scale_inv.view(1))
# Reset transpose cache
param._reset_caches()
# If there is no FP8 parameters, skip all operations.
if len(scales) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# Update scaling factors.
packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
_multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf)
torch.reciprocal(packed_scales, out=packed_scales)
_multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf)
# Reduce amaxes.
# Note: Assume each param has a separate amax.
packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
_multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf)
torch.distributed.all_reduce(
packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.data_parallel_group
)
_multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf)
for buffer in buffers:
amaxes = []
scales = []
scale_invs = []
for param in buffer.params:
if is_float8tensor(param):
scale, amax = get_fp8_scale_and_amax(param)
amaxes.append(amax.view(1))
scales.append(scale.view(1))
scale_invs.append(param._scale_inv.view(1))
# Reset transpose cache
param._reset_caches()
# If there is no FP8 parameters, skip all operations.
if len(scales) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# Update scaling factors.
packed_scales = torch.empty(
len(scales), dtype=torch.float32, device=scales[0].device
)
packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
_multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf)
torch.reciprocal(packed_scales, out=packed_scales)
_multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf)
# Reduce amaxes.
# Note: Assume each param has a separate amax.
packed_amaxes = torch.empty(
len(amaxes), dtype=torch.float32, device=amaxes[0].device
)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
_multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf)
torch.distributed.all_reduce(
packed_amaxes,
op=torch.distributed.ReduceOp.MAX,
group=buffer.data_parallel_group,
)
_multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf)
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
......@@ -1809,13 +2125,18 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
timers = self.config.timers
if timers is not None:
timers('params-all-gather', log_level=1).start(barrier=self.config.barrier_with_L1_time)
# If not overlapping all-gather for parameters, launch synchronous all-gather
# communication calls here. If overlapping all-gather for parameters, the first
# all-gather is launched asynchronously in the next optimizer.zero_grad() call
# and subsequent all-gathers are launched in the forward pre-hook.
if not self.ddp_config.overlap_param_gather:
if self.ddp_config.use_custom_fsdp:
for model_chunk in self.model_chunks:
model_chunk.start_param_sync()
else:
# If not overlapping all-gather for parameters, launch synchronous all-gather
# communication calls here. If overlapping all-gather for parameters, the first
# all-gather is launched asynchronously in the next optimizer.zero_grad() call
# and subsequent all-gathers are launched in the forward pre-hook.
if not self.ddp_config.overlap_param_gather:
for model_chunk in self.model_chunks:
model_chunk.start_param_sync()
if timers is not None:
timers('params-all-gather').stop()
......
File mode changed from 100755 to 100644
......@@ -4,6 +4,7 @@
import copy
import math
import warnings
from abc import ABC, abstractmethod
from itertools import chain
from logging import getLogger
......@@ -52,21 +53,25 @@ from .optimizer_config import OptimizerConfig
logger = getLogger(__name__)
def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
def _zero_grad_group_helper(
group: List[torch.nn.Parameter], set_to_none: bool, use_decoupled_grad: bool = False
):
"""
Zero out the gradient for a group of parameters.
Note: copied from torch.optim.optimizer.
"""
for param in group:
if param.grad is not None:
grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
if hasattr(param, grad_attr) and getattr(param, grad_attr) is not None:
if set_to_none:
param.grad = None
setattr(param, grad_attr, None)
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
grad_obj = getattr(param, grad_attr)
if grad_obj.grad_fn is not None:
grad_obj.detach_()
else:
param.grad.requires_grad_(False)
param.grad.zero_()
grad_obj.requires_grad_(False)
grad_obj.zero_()
def _multi_tensor_copy_this_to_that(
......@@ -105,7 +110,11 @@ class MegatronOptimizer(ABC):
):
"""Input optimizer is the base optimizer (e.g., Adam)."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
if self.optimizer is None:
warnings.warn(
f"WARNING: there is no optimizer on RANK {torch.distributed.get_rank()}. "
"This may be expected if you have frozen sub-models."
)
self.config = config
self.init_state_fn = init_state_fn
......@@ -114,9 +123,10 @@ class MegatronOptimizer(ABC):
Get list of parameters wrapped in optimizer.
"""
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
return params
def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
......@@ -131,7 +141,10 @@ class MegatronOptimizer(ABC):
params = self.get_parameters()
grads_for_norm = []
for param in params:
grad = param.grad
if self.config.use_precision_aware_optimizer:
grad = param.decoupled_grad if hasattr(param, "decoupled_grad") else None
else:
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
......@@ -182,18 +195,27 @@ class MegatronOptimizer(ABC):
def clip_grad_norm(self, clip_grad: float) -> float:
"""Compute and return grad norm, also clip grads."""
params = self.get_parameters()
grads_for_norm = self.get_main_grads_for_grad_norm()
if params:
grads_for_norm = self.get_main_grads_for_grad_norm()
else:
grads_for_norm = []
grad_norm = get_grad_norm_fp32(
grads_for_norm, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
)
clip_grad_by_total_norm_fp32(params, clip_grad, grad_norm)
if params:
clip_grad_by_total_norm_fp32(
params, clip_grad, grad_norm, self.config.use_precision_aware_optimizer
)
return grad_norm
def count_zeros(self) -> float:
"""Count number of zeros in model's gradients."""
params = self.get_parameters()
return count_zeros_fp32(
params, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
params,
grad_stats_parallel_group=self.get_grad_stats_parallel_group(),
use_decoupled_grad=self.config.use_precision_aware_optimizer,
)
@abstractmethod
......@@ -213,13 +235,6 @@ class MegatronOptimizer(ABC):
"""Simple scaling."""
return self.get_loss_scale() * loss
def start_param_sync(self, model_index: int, *unused):
"""
Start parameter synchronization for all optimizers.
This is a no-op for all non-distributed optimizers.
"""
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
......@@ -253,7 +268,10 @@ class MegatronOptimizer(ABC):
# "optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
if self.is_stub_optimizer:
return []
else:
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
......@@ -280,7 +298,7 @@ class MegatronOptimizer(ABC):
"""
@staticmethod
def _extract_common_per_param_step(state_dict) -> Union[int, torch.Tensor]:
def _extract_common_per_param_step(state_dict) -> Union[int, torch.Tensor, None]:
common_step = None
for param_idx, param_state in state_dict['state'].items():
param_step = param_state.get('step', None)
......@@ -356,20 +374,23 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
return self.grad_scaler.scale
def reload_model_params(self):
self._copy_model_params_to_main_params()
if self.param_groups:
self._copy_model_params_to_main_params()
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
main_grads = self._collect_main_grad_data_for_unscaling()
if not self.is_stub_optimizer:
main_grads = self._collect_main_grad_data_for_unscaling()
# Reset found inf.
self.found_inf.fill_(0.0)
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
)
if not self.is_stub_optimizer:
# Unscale and set found inf/nan
torch._amp_foreach_non_finite_check_and_unscale_(
main_grads, self.found_inf, self.grad_scaler.inv_scale
)
# Update across all model parallel instances.
torch.distributed.all_reduce(
......@@ -393,7 +414,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-copy-to-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_model_grads_to_main_grads()
if not self.is_stub_optimizer:
self._copy_model_grads_to_main_grads()
if timers is not None:
timers('optimizer-copy-to-main-grad').stop()
......@@ -427,7 +449,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-inner-step', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self.optimizer.step()
if not self.is_stub_optimizer:
self.optimizer.step()
if timers is not None:
timers('optimizer-inner-step').stop()
......@@ -436,7 +459,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-copy-main-to-model-params', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
self._copy_main_params_to_model_params()
if not self.is_stub_optimizer:
self._copy_main_params_to_model_params()
if timers is not None:
timers('optimizer-copy-main-to-model-params').stop()
......@@ -455,7 +479,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-clip-main-grad', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
grad_norm = None
grad_norm = 0.0
if self.config.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.config.clip_grad)
if timers is not None:
......@@ -466,7 +490,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
timers('optimizer-count-zeros', log_level=1).start(
barrier=self.config.barrier_with_L1_time
)
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else 0
if timers is not None:
timers('optimizer-count-zeros').stop()
......@@ -502,56 +526,63 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# Handle main parameters.
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] = self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type())
)
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
if optimizer:
# Three groups of parameters:
# float16_groups: original float16 parameters
# fp32_from_float16_groups: fp32 copy of float16 parameters
# fp32_from_fp32_groups: original fp32 parameters
self.float16_groups = []
self.fp32_from_float16_groups = []
self.fp32_from_fp32_groups = []
# For all the groups in the original optimizer:
for param_group in self.optimizer.param_groups:
float16_params_this_group = []
fp32_params_this_group = []
fp32_from_float16_params_this_group = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
float16_params_this_group.append(param)
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
# Store handle to main_param.
param.main_param = main_param
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
self.optimizer.state[main_param] = self.optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type())
)
self.float16_groups.append(float16_params_this_group)
self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
self.is_stub_optimizer = False
else:
self.is_stub_optimizer = True
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
......@@ -559,6 +590,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
fp32_from_float16_groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
if self.is_stub_optimizer:
return
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_float16_groups:
......@@ -567,6 +600,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
_zero_grad_group_helper(group, set_to_none)
def _collect_main_grad_data_for_unscaling(self):
if self.is_stub_optimizer:
return
main_grads = []
......@@ -640,7 +675,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
):
if is_loading:
self.init_state_fn(self.optimizer)
self.init_state_fn(self.optimizer, self.config)
state_dict = self.state_dict()
......@@ -677,7 +712,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
)
# save step as a shared step among all parameters. Separate per-parameter
# steps are not supported
state_dict['optimizer']['state']['common_step'] = step
if step:
state_dict['optimizer']['state']['common_step'] = step
return state_dict
def load_state_dict(self, state_dict):
......@@ -735,9 +771,12 @@ class FP32Optimizer(MegatronOptimizer):
super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn)
self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')
self.is_stub_optimizer = True if optimizer is None else False
def zero_grad(self, set_to_none=True):
"""Copied from torch.optim.optimizer"""
if self.is_stub_optimizer:
return
for group in self.optimizer.param_groups:
_zero_grad_group_helper(group['params'], set_to_none)
......@@ -748,6 +787,8 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def prepare_grads(self) -> bool:
"""Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
if self.is_stub_optimizer:
return False
timers = self.config.timers
# Copy main_grads to grads.
......@@ -767,6 +808,8 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def step_with_ready_grads(self) -> bool:
"""Step the optimizer with ready gradients, return successful."""
if self.is_stub_optimizer:
return True
timers = self.config.timers
# Update parameters.
......@@ -832,7 +875,7 @@ class FP32Optimizer(MegatronOptimizer):
self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
):
if is_loading:
self.init_state_fn(self.optimizer)
self.init_state_fn(self.optimizer, self.config)
state_dict = self.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
......@@ -846,7 +889,8 @@ class FP32Optimizer(MegatronOptimizer):
optim_state_to_sharding_state(state_dict, id_to_sharded_param_map, exclude_keys="step")
# save step as a shared step among all parameters. Separate per-parameter
# steps are not supported
state_dict['state']['common_step'] = step
if step:
state_dict['state']['common_step'] = step
return state_dict
......@@ -900,13 +944,19 @@ class ChainedOptimizer(MegatronOptimizer):
def __init__(self, chained_optimizers: List[MegatronOptimizer]):
self.model_chunks = []
self.config = getattr(chained_optimizers[0], 'config', None)
for optimizer in chained_optimizers:
if hasattr(optimizer, 'model_chunks'):
for model_chunk in optimizer.model_chunks:
if model_chunk not in self.model_chunks:
self.model_chunks.append(model_chunk)
assert self.config == getattr(optimizer, 'config', None)
# chained_optimizers would be empty in the case that a rank
# has no trainable parameters
if chained_optimizers:
self.config = getattr(chained_optimizers[0], 'config', None)
for optimizer in chained_optimizers:
if hasattr(optimizer, 'model_chunks'):
for model_chunk in optimizer.model_chunks:
if model_chunk not in self.model_chunks:
self.model_chunks.append(model_chunk)
assert self.config == getattr(optimizer, 'config', None)
self.is_stub_optimizer = False
else:
self.is_stub_optimizer = True
self.chained_optimizers = chained_optimizers
@property
......@@ -930,7 +980,10 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.zero_grad(set_to_none)
def get_loss_scale(self):
return self.chained_optimizers[0].get_loss_scale()
if self.chained_optimizers:
return self.chained_optimizers[0].get_loss_scale()
else:
return torch.tensor([1.0], dtype=torch.float32, device=torch.cuda.current_device())
def reload_model_params(self):
for optimizer in self.chained_optimizers:
......@@ -987,6 +1040,8 @@ class ChainedOptimizer(MegatronOptimizer):
@torch.no_grad()
def step(self):
"""ChainedOptimizer will step all optimizers one by one."""
if self.is_stub_optimizer:
return True, 0.0, 0
found_inf_flag = self.prepare_grads()
if found_inf_flag:
return False, None, None
......@@ -1005,6 +1060,7 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.get_parameters(),
max_norm=optimizer.config.clip_grad,
total_norm=grad_norm,
use_decoupled_grad=optimizer.config.use_precision_aware_optimizer,
)
# Count the zeros in the grads.
......@@ -1062,8 +1118,3 @@ class ChainedOptimizer(MegatronOptimizer):
optimizer.load_parameter_state_from_dp_zero(
state_dict, update_legacy_format=update_legacy_format
)
def start_param_sync(self, model_index: int, *unused):
"""Start parameter synchronization for all optimizers."""
for optimizer in self.chained_optimizers:
optimizer.start_param_sync(model_index, *unused)
......@@ -47,6 +47,23 @@ class OptimizerConfig:
params_dtype: torch.dtype = torch.float32
"""dtype used when intializing the weights. Defaults to torch.float32."""
use_precision_aware_optimizer: bool = False
"""If true, allows optimizer-related tensors (master_param, gradients and optimizer states)
to be set to lower precision. Defaults to False.
"""
main_grads_dtype: torch.dtype = torch.float32
"""dtype of main grads when enabling precision-aware-optimizer"""
main_params_dtype: torch.dtype = torch.float32
"""dtype of main params when enabling precision-aware-optimizer"""
exp_avg_dtype: torch.dtype = torch.float32
"""dtype of exp_avg when enabling precision-aware-optimizer"""
exp_avg_sq_dtype: torch.dtype = torch.float32
"""dtype of exp_avg_sq when enabling precision-aware-optimizer"""
###############
# Loss scaling
###############
......@@ -97,6 +114,34 @@ class OptimizerConfig:
overlap_param_gather_with_optimizer_step: bool = False
"""If true, overlap param all-gather of first bucket with optimizer step."""
#######################
# Optimizer Offload
#######################
optimizer_cpu_offload: bool = False
"""If True, offload optimizer states tensor and compute to CPU."""
optimizer_offload_fraction: float = 0.0
"""Specifies the fraction of optimizer states to offload from GPU memory to CPU."""
use_torch_optimizer_for_cpu_offload: bool = False
"""If True, use torch.optim.Optimizer for CPU offload."""
overlap_cpu_optimizer_d2h_h2d: bool = False
"""
When set to `True`, this flag enables overlapping of the CPU optimizer
update process with the data transfer operations. This can help improve
overall training efficiency by reducing idle time during data movement,
allowing the optimizer to perform updates while gradients and parameters
are being transferred between devices.
"""
pin_cpu_grads: bool = True
"""If True, pin the optimizer gradients to CPU memory."""
pin_cpu_params: bool = True
"""If True, pin the optimizer parameters to CPU memory."""
################
# Miscellaneous
################
......@@ -114,3 +159,54 @@ class OptimizerConfig:
config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""
def __post_init__(self):
"""Check the validity of the config."""
if self.use_precision_aware_optimizer:
assert (
self.optimizer == 'adam'
), '--use-precision-aware-optimizer only supported with adam'
assert (
self.use_distributed_optimizer
), '--use-precision-aware-optimizer only supported with distributed optimizer'
# Only the FusedAdam in TE and HybridDeviceOptimizer supports
# --use-precision-aware-optimizer.
# TODO: Remove this check when apex's FusedAdam is no longer used.
if self.optimizer_cpu_offload:
return
try:
import inspect
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
adam_args = inspect.signature(Adam).parameters
arg_names = [
'master_weight_dtype',
'exp_avg_dtype',
'exp_avg_sq_dtype',
'use_decoupled_grad',
]
for name in arg_names:
assert name in adam_args, (
"Current FusedAdam of TE doesn't support --use-precision-aware-optimizer, "
"please update TE version."
)
except ImportError:
raise RuntimeError(
'--use-precision-aware-optimizer requires FusedAdam from TransformerEngine, '
'but not found.'
)
else:
assert (
self.main_grads_dtype == torch.float32
), "main_grads_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.main_params_dtype == torch.float32
), "main_params_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.exp_avg_dtype == torch.float32
), "exp_avg_dtype can only be fp32 when not using precision-aware optimizer"
assert (
self.exp_avg_sq_dtype == torch.float32
), "exp_avg_sq_dtype can only be fp32 when not using precision-aware optimizer"
File mode changed from 100755 to 100644
......@@ -2,7 +2,7 @@
MAJOR = 0
MINOR = 10
MINOR = 12
PATCH = 0
PRE_RELEASE = 'rc0'
......
File mode changed from 100755 to 100644
......@@ -11,7 +11,7 @@ from typing import Callable, List, Optional
import torch
from .utils import GlobalMemoryBuffer
from .utils import GlobalMemoryBuffer, is_torch_min_version
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
......@@ -137,11 +137,49 @@ def get_nccl_options(pg_name, nccl_comm_cfgs):
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
if 'net_name' in nccl_comm_cfgs[pg_name]:
nccl_options.config.net_name = nccl_comm_cfgs[pg_name].get('net_name')
# verify net_name value
if nccl_options.config.net_name.lower() not in ['ib', 'socket']:
raise RuntimeError(
f"net_name ({nccl_options.config.net_name}) is not supported."
f"Accepted values: 'IB' or 'socket'."
)
return nccl_options
else:
return None
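For reference, a hypothetical `nccl_comm_cfgs` entry consumed by `get_nccl_options()`; the keys are exactly the ones read above, the values are placeholders:
# Not part of this diff: per-process-group NCCL options keyed by pg name
# ('dp', 'pp', 'tp', ...). Only the listed keys are consumed above.
nccl_comm_cfgs = {
    'dp': {'cga_cluster_size': 4, 'max_ctas': 32, 'min_ctas': 1},
    'pp': {'net_name': 'IB'},  # must be 'IB' or 'socket' (case-insensitive)
}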
def create_group(
ranks=None,
timeout=None,
backend=None,
pg_options=None,
use_local_synchronization=False,
group_desc=None,
):
"""Creates a ProcessGroup."""
kwargs = {
'ranks': ranks,
'timeout': timeout,
'backend': backend,
'pg_options': pg_options,
'use_local_synchronization': use_local_synchronization,
'group_desc': group_desc,
}
if not is_torch_min_version('2.4.0'):
kwargs.pop('group_desc')
if timeout is None:
# Old versions (e.g., v2.1.2) set default_pg_timeout as the default value of timeout
# in the function signature and then check the timeout value's type.
# Newer versions default timeout to None in the function signature; if the value
# is None, torch picks a value according to the backend and then checks the type.
# So we need to unset timeout here if the caller doesn't pass a value; otherwise
# a type error is raised.
kwargs.pop('timeout')
return torch.distributed.new_group(**kwargs)
def generate_masked_orthogonal_rank_groups(
world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
......@@ -270,7 +308,7 @@ def create_hierarchical_parallel_groups(
hierarchical_groups = []
accumulated_group_sizes = 1
processed_group_sizes = 1
for hierarchical_group_size in hierarchical_group_sizes:
for level, hierarchical_group_size in enumerate(hierarchical_group_sizes):
accumulated_group_sizes *= hierarchical_group_size
for k in range(group_size // accumulated_group_sizes):
for j in range(processed_group_sizes):
......@@ -278,7 +316,11 @@ def create_hierarchical_parallel_groups(
ranks[j + i * processed_group_sizes + k * accumulated_group_sizes]
for i in range(hierarchical_group_size)
]
sub_group = torch.distributed.new_group(global_sub_ranks, pg_options=pg_options)
sub_group = create_group(
global_sub_ranks,
pg_options=pg_options,
group_desc=f'HIERARCHICAL_CONTEXT_PARALLEL_GROUP_L{level}',
)
if rank in global_sub_ranks:
hierarchical_groups.append(sub_group)
processed_group_sizes *= hierarchical_group_size
......@@ -392,6 +434,7 @@ def initialize_model_parallel(
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
pipeline_model_parallel_comm_backend: Optional[str] = None,
use_sharp: bool = False,
context_parallel_size: int = 1,
hierarchical_context_parallel_sizes: Optional[List[int]] = None,
......@@ -405,6 +448,7 @@ def initialize_model_parallel(
encoder_pipeline_model_parallel_size: Optional[int] = 0,
get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
create_gloo_process_groups: bool = True,
) -> None:
# pylint: disable=line-too-long
"""Initialize model data parallel groups.
......@@ -445,6 +489,10 @@ def initialize_model_parallel(
pipeline_model_parallel_split_rank is 3, then ranks 0-2
will be the encoder and ranks 3-7 will be the decoder.
pipeline_model_parallel_comm_backend (str, optional):
The backend to use for pipeline parallel communication.
If None, the default backend will be used.
use_sharp (bool, default = False):
Set the use of SHARP for the collective communications of
data-parallel process groups. When `True`, run barrier
......@@ -519,6 +567,10 @@ def initialize_model_parallel(
A function that takes in a list of ranks for a pipeline group, and returns
those ranks that should have position embeddings.
create_gloo_process_groups (bool, default = True):
Create Gloo process groups if set to True. If set to False, Gloo process groups are
not created and calls to get Gloo process groups will result in assertion errors.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
......@@ -662,6 +714,12 @@ def initialize_model_parallel(
rank_offset=encoder_world_size,
)
assert (
order.endswith("pp")
or pipeline_model_parallel_size == 1
or expert_data_parallel_size == data_parallel_size
), "When not using pp-last rank ordering, the data parallel size of the attention and moe layers must be the same"
assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks(
"pp"
), f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \
......@@ -715,28 +773,48 @@ def initialize_model_parallel(
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
for ranks in generator_wrapper('dp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('dp', nccl_comm_cfgs),
group_desc='DATA_PARALLEL_GROUP',
)
group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
if create_gloo_process_groups:
group_gloo = create_group(
ranks, timeout=timeout, backend="gloo", group_desc='DATA_PARALLEL_GROUP_GLOO'
)
else:
group_gloo = None
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_GLOO = group_gloo
_DATA_PARALLEL_GLOBAL_RANKS = ranks
assert (
data_parallel_size % num_distributed_optimizer_instances == 0
), 'Data parallel size should be divisible by partial DistOpt shard factor'
intra_partial_data_parallel_size = data_parallel_size // num_distributed_optimizer_instances
data_parallel_size * context_parallel_size
) % num_distributed_optimizer_instances == 0, (
'Data parallel size should be divisible by partial DistOpt shard factor'
)
intra_partial_data_parallel_size = (
data_parallel_size * context_parallel_size
) // num_distributed_optimizer_instances
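# For example, with data_parallel_size=8, context_parallel_size=2 and
# num_distributed_optimizer_instances=4, each optimizer instance shards its
# states over intra_partial_data_parallel_size = 16 // 4 = 4 ranks.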
for ranks_with_cp in generator_wrapper('dp-cp'):
group_with_cp = torch.distributed.new_group(
ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
)
group_with_cp_gloo = torch.distributed.new_group(
ranks_with_cp, timeout=timeout, backend="gloo"
group_with_cp = create_group(
ranks_with_cp,
timeout=timeout,
pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
group_desc='DATA_PARALLEL_GROUP_WITH_CP',
)
if create_gloo_process_groups:
group_with_cp_gloo = create_group(
ranks_with_cp,
timeout=timeout,
backend="gloo",
group_desc='DATA_PARALLEL_GROUP_WITH_CP_GLOO',
)
else:
group_with_cp_gloo = None
if rank in ranks_with_cp:
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
......@@ -752,14 +830,21 @@ def initialize_model_parallel(
)
]
intra_partial_data_parallel_group_with_cp = torch.distributed.new_group(
intra_partial_data_parallel_group_with_cp = create_group(
intra_partial_data_parallel_ranks_with_cp,
timeout=timeout,
pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
)
intra_partial_data_parallel_group_with_cp_gloo = torch.distributed.new_group(
intra_partial_data_parallel_ranks_with_cp, timeout=timeout, backend="gloo"
pg_options=get_nccl_options('intra_dp_cp', nccl_comm_cfgs),
group_desc='INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
)
if create_gloo_process_groups:
intra_partial_data_parallel_group_with_cp_gloo = create_group(
intra_partial_data_parallel_ranks_with_cp,
timeout=timeout,
backend="gloo",
group_desc='INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO',
)
else:
intra_partial_data_parallel_group_with_cp_gloo = None
if rank in intra_partial_data_parallel_ranks_with_cp:
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = (
......@@ -774,10 +859,11 @@ def initialize_model_parallel(
i::intra_partial_data_parallel_size
]
inter_partial_data_parallel_group_with_cp = torch.distributed.new_group(
inter_partial_data_parallel_group_with_cp = create_group(
inter_partial_data_parallel_ranks_with_cp,
timeout=timeout,
pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
pg_options=get_nccl_options('inter_dp_cp', nccl_comm_cfgs),
group_desc='INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
)
if rank in inter_partial_data_parallel_ranks_with_cp:
......@@ -813,8 +899,11 @@ def initialize_model_parallel(
global _CONTEXT_PARALLEL_GLOBAL_RANKS
assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
for ranks in generator_wrapper('cp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('cp', nccl_comm_cfgs),
group_desc='CONTEXT_PARALLEL_GROUP',
)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
......@@ -826,7 +915,7 @@ def initialize_model_parallel(
ranks,
context_parallel_size,
hierarchical_context_parallel_sizes,
get_nccl_options('cp', nccl_comm_cfgs),
get_nccl_options('hcp', nccl_comm_cfgs),
)
# Build the model-parallel groups.
......@@ -834,8 +923,11 @@ def initialize_model_parallel(
global _MODEL_PARALLEL_GLOBAL_RANKS
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for ranks in generator_wrapper('tp-pp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('mp', nccl_comm_cfgs),
group_desc='MODEL_PARALLEL_GROUP',
)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
......@@ -848,8 +940,11 @@ def initialize_model_parallel(
_TENSOR_MODEL_PARALLEL_GROUP is None
), 'tensor model parallel group is already initialized'
for ranks in generator_wrapper('tp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('tp', nccl_comm_cfgs),
group_desc='TENSOR_MODEL_PARALLEL_GROUP',
)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
......@@ -868,10 +963,76 @@ def initialize_model_parallel(
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
if pipeline_model_parallel_comm_backend == 'ucc':
# The UCC backend provides two key benefits:
# 1) Achieves better bandwidth utilization than NCCL when using InfiniBand links.
# 2) Does not use GPU SM resources (Zero-SM), mitigating performance interference
# with overlapping compute kernels.
# The UCC backend is recommended in the following cases:
# 1) When the exposed pipeline-parallel (PP) communications are significant.
# - E.g., Pipeline parallelism with very few gradient accumulation steps.
# - It may provide better performance due to improved bandwidth utilization.
# 2) When the critical-path pipeline stage has substantial PP-communication overlap.
# - E.g., Uneven pipeline parallelism.
# - It may provide better performance due to zero SM resource usage.
if 'CUDA_DEVICE_MAX_CONNECTIONS' in os.environ:
# The UCC backend requires the CUDA_DEVICE_MAX_CONNECTIONS variable to be larger than 1
# to guarantee that UCC communications can overlap. If this environment variable is set to 1,
# all UCC communication will be serialized.
assert (
os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] != '1'
), "UCC-backend requires CUDA_DEVICE_MAX_CONNECTIONS > 1"
# Setting up required environment variables for the UCC backend
#
# "TORCH_UCC_BLOCKING_WAIT=none" allows non-blocking waits on the communication handle
# "UCC_EC_CUDA_STREAM_TASK_MODE" controls how CUDA execution engines (EC)
# schedule tasks on CUDA streams.
# "UCX_TLS" controls transport layer selection
# "NSYS_UCP_COMM_PARAMS=1" enables capturing UCX tracing in nsys profiling
# "UCX_RNDV_THRESH" controls the threshold for switching between
# eager and rendezvous (RNDV) communication protocols.
# "UCX_NET_DEVICES" selects which network interfaces UCX should use.
# "UCC_CL_BASIC_TLS" controls which transport layers are used by
# the basic collective library
os.environ['TORCH_UCC_BLOCKING_WAIT'] = (
os.environ['TORCH_UCC_BLOCKING_WAIT']
if "TORCH_UCC_BLOCKING_WAIT" in os.environ
else 'none'
)
os.environ['UCC_EC_CUDA_STREAM_TASK_MODE'] = (
os.environ['UCC_EC_CUDA_STREAM_TASK_MODE']
if "UCC_EC_CUDA_STREAM_TASK_MODE" in os.environ
else 'driver'
)
os.environ['UCX_TLS'] = (
os.environ['UCX_TLS'] if "UCX_TLS" in os.environ else 'ib,cuda_copy'
) # cuda_ipc (i.e., NVLink-enablement) will be later supported
os.environ['NSYS_UCP_COMM_PARAMS'] = '1'
os.environ['UCX_RNDV_THRESH'] = '0'
os.environ['UCX_NET_DEVICES'] = 'all'
os.environ['UCC_CL_BASIC_TLS'] = '^sharp,nccl'
for ranks in generator_wrapper('pp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('pp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
backend=pipeline_model_parallel_comm_backend,
pg_options=(
None
if pipeline_model_parallel_comm_backend == 'ucc'
else get_nccl_options('pp', nccl_comm_cfgs)
),
group_desc='PIPELINE_MODEL_PARALLEL_GROUP',
)
assert (
pipeline_model_parallel_comm_backend is None
or pipeline_model_parallel_comm_backend == 'nccl'
or pipeline_model_parallel_comm_backend == 'ucc'
), f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported'
if rank in ranks:
if _PIPELINE_MODEL_PARALLEL_GROUP is None:
_PIPELINE_MODEL_PARALLEL_GROUP = group
......@@ -884,18 +1045,22 @@ def initialize_model_parallel(
_PIPELINE_GLOBAL_RANKS = [_PIPELINE_GLOBAL_RANKS, ranks]
embedding_ranks = get_embedding_ranks(ranks)
group = torch.distributed.new_group(
embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
group = create_group(
embedding_ranks,
timeout=timeout,
pg_options=get_nccl_options('embd', nccl_comm_cfgs),
group_desc='EMBEDDING_GROUP',
)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
position_embedding_ranks = get_position_embedding_ranks(ranks)
group = torch.distributed.new_group(
group = create_group(
position_embedding_ranks,
timeout=timeout,
pg_options=get_nccl_options('embd', nccl_comm_cfgs),
pg_options=get_nccl_options('pos_embd', nccl_comm_cfgs),
group_desc='POSITION_EMBEDDING_GROUP',
)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
......@@ -908,14 +1073,20 @@ def initialize_model_parallel(
_TENSOR_AND_DATA_PARALLEL_GROUP is None
), 'Tensor + data parallel group is already initialized'
for ranks in generator_wrapper('tp-dp-cp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs),
group_desc='TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP',
)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
for ranks in generator_wrapper('tp-dp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs),
group_desc='TENSOR_AND_DATA_PARALLEL_GROUP',
)
if rank in ranks:
_TENSOR_AND_DATA_PARALLEL_GROUP = group
......@@ -925,8 +1096,11 @@ def initialize_model_parallel(
_TENSOR_AND_CONTEXT_PARALLEL_GROUP is None
), 'Tensor + context parallel group is already initialized'
for ranks in generator_wrapper('tp-cp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp_cp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('tp_cp', nccl_comm_cfgs),
group_desc='TENSOR_AND_CONTEXT_PARALLEL_GROUP',
)
if rank in ranks:
_TENSOR_AND_CONTEXT_PARALLEL_GROUP = group
......@@ -936,8 +1110,10 @@ def initialize_model_parallel(
global _EXPERT_MODEL_PARALLEL_GROUP
assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
for ranks in generator_wrapper('ep', is_expert=True):
group = torch.distributed.new_group(
ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
group = create_group(
ranks,
pg_options=get_nccl_options('ep', nccl_comm_cfgs),
group_desc='EXPERT_MODEL_PARALLEL_GROUP',
)
if rank in ranks:
_EXPERT_MODEL_PARALLEL_GROUP = group
......@@ -948,8 +1124,11 @@ def initialize_model_parallel(
_EXPERT_TENSOR_PARALLEL_GROUP is None
), 'Expert tensor model parallel group is already initialized'
for ranks in generator_wrapper('tp', is_expert=True):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('ep_tp', nccl_comm_cfgs),
group_desc='EXPERT_TENSOR_PARALLEL_GROUP',
)
if rank in ranks:
_EXPERT_TENSOR_PARALLEL_GROUP = group
......@@ -960,8 +1139,11 @@ def initialize_model_parallel(
_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is None
), 'Expert tensor + model parallel group is already initialized'
for ranks in generator_wrapper('tp-ep', is_expert=True):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('tp_ep_mp', nccl_comm_cfgs),
group_desc='EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP',
)
if rank in ranks:
_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = group
......@@ -972,8 +1154,11 @@ def initialize_model_parallel(
_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is None
), 'The expert_tensor_model_pipeline parallel group is already initialized'
for ranks in generator_wrapper('tp-ep-pp', is_expert=True):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('tp_ep_pp', nccl_comm_cfgs),
group_desc='EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP',
)
if rank in ranks:
_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = group
......@@ -985,10 +1170,18 @@ def initialize_model_parallel(
assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, 'Expert data group-gloo is already initialized'
for ranks in generator_wrapper('dp', is_expert=True):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
group = create_group(
ranks,
timeout=timeout,
pg_options=get_nccl_options('ep_dp', nccl_comm_cfgs),
group_desc='EXPERT_DATA_PARALLEL_GROUP',
)
group_gloo = torch.distributed.new_group(ranks, backend="gloo")
if create_gloo_process_groups:
group_gloo = create_group(
ranks, backend="gloo", group_desc='EXPERT_DATA_PARALLEL_GROUP_GLOO'
)
else:
group_gloo = None
if rank in ranks:
_EXPERT_DATA_PARALLEL_GROUP = group
_EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo
......@@ -1469,7 +1662,10 @@ def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last stage in the current rank's pipeline."""
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
return [group[last_rank_local] for group in _PIPELINE_GLOBAL_RANKS]
else:
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
......@@ -1733,6 +1929,14 @@ def get_expert_data_parallel_rank():
return 0
def get_expert_data_parallel_world_size():
"""Return world size for the expert data parallel group."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(group=get_expert_data_parallel_group())
else:
return 0
### End of expert-related functions region
......
File mode changed from 100755 to 100644
......@@ -166,11 +166,17 @@ def _p2p_ops(
):
reqs = {}
even_send_odd_recv_group = group
if get_pipeline_model_parallel_world_size() == 2:
if (
get_pipeline_model_parallel_world_size() == 2
and torch.distributed.get_backend(group) != 'ucc'
):
# Use the global process group for one of the two p2p communications
# to allow the overlap of the independent communications.
# Using the global process group is compatible because the pipeline-parallel
# communications set the source and destination by global rank.
# The only exception occurs when using the 'ucc' backend.
# Because the global communicator always uses the 'nccl' backend,
# we must ensure the else path is followed for the 'ucc' backend.
even_recv_odd_send_group = torch.distributed.group.WORLD
else:
even_recv_odd_send_group = group
......
......@@ -9,6 +9,7 @@ from torch.autograd.variable import Variable
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
drain_embedding_wgrad_compute,
......@@ -496,6 +497,9 @@ def forward_backward_no_pipelining(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
......@@ -1479,6 +1483,9 @@ def forward_backward_pipelining_with_interleaving(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
......@@ -1874,4 +1881,7 @@ def forward_backward_pipelining_without_interleaving(
if config.timers is not None:
config.timers('forward-backward').stop()
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
return forward_data_store
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt).
ModelOpt is a library comprising state-of-the-art model optimization techniques
including quantization and sparsity to compress model for efficient inference on
NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference.
More details on ModelOpt including installation and usage can be found at
https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Optional
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.post_training.modelopt.layers import (
BlockwiseFP8WeightTransformerLayer,
FP8WeightTransformerLayer,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
get_transformer_layer_offset,
)
def get_gpt_layer_modelopt_spec(
num_experts: Optional[int] = None,
local_core_attention: bool = False,
moe_grouped_gemm: bool = False,
remap_te_layernorm: bool = False,
qk_layernorm: bool = False,
) -> ModuleSpec:
"""Mix the native spec with TENorm.
This is essentially the native local spec except that the layernorm implementation
uses TENorm from Transformer-Engine. The reason is that FusedLayerNorm from apex
no longer supports the RMSNorm needed by Llama.
"""
warnings.warn(
"`get_gpt_layer_modelopt_spec` will be deprecated in a future release."
"Use `get_gpt_modelopt_spec` instead."
)
core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
mlp = get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
)
sharded_state_dict_keys_map = {}
if remap_te_layernorm:
if num_experts:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
}
else:
sharded_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
}
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=core_attention,
linear_proj=RowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
k_layernorm=TENorm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
# Map TE-layernorm-fusion keys back
sharded_state_dict_keys_map=sharded_state_dict_keys_map,
),
)
def get_gpt_modelopt_spec(
config: TransformerConfig,
local_core_attention: bool = False,
remap_te_layernorm: bool = False,
real_quant_cfg: str = "None",
):
"""Mix the native spec with TENorm.
This is essentially the native local spec except that the layernorm implementation
uses TENorm from Transformer-Engine. The reason is that FusedLayerNorm from apex
no longer supports the RMSNorm needed by Llama.
"""
moe_sharded_state_dict_keys_map = {}
dense_sharded_state_dict_keys_map = {}
if remap_te_layernorm:
input_layernorm_map = {'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'}
mla_qk_layernorm_map = {
"self_attention.q_layernorm.": 'self_attention.linear_q_up_proj.layer_norm_',
"self_attention.kv_layernorm.": 'self_attention.linear_kv_up_proj.layer_norm_',
}
dense_sharded_state_dict_keys_map = {'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_'}
if not config.multi_latent_attention:
moe_sharded_state_dict_keys_map.update(input_layernorm_map)
dense_sharded_state_dict_keys_map.update(input_layernorm_map)
else:
if config.qk_layernorm:
moe_sharded_state_dict_keys_map.update(mla_qk_layernorm_map)
dense_sharded_state_dict_keys_map.update(mla_qk_layernorm_map)
if real_quant_cfg == "None":
transformer_layer = TransformerLayer
elif real_quant_cfg == "fp8_real_quant":
transformer_layer = FP8WeightTransformerLayer
elif real_quant_cfg == "fp8_blockwise_real_quant":
transformer_layer = BlockwiseFP8WeightTransformerLayer
else:
raise ValueError("RealQuantTransformerLayer does not support {}".format(real_quant_cfg))
core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
if config.multi_latent_attention:
attn_module = MLASelfAttention
attn_submodules = MLASelfAttentionSubmodules(
linear_q_proj=ColumnParallelLinear,
linear_q_down_proj=ColumnParallelLinear,
q_layernorm=TENorm,
linear_q_up_proj=ColumnParallelLinear,
linear_kv_down_proj=ColumnParallelLinear,
kv_layernorm=TENorm,
linear_kv_up_proj=ColumnParallelLinear,
core_attention=core_attention,
linear_proj=RowParallelLinear,
)
else:
attn_module = SelfAttention
attn_submodules = SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=core_attention,
linear_proj=RowParallelLinear,
q_layernorm=TENorm if config.qk_layernorm else IdentityOp,
k_layernorm=TENorm if config.qk_layernorm else IdentityOp,
)
dense_mlp_spec = get_mlp_module_spec(use_te=False)
dense_layer_spec = ModuleSpec(
module=transformer_layer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=attn_module,
params={"attn_mask_type": AttnMaskType.causal},
submodules=attn_submodules,
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=dense_mlp_spec,
mlp_bda=get_bias_dropout_add,
# Map TE-layernorm-fusion keys back
sharded_state_dict_keys_map=dense_sharded_state_dict_keys_map,
),
)
if config.num_moe_experts is None:
return dense_layer_spec
moe_mlp_spec = get_mlp_module_spec(
use_te=False,
num_experts=config.num_moe_experts,
moe_grouped_gemm=False,
# use_te=True, num_experts=config.num_moe_experts, moe_grouped_gemm=True,
)
moe_layer_spec = ModuleSpec(
module=transformer_layer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=attn_module,
params={"attn_mask_type": AttnMaskType.causal},
submodules=attn_submodules,
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm,
mlp=moe_mlp_spec,
mlp_bda=get_bias_dropout_add,
# Map TE-layernorm-fusion keys back
sharded_state_dict_keys_map=moe_sharded_state_dict_keys_map,
),
)
# Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
# 0 stands for dense layers, 1 stands for expert layers.
# For integer N: Creates a pattern with one expert layer every N layers.
# For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense).
if isinstance(config.moe_layer_freq, int):
moe_layer_pattern = [
1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
]
elif isinstance(config.moe_layer_freq, list):
moe_layer_pattern = config.moe_layer_freq
assert len(moe_layer_pattern) == config.num_layers, (
f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
f"expected {config.num_layers}, "
f"current moe layer pattern: {config.moe_layer_freq}"
)
else:
raise ValueError(
f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
)
# Create the layer specs for the model.
layer_specs = []
for layer_number in range(config.num_layers):
if moe_layer_pattern[layer_number] == 1:
layer_specs.append(moe_layer_spec)
elif moe_layer_pattern[layer_number] == 0:
layer_specs.append(dense_layer_spec)
else:
raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")
# Slice the layer specs to only include the layers that are built in this pipeline stage.
# Note: MCore layer_number starts at 1
offset = get_transformer_layer_offset(config)
num_layers_to_build = get_num_layers_to_build(config)
layer_specs = layer_specs[offset : offset + num_layers_to_build]
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=TENorm)
return block_spec
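A hypothetical usage sketch of the spec builder above (not part of this commit); the TransformerConfig values are placeholders and model-parallel state is assumed to be initialized so the layer offset can be computed:
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=4, hidden_size=1024, num_attention_heads=8  # placeholder sizes
)
block_spec = get_gpt_modelopt_spec(
    config,
    local_core_attention=False,
    remap_te_layernorm=True,  # remap TE layernorm-fusion keys in sharded state dicts
    real_quant_cfg="None",    # or "fp8_real_quant" / "fp8_blockwise_real_quant"
)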
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from logging import getLogger
import torch
logger = getLogger(__name__)
def mcore_gpt_load_legacy_state_dict_pre_hook(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""Register a pre-hook to fix the state_dict key difference.
This prehook is used when trying to load the legacy Megatron-LM GPTModel into its
megatron/core variant that uses native ParallelLinear and Transformer-Engine Norm.
Only this particular spec supports post-training quantization and TensorRT-LLM
config export through `nvidia-modelopt` package.
Args:
state_dict: state dictionary
prefix: module name prefix
local_metadata: local metadata
strict: whether is in strict mode
missing_keys: missing state dict keys
unexpected_keys: unexpected state dict keys
error_msgs: error messages
"""
if "modelopt_state" in state_dict:
state_dict.pop("modelopt_state")
if "language_model" in state_dict:
language_model_state_dict = state_dict.pop("language_model")
if "embedding" in language_model_state_dict:
if "word_embeddings" in language_model_state_dict["embedding"]:
for key, param in language_model_state_dict["embedding"]["word_embeddings"].items():
state_dict.update({"embedding.word_embeddings." + key: param})
if "position_embeddings" in language_model_state_dict["embedding"]:
for key, param in language_model_state_dict["embedding"][
"position_embeddings"
].items():
state_dict.update({"embedding.position_embeddings." + key: param})
if "transformer" in language_model_state_dict:
for key, param in language_model_state_dict["transformer"].items():
state_dict.update({"decoder." + key: param})
else:
for key, param in language_model_state_dict["encoder"].items():
state_dict.update({"decoder." + key: param})
if "output_layer" in language_model_state_dict:
for key, param in language_model_state_dict["output_layer"].items():
state_dict.update({"output_layer." + key: param})
if torch.distributed.get_rank() == 0:
logger.info("ModelOptGPTModel {}".format(state_dict.keys()))
module_name_rewrite_list = [
("input_norm", "input_layernorm"),
(".attention.query_key_value", ".self_attention.linear_qkv"),
(".attention.dense", ".self_attention.linear_proj"),
("self_attention.query_key_value", "self_attention.linear_qkv"),
("self_attention.dense", "self_attention.linear_proj"),
("post_attention_layernorm", "pre_mlp_layernorm"),
("post_attention_norm", "pre_mlp_layernorm"),
("dense_h_to_4h", "linear_fc1"),
("dense_4h_to_h", "linear_fc2"),
("final_norm", "final_layernorm"),
]
key_rewrite_list = []
for key, _ in state_dict.items():
for old_name, new_name in module_name_rewrite_list:
if old_name in key:
key_rewrite_list += [(key, key.replace(old_name, new_name))]
for old_key, new_key in key_rewrite_list:
if torch.distributed.get_rank() == 0:
logger.info("replace {} with {}".format(old_key, new_key))
state_dict[new_key] = state_dict[old_key]
state_dict.pop(old_key)
def mcore_gpt_load_te_state_dict_pre_hook(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
"""Register a pre-hook to fix the state_dict key difference of.
This prehook is used when trying to load the megatron/core GPTModel that uses a
fused Transformer-Engine ParallelLinear into the variant that uses native ParallelLinear
and Transformer-Engine Norm (effectively to restore the fusion).
Only this particular spec supports post-training quantization and TensorRT-LLM
config export through `nvidia-modelopt` package.
Args:
state_dict: state dictionary
prefix: module name prefix
local_metadata: local metadata
strict: whether is in strict mode
missing_keys: missing state dict keys
unexpected_keys: unexpected state dict keys
error_msgs: error messages
"""
if "modelopt_state" in state_dict:
state_dict.pop("modelopt_state")
key_with_te_extra_state_to_pop = []
for key, _ in state_dict.items():
if "_extra_state" in key:
key_with_te_extra_state_to_pop += [key]
for key in key_with_te_extra_state_to_pop:
state_dict.pop(key)
module_name_rewrite_list = [
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("mlp.linear_fc1.layer_norm_weight", "pre_mlp_layernorm.weight"),
("mlp.linear_fc1.layer_norm_bias", "pre_mlp_layernorm.bias"),
]
key_rewrite_list = []
for key, _ in state_dict.items():
for old_name, new_name in module_name_rewrite_list:
if old_name in key:
key_rewrite_list += [(key, key.replace(old_name, new_name))]
for old_key, new_key in key_rewrite_list:
if torch.distributed.get_rank() == 0:
logger.info("replace {} with {}".format(old_key, new_key))
state_dict[new_key] = state_dict[old_key]
state_dict.pop(old_key)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from typing import Callable
import torch
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
try:
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import QuantModuleRegistry
from modelopt.torch.quantization.nn.modules.quant_linear import _QuantLinear
has_nvidia_modelopt = True
except Exception:
has_nvidia_modelopt = False
class Linear(torch.nn.Linear):
"""Local Linear impl as a replacement of TELinear."""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
is_expert: bool = False,
):
self.config = config
self._return_bias = skip_bias_add and bias
if skip_weight_param_allocation:
raise ValueError('torch.nn.Linear layers do not support skip_weight_param_allocation')
super().__init__(
in_features=input_size, out_features=output_size, bias=bias, dtype=config.params_dtype
)
for param in self.parameters():
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, 'allreduce', self.config.expert_model_parallel_size == 1)
else:
# Reduce the gradient on DP group
setattr(param, 'allreduce', True)
setattr(param, 'sequence_parallel', self.config.sequence_parallel)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
for k, v in state_dict.items():
if "_amax" in k or "_scale" in k:
if v.ndim == 0:
state_dict[k] = v.view(1)
sharded_state_dict = make_sharded_tensors_for_checkpoint(
state_dict, prefix, sharded_offsets=sharded_offsets
)
return sharded_state_dict
def forward(self, x):
"""Forward."""
out = super().forward(x)
if self._return_bias:
return out
return out, None
if has_nvidia_modelopt:
QuantModuleRegistry.register({Linear: Linear.__class__.__name__})(_QuantLinear)
class RealQuantTransformerLayer(TransformerLayer):
"""Real quantization transformer layer base class.
This base class initializes the default TransformerLayer and immediately
performs weight-only real quantization via TensorRT Model Optimizer.
All linear weights (Linear, ColumnParallelLinear, RowParallelLinear) that are picked
up will be replaced with a low-bit data type (default torch.uint8). If a sub-byte
real_quant_cfg is used, the weight shape will further be halved.
This module cannot be trained (all parameters frozen).
"""
verbose: bool = False
real_quant_cfg: str = "None"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if has_nvidia_modelopt and self.real_quant_cfg != "None":
REAL_QUANT_CFG_CHOICES = {
"fp8_real_quant": mtq.FP8_PER_TENSOR_REAL_QUANT_CFG,
"fp8_blockwise_real_quant": mtq.FP8_BLOCKWISE_REAL_QUANT_CFG,
}
mtq_cfg = REAL_QUANT_CFG_CHOICES.get(self.real_quant_cfg, None)
if mtq_cfg is None:
raise ValueError(
"RealQuantTransformerLayer does not support {}".format(self.real_quant_cfg)
)
self._collect_original_tensor_info()
mtq.quantize(self, mtq_cfg)
delattr(self, "_modelopt_state")
# Freeze all parameters since the real-quant linears cannot be trained.
for param in self.parameters():
param.requires_grad = False
if self.verbose:
self._report_quantize_tensor_info()
def _collect_original_tensor_info(self):
self._original_tensor_info = {}
for k, v in self.state_dict().items():
if isinstance(v, torch.Tensor):
self._original_tensor_info[k] = (str(v.dtype), str(v.shape))
def _report_quantize_tensor_info(self):
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
for k, v in self.state_dict().items():
if not isinstance(v, torch.Tensor):
continue
original_dtype, original_shape = self._original_tensor_info.get(k, ("-", "-"))
print(
"{:<64} {:<16} {:<32} {:<16} {:<32}".format(
k, original_dtype, original_shape, str(v.dtype), str(v.shape)
)
)
torch.distributed.barrier()
class FP8WeightTransformerLayer(RealQuantTransformerLayer):
"""FP8 weight transformer layer."""
real_quant_cfg: str = "fp8_real_quant"
class BlockwiseFP8WeightTransformerLayer(RealQuantTransformerLayer):
"""Blockwise FP8 weight transformer layer."""
real_quant_cfg: str = "fp8_blockwise_real_quant"
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.