Commit dfcb88ff authored by chenzk

v1.0.8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
import torch
from packaging import version
from torch import Tensor, nn
from nanotron import logging
logger = logging.get_logger(__name__)
class PytorchGELUTanh(nn.Module):
"""
A fast C implementation of the tanh approximation of the GeLU activation function. See
https://arxiv.org/abs/1606.08415.
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
match due to rounding errors.
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.12.0"):
raise ImportError(
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
"PytorchGELUTanh. Please upgrade torch."
)
def forward(self, input: Tensor) -> Tensor:
return nn.functional.gelu(input, approximate="tanh")
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class GELUActivation(nn.Module):
"""
Original implementation of the GELU activation function in the Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). This is now written in C in nn.functional.
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
super().__init__()
if use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class FastGELUActivation(nn.Module):
"""
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Module):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Module):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes, as
it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
return torch.clip(gelu(x), self.min, self.max)
class AccurateGELUActivation(nn.Module):
"""
Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
https://github.com/hendrycks/GELUs
Implemented along with MEGA (Moving Average Equipped Gated Attention)
"""
def __init__(self):
super().__init__()
self.precomputed_constant = math.sqrt(2 / math.pi)
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
class SiLUActivation(nn.Module):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
def forward(self, input: Tensor) -> Tensor:
return nn.functional.silu(input)
class MishActivation(nn.Module):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.9.0"):
self.act = self._mish_python
else:
self.act = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
return input * torch.tanh(nn.functional.softplus(input))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class LinearActivation(nn.Module):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
return input
class LaplaceActivation(nn.Module):
"""
Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
https://arxiv.org/abs/2209.10655
Inspired by squared relu, but with bounded range and gradient for better stability
"""
def forward(self, input, mu=0.707107, sigma=0.282095):
input = (input - mu).div(sigma * math.sqrt(2.0))
return 0.5 * (1.0 + torch.erf(input))
class ReLUSquaredActivation(nn.Module):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
def forward(self, input):
relu_applied = nn.functional.relu(input)
squared = torch.square(relu_applied)
return squared
class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)
ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"gelu_pytorch_tanh": PytorchGELUTanh,
"gelu_accurate": AccurateGELUActivation,
"laplace": LaplaceActivation,
"linear": LinearActivation,
"mish": MishActivation,
"quick_gelu": QuickGELUActivation,
"relu": nn.ReLU,
"relu2": ReLUSquaredActivation,
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": SiLUActivation,
"swish": SiLUActivation,
"tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
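# Illustrative usage sketch: activation modules are resolved by name via `get_activation` / `ACT2FN`,
# and the tanh-approximation variants agree up to rounding error. `_example_activation_lookup` is a
# hypothetical helper name (it assumes torch >= 1.12 so that "gelu_pytorch_tanh" is available).
def _example_activation_lookup():
    x = torch.randn(4, 8)
    y_new = get_activation("gelu_new")(x)            # pure-Python tanh approximation
    y_tanh = get_activation("gelu_pytorch_tanh")(x)  # fused C/CUDA tanh approximation
    assert torch.allclose(y_new, y_tanh, atol=1e-5)
    return y_new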
import torch
from torch import nn
class TritonLayerNorm(nn.LayerNorm):
def forward(
self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
from flash_attn.ops.triton.layer_norm import layer_norm_fn
return layer_norm_fn(
input,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
dropout_p=dropout_p,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=False,
return_dropout_mask=return_dropout_mask,
)
# This is equivalent to LLaMA RMSNorm
# https://github.com/huggingface/transformers/blob/28952248b19db29ca25ccf34a5eec413376494a9/src/transformers/models/llama/modeling_llama.py#L112
class TritonRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
def forward(
self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
from flash_attn.ops.triton.layer_norm import layer_norm_fn
return layer_norm_fn(
input,
self.weight,
None,
residual=residual,
eps=self.eps,
dropout_p=dropout_p,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=True,
return_dropout_mask=return_dropout_mask,
)
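# Reference sketch: a plain-PyTorch version of what TritonRMSNorm computes, mirroring the LLaMA RMSNorm
# linked above. The Triton path requires flash-attn and a CUDA device; this hypothetical helper
# (`_reference_rms_norm`, not part of the nanotron API) only documents the math.
def _reference_rms_norm(x, weight, eps=1e-5):
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)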
from nanotron.optim.base import BaseOptimizer
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer
from nanotron.optim.named_optimizer import NamedOptimizer
from nanotron.optim.optimizer_from_gradient_accumulator import OptimizerFromGradientAccumulator
from nanotron.optim.zero import ZeroDistributedOptimizer
__all__ = [
"BaseOptimizer",
"InheritFromOtherOptimizer",
"NamedOptimizer",
"OptimizerFromGradientAccumulator",
"ZeroDistributedOptimizer",
]
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
import torch
from typing_extensions import TypeAlias
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
class BaseOptimizer(ABC):
id_to_name: Dict[int, str]
param_groups: List[Dict[str, Any]]
@abstractmethod
def __getstate__(self):
...
@abstractmethod
def __setstate__(self, state):
...
@abstractmethod
def __repr__(self):
...
@abstractmethod
def zero_grad(self):
...
@abstractmethod
def state_dict_additional_keys(self) -> Set[str]:
"""Additional states we store in `state_dict`. It has to be a dictionary using parameter name as key, and a tensor as value."""
...
@abstractmethod
def state_dict(self) -> dict:
...
@abstractmethod
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
...
@abstractmethod
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
...
def inherit_from(self, cls) -> bool:
...
Optimizer = TypeVar("Optimizer", BaseOptimizer, torch.optim.Optimizer)
# Modified from torch.optim.Optimizer._process_value_according_to_param_policy
# (kept as a plain module-level function here, so no @staticmethod decorator is needed)
def _process_value_according_to_param_policy(
param: torch.Tensor,
value: torch.Tensor,
param_id: int,
param_groups: List[Dict[Any, Any]],
map_location: Optional[Union[str, torch.device]],
key: Hashable = None,
) -> torch.Tensor:
# If map_location is specified, use it instead of param.device
target_device = map_location if map_location is not None else param.device
fused = False
capturable = False
assert param_groups is not None
for pg in param_groups:
if param_id in pg["params"]:
fused = pg["fused"] if "fused" in pg else False
capturable = pg["capturable"] if "capturable" in pg else False
break
if key == "step":
if capturable or fused:
return value.to(dtype=torch.float32, device=target_device)
else:
return value
else:
if param.is_floating_point():
return value.to(dtype=param.dtype, device=target_device)
else:
return value.to(device=target_device)
# Modified from torch.optim.Optimizer.load_state_dict
@torch._disable_dynamo
def custom_load_state_dict(self, state_dict: StateDict, map_location: Optional[Union[str, torch.device]] = None) -> None:
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
map_location (str or torch.device, optional): Device where to load the optimizer states.
If None, states will be loaded to the same device as their corresponding parameters.
Default: None
"""
# shallow copy, to be consistent with module API
state_dict = state_dict.copy()
for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
hook_result = pre_hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result
# Validate the state_dict
groups = self.param_groups
# Deepcopy as we write into saved_groups later to update state
saved_groups = deepcopy(state_dict["param_groups"])
if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of " "parameter groups")
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
)
# Update the state
id_map = dict(
zip(chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups))
)
def _cast(param, value, param_id=None, param_groups=None, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
return _process_value_according_to_param_policy(param, value, param_id, param_groups, map_location, key)
elif isinstance(value, dict):
return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
new_group["params"] = group["params"]
return new_group
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups})
for post_hook in self._optimizer_load_state_dict_post_hooks.values():
post_hook(self)
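# Hypothetical usage sketch: `custom_load_state_dict` mirrors `torch.optim.Optimizer.load_state_dict`
# but lets the caller pick the target device via `map_location`, e.g. to keep optimizer states on CPU.
# `_example_reload_on_cpu` is an illustrative name and not part of the original module.
def _example_reload_on_cpu():
    model = torch.nn.Linear(4, 4)
    optim = torch.optim.AdamW(model.parameters())
    model(torch.randn(2, 4)).sum().backward()
    optim.step()
    saved = optim.state_dict()
    # Reload the states onto CPU regardless of where the parameters live.
    custom_load_state_dict(optim, saved, map_location="cpu")
    return optim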
from typing import Iterable, Optional, Tuple
import torch
import nanotron.distributed as dist
from nanotron import logging
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel.parameters import NanotronParameter
logger = logging.get_logger(__name__)
def clip_grad_norm(
mp_pg: dist.ProcessGroup,
named_parameters: Iterable[Tuple[str, NanotronParameter]],
max_norm: float,
grad_accumulator: Optional[GradientAccumulator],
norm_type: float = 2.0,
) -> torch.Tensor:
"""Clips gradients. Adapted from torch.nn.utils.clip_grad_norm_.
Norms are computed in fp32 precision to retain most accuracy.
Args:
mp_pg (dist.ProcessGroup): Process group for model parallelism, i.e. all the ranks that are part of the same model replica (TP x PP)
named_parameters (Iterable[(str, Parameter)]): an iterable of named Parameters that will have their gradients normalized.
max_norm (float or int): max norm of the gradients
grad_accumulator (GradientAccumulator): gradient accumulator. If not None (e.g. ZeRO-1 with fp32 gradient accumulation), the fp32 gradient buffers are clipped instead.
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
.. note:: In case parameters contains tied weights, we keep only a single copy of the gradient, but modify the
gradient of all tied weights.
"""
named_parameters = list(named_parameters)
world_rank = dist.get_rank()
# assert that all params require grad
for _, p in named_parameters:
assert p.requires_grad, "clip_grad_norm_ only supports Tensors that require grad"
if grad_accumulator is None:
grads = [
p.grad for _, p in named_parameters if not p.is_tied or world_rank == p.get_tied_info().global_ranks[0]
]
else:
# In case of FP32 Grad Accum, We need to clip all fp32 grads
grads = [
grad_accumulator.get_grad_buffer(name)
for name, p in named_parameters
if not p.is_tied or world_rank == p.get_tied_info().global_ranks[0]
]
# Calculate gradient norm
if norm_type == torch.inf:
if len(grads) > 0:
total_norm = torch.max(
torch.stack([torch.linalg.vector_norm(g.detach(), ord=torch.inf, dtype=torch.float) for g in grads])
)
else:
total_norm = torch.zeros([], dtype=torch.float, device=torch.device("cuda"))
dist.all_reduce(total_norm, group=mp_pg, op=dist.ReduceOp.MAX)
else:
if len(grads) > 0:
# TODO @nouamanetazi: Check if we should calculate the norm per parameter (remove .pow(norm_type))
total_norm = torch.linalg.vector_norm(
torch.stack([torch.linalg.vector_norm(g.detach(), ord=norm_type, dtype=torch.float) for g in grads]),
ord=norm_type,
dtype=torch.float,
).pow(norm_type)
else:
total_norm = torch.zeros([], dtype=torch.float, device=torch.device("cuda"))
dist.all_reduce(total_norm, group=mp_pg, op=dist.ReduceOp.SUM)
total_norm.pow_(1.0 / norm_type)
# Scale gradients
clip_coef = max_norm / (total_norm + 1.0e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
devices = {
param.grad.device if grad_accumulator is None else grad_accumulator.get_grad_buffer(name).device
for name, param in named_parameters
}
device_to_clip_coef_clamped = {device: clip_coef_clamped.to(device) for device in devices}
for name, param in named_parameters:
if grad_accumulator is None:
param.grad.detach().mul_(device_to_clip_coef_clamped[param.grad.device])
else:
grad_accumulator.get_grad_buffer(name).detach().mul_(
device_to_clip_coef_clamped[grad_accumulator.get_grad_buffer(name).device]
)
return total_norm
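# Single-process sketch of the norm aggregation used above: local per-tensor norms are combined into
# one p-norm, raised to the p-th power (so that ranks can later be summed with an all-reduce), and the
# p-th root is taken at the end. `_example_total_grad_norm` is a hypothetical helper, not nanotron API.
def _example_total_grad_norm(grads, norm_type: float = 2.0) -> torch.Tensor:
    local = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g.detach(), ord=norm_type, dtype=torch.float) for g in grads]),
        ord=norm_type,
        dtype=torch.float,
    ).pow(norm_type)
    # In the distributed version this value is all-reduced with SUM across the model-parallel group
    # before taking the root.
    return local.pow(1.0 / norm_type)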
import dataclasses
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Optional, Tuple
import torch
from torch.distributed import GradBucket
import nanotron.distributed as dist
from nanotron import logging
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage
logger = logging.get_logger(__name__)
class GradientAccumulator(ABC):
fp32_grads_allreduce_handle: Optional[torch.futures.Future]
@abstractmethod
def __init__(self, named_parameters: Iterator[Tuple[str, NanotronParameter]]):
...
@abstractmethod
def backward(self, loss: torch.Tensor):
...
@abstractmethod
def step(self):
...
@abstractmethod
def sync_gradients_across_dp(self, dp_pg: dist.ProcessGroup, reduce_op: dist.ReduceOp, reduce_scatter: bool):
...
@abstractmethod
def zero_grad(self):
...
@abstractmethod
def get_parameter_for_optimizer(self, name: str) -> NanotronParameter:
...
@abstractmethod
def get_grad_buffer(self, name: str) -> torch.Tensor:
...
@abstractmethod
def state_dict(self) -> Dict[str, torch.Tensor]:
...
@abstractmethod
def load_state_dict(self, state_dict: torch.Tensor):
...
class FP32GradientAccumulator(GradientAccumulator):
def __init__(
self,
named_parameters: Iterator[Tuple[str, NanotronParameter]],
grad_buckets_named_params: Optional[Iterator[Tuple[str, NanotronParameter]]] = None,
):
"""Create a gradient accumulator that will accumulate gradients in fp32.
Args:
named_parameters: The parameters that will be updated by the optimizer. In case of Zero 1, this is the parameters that will be updated in this DP rank.
grad_buckets_named_params: The parameters to accumulate gradients for. If None it defaults to `named_parameters`. In case of Zero 1, this should be all the parameters in the model.
Note: We use `grad_buckets_named_params` to keep grad buffers for all parameters even when Zero 1 is used. This is because we need to accumulate gradients for all parameters without having to reduce in every accumulation step.
Note: We make an fp32 copy of the parameters during initialization. Therefore, parameters need to be initialized or loaded from a checkpoint before constructing this gradient accumulator.
"""
if grad_buckets_named_params is None:
named_parameters = list(named_parameters)
grad_buckets_named_params = named_parameters
# Initialize grad bucket
self.fp32_grad_buffers, self._contiguous_fp32_grad_buffer = self.build_grad_buffers(
named_parameters=grad_buckets_named_params
)
# Assign big buffer for weights + grad in fp32
segment_index = {}
length = 0
for name, param in named_parameters:
if not param.requires_grad:
continue
start = length
end_weight = start + param.numel()
assert name not in segment_index
segment_index[name] = (start, end_weight, param)
length = end_weight
big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda")
self.parameters = {
name: {
"fp32": big_flat_buffer[start_weight:end_weight].view_as(param),
"half": param,
}
for name, (start_weight, end_weight, param) in segment_index.items()
}
with torch.inference_mode():
for _, elt in self.parameters.items():
fp32_param = elt["fp32"]
half_param = elt["half"]
# Check that fp32 weights have the same memory representation as half precision weights
assert fp32_param.stride() == half_param.stride()
# Copy weights from half precision to full precision
fp32_param.copy_(half_param)
# Set requires_grad=True
fp32_param.requires_grad = True
self._is_accumulation_sync_step = False
# We need the last allreduce handle to make sure it finishes before the optimizer step
self.fp32_grads_allreduce_handle: Optional[torch.futures.Future] = None
def assign_param_offsets(self, param_name_to_offsets: Dict[str, Dict[int, Tuple[int, int]]], dp_rank: int):
"""To use only when you use with ZeRODistributedOptimizer"""
self.param_name_to_offsets = {
name: elt[dp_rank] for name, elt in param_name_to_offsets.items() if dp_rank in elt
}
def sync_gradients_across_dp(self, dp_pg: dist.ProcessGroup, reduce_op: dist.ReduceOp, reduce_scatter: bool):
if dp_pg.size() == 1:
# They are already synced
return
if reduce_scatter:
# Usually you would run `all_reduce` so that every rank ends up with the fully synced gradients.
# However, when the optimizer states are sharded, each gradient shard only needs to reach the rank that runs the optimizer step on it.
# Effectively you replace an `all_reduce` with a `reduce_scatter`, which saves an `all_gather` when using the RING algorithm.
assert hasattr(self, "param_name_to_offsets")
named_offsets = sorted(self.param_name_to_offsets.items(), key=lambda x: x[0])
flat_grad_buffers = [self.fp32_grad_buffers[name]["fp32_grad"].view(-1) for name, _ in named_offsets]
dist.reduce_scatter_coalesced(
output_tensor_list=[
flat_grad_buffer[start_offset:end_offset]
for (_, (start_offset, end_offset)), flat_grad_buffer in zip(named_offsets, flat_grad_buffers)
],
input_tensor_lists=[
torch.split(
flat_grad_buffer,
split_size_or_sections=len(self.fp32_grad_buffers[name]["fp32_grad"].view(-1)) // dp_pg.size(),
)
for (name, _), flat_grad_buffer in zip(named_offsets, flat_grad_buffers)
],
group=dp_pg,
)
else:
dist.all_reduce(self._contiguous_fp32_grad_buffer, op=reduce_op, group=dp_pg)
@staticmethod
def build_grad_buffers(
named_parameters: Iterator[Tuple[str, NanotronParameter]],
) -> Tuple[Dict[str, Dict], torch.Tensor]:
"""Builds grad buffers for all model's parameters, independently of ZeRO sharding
Args:
named_parameters: Parameters to build buckets for. In case of Zero1, this should be all parameters.
Note:
In ZeRO-1, we need to accumulate grads for all parameters, because we need to allreduce all parameters' grads across DP at each sync step.
"""
named_parameters = [(name, param) for name, param in named_parameters if param.requires_grad]
needed_buffer_size = sum(param.numel() for _, param in named_parameters)
# important to have grads zeroed initially (see `self._accumulate_grad`)
contiguous_buffer_f32_gradients = torch.zeros(needed_buffer_size, dtype=torch.float, device="cuda")
untyped_storage = get_untyped_storage(contiguous_buffer_f32_gradients)
element_size = contiguous_buffer_f32_gradients.element_size()
# NOTE: Although `bias` can only exist on TP rank 0, it shouldn't be a problem here, because we only sync across DP
fp32_grad_buffers = OrderedDict() # keeps order of insertion
offset = 0
for name, param in named_parameters:
if not param.requires_grad:
continue
assert param.dtype != torch.float, f"Expected {name} not to be float"
assert param.is_contiguous(), f"Expected {name} to be contiguous"
next_offset = offset + param.numel() * element_size
fp32_grad_buffer = tensor_from_untyped_storage(
untyped_storage=untyped_storage[offset:next_offset], dtype=torch.float
)
fp32_grad_buffers[name] = {
"half": param,
# We create sliced tensors by also slicing storage.
# We need to specify "cuda" in order to share the same data storage, otherwise it build the tensor in "cpu" and copies over the data
"fp32_grad": fp32_grad_buffer.view_as(param),
}
offset = next_offset
return fp32_grad_buffers, contiguous_buffer_f32_gradients
def backward(self, loss: torch.Tensor):
result = loss.backward()
for name, elt in self.fp32_grad_buffers.items():
self._accumulate_grad(name=name, half_param=elt["half"])
return result
def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
"""Accumulate grad in fp32 and set the fp32 grad to the fp32 grad buffer, so that optimizer can update fp32 weights afterwards"""
assert half_param.grad is not None, f"Expected param {name} to have gradient."
fp32_grad = self.get_grad_buffer(name=name)
if self._is_accumulation_sync_step is False:
# WARNING: We assume fp32_grad_bucket is already zeroed
fp32_grad.add_(half_param.grad)
# In case _is_accumulation_sync_step = True: no need to add half gradients, because it's done in the allreduce hook
# TODO @thomasw21: Is it better to set to zero instead?
half_param.grad = None
# In the case an optimizer decides to set it to None, we need to re-assign previous buffer
if name in self.parameters:
fp32_param = self.parameters[name]["fp32"]
if hasattr(self, "param_name_to_offsets"):
if name not in self.param_name_to_offsets:
# When `name` isn't in `param_name_to_offsets` it means the slice is empty.
return
start_offset, end_offset = self.param_name_to_offsets[name]
grad = fp32_grad.view(-1)[start_offset:end_offset]
else:
grad = fp32_grad
fp32_param.grad = grad
@contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronizations across
data-parallel ranks.
Note: if we use `no_sync` once, that means we're in DDP mode, and we switch the default of self._is_accumulation_sync_step to True.
"""
old_is_accumulation_sync_step = self._is_accumulation_sync_step
self._is_accumulation_sync_step = False
try:
yield
finally:
self._is_accumulation_sync_step = old_is_accumulation_sync_step
@torch.inference_mode()
def step(self):
"""Updates fp32 weights from fp32 grads.
In the case where OptimizerFromGradientAccumulator and gradient_accumulator_builder use different parameters (e.g. ZeRO),
we only need to update the parameters that were updated by the optimizer.
"""
for name in self.parameters.keys():
fp32_param = self.parameters[name]["fp32"]
half_param = self.parameters[name]["half"]
# TODO @nouamane: should we use a fused kernel to copy?
# Copy weights from full precision to half precision
half_param.copy_(fp32_param)
def zero_grad(self):
# Full precision gradients are reset to zero/None after the underlying `optimizer.step`, so no need to reset them here.
for elt in self.fp32_grad_buffers.values():
half_param = elt["half"]
if half_param.grad is None:
continue
half_param.grad = None
# in case self.parameters and self.fp32_grad_buffers are not the same (e.g. we want to accumulate grads for all DPs, and only sync at the sync step)
self._contiguous_fp32_grad_buffer.zero_()
def get_parameter_for_optimizer(self, name: str) -> NanotronParameter:
return self.parameters[name]["fp32"]
def get_grad_buffer(self, name: str) -> torch.Tensor:
"""Returns the gradient of the parameter from the appropriate grad bucket."""
return self.fp32_grad_buffers[name]["fp32_grad"]
def state_dict(self) -> Dict[str, torch.Tensor]:
# We consider `fp32` parameters as a state of the gradient accumulator
return {name: elt["fp32"] for name, elt in self.parameters.items()}
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
assert set(state_dict.keys()) == set(self.parameters.keys())
with torch.inference_mode():
for name, elt in self.parameters.items():
elt["fp32"].copy_(state_dict[name])
@dataclasses.dataclass
class FP32GradBucketManager:
"""Manages the fp32 gradient buckets.
Attributes:
dp_pg: The process group to allreduce gradients across.
accumulator: The gradient accumulator which keeps the gradient buffers.
param_id_to_name: A dictionary mapping param ids to parameter names, used to look up each param's fp32 grad buffer."""
dp_pg: dist.ProcessGroup
accumulator: FP32GradientAccumulator
param_id_to_name: Dict[int, str]
def __post_init__(self):
self.accumulator._is_accumulation_sync_step = True
def get_fp32_accum_hook(
reduce_scatter: bool,
reduce_op: dist.ReduceOp = dist.ReduceOp.AVG,
) -> Callable:
"""Returns a DDP communication hook that performs gradient accumulation in fp32.
Args:
reduce_scatter: If True, `reduce_scatter` the fp32 gradients to the ranks that own the corresponding optimizer-state shards instead of `all_reduce`-ing them.
reduce_op: The reduction operation to perform.
"""
# s = torch.cuda.Stream()
def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
# nonlocal s
# DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation.
# See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details.
dp_pg = state.dp_pg
accumulator = state.accumulator
param_id_to_name = state.param_id_to_name
# Add new incoming gradient
# with torch.cuda.stream(s):
for param, grad in zip(bucket.parameters(), bucket.gradients()):
name = param_id_to_name[id(param)]
fp32_grad_buffer = accumulator.get_grad_buffer(name)
fp32_grad_buffer.add_(grad.view_as(fp32_grad_buffer))
# sync across dp
if dp_pg.size() == 1:
fut = torch.futures.Future()
fut.set_result(bucket.buffer())
return fut
if reduce_scatter:
assert hasattr(accumulator, "param_name_to_offsets")
grad_buffer_tensor_list = [
accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters()
]
device = grad_buffer_tensor_list[0].device
dtype = grad_buffer_tensor_list[0].dtype
output_tensor_list = [
grad_buffer[slice(*accumulator.param_name_to_offsets[param_id_to_name[id(param)]])]
if param_id_to_name[id(param)] in accumulator.param_name_to_offsets
else torch.empty(0, dtype=dtype, device=device)
for grad_buffer, param in zip(grad_buffer_tensor_list, bucket.parameters())
]
input_tensor_lists = [
torch.split(grad_buffer, split_size_or_sections=len(grad_buffer) // dp_pg.size())
for grad_buffer in grad_buffer_tensor_list
]
dist.reduce_scatter_coalesced(
output_tensor_list=output_tensor_list,
input_tensor_lists=input_tensor_lists,
op=reduce_op,
group=dp_pg,
async_op=True,
)
else:
grad_buffer_tensor_list = [
accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters()
]
accumulator.fp32_grads_allreduce_handle = dist.all_reduce_coalesced(
grad_buffer_tensor_list, group=dp_pg, async_op=True, op=reduce_op
)
# we shouldn't wait for this future for the rest of the backward
# with torch.cuda.stream(s):
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
half_grad_bucket = bucket.buffer()
fut.set_result(half_grad_bucket)
return fut # We don't care about the new half grad values, so we return the old ones
return fp32_accum_hook
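# Hypothetical wiring sketch: the hook above is meant to be registered as a DDP communication hook.
# It assumes an already-initialized process group, a DDP-wrapped half-precision model, and an existing
# FP32GradientAccumulator; `_example_register_fp32_hook` is an illustrative name only.
def _example_register_fp32_hook(ddp_model, accumulator, dp_pg):
    state = FP32GradBucketManager(
        dp_pg=dp_pg,
        accumulator=accumulator,
        param_id_to_name={id(p): name for name, p in ddp_model.named_parameters()},
    )
    # `register_comm_hook` makes DDP call `fp32_accum_hook` once each gradient bucket is ready.
    ddp_model.register_comm_hook(state=state, hook=get_fp32_accum_hook(reduce_scatter=False))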
from functools import cache
from typing import Callable, Dict, Optional, Set, Union
import torch
from nanotron.optim.base import BaseOptimizer, Optimizer, custom_load_state_dict
class InheritFromOtherOptimizer(BaseOptimizer):
def __init__(self, optimizer: Optimizer, id_to_name: Dict[int, str]):
self.id_to_name = id_to_name
# if the optimizer comes from torch, we wrap its load_state_dict so that `map_location` is supported
if isinstance(optimizer, torch.optim.Optimizer):
# Replace the load_state_dict method with our custom implementation that enables CPU offload
original_load_state_dict = optimizer.load_state_dict
optimizer.load_state_dict = (
lambda state_dict, map_location=None: custom_load_state_dict(
optimizer, state_dict, map_location=map_location
)
if map_location is not None
else original_load_state_dict(state_dict)
)
self.optimizer: Optimizer = optimizer
def __getstate__(self):
return self.optimizer.__getstate__()
def __setstate__(self, state):
return self.optimizer.__setstate__(state)
def __repr__(self):
return f"{self.__class__.__name__}({self.optimizer.__repr__()})"
def zero_grad(self):
return self.optimizer.zero_grad()
@cache
def state_dict_additional_keys(self) -> Set[str]:
if isinstance(self.optimizer, BaseOptimizer):
return self.optimizer.state_dict_additional_keys()
else:
return set()
def state_dict(self) -> dict:
return self.optimizer.state_dict()
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
return self.optimizer.load_state_dict(state_dict, map_location=map_location)
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
return self.optimizer.step(closure=closure)
def get_base_optimizer(self):
if isinstance(self.optimizer, torch.optim.Optimizer):
return self.optimizer
else:
return self.optimizer.get_base_optimizer()
@property
def param_groups(self):
return self.optimizer.param_groups
def inherit_from(self, cls):
if isinstance(self, cls):
return True
if isinstance(self.optimizer, InheritFromOtherOptimizer):
return self.optimizer.inherit_from(cls)
return False
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import torch
from nanotron import logging
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer
logger = logging.get_logger(__name__)
class NamedOptimizer(InheritFromOtherOptimizer):
"""Mimics somewhat the torch optimizer API"""
def __init__(
self,
named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]],
optimizer_builder: Callable[[Iterable[Dict[str, Any]]], torch.optim.Optimizer],
):
named_param_groups = list(named_params_or_groups)
if len(named_param_groups) == 0 or not isinstance(named_param_groups[0], dict):
named_param_groups = [{"named_params": named_param_groups}]
id_to_name = {}
params = []
for named_param_group in named_param_groups:
assert "named_params" in named_param_group
# No need to check that param_groups overlap since the optimizer will do it for us.
# https://github.com/pytorch/pytorch/blob/88b3810c94b45f5982df616e2bc4c471d173f491/torch/optim/optimizer.py#L473
id_to_name.update(
{id(param): name for name, param in named_param_group["named_params"] if id(param) not in id_to_name}
)
params.append(
{
**{k: v for k, v in named_param_group.items() if k != "named_params"},
"params": [param for _, param in named_param_group["named_params"]],
}
)
name_to_id = {v: k for k, v in id_to_name.items()}
assert len(id_to_name) == len(name_to_id)
# Sanity check
for param_group in params:
_params = param_group["params"]
for param in _params:
# https://github.com/pytorch/pytorch/issues/100701
assert param.numel() > 0
super().__init__(optimizer=optimizer_builder(params), id_to_name=id_to_name)
def state_dict(self) -> dict:
optim_state_dict = super().state_dict()
assert "names" not in optim_state_dict
state_id_to_name = {id(state): self.id_to_name[id(param)] for param, state in self.optimizer.state.items()}
optim_state_dict["names"] = {
index: state_id_to_name[id(state)] for index, state in optim_state_dict["state"].items()
}
return optim_state_dict
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
assert set(self.id_to_name.values()) == set(
state_dict["names"].values()
), f"Elements don't match:\n - Elements in `self.id_to_name` that aren't in the other one: {set(self.id_to_name.values()) - set(state_dict['names'].values())}\n - Elements in `state_dict[\"names\"]` that aren't in the other one: {set(state_dict['names'].values()) - set(self.id_to_name.values())}"
assert len(state_dict["state"]) == len(
state_dict["names"]
), f"Number of params in loaded state dict ({len(state_dict['state'])}) doesn't match number of names ({len(state_dict['names'])})"
assert len(state_dict["state"]) > 0, "Loading empty state dict"
OPTIMIZER_STATE_KEYS = sorted(state_dict["state"][0].keys() - {"step"})
for key in OPTIMIZER_STATE_KEYS:
for k, state in state_dict["state"].items():
assert (
key in state
), f"Key {key} not found in state dict: {state} which corresponds to param_name: {state_dict['names'][k]}"
return super().load_state_dict(state_dict, map_location=map_location)
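# Illustrative sketch: building a NamedOptimizer from a plain nn.Module. The builder receives regular
# torch param groups and must return a torch optimizer; NamedOptimizer only adds the id <-> name
# bookkeeping on top. `_example_named_optimizer` is a hypothetical helper name.
def _example_named_optimizer(model: torch.nn.Module) -> "NamedOptimizer":
    return NamedOptimizer(
        named_params_or_groups=model.named_parameters(),
        optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups, lr=1e-4),
    )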
from functools import cache
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union
import torch
from nanotron.optim.base import BaseOptimizer
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer
from nanotron.parallel.parameters import NanotronParameter
class OptimizerFromGradientAccumulator(InheritFromOtherOptimizer):
def __init__(
self,
gradient_accumulator_builder: Callable[[Iterable[Tuple[str, NanotronParameter]]], GradientAccumulator],
named_params_or_groups: Iterable[Union[Tuple[str, torch.Tensor], Dict[str, Any]]],
optimizer_builder: Callable[[Iterable[Dict[str, Any]]], BaseOptimizer],
):
named_param_groups = list(named_params_or_groups)
if len(named_param_groups) == 0 or not isinstance(named_param_groups[0], dict):
named_param_groups = [{"named_params": named_param_groups}]
name_to_param = {}
for named_param_group in named_param_groups:
for name, param in named_param_group["named_params"]:
if name in name_to_param:
raise ValueError(f"Duplicate key. {name} is already in `name_to_param`")
else:
name_to_param[name] = param
# Build gradient accumulator
gradient_accumulator = gradient_accumulator_builder(name_to_param.items())
self.gradient_accumulator = gradient_accumulator
# Obtained new params depending on the gradient accumulator
converted_named_param_group = [
{
**{k: v for k, v in named_param_group.items() if k != "named_params"},
"named_params": [
(name, gradient_accumulator.get_parameter_for_optimizer(name))
for name, _ in named_param_group["named_params"]
],
}
for named_param_group in named_param_groups
]
optimizer = optimizer_builder(converted_named_param_group)
super().__init__(optimizer=optimizer, id_to_name=optimizer.id_to_name)
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
loss = super().step(closure)
self.gradient_accumulator.step()
return loss
def zero_grad(self):
super().zero_grad()
return self.gradient_accumulator.zero_grad()
@cache
def state_dict_additional_keys(self) -> Set[str]:
return super().state_dict_additional_keys() | {"gradient_accumulator"}
def state_dict(self) -> dict:
state_dict = super().state_dict()
assert "gradient_accumulator" not in state_dict
state_dict["gradient_accumulator"] = self.gradient_accumulator.state_dict()
return state_dict
def load_state_dict(self, state_dict: dict, map_location: Optional[Union[str, torch.device]] = None) -> None:
gradient_accumulator_state_dict = state_dict.pop("gradient_accumulator")
super().load_state_dict(state_dict, map_location=map_location)
self.gradient_accumulator.load_state_dict(gradient_accumulator_state_dict)
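# Hypothetical wiring sketch: combining the fp32 gradient accumulator with a NamedOptimizer, roughly
# how nanotron's training helpers assemble the optimizer stack. It assumes half-precision parameters
# on a CUDA device (FP32GradientAccumulator allocates its buffers on "cuda"); names are illustrative.
def _example_optimizer_with_fp32_accum(named_params):
    from nanotron.optim.gradient_accumulator import FP32GradientAccumulator
    from nanotron.optim.named_optimizer import NamedOptimizer

    return OptimizerFromGradientAccumulator(
        gradient_accumulator_builder=lambda named: FP32GradientAccumulator(named_parameters=named),
        named_params_or_groups=named_params,
        optimizer_builder=lambda named_param_groups: NamedOptimizer(
            named_params_or_groups=named_param_groups,
            optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups, lr=1e-4),
        ),
    )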
import itertools
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch.optim
from functorch.dim import tree_map
from torch import nn
from tqdm import tqdm
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import human_format, log_rank, warn_once
from nanotron.optim.base import BaseOptimizer
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
logger = logging.get_logger(__name__)
class ZeroDistributedOptimizer(InheritFromOtherOptimizer):
"""Optimizer that handles partitioning of optimizer's states across DP ranks. See ZeRO Stage 1 in the paper https://arxiv.org/abs/1910.02054v3 for more details."""
def __init__(
self,
named_params_or_groups: Iterable[Union[Tuple[str, NanotronParameter], Dict[str, Any]]],
optimizer_builder: Callable[[Iterable[Dict[str, Any]]], BaseOptimizer],
dp_pg: ProcessGroup,
):
named_params_or_groups = list(named_params_or_groups)
if len(named_params_or_groups) == 0 or isinstance(named_params_or_groups[0], dict):
# case where named_params_or_groups is Iterable[Dict[str, Any]]
for d in named_params_or_groups:
assert (
"named_params" in d
), f"param_groups must contain a 'named_params' key, got a dict with keys {d.keys()}"
# keep only named_params_or_groups that require grads
named_params_or_groups = [
{
"named_params": [
(name, param) for name, param in named_param_group["named_params"] if param.requires_grad
],
**{k: v for k, v in named_param_group.items() if k != "named_params"},
}
for named_param_group in named_params_or_groups
]
self.zero_named_param_groups = named_params_or_groups
else:
# case where named_params_or_groups is Iterable[Tuple[str, NanotronParameter]]
# keep only named_params_or_groups that require grads
named_params_or_groups = [(name, param) for name, param in named_params_or_groups if param.requires_grad]
self.zero_named_param_groups = [{"named_params": named_params_or_groups}]
self.dp_pg = dp_pg # DP process group
# partition model's params across DP ranks.
# `self.param_name_to_dp_rank_offsets` sets mapping between each param inside self.named_params and its rank
# NOTE: some param_groups may have no params in the current rank. we still keep them in self.optimizer.param_groups
self.param_name_to_dp_rank_offsets = self._partition_parameters()
current_dp_rank = dist.get_rank(self.dp_pg)
param_groups_in_rank = [
{
"named_params": [
(
name,
get_sliced_tensor(
param=param,
start_offset=self.param_name_to_dp_rank_offsets[name][current_dp_rank][0],
end_offset=self.param_name_to_dp_rank_offsets[name][current_dp_rank][1],
),
)
for name, param in param_group["named_params"]
if current_dp_rank in self.param_name_to_dp_rank_offsets[name]
],
**{k: v for k, v in param_group.items() if k != "named_params"},
}
for param_group in self.zero_named_param_groups
]
# initialize rank's optimizer which is responsible for updating the rank's parameters
# NOTE: In case of ZeRO, `self.id_to_name` stores only names of parameters that are going to be updated by this DP rank's optimizer.
# NOTE: In case of ZeRO, `self.optimizer` will only get the parameters that are going to be updated by this DP's optimizer. Which
# means that `self.optimizer.param_groups` is only a subset of `self.param_groups`.
optimizer = optimizer_builder(param_groups_in_rank)
super().__init__(optimizer=optimizer, id_to_name=optimizer.id_to_name)
@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step (parameter update)."""
# TODO: @nouamanetazi: handle syncing param groups attrs (e.g. if we update lr)
loss = super().step(closure=closure)
# calculate param_size (model) + param_size (grads) + 2*param_size/DP_if_zero1 (optim_states)
expected_allocated = sum(
param.numel() * param.element_size() * 2 + param.numel() * param.element_size() * 2 / self.dp_pg.size()
for named_param_group in self.zero_named_param_groups
for _, param in named_param_group["named_params"]
)
log_rank(
f"[After optim states allocation] Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MB "
f"(Expected 2*param_size + 2*param_size/DP_if_zero1={expected_allocated / 1024**2:.2f}MB). "
f"Peak reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MB",
logger=logger,
level=logging.DEBUG,
group=self.dp_pg,
rank=0,
)
# All gather updated params
self._all_gather_params()
return loss
def zero_grad(self):
"""Copied from `torch.optim.optimizer.zero_grad` with the only change of using `self.param_groups` instead of `self.optimizer.param_groups`
because we want to zero out the gradients of all model params (not just the params in the current rank)"""
super().zero_grad()
# TODO @thomasw21: This is a call to torch internal API, we need to fix this
foreach = False # self.optimizer.defaults.get("foreach", False)
# TODO @thomasw21: This is a call to torch internal API, we need to fix this
# if not hasattr(self.optimizer, "_zero_grad_profile_name"):
# self.optimizer._hook_for_profile()
if foreach:
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
# TODO @thomasw21: This is a call to torch internal API, we need to fix this
# with torch.autograd.profiler.record_function(self.optimizer._zero_grad_profile_name):
# zero out the gradients of all model params (not just the params in the current rank)
for named_param_group in self.zero_named_param_groups:
for _, p in named_param_group["named_params"]:
if p.grad is not None:
p.grad = None
if foreach:
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._foreach_zero_(grads)
def _partition_parameters(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
named_params = [
(name, param)
for named_param_group in self.zero_named_param_groups
for name, param in named_param_group["named_params"]
if param.requires_grad
]
# maps each model's param to the optimizer's dp rank that is responsible for updating it
# We assume that parameters can be sharded across DP, i.e. we can "split" a parameter across different DP ranks. This breaks some optimizers, such as Adafactor.
# `param_name_to_dp_rank_offsets[name]` is a `Dict[int, Tuple[int, int]]`: keys are dp_rank, and the `Tuple[int, int]` holds the offsets of the slice of the param owned by that DP rank (a standalone sketch of this offset computation appears after this class)
param_name_to_dp_rank_offsets = {}
# NOTE: save the original shapes before flattening the params
# so that later on, we can reshape the params to their original shapes
# for topology-agnostic optimizer states loading
self._orig_param_shapes = {}
for name, param in named_params:
self._orig_param_shapes[name] = param.shape
for name, param in named_params:
# We assume parameter to be contiguous in order to have an easy way of sharding it.
assert param.is_contiguous(), f"Parameter {name} is not contiguous"
numel = param.numel()
padded_numel_per_dp = (numel - 1) // self.dp_pg.size() + 1
sizes = np.full(shape=(self.dp_pg.size()), fill_value=padded_numel_per_dp)
remainder = padded_numel_per_dp * self.dp_pg.size() - numel
# The last `remainder` ranks get one element less
if remainder > 0:
# Guard against remainder == 0: `sizes[-0:]` would select the entire array
sizes[-remainder:] -= 1
end_offsets = np.cumsum(sizes)
assert len(end_offsets) == self.dp_pg.size()
assert end_offsets[-1] == numel, f"Somehow {end_offsets[-1]} != {numel}"
# We also want the start indices
start_offsets = np.concatenate([[0], end_offsets[:-1]])
param_name_to_dp_rank_offsets[name] = {
dp_rank: (start_offsets[dp_rank], end_offsets[dp_rank])
for dp_rank in range(self.dp_pg.size())
if start_offsets[dp_rank] < end_offsets[dp_rank] # Only if the slice is not empty.
}
log_rank("[ZeRO sharding] Size of optimizer params per rank:", logger=logger, level=logging.INFO, rank=0)
all_numel = sum(
param_name_to_dp_rank_offsets[name][dp_rank][1] - param_name_to_dp_rank_offsets[name][dp_rank][0]
for name, param in named_params
for dp_rank in range(self.dp_pg.size())
if dp_rank in param_name_to_dp_rank_offsets[name]
)
for dp_rank in range(self.dp_pg.size()):
acc_numel = sum(
value[dp_rank][1] - value[dp_rank][0]
for value in param_name_to_dp_rank_offsets.values()
if dp_rank in value
)
log_rank(
f"[ZeRO sharding] DP Rank {dp_rank} has {human_format(acc_numel)} out of {human_format(all_numel)} ({0 if all_numel == 0 else acc_numel / all_numel * 100:.2f}%) params' optimizer states",
logger=logger,
level=logging.INFO,
rank=0,
)
return param_name_to_dp_rank_offsets
def _all_gather_params(self):
"""All gather updated params"""
all_named_tensors_to_gather = [
(name, param.view(-1))
for named_param_groups in self.zero_named_param_groups
for name, param in named_param_groups["named_params"]
]
if len(all_named_tensors_to_gather) == 0:
# No need to broadcast if there's nothing
return
if self.dp_pg.size() == 1:
# They should already be updated
return
current_dp_rank = dist.get_rank(self.dp_pg)
dist.all_gather_coalesced(
output_tensor_lists=[
[
tensor[slice(*self.param_name_to_dp_rank_offsets[name][dp_rank])]
if dp_rank in self.param_name_to_dp_rank_offsets[name]
else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
for dp_rank in range(self.dp_pg.size())
]
for name, tensor in all_named_tensors_to_gather
],
input_tensor_list=[
tensor[slice(*self.param_name_to_dp_rank_offsets[name][current_dp_rank])]
if current_dp_rank in self.param_name_to_dp_rank_offsets[name]
else torch.empty(0, dtype=tensor.dtype, device=tensor.device)
for name, tensor in all_named_tensors_to_gather
],
group=self.dp_pg,
)
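# Standalone sketch of the offset computation in `_partition_parameters` above: a parameter with
# `numel` elements is split into `dp_size` nearly equal contiguous slices, with the last ranks holding
# one element less when `numel` does not divide evenly. `_example_zero_offsets` is a hypothetical helper.
def _example_zero_offsets(numel: int = 10, dp_size: int = 4):
    padded_numel_per_dp = (numel - 1) // dp_size + 1
    sizes = np.full(shape=(dp_size,), fill_value=padded_numel_per_dp)
    remainder = padded_numel_per_dp * dp_size - numel
    if remainder > 0:
        sizes[-remainder:] -= 1
    end_offsets = np.cumsum(sizes)
    start_offsets = np.concatenate([[0], end_offsets[:-1]])
    # numel=10, dp_size=4 -> {0: (0, 3), 1: (3, 6), 2: (6, 8), 3: (8, 10)}
    return {dp_rank: (int(start_offsets[dp_rank]), int(end_offsets[dp_rank])) for dp_rank in range(dp_size)}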
# Helpers
class SlicedFlatTensor(torch.Tensor):
"""Subclass of `torch.Tensor` that unable to define `grad` getter on a slice of a flattened tensor."""
# Based on torch/testing/_internal/logging_tensor.py
# https://github.com/pytorch/pytorch/issues/102337#issuecomment-1579673041
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def get_sliced_flat_tensor(data, start_offset, end_offset):
with torch.no_grad():
return data.view(-1)[start_offset:end_offset]
@staticmethod
def __new__(cls, data, start_offset, end_offset):
sliced_tensor = cls.get_sliced_flat_tensor(data=data, start_offset=start_offset, end_offset=end_offset)
result = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
sliced_tensor.size(),
strides=sliced_tensor.stride(),
storage_offset=sliced_tensor.storage_offset(),
# TODO: clone storage aliasing
dtype=sliced_tensor.dtype,
layout=sliced_tensor.layout,
device=sliced_tensor.device,
requires_grad=sliced_tensor.requires_grad,
)
return result
def __init__(self, data, start_offset, end_offset):
super().__init__()
# TODO @thomasw21: Make it so that you can never update this value
self.sliced_flat_tensor = self.get_sliced_flat_tensor(
data=data, start_offset=start_offset, end_offset=end_offset
)
self.orig_data = data
self.start_offset = start_offset
self.end_offset = end_offset
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
return e.sliced_flat_tensor if isinstance(e, cls) else e
def never_wrap(e):
# Never re-wrap
return e
return tree_map(never_wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
def _get_grad(self):
if self.orig_data.grad is None:
return None
with torch.no_grad():
return self.orig_data.grad.view(-1)[self.start_offset : self.end_offset]
def _set_grad(self, grad):
if grad is not None:
orig_grad = self._get_grad()
if orig_grad is None:
raise NotImplementedError(
"Trying to set gradient on a sliced tensor when the original tensor hasn't allocated the buffer for the gradient"
)
orig_grad.copy_(grad)
return
# TODO @thomasw21: This is unfortunately necessary since we might pass `SliceTensor` to the optimizer.
warn_once(
logger=logger,
msg="You're setting a `SlicedTensor` gradient to None. We're going to assume you meant to set the original tensor gradient to None.",
rank=0,
)
self.orig_data.grad = None
def _del_grad(self):
raise NotImplementedError
# TODO @thomasw21: Figure out why this function doesn't get inherited. https://github.com/pytorch/pytorch/issues/102337#issuecomment-1634363356
def data_ptr(self):
return self.sliced_flat_tensor.data_ptr()
grad = property(_get_grad, _set_grad, _del_grad)
def get_sliced_tensor(param: NanotronParameter, start_offset: int, end_offset: int):
# This allows us to create a leaf tensor, despite sharing the underlying storage
result = SlicedFlatTensor(data=param, start_offset=start_offset, end_offset=end_offset)
return result
def find_optim_index_from_param_name(
param_name: str,
# NOTE: (pp_rank, dp_rank, tp_rank) or (pp_rank, tp_rank)
ckp_sharded_optim_states: Union[Tuple[Tuple[int, int, int], torch.Tensor], Tuple[Tuple[int, int], torch.Tensor]],
is_zero1: bool,
pp_rank=0,
) -> int:
param_name = param_name.replace("module.", "")
# NOTE: since all shards have the same optim state names
# so we take the first shard (except optionally the pp dimension)
if is_zero1 is True:
# NOTE: (pp_rank, dp_rank, tp_rank)
OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0, 0)]["names"]
else:
# NOTE: (pp_rank, tp_rank)
OPTIM_STATE_INDEX_TO_PARAM_NAME = ckp_sharded_optim_states[(pp_rank, 0)]["names"]
return next((k for k, v in OPTIM_STATE_INDEX_TO_PARAM_NAME.items() if v == param_name), None)
def extract_parallel_ranks_from_shard_path(
shard_path: Path, is_zero1: bool
) -> Union[Tuple[int, int, int], Tuple[int, int]]:
"""Extract parallel ranks from shard path
For example, if the shard path is:
+ For ZeRO-1: /path/to/optimizer_pp-0-of-1_dp-0-of-2_tp-0-of-1.pt
then the function will return (0, 0, 0) (pp_rank, dp_rank, tp_rank)
For ZeRO-0: /path/to/optimizer_pp-0-of-1_tp-0-of-1.pt
then the function will return (0, 0) (pp_rank, tp_rank)
"""
if is_zero1 is True:
# TODO(xrsrke): use the same pattern as weight checkpoints
# in weight checkpoints, we do pp-rank-.... but here we only do pp-...
# TODO(xrsrke): don't hardcode this
pattern = r"optimizer_pp-(\d+)-of-\d+_dp-(\d+)-of-\d+_tp-(\d+)-of-\d+\.pt"
match = re.search(pattern, str(shard_path))
pp_rank, dp_rank, tp_rank = match.groups()
return int(pp_rank), int(dp_rank), int(tp_rank)
else:
# NOTE: this is zero0 checkpoint
pattern = r"pp-(\d+)-of-\d+_tp-(\d+)-of-\d+"
match = re.search(pattern, str(shard_path))
pp_rank, tp_rank = match.groups()
return int(pp_rank), int(tp_rank)
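# Quick check sketch of the filename parsing above, using the ZeRO-1 naming scheme from the docstring.
# `_example_extract_ranks` is a hypothetical helper, not part of the original module.
def _example_extract_ranks():
    path = Path("/tmp/optimizer_pp-0-of-1_dp-1-of-2_tp-0-of-1.pt")
    assert extract_parallel_ranks_from_shard_path(path, is_zero1=True) == (0, 1, 0)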
def merge_dp_shard_in_zero1_optimizer(
model: nn.Module,
optimizer_config,
shard_paths: List[Path],
parallel_context: ParallelContext,
map_location: Optional[str] = None,
) -> Dict[Tuple[int, int], Dict[str, torch.Tensor]]: # (pp_rank, tp_rank): param_name -> optim_state
assert (
optimizer_config["configs"]["param_name_to_dp_rank_offsets"] is not None
), "param_name_to_dp_rank_offsets is required"
checkpoint_pp_size = optimizer_config["parallelism"]["pp_size"]
checkpoint_tp_size = optimizer_config["parallelism"]["tp_size"]
ckp_sharded_optim_states = {}
for shard_path in shard_paths:
pp_rank, dp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=True)
ckp_sharded_optim_states[(pp_rank, dp_rank, tp_rank)] = torch.load(shard_path, map_location=map_location)
param_name_to_dp_rank_offsets = optimizer_config["configs"]["param_name_to_dp_rank_offsets"]
optimizer_state_names = ckp_sharded_optim_states[(0, 0, 0)]["state"][0].keys()
def get_numel_of_unsharded_dp_param(param_name):
dp_offsets = param_name_to_dp_rank_offsets[param_name]
return max(int(value) for values in dp_offsets.values() for value in values)
def assign_shard_to_buffer(buffer, offset, value):
offset_start, offset_end = map(int, offset)
buffer[offset_start:offset_end] = value
param_names = sorted(model.state_dict().keys(), key=lambda x: x)
ckp_merged_dp_shards_optim_states = {}
for pp_rank, tp_rank in tqdm(
list(itertools.product(range(int(checkpoint_pp_size)), range(int(checkpoint_tp_size)))),
disable=dist.get_rank(parallel_context.world_pg) != 0,
desc="Merging ZeRO-1's shards across data parallel dimension",
):
# NOTE: filter only the shards that correspond to the current pp_rank and tp_rank
filtered_ckp_sharded_optim_states = {}
for (pp, dp, tp), ckp_optim_state in ckp_sharded_optim_states.items():
if pp == pp_rank and tp == tp_rank:
filtered_ckp_sharded_optim_states[dp] = ckp_optim_state
# NOTE: now merge the shards across data parallel dimension
# for each parameter, we need to merge all shards across data parallel dimension
merged_dp_shards_optim_states = {}
merged_dp_shards_optim_states["state"] = {}
for param_name in param_names:
unshard_dp_size = get_numel_of_unsharded_dp_param(param_name)
optim_state_index = find_optim_index_from_param_name(
param_name=param_name,
ckp_sharded_optim_states=ckp_sharded_optim_states,
is_zero1=True,
)
merged_dp_shards_optim_states["state"][optim_state_index] = {}
for state_name in optimizer_state_names:
unsharded_dp_buffer = torch.zeros(unshard_dp_size, device="cuda")
# NOTE: now merge all the params across data parallel dimension
for dp_rank, ckp_optim_state in filtered_ckp_sharded_optim_states.items():
# NOTE: extract the optimizer state of the current parameter
ckp_optim_state = ckp_optim_state["state"][optim_state_index]
ckp_offset = param_name_to_dp_rank_offsets[param_name][str(dp_rank)]
assign_shard_to_buffer(unsharded_dp_buffer, ckp_offset, ckp_optim_state[state_name])
# NOTE: in optimizer states, "state" uses an index to represent the parameter,
# not the parameter name
merged_dp_shards_optim_states["state"][optim_state_index][state_name] = unsharded_dp_buffer
# NOTE: each dp shard has the same step
merged_dp_shards_optim_states["state"][optim_state_index]["step"] = ckp_optim_state["step"]
ckp_merged_dp_shards_optim_states[(pp_rank, tp_rank)] = merged_dp_shards_optim_states
# NOTE: each dp shard has the same "names" and "param_groups" since they belong to the same tp shard
# the 0 in (pp_rank, 0, tp_rank) is the dp_rank
ckp_merged_dp_shards_optim_states[(pp_rank, tp_rank)]["names"] = ckp_sharded_optim_states[
(pp_rank, 0, tp_rank)
]["names"]
ckp_merged_dp_shards_optim_states[(pp_rank, tp_rank)]["param_groups"] = ckp_sharded_optim_states[
(pp_rank, 0, tp_rank)
]["param_groups"]
assert len(ckp_merged_dp_shards_optim_states) == int(checkpoint_pp_size) * int(
checkpoint_tp_size
), f"Expect {int(checkpoint_pp_size) * int(checkpoint_tp_size)} merged dp shards, got {len(ckp_merged_dp_shards_optim_states)}"
# NOTE: sanity check, make sure each merged checkpoint
# has the same dict key as the original checkpoint
for (pp_rank, tp_rank), ckp_optim_state in ckp_merged_dp_shards_optim_states.items():
# NOTE: we remove the gradient_accumulator key from sanity check
# because we don't merge gradient_accumulator states
missing_keys = set(ckp_optim_state.keys()) - set(ckp_sharded_optim_states[(pp_rank, 0, tp_rank)].keys())
assert (
len(missing_keys - {"gradient_accumulator"}) == 0
), "Expected the merged dp shards to have the same keys as the original dp shards, but merged dp shard misses: {}".format(
missing_keys
)
return ckp_merged_dp_shards_optim_states
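# NOTE: a minimal, single-process sketch of the DP merge performed above, with hypothetical
# offsets. Each dp rank owns a contiguous slice of the flattened optimizer state; merging just
# writes every slice back into one unsharded buffer at its recorded offset.
def _example_merge_dp_slices():
    # param_name -> {dp_rank: (start, end)} offsets, as stored in `param_name_to_dp_rank_offsets`
    dp_offsets = {"0": (0, 3), "1": (3, 6)}
    shards = {0: torch.tensor([1.0, 2.0, 3.0]), 1: torch.tensor([4.0, 5.0, 6.0])}
    unsharded = torch.zeros(6)
    for dp_rank, shard in shards.items():
        start, end = map(int, dp_offsets[str(dp_rank)])
        unsharded[start:end] = shard
    assert torch.equal(unsharded, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))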
# flake8: noqa
from nanotron.parallel.context import ParallelContext
import os
from typing import Literal, Tuple, Annotated
import numpy as np
import torch
import nanotron.distributed as dist
DistributedBackend = Literal["gloo", "mpi", "nccl"]
class ParallelContext:
def __init__(
self,
tensor_parallel_size: int,
pipeline_parallel_size: int,
data_parallel_size: int,
expert_parallel_size: int = 1,
backend: DistributedBackend = "nccl",
):
"""Initialize parallel context."""
num_gpus_per_model = tensor_parallel_size * pipeline_parallel_size * expert_parallel_size
world_size = int(os.environ["WORLD_SIZE"])
assert (
world_size % data_parallel_size == 0
), "The total number of processes must be divisible by the data parallel size."
assert world_size % num_gpus_per_model == 0, (
"The total number of processes must be divisible by "
"the number of GPUs per model (tensor_parallel_size * pipeline_parallel_size * expert_parallel_size)."
)
if num_gpus_per_model * data_parallel_size != world_size:
raise ValueError(
f"The number of processes required to run all replicas ({num_gpus_per_model * data_parallel_size}) "
f"must be equal to the world size ({world_size})."
)
if not dist.is_available():
raise ValueError("torch.distributed is not available as a package, please install it.")
self.tensor_parallel_size = tensor_parallel_size
self.pipeline_parallel_size = pipeline_parallel_size
self.data_parallel_size = data_parallel_size
self.expert_parallel_size = expert_parallel_size
self._groups = {}
self.set_device()
assert backend == "nccl", "Only nccl backend is supported for now."
if not dist.is_initialized():
dist.initialize_torch_distributed()
world_size = int(os.getenv("WORLD_SIZE", "1"))
ranks = list(range(world_size))
process_group = dist.new_group(
ranks=ranks,
backend=dist.get_backend(),
)
self.world_pg = process_group
self._init_parallel_groups()
def _init_parallel_groups(self):
"""Initialize 3D parallelism's all process groups."""
dist.barrier()
world_size = int(os.environ["WORLD_SIZE"])
ranks = np.arange(0, world_size).reshape(
(
self.expert_parallel_size,
self.pipeline_parallel_size,
self.data_parallel_size,
self.tensor_parallel_size,
)
)
self.world_ranks_to_pg = {}
# Relevant process groups containing the current rank
self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3)).reshape((-1, self.tensor_parallel_size)))
self.dp_pg = self.create_new_group(ranks.transpose((3, 0, 1, 2)).reshape((-1, self.data_parallel_size)))
self.pp_pg = self.create_new_group(ranks.transpose((2, 3, 0, 1)).reshape((-1, self.pipeline_parallel_size)))
self.expert_pg = self.create_new_group(ranks.transpose((1, 2, 3, 0)).reshape((-1, self.expert_parallel_size)))
# model parallel group = combination of tp and pp and exp for a given dp rank
self.mp_pg = self.create_new_group(
[ranks[:, :, dp_rank, :].reshape(-1) for dp_rank in range(self.data_parallel_size)]
)
self.tp_and_expert_pg = self.create_new_group(
[
ranks[:, pp_rank, dp_rank, :].reshape(-1)
for pp_rank in range(self.pipeline_parallel_size)
for dp_rank in range(self.data_parallel_size)
]
)
self.world_rank_matrix: np.ndarray = ranks
def create_new_group(self, all_groups_ranks: np.ndarray) -> dist.ProcessGroup:
dist.barrier()
rank = int(os.environ["RANK"])
new_group_containing_rank = None
for group_ranks in all_groups_ranks:
sorted_ranks = tuple(sorted(group_ranks))
# add new group to `world_ranks_to_pg`
if sorted_ranks not in self.world_ranks_to_pg:
new_group = dist.new_group(ranks=group_ranks)
self.world_ranks_to_pg[sorted_ranks] = new_group
else:
new_group = self.world_ranks_to_pg[sorted_ranks]
if rank in sorted_ranks:
new_group_containing_rank = new_group
dist.barrier()
return new_group_containing_rank
def set_device(self):
local_rank = int(os.getenv("LOCAL_RANK", "0"))
# NOTE: Set the device id.
# `torch.cuda.device_count` should return the number of devices on a single node.
# We assume the nodes to be homogeneous (same number of gpus per node)
device_id = local_rank
torch.cuda.set_device(torch.cuda.device(device_id))
def get_local_ranks(self, world_rank: int) -> Tuple[int, int, int, int]:
"""Return the (ep_rank, pp_rank, dp_rank, tp_rank) coordinates of a global rank."""
return tuple(i.item() for i in np.where(self.world_rank_matrix == world_rank))
def destroy(self):
if not dist.is_initialized():
return
dist.barrier()
dist.destroy_process_group()
def get_global_rank(
self,
ep_rank: int,
pp_rank: int,
dp_rank: int,
tp_rank: int,
) -> np.int64:
"""
Get the global rank based on the specified ranks in different parallel groups.
:param ep_rank: int, Rank in the expert parallel group.
:param pp_rank: int, Rank in the pipeline parallel group.
:param dp_rank: int, Rank in the data parallel group.
:param tp_rank: int, Rank in the tensor parallel group.
:return: numpy.int64, The global rank.
"""
return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank]
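# NOTE: a minimal usage sketch (not part of this module). It assumes the script is launched with
# `torchrun --nproc_per_node=8 ...` so that WORLD_SIZE/RANK/LOCAL_RANK are set, and that the
# 8 GPUs are split into tp=2, pp=2, dp=2.
if __name__ == "__main__":
    parallel_context = ParallelContext(
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        data_parallel_size=2,
    )
    rank = int(os.environ["RANK"])
    ep_rank, pp_rank, dp_rank, tp_rank = parallel_context.get_local_ranks(rank)
    # Round-trip: the 4D coordinates map back to the same global rank.
    assert parallel_context.get_global_rank(ep_rank, pp_rank, dp_rank, tp_rank) == rank
    parallel_context.destroy()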
from contextlib import contextmanager
from typing import Optional
import torch
from nanotron import distributed as dist
from nanotron.optim.gradient_accumulator import GradientAccumulator
from torch import nn
@contextmanager
def ddp_trigger_sync_in_bwd(model_ddp):
"""Trigger the sync of the gradients in the next backward pass of DDP model."""
assert isinstance(model_ddp, torch.nn.parallel.DistributedDataParallel)
old_require_backward_grad_sync = model_ddp.require_backward_grad_sync
old_require_forward_param_sync = model_ddp.require_forward_param_sync
model_ddp.require_backward_grad_sync = True
model_ddp.require_forward_param_sync = True
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L1325-L1356
model_ddp.reducer.prepare_for_backward([])
try:
yield
finally:
model_ddp.require_backward_grad_sync = old_require_backward_grad_sync
model_ddp.require_forward_param_sync = old_require_forward_param_sync
def sync_gradients_across_dp(
module: nn.Module,
dp_pg: dist.ProcessGroup,
reduce_op: dist.ReduceOp,
grad_accumulator: Optional[GradientAccumulator],
**sync_options,
):
"""Sync gradients across data parallelism.
Args:
module: The module to sync gradients for.
dp_pg: The data parallelism process group.
reduce_op: The reduce operation to use.
grad_accumulator: The gradient accumulator to use.
sync_options: Additional options given when using `grad_accumulator`. Please look at `GradientAccumulator.sync_gradients_across_dp` for documentation
"""
if grad_accumulator is not None:
# Optimized path: let the gradient accumulator handle the reduction across data parallelism
grad_accumulator.sync_gradients_across_dp(dp_pg=dp_pg, reduce_op=reduce_op, **sync_options)
return
# Sync gradients
for name, param in module.named_parameters():
dist.all_reduce(param.grad, op=reduce_op, group=dp_pg)
import dataclasses
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron import logging
if TYPE_CHECKING:
from nanotron.models import NanotronModel
logger = logging.get_logger(__name__)
@dataclasses.dataclass
class SlicesPair:
local_slices: Tuple[slice, ...]
global_slices: Tuple[slice, ...]
@staticmethod
def slice_to_str(s: slice):
# e.g. slice(0, 10, 2) -> "0,10,2"
# e.g. slice(None, None, None) -> "None,None,None"
return ",".join(str(x) if x is not None else "None" for x in (s.start, s.stop, s.step))
@staticmethod
def str_to_slice(s: str):
return slice(*(int(x) if x != "None" else None for x in s.split(",")))
def __str__(self):
# e.g. local_slices (slice(0, 10, 2), slice(None, None, None)) -> "0,10,2|None,None,None"
local_slices_str = "|".join(map(self.slice_to_str, self.local_slices))
# e.g. global_slices (slice(0, 20, 4), slice(None, None, None)) -> "0,20,4|None,None,None"
global_slices_str = "|".join(map(self.slice_to_str, self.global_slices))
# e.g. "0,10,2|None,None,None#0,20,4|None,None,None"
return f"{local_slices_str}#{global_slices_str}"
@classmethod
def from_str(cls, string: str):
local_slices_str, global_slices_str = string.split("#")
local_slices = tuple(map(cls.str_to_slice, local_slices_str.split("|")))
global_slices = tuple(map(cls.str_to_slice, global_slices_str.split("|")))
return cls(local_slices, global_slices)
@classmethod
def tuple_to_str(cls, pairs):
# e.g. 2 SlicesPair, 1st SlicesPair local_slices "0,10,2|None,None,None" and global_slices "0,10,2|None,None,None"
# 2nd SlicesPair local_slices "0,20,4|None,None,None" and global_slices "0,40,8|None,None,None"
# -> "0,10,2|None,None,None#0,10,2|None,None,None;0,20,4|None,None,None#0,40,8|None,None,None"
return ";".join(map(str, pairs))
@classmethod
def tuple_from_str(cls, string: str):
return tuple(map(cls.from_str, string.split(";")))
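# NOTE: a minimal, self-contained sketch of the string round-trip above, with made-up slices.
# Defined only for illustration.
def _example_slices_pair_roundtrip():
    pair = SlicesPair(
        local_slices=(slice(0, 10, 2), slice(None, None, None)),
        global_slices=(slice(0, 20, 4), slice(None, None, None)),
    )
    encoded = str(pair)  # "0,10,2|None,None,None#0,20,4|None,None,None"
    assert SlicesPair.from_str(encoded) == pair
    assert SlicesPair.tuple_from_str(SlicesPair.tuple_to_str((pair, pair))) == (pair, pair)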
@dataclasses.dataclass
class TiedInfo:
name: str
# name must be defined starting from `root_module` (e.g. root_module.dense0.dense1.weight)
root_module: nn.Module
global_ranks: Tuple[int, ...]
# None signifies that we do not reduce
reduce_op: Optional[dist.ReduceOp]
def get_full_name_from_model(self, model: nn.Module) -> str:
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
return self.get_full_name_from_module_id_to_prefix(module_id_to_prefix)
def get_full_name_from_module_id_to_prefix(self, module_id_to_prefix: Dict[int, str]) -> str:
return f"{module_id_to_prefix[id(self.root_module)]}{self.name}" # this assumes root_module is part of module_id_to_prefix
@dataclasses.dataclass
class ShardedInfo:
global_ranks: Tuple[int, ...]
# Info about which slice of the unsharded tensor (global_slices) the current sharded tensor (local_slices) corresponds to
local_global_slices_pairs: Tuple[SlicesPair, ...]
# The shape of the unsharded tensor
unsharded_shape: Tuple[int, ...]
def is_tp_sharded(self, parallel_context) -> bool:
return set(dist.get_global_ranks(parallel_context.tp_pg)).issubset(set(self.global_ranks))
def is_expert_sharded(self, parallel_context) -> bool:
return set(dist.get_global_ranks(parallel_context.expert_pg)).issubset(set(self.global_ranks))
def is_dp_sharded(self, parallel_context):
return set(dist.get_global_ranks(parallel_context.dp_pg)).issubset(set(self.global_ranks))
class NanotronParameter(nn.Parameter):
"""Base class for all parameters in Nanotronmodels
A NanotronParameter can have specific properties:
- sharded: the parameter is considered to be `sharded` across multiple devices
- tied: the parameter is considered to be `tied` with other parameters. We sum gradients over those.
.. note::
Notes about tied weights:
- Tied weights are weights that need to be synced only within the same DP rank, regardless of whether they are part of the TP strategy or just shared between two layers.
- Syncing tied weights usually requires summing gradients.
- Some weights are synced without needing to reduce grads over ranks. They can be on the same device (ex: enc/dec embeds in the same PP stage) or duplicated across TP ranks, duplicating the workload (ex: LN using traditional TP).
- Even if some weights don't need their grads to be reduced, it's still useful to mark them as tied. For example, the current serialization format requires them to be marked correctly.
"""
NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME = "__nanotron_metadata__"
NANOTRON_PARAMETER_METADATA_TIED_KEY = "tied"
NANOTRON_PARAMETER_METADATA_SHARDED_KEY = "sharded"
def __new__(cls, tensor: torch.Tensor, requires_grad: bool = True):
param = nn.Parameter.__new__(cls, data=tensor.data.detach(), requires_grad=requires_grad)
if isinstance(tensor, NanotronParameter):
# Check that we don't inherit a weird class
# We copy in order not to make in-place operation
assert type(tensor) == NanotronParameter
setattr(
param,
cls.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME,
getattr(tensor, cls.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME).copy(),
)
else:
setattr(param, cls.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME, {})
return param
def _set_metadata(self, key: str, value: Any):
metadata = getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)
if key in metadata:
raise ValueError(
f"We shouldn't override previous metadata. Key to be overridden: {key}, current metadata: {metadata}"
)
else:
metadata[key] = value
def mark_as_tied(
self,
name: str,
global_ranks: Tuple[int, ...],
reduce_op: Optional[dist.ReduceOp],
root_module: "NanotronModel",
):
self._set_metadata(
self.NANOTRON_PARAMETER_METADATA_TIED_KEY,
TiedInfo(name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module),
)
def get_tied_info(self) -> TiedInfo:
return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[
self.NANOTRON_PARAMETER_METADATA_TIED_KEY
]
@property
def is_tied(self) -> bool:
return self.NANOTRON_PARAMETER_METADATA_TIED_KEY in getattr(
self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
)
def mark_as_sharded(
self,
global_ranks: Tuple[int, ...],
local_global_slices_pairs: Tuple[SlicesPair, ...],
unsharded_shape: Tuple[int, ...],
):
self._set_metadata(
self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY,
ShardedInfo(
global_ranks=global_ranks,
local_global_slices_pairs=local_global_slices_pairs,
unsharded_shape=unsharded_shape,
),
)
def get_sharded_info(self) -> ShardedInfo:
return getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)[
self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY
]
@property
def is_sharded(self) -> bool:
return self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY in getattr(
self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME
)
def sanity_check(root_module: nn.Module):
"""Makes sure that the module is in Nanotronformat
Format:
- all parameters are `NanotronParameter`, this allows us to add metadata to a parameter.
"""
for name, param in root_module.named_parameters():
if not isinstance(param, NanotronParameter):
raise ValueError(
f"Nanotronrequires model to be in Nanotronformat, ie all parameters are required to be a NanotronParameter. {name} isn't."
)
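# NOTE: a minimal, single-process sketch of how the metadata above is attached. The module,
# ranks and slices are made up; no process group is needed because we only set metadata.
def _example_nanotron_parameter():
    linear = nn.Linear(4, 4, bias=False)
    linear.weight = NanotronParameter(linear.weight)
    linear.weight.mark_as_sharded(
        global_ranks=(0, 1),
        local_global_slices_pairs=(
            SlicesPair(local_slices=(slice(0, 4),), global_slices=(slice(0, 4),)),
        ),
        unsharded_shape=(8, 4),
    )
    assert linear.weight.is_sharded and not linear.weight.is_tied
    sanity_check(linear)  # passes: every parameter is a NanotronParameter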
## Pipeline parallelism
We choose to mimic the "torch" eager semantics:
- Declare module/blocks at init time
- Declare edges during forward
# Scheduling
# All forward, all backward (easy, but memory expensive)
# 1f1b (much nicer)
We're going to assume that all Pipeline blocks are assigned to a rank in a contiguous manner.
Warmup:
```
Rank 1: [forward(), forward(), forward(), forward(), backward()]
Rank 2: [forward(), forward(), forward(), forward(), backward(), backward()]
Rank 3: [forward(), forward(), forward(), forward(), backward(), backward(), backward()]
Rank 4: [forward(), backward(), forward(), backward(), forward(), backward()]
```
// TODO @thomasw21: How do we extrapolate this notion to a tree. Not sure exactly, but topological ordering should be fine
# TODOs:
- [ ] passing activations that don't require backward breaks the invariant 1f1b relies on: each stage sees the same number of forwards as backwards (in the stage sense)
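A rough sketch of the per-rank 1f1b schedule discussed above (illustrative helper, not part of the engine API): each stage runs `pp_size - rank - 1` warmup forwards, then alternates one forward with one backward, and finally drains the remaining backwards.
```
def one_f_one_b_schedule(pp_size: int, rank: int, nb_microbatches: int) -> list:
    """Ordered list of "F"/"B" steps one stage executes under 1f1b (assumes nb_microbatches >= pp_size - 1)."""
    warmup = pp_size - rank - 1
    schedule = ["F"] * warmup
    for _ in range(nb_microbatches - warmup):
        schedule += ["F", "B"]
    schedule += ["B"] * warmup  # drain the remaining backwards
    return schedule

# pp_size=4, nb_microbatches=4:
#   rank 0 -> F F F F B B B B
#   rank 3 -> F B F B F B F B
```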
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel.pipeline_parallel.functional import (
recv_from_pipeline_state_buffer,
send_to_pipeline_state_buffer,
)
from nanotron.parallel.pipeline_parallel.p2p import P2P, BatchTensorSendRecvState
from nanotron.parallel.pipeline_parallel.state import PipelineBatchState, PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
class PipelineBlock(nn.Module):
"""Most granular pipeline block, ie within this module, everything will be part of a single rank, ie the entire computation within this block will happen on a specific device.
Current limitations:
- PipelineBlocks have to wrap a method/function/module that outputs a Dict[str, torch.Tensor]
Some considerations:
- In the literature, authors often refer to pipeline stages as a granularity block. Our notion is more granular. A pipeline stage is list of contiguous (in the forward sense) of pipeline blocks.
All PipelineBlock definition exist in each rank, they are just instantiated/built on a single rank per pipeline parallel process group.
"""
def __init__(
self,
p2p: P2P,
module_builder: Callable[..., Callable[..., Union[torch.Tensor, Dict[str, torch.Tensor]]]],
module_kwargs: Dict[str, Any],
module_input_keys: Set[str],
module_output_keys: Set[str],
):
super().__init__()
# Module follows a restrictive API: module.forward returns a `Dict[str, torch.Tensor]`
self.p2p = p2p
# None signifies that we don't use a specific pipeline engine and just run a typical torch forward/backward pass
self.pipeline_state: Optional[PipelineBatchState] = None
self.module_builder = module_builder
self.module_kwargs = module_kwargs
self.module_input_keys = set(module_input_keys)
self.module_output_keys = set(module_output_keys)
def build_and_set_rank(self, pp_rank: int):
"""This method is used to define on which rank computation is going to happen"""
assert pp_rank < self.p2p.pg.size()
self.rank = pp_rank
if pp_rank == dist.get_rank(self.p2p.pg):
# Instantiate the module
self.pp_block = self.module_builder(**self.module_kwargs)
def extra_repr(self) -> str:
return f"pp_rank={self.rank}" if hasattr(self, "rank") else ""
def set_pipeline_state(self, pipeline_state: Optional[PipelineBatchState]):
self.pipeline_state = pipeline_state
def forward(self, **kwargs):
"""Forward pass
We use a mechanism using TensorPointers to pass Tensors around
All non-Tensor objects and TensorPointers are considered pass-through; they are never meant to be communicated across processes
:param kwargs: Dict[str, Union[TensorPointer, torch.Tensor, Any]]
:return: Dict[str, Union[TensorPointer, torch.Tensor, Any]]
"""
assert self.module_input_keys == set(
kwargs.keys()
), f"Expected {self.module_input_keys}, got {set(kwargs.keys())}"
sorted_kwargs = sorted(kwargs.items(), key=get_sort_key(dist.get_rank(self.p2p.pg)))
# If the current rank is not the one running the compute
if dist.get_rank(self.p2p.pg) != self.rank:
# TODO(kunhao): A better design is to pop this up for both if else branches.
batch_send_recv = BatchTensorSendRecvState(self.p2p)
# Send activations from other devices to local rank
for name, tensor in sorted_kwargs:
if isinstance(tensor, TensorPointer):
# Current rank is neither the rank holding the data nor the rank responsible for computing block
continue
else:
assert isinstance(tensor, torch.Tensor)
# We need to send the tensor to the rank that actually runs the compute
if self.pipeline_state is not None:
send_to_pipeline_state_buffer(
tensor,
to_rank=self.rank,
p2p=self.p2p,
pipeline_state=self.pipeline_state,
)
continue
if tensor.requires_grad is True:
raise ValueError(
f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
)
batch_send_recv.add_send(tensor=tensor, to_rank=self.rank)
batch_send_recv.flush()
# Return that the outputs all live on the rank responsible for computing the block
# TODO @thomasw21: Figure out a way to build dummy_input in a generic sense, and remove the necessity to have Dict[str, torch.Tensor] as output
return {k: TensorPointer(group_rank=self.rank) for k in self.module_output_keys}
# Recv activations from other devices to local rank
new_kwargs: Dict[str, torch.Tensor] = {}
name_to_recv_id = {}
batch_send_recv = BatchTensorSendRecvState(self.p2p)
for name, tensor in sorted_kwargs:
if isinstance(tensor, TensorPointer):
# Current rank is the one running the compute, we need to query the tensor
# new_kwargs[name] = recv_tensor(from_rank=tensor.group_rank, p2p=self.p2p)
# This assumes that prior communication was already done
# In case of interleaved 1f1b, if this is the second model chunk, then we need to send the previous activations before receiving the current activations
if isinstance(self.pipeline_state, PipelineTrainBatchState):
for _ in range(len(self.pipeline_state.microbatches_activations_to_send)):
send_activation = self.pipeline_state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
if self.pipeline_state is not None:
new_kwargs[name] = recv_from_pipeline_state_buffer(
from_rank=tensor.group_rank,
p2p=self.p2p,
pipeline_state=self.pipeline_state,
)
continue
# We don't store result in a buffer
recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank)
name_to_recv_id[name] = recv_id
else:
new_kwargs[name] = tensor
# Run receiving communications
recv_tensors = batch_send_recv.flush()
assert len(recv_tensors) == len(name_to_recv_id)
for name, recv_id in name_to_recv_id.items():
assert name not in new_kwargs
new_tensor = recv_tensors[recv_id]
if new_tensor.requires_grad is True:
raise ValueError(
f"Pipeline engine is None and tensor requires grad. Tried receiving a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine."
)
new_kwargs[name] = new_tensor
output = self.pp_block(**new_kwargs)
# Helper for functions that return tensors
if isinstance(output, torch.Tensor):
assert len(self.module_output_keys) == 1
output = {next(iter(self.module_output_keys)): output}
assert isinstance(output, dict), "Modules within a Pipeline Block have to return a Dict[str, torch.Tensor]"
assert self.module_output_keys == set(
output.keys()
), f"Expected {self.module_output_keys}, got {set(output.keys())}"
return output
def get_min_max_rank(module: torch.nn.Module) -> Tuple[int, int]:
"""Finds min and max PP ranks of the underlying PipelineBlocks"""
ranks = [module.rank for module in module.modules() if isinstance(module, PipelineBlock)]
return min(ranks), max(ranks)
def get_sort_key(current_rank: int):
"""The idea is to free earlier ranks earlier."""
def sort_key(elt: Tuple[str, Union[torch.Tensor, TensorPointer]]):
name, tensor = elt
rank: int
if isinstance(tensor, TensorPointer):
rank = tensor.group_rank
else:
rank = current_rank
return rank, name
return sort_key
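# NOTE: a minimal, single-process sketch of `get_sort_key`: TensorPointers are ordered by the
# rank holding the data, local tensors use the current rank, and ties break on the kwarg name.
# Defined only for illustration.
def _example_sort_key():
    kwargs = {
        "hidden_states": TensorPointer(group_rank=1),
        "sequence_mask": torch.ones(2, 4),
        "positions": TensorPointer(group_rank=0),
    }
    ordered = sorted(kwargs.items(), key=get_sort_key(current_rank=2))
    assert [name for name, _ in ordered] == ["positions", "hidden_states", "sequence_mask"]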
from contextlib import contextmanager
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.state import PipelineBatchState
from torch import nn as torch_nn
@contextmanager
def attach_pipeline_state_to_model(model: torch_nn.Module, pipeline_state: PipelineBatchState):
"""Attach the pipeline state to all the PipelineBlocks within `model`"""
old_pipeline_states = []
# Set new
for name, module in model.named_modules():
if not isinstance(module, PipelineBlock):
continue
old_pipeline_state = module.pipeline_state
assert old_pipeline_state is None, "We never replace an existing pipeline state, we only set one when there's none"
old_pipeline_states.append((old_pipeline_state, module))
module.set_pipeline_state(pipeline_state)
try:
yield
finally:
for old_pipeline_state, module in old_pipeline_states:
module.set_pipeline_state(old_pipeline_state)
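# NOTE: a minimal usage sketch with a hypothetical pipeline-sharded `model` and `micro_batch`.
# The state is only attached for the duration of the `with` block; afterwards every
# PipelineBlock goes back to having no pipeline state.
def _example_attach_pipeline_state(model: torch_nn.Module, micro_batch: dict):
    from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState

    state = PipelineTrainBatchState()
    with attach_pipeline_state_to_model(model=model, pipeline_state=state):
        return model(**micro_batch)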
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Optional, Union
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import log_rank
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import ContextManagers
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel
logger = logging.get_logger(__name__)
class PipelineEngine(ABC):
def __init__(self):
self.nb_microbatches: Optional[int] = None
def forward(
self,
context: ContextManagers,
state: PipelineTrainBatchState,
micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]],
model: torch_nn.Module,
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Increment the number of forwards
state.nb_forwards += 1
log_rank(
f"Forward micro batch id: {state.nb_forwards}",
logger=logger,
level=logging.DEBUG,
)
# IMPORTANT as it's basically the context manager storing all the intermediary activations
state.new_micro_batch_forward()
with context:
output = model(**micro_batch)
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# We normalize our loss
if not isinstance(output["loss"], TensorPointer):
output["loss"] = output["loss"] / self.nb_microbatches
# Add output as activations that require backward pass
if not isinstance(output["loss"], TensorPointer):
assert output["loss"].requires_grad
state.register_activation_requiring_backward(output["loss"])
return output
@staticmethod
def _get_fwd_context(model: torch_nn.Module):
is_ddp = isinstance(model, DistributedDataParallel)
# We never want DDP to trigger a gradient sync in the next backward pass
context = ContextManagers([model.no_sync()] if is_ddp else [])
return context
def backward(
self, context: ContextManagers, state: PipelineTrainBatchState, grad_accumulator: Optional[GradientAccumulator]
):
# Increment the number of backwards
state.nb_backwards += 1
log_rank(
f"Backward micro batch id: {state.nb_forwards}",
logger=logger,
level=logging.DEBUG,
)
# Go backward entirely
activations = state.pop_last_activations_requiring_backward()
if len(activations) == 0:
return
with context:
if grad_accumulator is None:
sum(activations).backward()
else:
grad_accumulator.backward(sum(activations))
# TODO @nouamane: this fixes interleaved afab but makes 1f1b hang
# with context:
# if grad_accumulator is None:
# for activation in reversed(activations): #TODO @nouamane: need to bwd only 2nd chunk
# activation.backward()
# else:
# for activation in reversed(activations):
# grad_accumulator.backward(activation)
def _get_bwd_context(
self,
model: torch_nn.Module,
nb_backwards: int,
grad_accumulator: Optional[GradientAccumulator],
):
assert (
self.nb_microbatches is not None
), "You must call `train_batch_iter` first and set `self.nb_microbatches`"
is_ddp = isinstance(model, DistributedDataParallel)
context_list = []
if is_ddp:
if grad_accumulator is not None and nb_backwards < self.nb_microbatches - 1:
context_list.append(grad_accumulator.no_sync()) # Prevents grad accumulator from syncing
if nb_backwards == self.nb_microbatches - 1:
# Triggers DDP to sync gradients in the next backward pass
context_list.append(ddp_trigger_sync_in_bwd(model_ddp=model))
context = ContextManagers(context_list)
return context
@abstractmethod
def train_batch_iter(
self,
model: torch_nn.Module,
pg: ProcessGroup,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
grad_accumulator: Optional[GradientAccumulator],
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
"""If model returns tensor, we use it as a loss to backpropagate. If model returns a dict, we assume that the key "loss" is the loss to backpropagate."""
...
@torch.inference_mode()
def validate_batch_iter(
self,
model: torch_nn.Module,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineTrainBatchState() # TODO: do i need state?
self.nb_microbatches = nb_microbatches
outputs = []
with attach_pipeline_state_to_model(model=model, pipeline_state=state):
# All forward
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
return outputs
class AllForwardAllBackwardPipelineEngine(PipelineEngine):
def __init__(self):
super().__init__()
def train_batch_iter(
self,
model: torch_nn.Module,
pg: ProcessGroup,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
grad_accumulator: Optional[GradientAccumulator],
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineTrainBatchState()
self.nb_microbatches = nb_microbatches
outputs = []
with attach_pipeline_state_to_model(model=model, pipeline_state=state):
# All forward
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
# All backward
for _ in range(len(state.microbatches_activations_requiring_backward)):
context = self._get_bwd_context(
model=model,
nb_backwards=state.nb_backwards,
grad_accumulator=grad_accumulator,
)
self.backward(context=context, state=state, grad_accumulator=grad_accumulator)
for _ in range(len(state.microbatches_grads_to_send)):
send_grads = state.microbatches_grads_to_send.popleft()
# Execute
send_grads()
# Make sure that micro batches are all fully consumed
state.check_buffers_empty()
return outputs
class OneForwardOneBackwardPipelineEngine(PipelineEngine):
def __init__(self):
super().__init__()
def train_batch_iter(
self,
model: torch_nn.Module,
pg: ProcessGroup,
batch: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
nb_microbatches: int,
grad_accumulator: Optional[GradientAccumulator],
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
"""Check https://arxiv.org/abs/2104.04473 for diagrams for the pipeline engine"""
self.nb_microbatches = nb_microbatches
assert (
self.nb_microbatches >= pg.size() - 1
), f"Number of microbatches ({self.nb_microbatches}) must be at least PP_SIZE-1={pg.size() - 1} when using the OneForwardOneBackwardPipelineEngine"
state = PipelineTrainBatchState()
outputs = []
batch = iter(batch)
current_pp_rank = dist.get_rank(pg)
with attach_pipeline_state_to_model(model=model, pipeline_state=state):
# Init
for _ in range(pg.size() - current_pp_rank - 1):
micro_batch = next(batch)
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Send tensors
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
# Execute
send_activation()
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
# We make `output` a dict
if not isinstance(output, dict):
output = {"loss": output}
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)
# One backward
context = self._get_bwd_context(
model=model,
nb_backwards=state.nb_backwards,
grad_accumulator=grad_accumulator,
)
self.backward(context=context, state=state, grad_accumulator=grad_accumulator)
# Check figure in paper: the remaining blocks are all backwards and there are only `pg.size() - current_pp_rank - 1` of them left
assert len(state.microbatches_activations_requiring_backward) == pg.size() - current_pp_rank - 1
# No more activation to send/recv
assert (
len(state.microbatches_activations_to_send) == 0
), f"There are activations left for me to send still: {len(state.microbatches_activations_to_send)}"
assert (
len(state.microbatches_activations_to_recv) == 0
), f"There are activations left for me to recv still: {len(state.microbatches_activations_to_recv)}"
# Close: compute backward for the rest
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_grads_to_send)):
send_grads = state.microbatches_grads_to_send.popleft()
# Execute
send_grads()
for _ in range(len(state.microbatches_activations_requiring_backward)):
context = self._get_bwd_context(
model=model,
nb_backwards=state.nb_backwards,
grad_accumulator=grad_accumulator,
)
self.backward(context=context, state=state, grad_accumulator=grad_accumulator)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_grads_to_send)):
send_grads = state.microbatches_grads_to_send.popleft()
# Execute
send_grads()
# Make sure that micro batches are all fully consumed
state.check_buffers_empty()
return outputs
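# NOTE: a minimal usage sketch (hypothetical `model`, `parallel_context` and list of
# microbatches): both engines share the same entry point and only differ in how forwards and
# backwards interleave. Defined only for illustration.
def _example_train_step(model, parallel_context, micro_batches, nb_microbatches):
    engine = AllForwardAllBackwardPipelineEngine()
    outputs = engine.train_batch_iter(
        model=model,
        pg=parallel_context.pp_pg,
        batch=micro_batches,
        nb_microbatches=nb_microbatches,
        grad_accumulator=None,
    )
    # On the last pipeline stage, `outputs` holds one detached loss dict per microbatch.
    return outputs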