Commit 71e79847 authored by chenzk

v1.0.3
import dataclasses
@dataclasses.dataclass
class TensorPointer:
"""Dataclass specifying from which rank we need to query a tensor from in order to access data"""
# Needed to understand from which rank to get the tensor
# TODO @thomasw21: Maybe add which group it belongs to as well? Typically this is highly correlated to `p2p.pg`
group_rank: int
# TODO @thomasw21: Maybe add a tag (torch.distributed.send/recv allow for tagging)
from nanotron.models import NanotronModel
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from torch import nn
from torch.nn.parallel import DistributedDataParallel
def get_input_output_pp_ranks(model: NanotronModel | DistributedDataParallel):
if isinstance(model, DistributedDataParallel):
input_pp_rank = model.module.input_pp_rank
output_pp_rank = model.module.output_pp_rank
else:
input_pp_rank = model.input_pp_rank
output_pp_rank = model.output_pp_rank
return input_pp_rank, output_pp_rank
def get_pp_rank_of(target: str, module: nn.Module):
"""Assuming a model with pipeline blocks, we want to know in which pp rank the module/parameter whose name is `target`"""
if isinstance(module, PipelineBlock):
return module.rank
atoms = target.split(".")
current_module = module
for atom in atoms:
if not hasattr(current_module, atom):
raise AttributeError(f'{current_module._get_name()} has no attribute `"{atom}"`')
current_module = getattr(current_module, atom)
if isinstance(current_module, PipelineBlock):
return current_module.rank
if not isinstance(current_module, nn.Module):
raise AttributeError(f'`"{atom}"` is not an nn.Module')
raise ValueError(f'`"{target}"` is not inside a PipelineBlock and thus does not have a pp_rank')
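# Usage note (illustrative, not from the original source): given a model built from
# PipelineBlocks, a dotted target such as "decoder.0.attn.qkv.weight" (hypothetical name)
# is resolved atom by atom until a PipelineBlock is reached, and that block's `rank` is the
# pipeline-parallel rank holding the module/parameter, e.g.
#   pp_rank = get_pp_rank_of("decoder.0.attn.qkv.weight", module=root_model)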
import dataclasses
from typing import List, Optional, Tuple
import numpy as np
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel.parameters import NanotronParameter, SlicesPair
@dataclasses.dataclass
class SplitConfig:
split_dim: int
# contiguous_chunks is a tuple of chunk sizes along the split_dim
# sharding happens inside each chunk
# if None, by default contiguous_chunks = (unsharded_param.shape[split_dim],)
contiguous_chunks: Optional[Tuple[int, ...]] = None
def create_sharded_parameter(
parameter: nn.Parameter,
global_ranks: Tuple[int, ...],
local_global_slices_pairs: Tuple[SlicesPair, ...],
unsharded_shape: Tuple[int, ...],
) -> NanotronParameter:
if not isinstance(parameter, NanotronParameter):
parameter = NanotronParameter(tensor=parameter)
parameter.mark_as_sharded(
global_ranks=global_ranks,
local_global_slices_pairs=local_global_slices_pairs,
unsharded_shape=unsharded_shape,
)
return parameter
def create_sharded_parameter_from_config(
parameter: nn.Parameter,
pg: dist.ProcessGroup,
split_config: SplitConfig,
) -> NanotronParameter:
current_rank = dist.get_rank(pg)
param_num_dims = len(parameter.shape)
global_ranks = dist.get_global_ranks(pg)
split_dim = split_config.split_dim
assert split_dim < param_num_dims
contiguous_chunks = split_config.contiguous_chunks
if contiguous_chunks is None:
# we are assuming that the parameter is contiguous along the split_dim, i.e. 1 whole chunk
# all parameters are equally shardable across the process group along the split_dim
shard_length = parameter.shape[split_dim]
global_slice = slice(current_rank * shard_length, (current_rank + 1) * shard_length)
# construct a mapping from local slices to global slices, multi-dimensional version
local_slices = tuple(slice(None) for _ in range(param_num_dims))
global_slices = tuple(global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims))
local_global_slices_pairs = (SlicesPair(local_slices=local_slices, global_slices=global_slices),)
unsharded_shape = tuple(
pg.size() * param_dim_size if dim_id == split_dim else param_dim_size
for dim_id, param_dim_size in enumerate(parameter.shape)
)
else:
# support custom contiguous chunk sizes; sharding happens within each chunk along the split_dim
local_global_slices_pairs: List[SlicesPair] = []
chunks_global_offset = np.cumsum((0,) + contiguous_chunks)
chunks_local_offset = chunks_global_offset // pg.size()
for chunk, chunk_global_start, chunk_local_start, chunk_local_end in zip(
contiguous_chunks,
chunks_global_offset[:-1],
chunks_local_offset[:-1],
chunks_local_offset[1:],
strict=True,
):
# we assume that we are doing equal split at the chunk level
assert chunk % pg.size() == 0, f"chunk size {chunk} must be divisible by process group size {pg.size()}"
shard_length = chunk // pg.size()
# we have: chunk_local_end = chunk_local_start + shard_length
local_slice = slice(chunk_local_start, chunk_local_end)
global_slice = slice(
current_rank * shard_length + chunk_global_start,
(current_rank + 1) * shard_length + chunk_global_start,
)
local_slices = tuple(
local_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims)
)
global_slices = tuple(
global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims)
)
local_global_slices_pairs.append(SlicesPair(local_slices=local_slices, global_slices=global_slices))
local_global_slices_pairs: Tuple[SlicesPair, ...] = tuple(local_global_slices_pairs)
unsharded_shape = tuple(
chunks_global_offset[-1] if dim_id == split_dim else param_dim_size
for dim_id, param_dim_size in enumerate(parameter.shape)
)
return create_sharded_parameter(
parameter=parameter,
global_ranks=global_ranks,
local_global_slices_pairs=local_global_slices_pairs,
unsharded_shape=unsharded_shape,
)
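# Illustrative sketch (hypothetical helper, not part of the library): the local/global slice
# arithmetic used above for `contiguous_chunks`, written with plain numpy so it can be checked
# without a process group. The name `_sketch_chunked_slices` and its default inputs are made up.
def _sketch_chunked_slices(contiguous_chunks=(4, 4), world_size=2, rank=1):
    """Return one (local_slice, global_slice) pair per chunk for the given rank."""
    chunks_global_offset = np.cumsum((0,) + tuple(contiguous_chunks))
    chunks_local_offset = chunks_global_offset // world_size
    pairs = []
    for chunk, g_start, l_start, l_end in zip(
        contiguous_chunks, chunks_global_offset[:-1], chunks_local_offset[:-1], chunks_local_offset[1:]
    ):
        shard_length = chunk // world_size
        pairs.append(
            (
                # where this shard lives inside the local (already sharded) parameter
                slice(int(l_start), int(l_end)),
                # where it lives inside the unsharded parameter
                slice(int(rank * shard_length + g_start), int((rank + 1) * shard_length + g_start)),
            )
        )
    return pairs
# For contiguous_chunks=(4, 4) on 2 ranks, rank 1 holds local rows [0:2) and [2:4),
# which map to global rows [2:4) and [6:8) respectively.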
def mark_all_parameters_in_module_as_sharded(module: nn.Module, pg: dist.ProcessGroup, split_config: SplitConfig):
"""
Mark parameters as sharded within a module. We assume that parameters are equally shardable across the process group.
:param module: nn.Module
:param pg: dist.ProcessGroup
:param split_config: SplitConfig
:return:
"""
for module_name, submodule in module.named_modules():
for param_name, param in list(submodule.named_parameters(recurse=False)):
new_param = create_sharded_parameter_from_config(parameter=param, pg=pg, split_config=split_config)
setattr(submodule, param_name, new_param)
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch import distributed as torch_dist
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup
class DifferentiableIdentity(torch.autograd.Function):
"""All-reduce gradients in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
return tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllReduceSum.apply(grad_output, group), None
class DifferentiableAllReduceSum(torch.autograd.Function):
"""All-reduce in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
if group.size() == 1:
return tensor
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
return tensor
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class DifferentiableAllGather(torch.autograd.Function):
"""All gather in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
if group.size() == 1:
return tensor
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
unsharded_batch_size = sharded_batch_size * group.size()
unsharded_tensor = torch.empty(
unsharded_batch_size,
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group)
return unsharded_tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None
class DifferentiableReduceScatterSum(torch.autograd.Function):
"""Reduce scatter in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
if group.size() == 1:
return tensor
# TODO @thomasw21: shard along another dimension
unsharded_batch_size, *rest_size = tensor.shape
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
assert unsharded_batch_size % group.size() == 0
# TODO @thomasw21: Collectives seem to require tensors to be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
tensor = tensor.contiguous()
sharded_tensor = torch.empty(
unsharded_batch_size // group.size(),
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllGather.apply(grad_output, group), None
# -----------------
# Helper functions.
# -----------------
def differentiable_identity(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableIdentity.apply(tensor, group)
def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllReduceSum.apply(tensor, group)
def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllGather.apply(tensor, group)
def differentiable_reduce_scatter_sum(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableReduceScatterSum.apply(tensor, group)
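# Note (illustrative, not from the original source): these primitives form forward/backward
# pairs so that autograd routes gradients through the matching collective:
#   differentiable_identity            forward: no-op            backward: all-reduce(SUM)
#   differentiable_all_reduce_sum      forward: all-reduce(SUM)  backward: no-op
#   differentiable_all_gather          forward: all-gather       backward: reduce-scatter(SUM)
#   differentiable_reduce_scatter_sum  forward: reduce-scatter   backward: all-gather
# Hypothetical usage sketch, assuming `tp_pg` is an already-initialised process group:
#   x = differentiable_identity(x, group=tp_pg)    # gradients get all-reduced across tp ranks
#   h = differentiable_all_gather(h, group=tp_pg)  # gradients get reduce-scattered back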
from enum import Enum, auto
# TODO @thomasw21: python 3.11 introduces `StrEnum` which would've been great to use.
class TensorParallelLinearMode(Enum):
ALL_REDUCE = auto()
REDUCE_SCATTER = auto()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from torch.nn import functional as F
import nanotron.distributed as dist
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1
class _ShardedCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(
ctx,
sharded_logits, # (batch_size, length, sharded_hidden_size)
target, # (batch_size, length)
group: dist.ProcessGroup,
):
# Maximum value along last dimension across all GPUs.
logits_max = torch.max(sharded_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
# Subtract the maximum value.
sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)
# Get the shard's indices
sharded_hidden_size = sharded_logits.shape[-1]
rank = dist.get_rank(group)
start_index = rank * sharded_hidden_size
end_index = start_index + sharded_hidden_size
# Create a mask for target ids outside this shard's vocab range (1/True means it needs to be masked).
target_mask = (target < start_index) | (target >= end_index)
masked_target = target.clone() - start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For simplicity, we convert logits to a 2-D tensor with size
# [*, shard-size] and target to a 1-D tensor of size [*].
logits_2d = sharded_logits.view(-1, sharded_hidden_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
if predicted_logits_1d.is_contiguous():
predicted_logits_1d = predicted_logits_1d.clone()
else:
predicted_logits_1d = predicted_logits_1d.contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = sharded_logits
torch.exp(sharded_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss.view_as(target)
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
sharded_hidden_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, sharded_hidden_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None, None
def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtype: Optional[torch.dtype] = None):
"""Helper function for the cross entropy."""
if dtype is not None:
# Cast input to specific dtype.
sharded_logits = sharded_logits.to(dtype=dtype)
return _ShardedCrossEntropy.apply(sharded_logits, target, group)
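# Illustrative sketch (hypothetical helper, single process, not part of the library): the
# masked predicted-logit lookup performed in `_ShardedCrossEntropy.forward`. Each simulated
# "rank" keeps only the targets falling inside its vocab shard; summing the per-shard lookups
# (the role of the all-reduce) recovers the full gather.
def _sketch_sharded_logit_lookup(num_tokens=6, vocab=10, tp=2):
    logits = torch.randn(num_tokens, vocab)
    target = torch.randint(0, vocab, (num_tokens,))
    picked = torch.zeros(num_tokens)
    for rank, shard in enumerate(torch.chunk(logits, tp, dim=-1)):
        start_index = rank * shard.shape[-1]
        end_index = start_index + shard.shape[-1]
        target_mask = (target < start_index) | (target >= end_index)
        masked_target = (target - start_index).masked_fill(target_mask, 0)
        vals = shard[torch.arange(num_tokens), masked_target]
        picked += vals.masked_fill(target_mask, 0.0)  # stands in for dist.all_reduce(..., op=SUM)
    assert torch.allclose(picked, logits[torch.arange(num_tokens), target])
    return picked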
class _ColumnLinearAsyncCommunication(torch.autograd.Function):
"""Adapted from https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215"""
@staticmethod
@assert_cuda_max_connections_set_to_1
def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
ctx.use_bias = bias is not None
ctx.tp_mode = tp_mode
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.tensor_shape = tensor.size()
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
group_size = group.size()
current_rank = dist.get_rank(group)
if group_size == 1:
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
else:
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
# ctx.save_for_backward(tensor, weight)
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
gathered_batch_size = sharded_batch_size * group.size()
if tp_recompute_allgather:
gathered_tensor = MemoryBuffer().get(
"allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
)
else:
gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)
handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
# Compute a shard of column_linear at the same time as the AllGather
# We can compute the matmul of the shard we already hold with the current rank's weight
# We assume that rank 0 holds w0, rank 1 holds w1, etc.
# weights: w0 w1 w2 w3
# rank 0: X - - -
# rank 1: - X - -
# rank 2: - - X -
# rank 3: - - - X
# We call the corresponding shard of output "same_device_shard"
output_size = weight.shape[0]
gathered_output = torch.empty(
gathered_batch_size,
*intermediate_size,
output_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
before_shard, same_device_shard, after_shard = torch.split(
gathered_output,
split_size_or_sections=[
sharded_batch_size * current_rank,
sharded_batch_size,
sharded_batch_size * (group_size - current_rank - 1),
],
dim=0,
)
first_dims = math.prod([sharded_batch_size, *intermediate_size])
if bias is None:
torch.mm(
input=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
# Wait communication
handle.wait()
if tp_recompute_allgather:
ctx.save_for_backward(tensor, weight)
else:
ctx.save_for_backward(gathered_tensor, weight)
# Compute all the other shards that are obtained from AllGather
# weights: w0 w1 w2 w3
# rank 0: - X X X
# rank 1: X - X X
# rank 2: X X - X
# rank 3: X X X -
# Since these pieces may not be contiguous in the output (for r1 and r2 they are separated by "same_device_shard"),
# we need to compute them separately, i.e. "before_shard" and "after_shard"
# For r0, "before_shard" is empty. For r3, "after_shard" is empty.
if before_shard.numel() > 0:
first_dims = math.prod(before_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
if after_shard.numel() > 0:
first_dims = math.prod(after_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
return gathered_output
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
@staticmethod
@assert_cuda_max_connections_set_to_1
def backward(ctx, grad_output):
tensor, weight = ctx.saved_tensors
group = ctx.group
use_bias = ctx.use_bias
tp_mode = ctx.tp_mode
handle1: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather:
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
if group.size() == 1:
total_tensor = tensor
else:
unsharded_batch_size = sharded_batch_size * group.size()
unsharded_tensor = MemoryBuffer().get(
"allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the tensor gradient computation
total_tensor = unsharded_tensor
else:
total_tensor = tensor
grad_tensor = grad_output.matmul(weight)
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)
handle2: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
if group.size() == 1:
sub_grad_tensor = grad_tensor
else:
sub_grad_tensor = torch.empty(
ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
)
# reduce_scatter
handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
# Asynchronous all-reduce
handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
else:
raise ValueError()
grad_bias = grad_output.sum(dim=0) if use_bias else None
if handle1 is not None:
handle1.wait()
# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = grad_output.t().matmul(total_tensor)
if handle2 is not None:
handle2.wait()
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return sub_grad_tensor, grad_weight, grad_bias, None, None, None
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
return grad_tensor, grad_weight, grad_bias, None, None, None
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
"""
Column linear with memory_buffer for the allgather, context parallel
enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
async communication disabled.
"""
@staticmethod
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_recompute_allgather: bool,
):
# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if group.size() == 1:
total_input = input.contiguous()
elif tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
# Prepare context.
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.input_size = input.shape
if tp_recompute_allgather:
ctx.save_for_backward(input, weight, bias)
else:
ctx.save_for_backward(total_input, weight, bias)
# Get linear output.
out = F.linear(total_input, weight, bias)
return out
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Either allgather the inputs again or get them from context.
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if group.size() == 1 or not tp_recompute_allgather:
total_input, weight, bias = ctx.saved_tensors
else:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.contiguous()
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim)
# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
if group.size() == 1:
sub_grad_input = grad_input
else:
# Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
# We set grad_input to be contiguous in case it isn't already.
grad_input = grad_input.contiguous()
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None
return sub_grad_input, grad_weight, grad_bias, None, None
def column_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool = True,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
return F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
input, weight, bias, group, tp_recompute_allgather
)
raise ValueError(f"Got unexpected mode: {tp_mode}.")
class _RowLinearAsyncCommunication(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, weight, bias, group, tp_mode):
assert (
tp_mode is TensorParallelLinearMode.REDUCE_SCATTER
), f"async communication in RowLinear only supports REDUCE_SCATTER, got {tp_mode}"
if group is None:
group = dist.distributed_c10d._get_default_group()
ctx.use_bias = bias is not None
ctx.group = group
out = F.linear(tensor, weight, bias)
if group.size() > 1:
out = differentiable_reduce_scatter_sum(out, group=group)
ctx.save_for_backward(tensor, weight)
return out
@staticmethod
@assert_cuda_max_connections_set_to_1
def backward(ctx, grad_output):
tensor, weight = ctx.saved_tensors
group = ctx.group
use_bias = ctx.use_bias
handle: Optional[dist.Work] = None
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = grad_output.shape
if group.size() == 1:
total_grad_output = grad_output
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_output = MemoryBuffer().get(
"allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
# total_grad_output: [b, s, h_out]
# weight: [h_out, h_in/n]
# total_grad_tensor: [b, s, h_in/n]
# grad_output: [b/n, s, h_out]
sharded_batch_size, *rest_size_grad_output = grad_output.shape
rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]]
if group.size() == 1:
total_grad_tensor = grad_output.matmul(weight)
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_tensor = torch.empty(
unsharded_batch_size,
*rest_size_grad_tensor,
device=grad_output.device,
dtype=grad_output.dtype,
requires_grad=False,
)
before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split(
total_grad_tensor,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# compute local shard
torch.mm(
input=grad_output.view(-1, grad_output.shape[-1]),
mat2=weight,
out=same_device_shard_grad_tensor.view(-1, weight.shape[1]),
)
if handle is not None:
handle.wait()
before_shard_grad_output, _, after_shard_grad_output = torch.split(
total_grad_output,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# before shard compute
if before_shard_grad_tensor.numel() > 0:
torch.mm(
input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]),
mat2=weight,
out=before_shard_grad_tensor.view(-1, weight.shape[1]),
)
# after shard compute
if after_shard_grad_tensor.numel() > 0:
torch.mm(
input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]),
mat2=weight,
out=after_shard_grad_tensor.view(-1, weight.shape[1]),
)
# Convert the tensor shapes to 2D for execution compatibility
tensor = tensor.contiguous()
tensor_first_dims, tensor_last_dim = tensor.shape[:-1], tensor.shape[-1]
tensor = tensor.view(math.prod(tensor_first_dims), tensor_last_dim)
# Convert the tensor shapes to 2D for execution compatibility
total_grad_output_first_dims, total_grad_output_last_dim = (
total_grad_output.shape[:-1],
total_grad_output.shape[-1],
)
total_grad_output = total_grad_output.view(math.prod(total_grad_output_first_dims), total_grad_output_last_dim)
# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = total_grad_output.t().matmul(tensor)
grad_bias = total_grad_output.sum(dim=0) if use_bias else None
return total_grad_tensor, grad_weight, grad_bias, None, None
def row_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
):
if async_communication:
return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
out = F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
out = differentiable_all_reduce_sum(out, group=group)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
out = differentiable_reduce_scatter_sum(out, group=group)
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
return out
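# Illustrative sketch (hypothetical helper, single process, not part of the library): the
# algebra behind `row_linear`. Sharding the input feature dimension of both the activation
# and the weight gives partial products whose sum (performed above by the all-reduce or
# reduce-scatter) equals the unsharded matmul.
def _sketch_row_parallel_identity(batch=5, in_features=8, out_features=3, tp=2):
    x = torch.randn(batch, in_features)
    w = torch.randn(out_features, in_features)
    partial_sum = sum(
        xs @ ws.t() for xs, ws in zip(torch.chunk(x, tp, dim=1), torch.chunk(w, tp, dim=1))
    )
    assert torch.allclose(partial_sum, x @ w.t(), atol=1e-5)
    return partial_sum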
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.distributed import get_global_rank
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.sharded_parameters import (
SplitConfig,
create_sharded_parameter_from_config,
mark_all_parameters_in_module_as_sharded,
)
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.functional import (
column_linear,
row_linear,
)
from nanotron.parallel.tied_parameters import create_tied_parameter
class TensorParallelColumnLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
tp_recompute_allgather: bool = True,
):
self.pg = pg
self.world_size = pg.size()
assert out_features % self.world_size == 0
self.in_features = in_features
self.out_features = out_features // self.world_size
self.tp_recompute_allgather = tp_recompute_allgather
super().__init__(
in_features=self.in_features,
out_features=self.out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.mode = mode
self.async_communication = async_communication
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == out_features
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to out_features ({out_features})"
split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)
mark_all_parameters_in_module_as_sharded(
self,
pg=self.pg,
split_config=split_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return column_linear(
input=x,
weight=self.weight,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
tp_recompute_allgather=self.tp_recompute_allgather,
)
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}"
class TensorParallelRowLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
):
self.pg = pg
self.world_size = pg.size()
assert in_features % self.world_size == 0
self.in_features = in_features // self.world_size
self.out_features = out_features
# No need to shard the bias term, only rank 0 would have it
bias = dist.get_rank(self.pg) == 0 and bias
super().__init__(
in_features=self.in_features,
out_features=self.out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.mode = mode
self.async_communication = async_communication
if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
raise ValueError("async_communication is not supported for ALL_REDUCE mode")
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == in_features
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to in_features ({in_features})"
split_config = SplitConfig(split_dim=1, contiguous_chunks=contiguous_chunks)
self._mark_all_parameters_in_module_as_sharded(split_config)
def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
for name, param in list(self.named_parameters()):
if name == "bias":
# `bias` only exists in rank 0 because it's not sharded
new_param = NanotronParameter(tensor=param)
else:
new_param = create_sharded_parameter_from_config(
parameter=param,
pg=self.pg,
split_config=split_config,
)
setattr(self, name, new_param)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return row_linear(
input=x,
weight=self.weight,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_in_features={self.in_features * self.world_size}"
class TiedLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
):
self.pg = pg
self.world_size = pg.size()
self.mode = mode
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
self._mark_all_parameters_in_module_as_tied()
def _mark_all_parameters_in_module_as_tied(self):
for name, param in list(self.named_parameters()):
new_param = create_tied_parameter(
parameter=param,
name=name,
global_ranks=tuple(sorted((get_global_rank(self.pg, i) for i in range(self.pg.size())))),
reduce_op=None if self.mode is TensorParallelLinearMode.ALL_REDUCE else dist.ReduceOp.SUM,
root_module=self,
)
setattr(self, name, new_param)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = super().forward(x)
if self.mode is TensorParallelLinearMode.ALL_REDUCE:
y = differentiable_identity(y, group=self.pg)
elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
y = differentiable_all_gather(y, group=self.pg)
else:
raise ValueError(f"Got unexpected mode: {self.mode}.")
return y
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
):
self.pg = pg
self.rank = dist.get_rank(self.pg)
self.world_size = pg.size()
self.original_num_embeddings = num_embeddings
# TODO @thomasw21: Fix and remove that constraint. Typically there's no reason to have such a constraint.
assert num_embeddings % self.world_size == 0
block_size = num_embeddings // self.world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.rank * block_size
self.max_id = (self.rank + 1) * block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
self.mode = mode
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == num_embeddings
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal to num_embeddings ({num_embeddings})"
split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)
mark_all_parameters_in_module_as_sharded(self, pg=self.pg, split_config=split_config)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.pg.size() > 1:
# `0` if input is in the correct interval, else `1`
input_mask = torch.logical_or(self.min_id > input_ids, input_ids >= self.max_id)
# translate for [0, self.max_id - self.min_id[
masked_input = input_ids.clone() - self.min_id
# default all out of bounds values to `0`
masked_input[input_mask] = 0
else:
masked_input = input_ids
out = super().forward(masked_input)
if self.pg.size() > 1:
out = out * (~input_mask[..., None])
if self.mode is TensorParallelLinearMode.ALL_REDUCE:
out = differentiable_all_reduce_sum(out, group=self.pg)
elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
out = differentiable_reduce_scatter_sum(out, group=self.pg)
else:
raise ValueError(f"Got unexpected mode: {self.mode}.")
return out
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_num_embeddings={self.original_num_embeddings}"
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from torch import nn
from nanotron import distributed as dist
from nanotron import logging
from nanotron.logging import log_rank
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_parameter_and_parent_module
logger = logging.get_logger(__name__)
def create_tied_parameter(
parameter: nn.Parameter,
name: str,
global_ranks: Tuple[int, ...],
reduce_op: Optional[dist.ReduceOp],
root_module: nn.Module,
) -> NanotronParameter:
if not isinstance(parameter, NanotronParameter):
parameter = NanotronParameter(tensor=parameter)
parameter.mark_as_tied(name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module)
return parameter
def tie_parameters(
root_module: nn.Module,
ties: List[Tuple[str, Tuple[int, ...]]],
parallel_context: ParallelContext,
reduce_op: Optional[dist.ReduceOp],
):
"""
Tie parameters.
Within a single device, tied parameters are replaced with a single Parameter
Across devices, we add metadata to Parameters that require extra synchronization.
:param root_module: nn.Module
:param ties: List[Tuple[str, Tuple[int, ...]]]: a tie is (param_target, global_ranks)
:param parallel_context: ParallelContext
:return:
"""
if len(ties) < 1:
raise ValueError("Can't tie nothing")
# TODO @thomasw21: When we support Zero3 this isn't true anymore
dp_ranks = tuple(
sorted(
{
parallel_context.get_local_ranks(world_rank=global_rank)[2]
for _, global_ranks in ties
for global_rank in global_ranks
}
)
)
assert (
len(dp_ranks) == 1
), f"Tying weights has to happen with a replica of a model. Got the ranks from the following replicas: {dp_ranks}"
name = ties[0][0]
global_ranks = tuple(sorted(set().union(*(tie[1] for tie in ties))))
new_param = None
world_rank = dist.get_rank(parallel_context.world_pg)
for tie_target, tie_model_ranks in ties:
if world_rank not in tie_model_ranks:
continue
param, parent_module, param_name = get_parameter_and_parent_module(target=tie_target, root_module=root_module)
# If they are physically in the same device, then we tie them
if new_param is None:
new_param = create_tied_parameter(
parameter=param, name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module
)
# Re-assign it to the original name. We assign the raw tensor instead of the parameter since we moved it already.
setattr(parent_module, param_name, new_param)
def create_pg_for_tied_weights(root_module: nn.Module, parallel_context: ParallelContext):
"""Tied weights are tied across specific set of global ranks, we use this method to create process groups for each difference set of global ranks"""
group_ranks = {
param.get_tied_info().global_ranks
for name, param in root_module.named_parameters()
if isinstance(param, NanotronParameter) and param.is_tied
}
world_group_ranks = [None] * parallel_context.world_pg.size()
dist.all_gather_object(world_group_ranks, group_ranks, group=parallel_context.world_pg)
all_group_ranks = sorted(
set().union(*world_group_ranks),
)
for global_ranks in all_group_ranks:
if global_ranks not in parallel_context.world_ranks_to_pg:
parallel_context.world_ranks_to_pg[global_ranks] = dist.new_group(global_ranks)
def get_tied_id_to_param(
parameters: List[NanotronParameter], root_module: nn.Module
) -> Dict[Tuple[str, Tuple[int, ...]], NanotronParameter]:
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in root_module.named_modules()}
# Fix the root_model
module_id_to_prefix[id(root_module)] = ""
return {
(
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix),
param.get_tied_info().global_ranks, # TODO @nouamane: merge groups which tie the same parameter
): param
for param in parameters
if param.is_tied
}
def sync_tied_weights_gradients(
module: nn.Module, # TODO: NanotronModel
parallel_context: ParallelContext,
grad_accumulator: Optional[GradientAccumulator],
):
tied_id_to_param = get_tied_id_to_param(
parameters=[param for param in module.parameters() if param.requires_grad], root_module=module
)
# Only first and last rank should print the warning
for rank in [0, parallel_context.world_pg.size() - 1]:
log_rank(
f"[Debug Tied Weights] Syncing the following tied weights: {tied_id_to_param.keys()}",
logger=logger,
level=logging.DEBUG,
group=parallel_context.world_pg,
rank=rank,
)
# Group tensors to reduce by process groups
# Important to use ordered dict in order to be synchronized across all ranks
group_ranks_and_reduce_op_to_tensors_to_reduce = OrderedDict()
for (name, group_ranks), tied_param in sorted(tied_id_to_param.items(), key=lambda x: x[0]):
tied_info = tied_param.get_tied_info()
# Some weights don't require any syncing, because they are by design synchronised
if tied_info.reduce_op is None:
continue
if grad_accumulator is not None:
tied_grad = grad_accumulator.get_grad_buffer(name=name)
else:
tied_grad = tied_param.grad
log_rank(
f"Syncing tied weights {name} across ranks {group_ranks} ...",
logger=logger,
level=logging.DEBUG,
group=parallel_context.world_ranks_to_pg[group_ranks],
rank=0,
)
key = (group_ranks, tied_info.reduce_op)
if key in group_ranks_and_reduce_op_to_tensors_to_reduce:
group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)].append(tied_grad)
else:
group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)] = [tied_grad]
for (group_ranks, reduce_op), tensors in group_ranks_and_reduce_op_to_tensors_to_reduce.items():
dist.all_reduce_coalesced(tensors=tensors, op=reduce_op, group=parallel_context.world_ranks_to_pg[group_ranks])
import functools
import operator
import os
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
from nanotron.utils import Singleton
class MemoryBuffer(metaclass=Singleton):
"""
Global memory buffer to store intermediate activations that do not need to be cached for the backward pass.
"""
def __init__(self):
self.buffer = {}
def get(self, name: str, shape: tuple[int, ...], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
required_numel = functools.reduce(operator.mul, shape, 1)
if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
self.buffer[name, dtype] = torch.empty(
required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[name, dtype][:required_numel].view(shape)
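# Usage note (illustrative, not from the original source; requires a CUDA device since the
# buffer is allocated on `torch.cuda.current_device()`): because `MemoryBuffer` is a Singleton,
# repeated calls with the same `name` and `dtype` hand out views over a single growing
# allocation instead of allocating a fresh tensor for every all-gather, e.g.
#   buf = MemoryBuffer()
#   t1 = buf.get("allgather", (4, 1024), dtype=torch.bfloat16)
#   t2 = buf.get("allgather", (2, 1024), dtype=torch.bfloat16)  # reuses the same storage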
def assert_cuda_max_connections_set_to_1(func):
flag_is_set_to_1 = None
@functools.wraps(func)
def wrapper(*args, **kwargs):
nonlocal flag_is_set_to_1
if flag_is_set_to_1 is None:
assert os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") == "1"
flag_is_set_to_1 = True
return func(*args, **kwargs)
return wrapper
def initial_sync(model: nn.Module, parallel_context: ParallelContext):
# Synchronize across dp: basic assumption
sorted_name_params = sorted(model.named_parameters(), key=lambda x: x[0])
for name, param in sorted_name_params:
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=parallel_context.dp_pg)
# Synchronize across tied weights: basic assumption
for (_, group_ranks), param in sorted(
get_tied_id_to_param(parameters=model.parameters(), root_module=model).items(), key=lambda x: x[0]
):
group = parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
import contextlib
import random
from dataclasses import dataclass
from typing import MutableMapping, Optional, Tuple
import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup
@dataclass
class RandomState:
random: Tuple[int, Tuple[int, ...], None]
numpy: Tuple[str, np.ndarray, int, int, float]
torch_cpu: torch.Tensor
torch_cuda: Optional[torch.Tensor]
def __eq__(self, other):
return (
isinstance(other, RandomState)
and all(v1 == v2 for v1, v2 in zip(self.random, other.random))
and all(
np.array_equal(v1, v2) if isinstance(v1, np.ndarray) else v1 == v2
for v1, v2 in zip(self.numpy, other.numpy)
)
and torch.equal(self.torch_cpu, other.torch_cpu)
and (
other.torch_cuda is None if self.torch_cuda is None else torch.equal(self.torch_cuda, other.torch_cuda)
)
)
class RandomStates(MutableMapping[str, RandomState]):
def __init__(self, dict: dict):
for key, value in dict.items():
self.check_type(key, value)
# TODO @thomasw21: We make a copy for safety measure.
self._dict = dict.copy()
@staticmethod
def check_type(key, value):
if not isinstance(key, str):
raise ValueError(f"Expected key to be of type str. Got {type(key)}")
if not isinstance(value, RandomState):
raise ValueError(f"Expected value to be of type `nanotron.dataclass.RandomState`. Got {type(value)}")
def __getitem__(self, item):
return self._dict[item]
def __iter__(self):
return self._dict.__iter__()
def __len__(self):
return len(self._dict)
def __delitem__(self, key):
raise ValueError("Can't delete a random states key")
def __setitem__(self, key, value):
if key not in self._dict:
raise ValueError("Can't add a new random states after initialisation")
self.check_type(key, value)
return self._dict.__setitem__(key, value)
def __eq__(self, other):
if not isinstance(other, RandomStates):
return False
return self._dict == other._dict
def set_random_seed(seed: int):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def set_random_state(random_state: RandomState):
random.setstate(random_state.random)
np.random.set_state(random_state.numpy)
torch.set_rng_state(random_state.torch_cpu)
if torch.cuda.is_available():
torch.cuda.set_rng_state(random_state.torch_cuda, "cuda")
else:
assert random_state.torch_cuda is None
def get_current_random_state():
"""Returns a snapshot of current random state"""
return RandomState(
random=random.getstate(),
numpy=np.random.get_state(),
torch_cpu=torch.random.get_rng_state(),
torch_cuda=torch.cuda.get_rng_state("cuda") if torch.cuda.is_available() else None,
)
@contextlib.contextmanager
def branch_random_state(random_states: RandomStates, key: str, enabled: bool):
"""
Context manager handling random state:
- upon entering: Stores current random state and set new random state defined by key.
- upon exiting: updates key in `random_states` to the new current random state, and set back the old one.
"""
if not enabled:
yield
return
old_random_state = get_current_random_state()
# Get the new state associated to the key
new_random_state = random_states[key]
set_random_state(new_random_state)
try:
yield
finally:
# Update state from parallel_context with the newest state
new_random_state = get_current_random_state()
random_states[key] = new_random_state
# Set the old state back
set_random_state(old_random_state)
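# Illustrative usage sketch (hypothetical helper, not part of the library): a named random
# stream advanced independently of the default one.
def _sketch_branch_random_state():
    random_states = RandomStates({"tp_sync": get_current_random_state()})
    with branch_random_state(random_states, key="tp_sync", enabled=True):
        inside = torch.rand(1)  # drawn from (and advancing) the "tp_sync" stream
    outside = torch.rand(1)  # drawn from the restored default stream
    return inside, outside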
def get_synced_random_state(
random_state: RandomState,
pg: ProcessGroup,
):
# We use rank 0 as a reference and broadcast random states from that rank to all the other ranks within a group in order to sync them
reference_rank = 0
if dist.get_rank(pg) == reference_rank:
random_states = [random_state]
else:
random_states = [None]
# TODO @thomasw21: broadcast tensor using `broadcast` in order not to use pickle
dist.broadcast_object_list(
random_states, src=dist.get_global_rank(pg, reference_rank), group=pg, device=torch.device("cuda")
)
new_random_state = random_states[0]
assert new_random_state is not None
return new_random_state
from .fsspec import check_path_is_local, fs_copy, fs_open
from .s3_mover import S3Mover
__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
import contextlib
from pathlib import Path
from typing import Tuple, Union
import fsspec
from fsspec.implementations import local
def get_filesystem_and_path(path: Path, storage_options=None) -> Tuple[fsspec.AbstractFileSystem, str]:
# Use supported filesystems in `fsspec`. If you need another one, please use `fsspec.registry.register_implementation`
# DO NOT USE `mode` argument as it adds a suffix `0.part` when using `mode="w"`.
fs, _, paths = fsspec.core.get_fs_token_paths(str(path), storage_options=storage_options)
assert len(paths) == 1
return fs, paths[0]
@contextlib.contextmanager
def fs_open(
file: Union[str, Path],
mode="r",
):
# TODO @thomasw21: pass storage options.
fs, path = get_filesystem_and_path(file)
with fs.open(path, mode=mode) as f:
yield f
def fs_copy(
input_file: Union[str, Path],
output_file: Union[str, Path],
):
"""Copy file from input to output (possibly on s3/other fs)"""
with fs_open(input_file, mode="rb") as fi, fs_open(output_file, mode="wb") as fo:
fo.write(fi.read())
def check_path_is_local(path: Path, storage_options=None) -> bool:
return isinstance(get_filesystem_and_path(path=path, storage_options=storage_options)[0], local.LocalFileSystem)
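# Illustrative usage sketch (hypothetical helper and paths, not part of the library): the
# helpers work transparently on local paths and, with the right fsspec backend installed,
# on remote ones such as s3://...
def _sketch_fs_roundtrip(tmp_dir: str = "/tmp"):
    src = f"{tmp_dir}/nanotron_fs_sketch_src.txt"
    dst = f"{tmp_dir}/nanotron_fs_sketch_dst.txt"
    with fs_open(src, mode="w") as f:
        f.write("hello")
    fs_copy(src, dst)
    assert check_path_is_local(Path(dst))
    with fs_open(dst, mode="r") as f:
        return f.read()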
import glob
import json
import os
import subprocess
import time
from datetime import datetime
from enum import Enum
from typing import Optional, Union
import torch
from datasets.download.streaming_download_manager import xPath
from filelock import FileLock, Timeout
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import human_format
logger = logging.get_logger(__name__)
class S3Mover:
# TODO @eliebak update the doc to state that this class is also used to download checkpoints to the disk with start_downloading
"""Take care of uploading a checkpoint to S3 in the background and remove it from the disk.
Args:
local_path: Path to the checkpoints on the local disk
s3_path: Path to the checkpoints on S3
remove_after_upload: If True, remove the checkpoint from the disk after uploading it to S3
s5cmd_numworkers: Number of workers to use for the s5cmd command
s5cmd_concurrency: Concurrency to use for the s5cmd command
s5cmd_path: Path to the s5cmd command
dummy: If True, don't actually upload/remove/etc. anything. Useful in multi-process-per-node setups so that only one process actually uploads.
Usage:
# Create a mover - use dummy=True for all the processes that shouldn't do anything (e.g. all but one per node)
mover = S3Mover(local_path=/scratch/my-checkpoints,
s3_path=s3://my-bucket/my-checkpoints,
remove_after_upload=True,
s5cmd_numworkers=96,
s5cmd_concurrency=10,
s5cmd_path=/admin/user/my/bin/s5cmd,
dummy=False)
while training:
# from time to time, update the state
mover_status = mover.update()
...
# When saving a checkpoint, check if the previous checkpoint has been uploaded and removed
# in a distributed setting
"""
class S3MoverState(Enum):
IDLE = "IDLE"
UPLOADING = "UPLOADING"
DOWNLOADING = "DOWNLOADING"
REMOVING_CHECKPOINT = "REMOVING_CHECKPOINT"
class DummyPopen:
def __init__(self, *args, **kwargs):
pass
def poll(self):
return 0
def communicate(self):
return ("", "")
def __init__(
self,
local_path: xPath,
s3_path: xPath,
post_upload_callback: Optional[callable] = None,
remove_after_upload: Optional[bool] = True,
s5cmd_numworkers: Optional[int] = None,
s5cmd_concurrency: Optional[int] = None,
s5cmd_path: Optional[str] = None,
s5cmd_credentials: Optional[str] = None,
clean_up_local_on_start: bool = False,
dummy: bool = False,
s3_region: str = "us-east-1",
):
self.process: Optional[Union[subprocess.Popen, S3Mover.DummyPopen]] = None
self.remove_after_upload = remove_after_upload
self.s5cmd_numworkers = s5cmd_numworkers
self.s5cmd_concurrency = s5cmd_concurrency
self.s5cmd_path = s5cmd_path if s5cmd_path is not None else "s5cmd"
self.s5cmd_credentials = s5cmd_credentials
self.lock_file = None
self.dummy = dummy
self.s3_region = s3_region
self.post_upload_callback = post_upload_callback
self.post_upload_callback_outputs = None
local_path = str(local_path)
if not local_path.startswith("/scratch/"):
self._warning(f"The local path is not on the scratch drive: {local_path}")
if not local_path.endswith("/"):
local_path += "/"
s3_path = str(s3_path)
if not s3_path.endswith("/"):
s3_path += "/"
self.local_path = local_path
self.s3_path = s3_path
s3_bucket, s3_prefix = s3_path.replace("s3://", "").split("/", maxsplit=1)
self.s3_path_direct_link = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?region={self.s3_region}&prefix={s3_prefix}&showversions=false"
self._reset_state()
if clean_up_local_on_start:
self._start_removing()
def _warning(self, message):
if self.dummy:
return
logger.warning(message)
def _info(self, message):
if self.dummy:
return
logger.info(message)
def _reset_state(self):
self.state = self.S3MoverState.IDLE
self.num_uploaded_files = 0
if self.lock_file is not None:
self._release_lock()
self.lock_file = None
self.stdout = ""
self.start_time: datetime = None
self.cmd = ""
def _popen(self, cmd: list):
self.stdout = ""
self.start_time = datetime.now()
self.cmd = cmd
if self.dummy:
return self.DummyPopen(cmd)
else:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
os.set_blocking(process.stdout.fileno(), False)
return process
def _acquire_lock(self, file_path: str) -> bool:
if self.dummy:
return True
if file_path.endswith("/"):
lock_file_path = file_path[:-1] + ".lock"
else:
lock_file_path = file_path + ".lock"
self.lock_file = FileLock(lock_file_path)
try:
self.lock_file.acquire(timeout=1)
except Timeout:
message = f"[S3] The checkpoint files {lock_file_path} are currently locked by another process. "
self._warning(message)
return False
return True
def get_state_as_int(self) -> int:
"""Return the state as an int"""
if self.state == self.S3MoverState.IDLE:
return 0
elif self.state == self.S3MoverState.UPLOADING:
return 1
elif self.state == self.S3MoverState.DOWNLOADING:
return 2
elif self.state == self.S3MoverState.REMOVING_CHECKPOINT:
return 3
else:
return -1
def _release_lock(self):
if self.dummy:
return
if self.lock_file is not None and self.lock_file.is_locked:
self.lock_file.release()
def get_current_stdout(self) -> str:
"""Return the current stdout of the process if any"""
if self.process is None or isinstance(self.process, self.DummyPopen):
return ""
try:
stdout = self.process.stdout.read()
except ValueError:
stdout = "" # The buffer is already closed: "ValueError: read of closed file"
if stdout:
self.stdout += stdout.decode()
return self.stdout
def wait_for_completion(self):
while self.state != self.S3MoverState.IDLE:
_ = self.update()
time.sleep(0.5)
def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None):
"""Wait for the previous checkpoint to be fully uploaded and removed in a distributed setting.
Will wait for all processes to be ready.
"""
if group is None:
group = dist.torch_dist.distributed_c10d._get_default_group()
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
dist.barrier()
all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
if all_saved != group.size() and self.state != self.S3MoverState.IDLE:
self._warning(
f"Waiting previous checkpoint saving is finished - S3Mover {dist.get_rank(group)} still in {self.state} state.",
)
while all_saved != group.size():
stdout = self.get_current_stdout()
stdout_lines = [lst for lst in stdout.split("\n") if lst]
if self.state != self.S3MoverState.IDLE:
self._warning(
f"[S3] Waiting {self.state.value}: {all_saved} / {group.size()}. Stdout: {len(stdout_lines)} end: {stdout_lines[-1:]}",
)
# sync all our saves over NCCL; we could do a dist barrier later, but this helps us avoid losing NCCL connections down the line
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
dist.barrier()
all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
time.sleep(1)
def is_previous_save_finished(self) -> bool:
"""Return True if a potential previous checkpoint has been fully uploaded to S3
and removed from the drive
"""
self.update()
return self.state == self.S3MoverState.IDLE
def _start_downloading(self, sub_folder: Optional[str] = None) -> bool:
self._warning(
f"[S3] Downloading checkpoint in background from {self.s3_path} to {self.local_path} (direct link: {self.s3_path_direct_link})"
)
cmd = [self.s5cmd_path, "--json"]
if self.s5cmd_credentials is not None:
cmd += ["--credentials-file", self.s5cmd_credentials]
if self.s5cmd_numworkers is not None:
cmd += ["--numworkers", str(self.s5cmd_numworkers)]
cmd += ["cp"]
if self.s5cmd_concurrency is not None:
cmd += ["--concurrency", str(self.s5cmd_concurrency)]
cmd += [self.s3_path + "*", self.local_path]
self.process = self._popen(cmd)
self.state = self.S3MoverState.DOWNLOADING
return True
def _post_downloading(self) -> bool:
self.get_current_stdout()
s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
total_files = len([i for i in s5cmd_results if i["success"]])
total_not_downloaded_files = len([i for i in s5cmd_results if not i["success"]])
if total_not_downloaded_files == 0:
all_upload = "all files"
success = True
else:
all_upload = "not all files"
success = False
total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
total_time = (datetime.now() - self.start_time).total_seconds()
self._warning(
f"[S3] Successfully downloaded {total_files} files for a total of {human_format(total_size)}B in {total_time}"
f"sec ({all_upload}) from S3 at {self.s3_path} to {self.local_path}"
f"(direct link: {self.s3_path_direct_link})"
)
return success
def _start_uploading(
self,
) -> bool:
if not os.path.exists(self.full_local_path):
message = f"[S3] Checkpoint {self.full_local_path} does not exist, cannot upload to S3"
self._warning(message)
return False
# Get a file lock on the first file
local_files = glob.glob(self.full_local_path + "/**/*.*", recursive=True)
if not local_files:
self._warning(f"[S3] No files found in {self.full_local_path}, nothing to upload")
return False
locked = self._acquire_lock(local_files[0])
if not locked:
return False
self._warning(
f"[S3] Uploading checkpoint in background from {self.full_local_path} to {self.full_s3_path} (direct link: {self.s3_path_direct_link})"
)
cmd = [self.s5cmd_path, "--json"]
if self.s5cmd_credentials is not None:
cmd += ["--credentials-file", self.s5cmd_credentials]
if self.s5cmd_numworkers is not None:
cmd += ["--numworkers", str(self.s5cmd_numworkers)]
cmd += ["cp", "--exclude", "*.lock", "--exclude", "*.lock.*"]
if self.s5cmd_concurrency is not None:
cmd += ["--concurrency", str(self.s5cmd_concurrency)]
cmd += [self.full_local_path, self.full_s3_path]
self.process = self._popen(cmd)
self.state = self.S3MoverState.UPLOADING
return True
def _post_uploading(self) -> bool:
self.get_current_stdout()
s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
local_files = glob.glob(self.full_local_path + "/**/*.?*", recursive=True)
total_files = len([i for i in s5cmd_results if i["success"]])
self.num_uploaded_files = total_files
if len(local_files) == total_files:
all_upload = "all files"
success = True
else:
all_upload = f"not all files: {len(local_files)} out of {total_files}"
success = False
total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
total_time = (datetime.now() - self.start_time).total_seconds()
self._warning(
f"[S3] Successfully uploaded {total_files} files for a total of {human_format(total_size)}B in {total_time} sec"
f"({all_upload}) from {self.full_local_path} to S3 at {self.full_s3_path} "
f"(direct link: {self.s3_path_direct_link})"
)
if self.post_upload_callback:
self.post_upload_callback_outputs = self.post_upload_callback(uploaded_files=s5cmd_results)
self._release_lock()
return success
def _start_removing(self) -> bool:
top_dir_in_local_checkpoint = [dir for dir in glob.glob(self.local_path + "/*") if os.path.isdir(dir)]
names_dir = [os.path.basename(dir) for dir in top_dir_in_local_checkpoint]
if len(names_dir) == 0:
# If the local folder is already empty, or another process has already started removing it, we skip with a no-op
self._warning("[S3] Local checkpoint is empty, skipping removal")
cmd = ["echo", "'skipping'"]
self.process = self._popen(cmd)
self.state = self.S3MoverState.REMOVING_CHECKPOINT
return True
self._warning(f"[S3] Removing checkpoint in background: {names_dir}")
locked = self._acquire_lock(top_dir_in_local_checkpoint[0])
if not locked:
return False
cmd = ["rm", "-rfv"] + top_dir_in_local_checkpoint
self.process = self._popen(cmd)
self.state = self.S3MoverState.REMOVING_CHECKPOINT
return True
def _post_removing(self) -> bool:
self.get_current_stdout()
local_files = [
loc_f
for loc_f in self.stdout.split("\n")
if "directory" not in loc_f.lower() and loc_f and ".lock" not in loc_f
]
if len(local_files) == self.num_uploaded_files:
all_removed = "all files"
success = True
else:
all_removed = "not all files"
success = False
self._release_lock()
total_time = (datetime.now() - self.start_time).total_seconds()
self._warning(
f"[S3] Successfully removed {len(local_files)} local files ({all_removed}) from {self.local_path} (uploaded to {self.s3_path_direct_link}) in {total_time}"
)
return success
def update(self) -> (str, str):
"""Update the state of the mover: UPLOADING => REMOVING_DUPLICATED => DUPLICATING => REMOVING_CHECKPOINT => IDLE
Returns:
(str, str): The state and the stdout of the process if any
"""
if self.process is None:
self._reset_state()
return self.state.value, self.stdout
return_code = self.process.poll()
if return_code is None:
# Still running
return self.state.value, self.stdout
if return_code != 0:
self.get_current_stdout()
self._warning(
f"[S3] Error running command {self.cmd} during process {self.state.value}, "
f"return code {return_code}, return message {self.stdout}"
)
return self.state.value, self.stdout
if self.state == self.S3MoverState.DOWNLOADING:
self._post_downloading()
self._reset_state()
elif self.state == self.S3MoverState.UPLOADING:
self._post_uploading()
if self.remove_after_upload:
self._start_removing()
else:
self._reset_state()
elif self.state == self.S3MoverState.REMOVING_CHECKPOINT:
self._post_removing()
self._reset_state()
return self.state.value, self.stdout
def start_uploading(self, sub_folder=None):
"""Start uploading last saved checkpoint to S3 in the background.
After running this method, you should call regularly `update` to update the
state to duplicating and then removing.
For a blocking upload, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
"""
self.update()
if self.state != self.S3MoverState.IDLE:
message = "[S3] Cannot move to S3 as the previous checkpoint has not been uploaded and removed"
self._warning(message)
return False
self.full_local_path = self.local_path + (f"/{sub_folder}" if sub_folder else "")
self.full_s3_path = self.s3_path + (f"/{sub_folder}" if sub_folder else "")
return self._start_uploading()
def start_downloading(self):
"""Start downloading a checkpoint from S3 in the background.
After calling this method, you should call `update` regularly to refresh the
state.
For a blocking download, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
"""
self.update()
if self.state != self.S3MoverState.IDLE:
message = f"[S3] Cannot download from S3 as the state is not IDLE but {self.state.value}"
self._warning(message)
return False
return self._start_downloading()
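# --------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the S3Mover above is
# typically driven from a training loop. Only one process per node should do real
# S3 work; the others pass dummy=True. `run_training_step`, `save_checkpoint`,
# `checkpoints_path` and `world_pg` are hypothetical placeholders, not nanotron APIs.
def _example_s3_mover_loop(mover: "S3Mover", world_pg, run_training_step, save_checkpoint, checkpoints_path, num_steps: int, checkpoint_interval: int):
    for step in range(1, num_steps + 1):
        run_training_step()
        if step % checkpoint_interval == 0:
            # Block until the previous checkpoint has been fully uploaded and removed
            mover.distributed_wait_for_completion(group=world_pg)
            save_checkpoint(checkpoints_path / str(step))
            # Kick off a background upload of the sub-folder we just wrote
            mover.start_uploading(sub_folder=str(step))
    # Make sure the last upload (and local removal) has finished before exiting
    mover.wait_for_completion()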
from contextlib import contextmanager
from typing import Callable, Optional
import torch
from nanotron import distributed as dist
from nanotron import logging, optim
from nanotron.config import Config
from nanotron.logging import get_logger, log_rank
from nanotron.models import NanotronModel
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
logger = get_logger(__name__)
def assert_tensor_synced_across_pg(
tensor: torch.Tensor,
pg: dist.ProcessGroup,
msg: Optional[Callable[[str], str]] = None,
reference_rank: int = 0,
):
"""Assert that `tensor` is synced across `pg` with reference rank. Note that this always passes for reference rank"""
if dist.get_rank(pg) == reference_rank:
reference_tensor = tensor
else:
reference_tensor = torch.empty_like(tensor)
dist.broadcast(
reference_tensor,
src=dist.get_global_rank(group=pg, group_rank=reference_rank),
group=pg,
)
# TODO @nouamane: Getting Greatest absolute difference: 4.6e-10 at large scale when syncing tied weights
torch.testing.assert_close(tensor, reference_tensor, msg=msg)
# TODO @nouamanetazi: remove this with SANITY_CHECKS
@contextmanager
def assert_fail_except_rank_with(exception_class, rank_exception, pg):
try:
yield
except exception_class:
if rank_exception == dist.get_rank(pg):
raise AssertionError(f"Expected rank {rank_exception} to not raise {exception_class}.")
else:
return
except Exception as e:
raise AssertionError(f"Expected {exception_class} to be raised, but got {type(e)} instead:\n{e}")
if dist.get_rank(pg) != rank_exception:
raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")
def before_tbi_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that the model params are synchronized across dp
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synchronized across DP {err}",
)
# SANITY CHECK: Tied weights are synchronized
tied_params_list = sorted(
get_tied_id_to_param(
parameters=unwrapped_model.parameters(),
root_module=unwrapped_model,
).items(),
key=lambda x: x[0],
)
for (name, group_ranks), param in tied_params_list:
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=param,
pg=group,
msg=lambda err: f"[Before train] Tied weights {name} are not synchronized. {err}",
)
# SANITY CHECK: Check that model grads are zeroed or None
for name, param in unwrapped_model.named_parameters():
if param.grad is not None:
torch.testing.assert_close(
param.grad,
torch.zeros_like(param.grad),
atol=0,
rtol=0,
msg="Model half precision grads must be zeroed or None in first accumulation step.",
)
# SANITY CHECK: Check that the grad accumulator buffers are ready for DDP
if grad_accumulator is not None:
for _, elt in grad_accumulator.fp32_grad_buffers.items():
fp32_grad_buffer = elt["fp32_grad"]
torch.testing.assert_close(
fp32_grad_buffer,
torch.zeros_like(fp32_grad_buffer),
atol=0,
rtol=0,
msg="Grad accumulator buffers must be zeroed in first accumulation step.",
)
# TODO: add checks for memory contiguousness
# SANITY CHECK: Check that optimizer's lr is synchronized with lr_scheduler
for i, group in enumerate(lr_scheduler.optimizer.param_groups):
assert (
group["lr"] == lr_scheduler.get_last_lr()[i]
), f"Optimizer and LR scheduler are not in sync. Got {group['lr']} and {lr_scheduler.get_last_lr()[i]}"
break
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_tbi_sanity_checks()
def after_tbi_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that gradients flow through the entire model
# SANITY CHECK: Check that all parameters that require gradients actually have a gradient
# SANITY CHECK: Check for nan/inf
for name, param in unwrapped_model.named_parameters():
if not param.requires_grad:
continue
if param.is_tied:
tied_info = param.get_tied_info()
name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=unwrapped_model.module_id_to_prefix
)
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
if grad is None:
log_rank(
f"Process rank {dist.get_rank(parallel_context.world_pg)}/{parallel_context.world_pg.size()}: {name} is missing gradient",
logger=logger,
level=logging.ERROR,
)
elif torch.isnan(grad).any() or torch.isinf(grad).any():
raise ValueError("Gradient is nan or inf")
# SANITY CHECK: run model specific sanity checks
unwrapped_model.after_tbi_sanity_checks()
def before_optim_step_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
optimizer: optim.BaseOptimizer,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Test tied weights gradients are synchronized
for (name, group_ranks), param in sorted(
get_tied_id_to_param(parameters=unwrapped_model.parameters(), root_module=unwrapped_model).items(),
key=lambda x: x[0],
):
if not param.requires_grad:
continue
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
assert grad is not None, f"Grad is None for {name}"
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=grad,
pg=group,
msg=lambda err: f"[Before optimizer step] Tied weights grads for {name} are not synchronized. {err}",
)
# SANITY CHECK: Test gradients are synchronized across DP
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
if not param.requires_grad:
continue
if param.is_tied:
tied_info = param.get_tied_info()
name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=unwrapped_model.module_id_to_prefix
)
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
assert grad is not None, f"Grad is None for {name}"
assert_tensor_synced_across_pg(
tensor=grad,
pg=parallel_context.dp_pg,
msg=lambda err: f"[Before optimizer step] weights grads for {name} are not synchronized across DP. {err}",
)
# SANITY CHECK: Check that the model params are synchronized across dp
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synchronized across DP {err}",
)
# SANITY CHECK: Tied weights are synchronized
tied_params_list = sorted(
get_tied_id_to_param(parameters=unwrapped_model.parameters(), root_module=unwrapped_model).items(),
key=lambda x: x[0],
)
for (name, group_ranks), param in tied_params_list:
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=param,
pg=group,
msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}",
)
# SANITY CHECK: Check that optimizer states are synchronized across DP
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_optim_step_sanity_checks()
def after_optim_step_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that gradients are cleared
for name, param in unwrapped_model.named_parameters():
if not param.requires_grad:
continue
if param.grad is not None:
log_rank(
f"Process rank { dist.get_rank(parallel_context.world_pg)}/{parallel_context.world_pg.size()}: {name} still has gradient despite having ran the optimizer",
logger=logger,
level=logging.ERROR,
)
# SANITY CHECK: run model specific sanity checks
unwrapped_model.after_optim_step_sanity_checks()
def check_optim_state_in_sync(optim_state_dict: dict, pg: dist.ProcessGroup):
for _, optim_state in sorted(optim_state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
if name == "step":
continue
assert_tensor_synced_across_pg(
tensor=tensor, pg=pg, msg=lambda err: f"{name} are not synced across DP {err}"
)
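# Illustrative sketch (not part of the original module): where the sanity-check hooks
# above would sit within a single training iteration. `run_forward_backward` and
# `run_optimizer_step` are hypothetical placeholders; the real wiring lives in
# nanotron's trainer.
def _example_sanity_check_order(config, parallel_context, unwrapped_model, grad_accumulator, optimizer, lr_scheduler, run_forward_backward, run_optimizer_step):
    before_tbi_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator, lr_scheduler)
    run_forward_backward()  # forward + backward over the accumulation steps
    after_tbi_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator)
    before_optim_step_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator, optimizer)
    run_optimizer_step()  # optimizer.step() followed by zero_grad()
    after_optim_step_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator)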
import math
from abc import abstractmethod
from enum import Enum, auto
from typing import Dict
from nanotron.config import ModelArgs
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from torch import nn
from torch.nn import init
class ParametrizationMethod(Enum):
STANDARD = auto()
SPECTRAL_MUP = auto()
class Parametrizator:
def __init__(self, config: ModelArgs):
self.config = config
def parametrize(self, param_name: str, module: nn.Module):
if not isinstance(module, tuple(self.MODULE_TO_PARAMETRIZE.keys())):
raise Exception(f"Parameter {param_name} was not initialized")
return self.MODULE_TO_PARAMETRIZE[type(module)](param_name, module)
class StandardParametrizator(Parametrizator):
def __init__(self, config: ModelArgs):
super().__init__(config)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._parametrize_column_linear,
TensorParallelRowLinear: self._parametrize_row_linear,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
}
self.std = config.init_method.std
self.num_layers = config.model_config.num_hidden_layers
def _parametrize_column_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_row_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
std = self.std / math.sqrt(2 * self.num_layers)
init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_embedding(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
class SpectralMupParametrizator(Parametrizator):
"""
A Spectral Condition for Feature Learning by Greg Yang, et al.
https://arxiv.org/abs/2310.17813
"""
def __init__(self, config: ModelArgs):
super().__init__(config)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._parametrize_mup_weight,
TensorParallelRowLinear: self._parametrize_mup_weight,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
}
self.std = 1.0
@staticmethod
def _compute_spectral_std(std: float, fan_in: int, fan_out: int):
"""
Parametrization 1 (Spectral parametrization)
Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al.
σₗ = Θ(1/√nₗ₋₁ min{1, √(nₗ/nₗ₋₁)})
"""
return (std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))
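# Worked example (illustrative, with the default std=1.0 set in __init__): for a
# projection with fan_in=4096 and fan_out=1024,
#   sigma = (1.0 / sqrt(4096)) * min(1, sqrt(1024 / 4096)) = (1 / 64) * 0.5 = 0.0078125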
def _parametrize_mup_weight(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
data = module.weight if param_name == "weight" else module.bias
fan_in, fan_out = init._calculate_fan_in_and_fan_out(data)
world_size = module.world_size
if isinstance(module, TensorParallelColumnLinear):
fan_out = fan_out * world_size
elif isinstance(module, TensorParallelRowLinear):
fan_in = fan_in * world_size
else:
raise ValueError(f"Unknown module {module}")
std = SpectralMupParametrizator._compute_spectral_std(std=self.std, fan_in=fan_in, fan_out=fan_out)
init.normal_(data, mean=0.0, std=std)
def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
# NOTE: you're free to change the initialization of layer norm
# as it's not a part of µTransfer
if "weight" == param_name:
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_embedding(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]
# NOTE: you're free to change the initialization of input embedding/lm head
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
class LearningRateForParametrizator:
def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]):
self.lr = lr
self.names_to_modules = names_to_modules
@abstractmethod
def get_lr(self, param_name: str, module: nn.Module) -> float:
raise NotImplementedError
class LearningRateForSP(LearningRateForParametrizator):
"""All parameters get the same learning rate."""
def get_lr(self, param_name: str, param: nn.Module) -> float:
return self.lr
class LearningRateForSpectralMup(LearningRateForParametrizator):
"""
A Spectral Condition for Feature Learning by Greg Yang, et al.
NOTE: each parameter gets a custom learning rate based on its fan-in and fan-out.
"""
def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]):
super().__init__(lr, names_to_modules)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._get_mup_lr,
TensorParallelRowLinear: self._get_mup_lr,
TritonRMSNorm: self._get_global_lr,
TensorParallelEmbedding: self._get_global_lr,
}
def _get_mup_lr(self, param: nn.Parameter, module: nn.Module):
"""
Parametrization 1 (Spectral parametrization)
Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al.
ηₗ = Θ(nₗ/nₗ₋₁)
"""
fan_in, fan_out = init._calculate_fan_in_and_fan_out(param)
world_size = module.world_size
if isinstance(module, TensorParallelColumnLinear):
fan_out = fan_out * world_size
elif isinstance(module, TensorParallelRowLinear):
fan_in = fan_in * world_size
else:
raise ValueError(f"Unknown module {module}")
return self.lr * (fan_out / fan_in)
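# Worked example (illustrative): with a base lr of 3e-4 and fan_in=4096, fan_out=1024
# (after accounting for the tensor-parallel world size), the per-layer learning rate
# would be 3e-4 * (1024 / 4096) = 7.5e-5.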
def _get_global_lr(self, param: nn.Parameter, module: nn.Module) -> float:
return self.lr
def get_lr(self, param_name: str, param: nn.Parameter) -> float:
"""Return the learning rate for the given parameter."""
# NOTE: param_name should be like 'model.token_position_embeddings.pp_block.token_embedding.weight'
# since names_to_modules map module_name to module
# so we remove the .weight and .bias from param_name to get the module_name
module_name = param_name.rsplit(".", 1)[0]
module = self.names_to_modules[module_name]
return self.MODULE_TO_PARAMETRIZE[type(module)](param, module)
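# Illustrative sketch (not part of the original module): a hedged example of the
# calling convention of `parametrize` and `get_lr` above. Building real tensor-parallel
# modules requires a process group and is out of scope here; `model` is a hypothetical
# nn.Module containing the module types registered in MODULE_TO_PARAMETRIZE.
def _example_apply_parametrization(model: nn.Module, config: ModelArgs, base_lr: float = 3e-4):
    parametrizator = SpectralMupParametrizator(config)
    lr_helper = LearningRateForSpectralMup(lr=base_lr, names_to_modules=dict(model.named_modules()))
    param_lrs = {}
    for module_name, module in model.named_modules():
        # Only the registered module types are handled; skip containers and anything else
        if type(module) not in parametrizator.MODULE_TO_PARAMETRIZE:
            continue
        for local_name, param in module.named_parameters(recurse=False):
            full_name = f"{module_name}.{local_name}" if module_name else local_name
            # Initialize the parameter in place according to its module's rule
            parametrizator.parametrize(local_name, module)
            # Query the per-parameter learning rate (fan-in/fan-out aware for TP linears)
            param_lrs[full_name] = lr_helper.get_lr(full_name, param)
    return param_lrs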
# flake8: noqa
from nanotron.serialize.main import *
from nanotron.serialize.optimizer import *
from nanotron.serialize.random import *
from nanotron.serialize.weights import *
import os
from pathlib import Path
from typing import Optional, cast
import torch
from datasets.download.streaming_download_manager import xPath
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from nanotron import distributed as dist
from nanotron import logging
from nanotron import optim as optim
from nanotron.config import Config
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
from nanotron.sanity_checks import (
assert_tensor_synced_across_pg,
check_optim_state_in_sync,
)
from nanotron.serialize.metadata import TrainingMetadata, save_meta
from nanotron.serialize.optimizer import (
save_lr_scheduler,
save_optimizer,
)
from nanotron.serialize.weights import save_weights
"""
We're going to use safetensors. The reason is that loading segments is going to be much easier.
Requirements:
- the serialized format needs to be able to recover the current training state (random states, weights, optimizer states)
- the serialized format should be topology agnostic. This makes things much easier with varying topologies.
Current way of thinking:
- one file = one tensor (it would create a huge number of files, but we should revisit only if that's a problem)
Version 1:
- serialize -> dumps every process weights in individual files
- load -> assume topology is exactly the same.
"""
logger = logging.get_logger(__name__)
def save(
config: "Config",
model: nn.Module,
optimizer: optim.BaseOptimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
parallel_context: ParallelContext,
training_metadata: TrainingMetadata,
root_folder: Path,
should_save_config: bool = True,
should_save_model: bool = True,
should_save_optimizer: bool = True,
should_save_lr_scheduler: bool = True,
sanity_checks: bool = True,
) -> None:
assert isinstance(training_metadata, TrainingMetadata)
try:
if should_save_config:
config.save_as_yaml(root_folder / "config.yaml")
except Exception as e:
# TODO @nouamane: catch full disk error
log_rank(
f"Error while saving config: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
try:
if should_save_model:
save_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)
except Exception as e:
log_rank(
f"Error while saving weights checkpoint: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
try:
if should_save_optimizer:
save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
except Exception as e:
log_rank(
f"Error while saving optimizer checkpoint: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
try:
if should_save_lr_scheduler:
lr_scheduler = cast(LambdaLR, lr_scheduler)
assert len(lr_scheduler.lr_lambdas) == len(
optimizer.param_groups
), "The number of lambdas functions in the scheduler should be equal to the number of parameter groups in the optimizer."
save_lr_scheduler(
lr_scheduler=lr_scheduler,
# is_zero=True,
is_zero=config.optimizer.zero_stage,
parallel_context=parallel_context,
root_folder=root_folder,
)
except Exception as e:
log_rank(
f"Error while saving lr_scheduler checkpoint: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
save_meta(root_folder=root_folder, parallel_context=parallel_context, training_metadata=training_metadata)
# TODO @thomas21: sanity check, not sure whether that needs to happen at testing or now (depends how much it costs)
###
# SANITY CHECK: Check that the model params are synchronized across `parallel_context.dp_pg`
if sanity_checks:
for name, param_or_buffer in sorted(model.state_dict().items(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param_or_buffer,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synced across DP {err}",
)
# SANITY CHECK: Check that the tied parameters are synchronized
sorted_tied_parameters = sorted(
(
param
for parameters_group in optimizer.param_groups
for param in parameters_group["params"]
if param.requires_grad and isinstance(param, NanotronParameter) and param.is_tied
),
key=lambda param: param.get_tied_info().name,
)
for tied_param in sorted_tied_parameters:
tied_info = tied_param.get_tied_info()
group_ranks = tied_info.global_ranks
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}"
)
if not optimizer.inherit_from(optim.ZeroDistributedOptimizer):
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)
# SANITY CHECK: tied parameters have their optimizer states synchronized
# Compute a mapping from id_ to index in the optimizer sense
state_dict = optimizer.state_dict()
assert len(optimizer.param_groups) == len(state_dict["param_groups"])
index_to_param = {}
for real_param_group, index_param_group in zip(optimizer.param_groups, state_dict["param_groups"]):
indices = index_param_group["params"]
parameters = real_param_group["params"]
assert len(indices) == len(parameters)
for param, index in zip(parameters, indices):
assert index not in index_to_param
index_to_param[index] = param
current_state_dict = optimizer.state_dict()
for index, optim_state in sorted(current_state_dict["state"].items(), key=lambda x: x[0]):
param = index_to_param[index]
if not isinstance(param, NanotronParameter):
continue
if not param.is_tied:
# If it's not shared, we don't need to check it's synced
continue
tied_info = param.get_tied_info()
group_ranks = tied_info.global_ranks
group = parallel_context.world_ranks_to_pg[group_ranks]
reference_rank = 0
current_rank = dist.get_rank(group)
for name, tensor in optim_state.items():
# FIXME @thomasw21: Some data is actually on `cpu`, just for this test we move it to `cuda`
tensor = tensor.to("cuda")
if current_rank == reference_rank:
reference_tensor = tensor
else:
reference_tensor = torch.empty_like(tensor)
dist.broadcast(
reference_tensor,
src=get_global_rank(group=group, group_rank=reference_rank),
group=group,
)
torch.testing.assert_close(
tensor,
reference_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}",
)
dist.barrier(parallel_context.world_pg)
def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
"""Parse checkpoint path from config and download checkpoint from S3 if needed.
Args:
config: Config object.
Returns:
Path to checkpoint or None if no checkpoint.
"""
load_from_candidate = config.checkpoints.resume_checkpoint_path
if load_from_candidate is not None:
if check_path_is_local(load_from_candidate):
latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt"
if latest_meta_path.exists():
with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi:
# TODO @thomasw21: make a better structure system so that we get typing correct
load_from_candidate = int(fi.read())
checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate)
elif (config.checkpoints.resume_checkpoint_path / "model_config.json").exists():
# we assume that the checkpoint path is a path to a checkpoint
checkpoint_path = config.checkpoints.resume_checkpoint_path
else:
log_rank(
f"No previous checkpoint found in: {latest_meta_path}",
logger=logger,
level=logging.INFO,
rank=0,
)
return None
log_rank(
f"Loading checkpoint from {checkpoint_path}",
logger=logger,
level=logging.INFO,
rank=0,
)
else:
latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt"
if latest_meta_path.exists():
# if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint
with fs_open(latest_meta_path, mode="r") as fi:
latest_iteration = int(fi.read())
s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path
checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration) # save_path
elif config.checkpoints.resume_checkpoint_path.exists():
# we assume that the checkpoint path is a path to a checkpoint
s3_path = config.checkpoints.resume_checkpoint_path # load_path
checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name # save_path
else:
log_rank(
f"No previous checkpoint found in: {config.checkpoints.resume_checkpoint_path}\n Initializing from scratch.",
logger=logger,
level=logging.WARNING,
rank=0,
)
return None
log_rank(
f"Downloading checkpoint from S3 in {checkpoint_path} ",
logger=logger,
level=logging.WARNING,
rank=0,
)
# Download checkpoint from S3
s3_mover = S3Mover(
local_path=os.path.join(checkpoint_path),
s3_path=os.path.join(s3_path),
s5cmd_numworkers=config.s3_upload.s5cmd_numworkers,
s5cmd_concurrency=config.s3_upload.s5cmd_concurrency,
s5cmd_path=config.s3_upload.s5cmd_path,
dummy=bool(int(os.environ.get("LOCAL_RANK", "0")) != 0),
)
s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
s3_mover.start_downloading()
s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
return checkpoint_path
import dataclasses
import json
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
import dacite
import torch
from dacite import from_dict
from packaging.version import Version
from nanotron import distributed as dist
from nanotron.constants import CHECKPOINT_FILE_NAME, CHECKPOINT_VERSION
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import SlicesPair
@dataclasses.dataclass
class DataStageMetadata:
"""
name: The name of this data stage.
start_training_step: The training step at which this stage starts.
consumed_train_samples: The number of samples consumed by the model in this stage (each stage starts from zero).
# NOTE: we should allow people to change the name of the data stages in the config file,
# but not the start_training_step, because it could
"""
name: str
start_training_step: int
consumed_train_samples: int
@dataclasses.dataclass
class TrainingMetadata:
"""
consumed_train_samples: The number of samples consumed globally, across all stages.
last_train_step: The last training step across all stages.
last_stage_idx: The index of the last stage that was trained.
data_stages: The metadata for each stage.
"""
consumed_train_samples: int
last_train_step: int
# TODO(xrsrke): make this not optional, once we entirely remove
# the old checkpoint version
last_stage_idx: Optional[int] = None
data_stages: Optional[List[DataStageMetadata]] = None
def __post_init__(self):
# NOTE: this is a sanity check after loading a trained checkpoint
# Old checkpoints may not carry data_stages at all, so only check when it is present
if self.data_stages is not None:
total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages)
assert (
self.consumed_train_samples == total_consumed_samples_across_stages
), "Mismatch between the total consumed samples and the sum of consumed samples across stages! Something went wrong in the training."
# TODO(xrsrke): remove this once we entirely remove non-data-stage training
if self.last_stage_idx is not None:
assert self.data_stages is not None, "data_stages should not be None if last_stage_idx is not None"
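# Illustrative sketch (not part of the original module): constructing the metadata above
# for a two-stage run. The numbers are made up; the invariant enforced by __post_init__
# is that consumed_train_samples equals the sum of the per-stage counters.
def _example_training_metadata() -> TrainingMetadata:
    stages = [
        DataStageMetadata(name="stage_1", start_training_step=1, consumed_train_samples=1_000),
        DataStageMetadata(name="stage_2", start_training_step=500, consumed_train_samples=2_000),
    ]
    return TrainingMetadata(
        consumed_train_samples=3_000,  # must equal 1_000 + 2_000
        last_train_step=1_500,
        last_stage_idx=1,
        data_stages=stages,
    )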
@dataclasses.dataclass
class CheckpointMetadata:
version: Version
tp: int
dp: int
metas: TrainingMetadata
custom_metas: Optional[Dict[str, Any]] = None
@dataclasses.dataclass
class TensorMetadata:
# Mandatory for checkpoint version higher than 1.2
version: Version
# Anything users want to store
# Info of to what slice of the unsharded tensor (global_slices) the current sharded tensor corresponds (local_slices)
local_global_slices_pairs: Tuple[SlicesPair, ...]
# The shape of the unsharded tensor
unsharded_shape: Tuple[int, ...]
_metadata_config: ClassVar[dacite.Config] = dacite.Config(
cast=[Version],
type_hooks={
Tuple[SlicesPair, ...]: SlicesPair.tuple_from_str,
Tuple[int, ...]: lambda x: torch.Size(int(size) for size in x.strip("()").split(",") if size),
},
strict=True,
)
def to_str_dict(self) -> Dict[str, str]:
return {
"version": str(self.version),
"local_global_slices_pairs": SlicesPair.tuple_to_str(self.local_global_slices_pairs),
"unsharded_shape": str(tuple(self.unsharded_shape)),
}
@classmethod
def from_str_dict(cls, dictionary: Dict[str, str]) -> "TensorMetadata":
tensor_metadata: TensorMetadata = dacite.from_dict(
data_class=TensorMetadata,
data=dictionary,
config=cls._metadata_config,
)
return tensor_metadata
def process_type(elt: Any, type_hooks: Dict[Type, Callable[[Any], Any]]):
if isinstance(elt, dict):
return to_dict(elt, type_hooks=type_hooks)
elif elt.__class__ in type_hooks:
return type_hooks[elt.__class__](elt)
elif isinstance(elt, (list, tuple)):
return to_list(elt, type_hooks=type_hooks)
else:
return elt
def to_dict(dict_: Dict, type_hooks: Dict[Type, Callable[[Any], Any]]):
result = {}
for key, value in dict_.items():
result[key] = process_type(value, type_hooks=type_hooks)
return result
def to_list(list_: Union[List, Tuple], type_hooks: Dict[Type, Callable[[Any], Any]]):
return list_.__class__((process_type(elt, type_hooks=type_hooks) for elt in list_))
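# Illustrative sketch (not part of the original module): what the type_hooks helpers
# above do. `save_meta` below uses them to turn values that json cannot serialize
# (here, packaging.version.Version) into strings while traversing nested containers.
def _example_type_hooks() -> dict:
    raw = {"version": Version("1.2"), "parallel_sizes": (2, 4)}
    # Version objects go through the hook; dicts/lists/tuples are traversed recursively
    return to_dict(raw, type_hooks={Version: str})
    # -> {"version": "1.2", "parallel_sizes": (2, 4)}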
def save_meta(parallel_context: ParallelContext, root_folder: Path, training_metadata: TrainingMetadata):
assert isinstance(training_metadata, TrainingMetadata)
if dist.get_rank(parallel_context.world_pg) != 0:
return
root_folder.mkdir(exist_ok=True, parents=True)
checkpoint_metadata = CheckpointMetadata(
version=CHECKPOINT_VERSION,
tp=parallel_context.tp_pg.size(),
dp=parallel_context.dp_pg.size(),
metas=training_metadata,
)
# There are some types that require manual casting in order to work correctly.
processed_metadata = process_type(dataclasses.asdict(checkpoint_metadata), type_hooks={Version: lambda x: str(x)})
with open(root_folder / CHECKPOINT_FILE_NAME, mode="w") as fo:
json.dump(processed_metadata, fo, indent=2, sort_keys=True)
def load_meta(parallel_context: ParallelContext, root_folder: Path) -> CheckpointMetadata:
with open(root_folder / CHECKPOINT_FILE_NAME, mode="r") as fi:
checkpoint_metadata = json.load(fi)
checkpoint_metadata = from_dict(
data_class=CheckpointMetadata,
data=checkpoint_metadata,
config=dacite.Config(
cast=[Version],
),
)
# Assume that we're always backward compatible, we only increment CHECKPOINT_VERSION when there's a breaking change.
assert (
checkpoint_metadata.version <= CHECKPOINT_VERSION
), f"Checkpoint is of version {checkpoint_metadata.version}, Current `nanotron` checkpoint version is {CHECKPOINT_VERSION}"
return checkpoint_metadata
import json
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
from torch import nn
from tqdm import tqdm
from nanotron import distributed as dist
from nanotron import optim
from nanotron.optim.zero import (
ZeroDistributedOptimizer,
extract_parallel_ranks_from_shard_path,
find_optim_index_from_param_name,
get_sliced_tensor,
merge_dp_shard_in_zero1_optimizer,
)
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.serialize.metadata import TensorMetadata
from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors
# TODO(xrsrke): take rank instead of parallel_context
def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
if is_zero is True:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
else:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
def lr_scheduler_filename(parallel_context: ParallelContext, is_zero: bool):
if is_zero is True:
return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
else:
return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
def save_optimizer(
optimizer: optim.BaseOptimizer,
parallel_context: ParallelContext,
root_folder: Path,
):
"""Saves optimizer states
- If Zero-0 is used, optimizer states are replicated across all DPs. Only DP-0 saves the states
- If Zero-1 is used, optimizer states are sharded across all DPs. Each DP saves its own states
"""
if (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank(parallel_context.dp_pg) > 0:
# this is Zero-0, so only DP-0 saves the optimizer states
return
# TODO: Figure out if I need to save param groups. Right now I'm assuming no as we only store what's trainable
# TODO: We can probably "rotate" so that every process stores something (maybe doesn't matter if we're I/O bound)
root_folder = root_folder / "optimizer"
root_folder.mkdir(exist_ok=True, parents=True)
if dist.get_rank(parallel_context.world_pg) == 0:
with open(root_folder / "optimizer_config.json", "w") as fo:
tp_size = parallel_context.tp_pg.size()
pp_size = parallel_context.pp_pg.size()
dp_size = parallel_context.dp_pg.size()
expert_parallel_size = parallel_context.expert_parallel_size
config = {
"type": str(optimizer.__class__.__name__),
"parallelism": {
"tp_size": str(tp_size),
"dp_size": str(dp_size),
"pp_size": str(pp_size),
"expert_parallel_size": str(expert_parallel_size),
},
"configs": {},
}
if isinstance(optimizer, ZeroDistributedOptimizer):
# NOTE: in order to serialize, we must save all keys and values as strings
def convert_to_string(input_item):
if isinstance(input_item, dict):
return {str(key): convert_to_string(value) for key, value in input_item.items()}
elif isinstance(input_item, list):
return [convert_to_string(element) for element in input_item]
elif isinstance(input_item, tuple):
return tuple(convert_to_string(element) for element in input_item)
else:
return str(input_item)
# NOTE: if it's a ZeRO-1 optimizer, then we save how the parameters are sharded
# across data parallel dimension, so that we can reconstruct the optimizer states
assert optimizer.param_name_to_dp_rank_offsets is not None, "param_name_to_dp_rank_offsets is required"
config["configs"]["param_name_to_dp_rank_offsets"] = convert_to_string(
optimizer.param_name_to_dp_rank_offsets
)
# NOTE: since tp sharded params are flattened, so we need to save the original param shapes
# so that we can reconstruct the original shapes => reconstruct the unsharded params in the tensor parallel dimension
config["configs"]["orig_param_shapes"] = convert_to_string(optimizer._orig_param_shapes)
json.dump(config, fo)
# We dump the optimizer state using `torch.save`
torch.save(
optimizer.state_dict(),
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
)
def save_lr_scheduler(
lr_scheduler,
is_zero,
parallel_context: ParallelContext,
root_folder: Path,
):
"""Saves lr scheduler states"""
if not is_zero and dist.get_rank(parallel_context.dp_pg) > 0:
# this is Zero-0, so only DP-0 saves the optimizer states
return
root_folder = root_folder / "lr_scheduler"
root_folder.mkdir(exist_ok=True, parents=True)
# We dump the optimizer state using `torch.save`
torch.save(
lr_scheduler.state_dict(),
root_folder / lr_scheduler_filename(parallel_context, is_zero),
)
# Helper functions to move optimizer states
@torch.no_grad()
def state_dict_to_device(state_dict: Dict, device: str) -> Dict:
assert (
state_dict["state"][0]["exp_avg"].device.type == "cpu"
), "Optimizer states should be on CPU to avoid extra memory usage when loading from checkpoint"
torch.cuda.empty_cache()
for _, optim_state in sorted(state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
optim_state[name] = tensor.to(device)
assert (
state_dict["state"][0]["exp_avg"].device.type == "cuda"
), "Optimizer states should be on GPU because model is on GPU"
torch.cuda.empty_cache()
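# Illustrative sketch (not part of the original file): `state_dict_to_device` mutates the
# optimizer state dict in place, so an assumed resume-on-GPU flow would look like:
#   state_dict = torch.load(ckpt_optimizer_path, map_location="cpu")  # hypothetical path
#   state_dict_to_device(state_dict, device="cuda")
#   optimizer.load_state_dict(state_dict)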
@torch.no_grad()
def load_optimizer(
optimizer: optim.BaseOptimizer,
parallel_context: ParallelContext,
root_folder: Path,
map_location: Optional[str] = None,
param_shard_metadata: Optional[Dict[str, Dict[Tuple[str, str], TensorMetadata]]] = None,  # param_name -> {(pp_rank, tp_rank) -> TensorMetadata}
model: Optional[nn.Module] = None,
):
root_folder = root_folder / "optimizer"
ckp_optimizer_config_path = root_folder / "optimizer_config.json"
with open(ckp_optimizer_config_path, "r") as file:
ckp_optimizer_config = json.load(file)
ckp_pp_size = ckp_optimizer_config["parallelism"]["pp_size"]
ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
ckpt_expert_parallel_size = ckp_optimizer_config["parallelism"]["expert_parallel_size"]
if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
parallel_context.pp_pg.size()
):
if int(ckp_pp_size) != int(parallel_context.pp_pg.size()):
warnings.warn(
"You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!"
)
assert (
param_shard_metadata is not None
), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
assert (
model is not None
), "You have to pass the model in order to adjust the optimizer states according to how the current parameters are sharded"
def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -> TensorMetadata:
return param_shard_metadata[param_name.replace("module.", "")][(str(pp_rank), str(tp_rank))]
ckp_optim_type = ckp_optimizer_config["type"]
if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: if the checkpoint is from a Zero-1 optimizer, then we need to merge the shards
# across data parallel dimension, before merging the shards across tensor parallel dimension
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}-exp-*-of-{ckpt_expert_parallel_size}.pt"
)
)
ckp_sharded_optim_states = merge_dp_shard_in_zero1_optimizer(
model, ckp_optimizer_config, shard_paths, parallel_context, map_location
)
else:
# NOTE: if the checkpoint is from a Zero-0 optimizer, then we don't need to merge the shards
# across data parallel dimension, just directly load the checkpoints
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt"
) # WARN: wildcard here after tp can hold `0-of-1_exp-0`
)
ckp_sharded_optim_states = {}
for shard_path in shard_paths:
pp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=False)
ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(
shard_path, map_location=map_location
) # load all optim states in mem
model_state_dict = model.state_dict()
new_optim_state_dict = optimizer.state_dict()
new_optim_state_dict["state"] = defaultdict(dict)
# TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
OPTIMIZER_STATE_DTYPE = ckp_sharded_optim_states[(0, 0)]["state"][0][OPTIMIZER_STATE_NAMES[0]].dtype
# NOTE: because we can only resume training with the same optimizer type
# (0, 0) = (pp_rank, tp_rank)
# NOTE: also we don't merge "step" because it's just a scalar
param_names = list(model_state_dict.keys())
new_optim_state_param_names = {}
# NOTE: iterates through all model parameters in the local pipeline parallel rank (hence, might not be the full model).
# Since model parameters and optimizer states are aligned, loads only the optimizer states for these parameters from the checkpoint shards.
for param_index, param_name in tqdm(
enumerate(param_names),
disable=dist.get_rank(parallel_context.world_pg) != 0,
desc="Topology-agnostic optimizer loading",
):
try:
param = model.get_parameter(param_name)
except AttributeError:
param = None
if not isinstance(param, NanotronParameter):
raise NotImplementedError("Parameters are required to be NanotronParameter")
# NOTE: for tied parameters, the metadata is stored using the parameter name,
# while the data is stored using the name of the main tied parameter,
# which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
# for `model.lm_head.pp_block.weight`).
base_name = param.get_tied_info().name if param.is_tied else param_name
if param_name != base_name:
# NOTE: skip tied parameter if main tied parameter has already been loaded
# (not always the case if pipeline parallel)
if base_name in new_optim_state_param_names.values():
continue
new_optim_state_param_names[param_index] = base_name
if param.is_sharded:
# NOTE: optimizer states's shape is equal to the parameter's shape
# NOTE: sometimes an unsharded parameter's shape differs
# from an unsharded optimizer state's shape
new_shard_metadata = param.get_sharded_info()
new_unshared_shape = new_shard_metadata.unsharded_shape
# NOTE: restore each state tensor (e.g. exp_avg) by iterating through
# the optimizer state shards saved using the previous topology
for state_key in OPTIMIZER_STATE_NAMES:
# TODO(xrsrke): free the memory of the shards that isn't
# corresponding to the current rank
# TODO: maybe better to allocate memory for all states at once
buffer = torch.zeros_like(param, device=map_location, dtype=OPTIMIZER_STATE_DTYPE)
unsharded_buffer = torch.empty(
new_unshared_shape, device=map_location, dtype=OPTIMIZER_STATE_DTYPE
)
for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
)
if old_optim_state_index is None:
continue # NOTE: param is not in this pp shard
ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
# NOTE: the metadata for the main parameter of a tied parameter might be in a
# different pipeline parallel shard.
if param.is_tied:
metadata_pp_rank = next(
iter(param_shard_metadata[param_name.replace("module.", "")].keys())
)[0]
else:
metadata_pp_rank = pp_rank
ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)
# NOTE: if the checkpoint is from a Zero-1 optimizer,
# so it's flattened, so we need to reshape it
if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: this is the original shape of the parameter before being flattened
orig_shape = ckp_optimizer_config["configs"]["orig_param_shapes"][param_name]
orig_shape = [int(dim) for dim in orig_shape]
ckp_shard_data = ckp_shard_data.view(orig_shape)
new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
buffer,
unsharded_buffer,
[
(ckp_shard_data, ckp_shard_metadata.local_global_slices_pairs),
],
new_shard_metadata,
)
else:
# Handle non-sharded params (e.g. layernorm)
for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
)
if old_optim_state_index is None:
continue # Param not in this PP shard
# For non-sharded params, just copy over the state directly
for state_key in OPTIMIZER_STATE_NAMES:
new_optim_state_dict["state"][param_index][state_key] = ckp_optim_state["state"][
old_optim_state_index
][state_key]
if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: flatten the optimizer states
new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][param_index][
state_key
].flatten()
# NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
# try to get the step value as well.
step = ckp_optim_state["state"][old_optim_state_index].get("step")
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step
# NOTE: we throw away ckp_optim_state['gradient_accumulator'] which has fp32 grads
new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
else:
# TODO @thomasw21: Load optimizer type and check that it's compatible, otherwise we might be loading something else completely
state_dict = torch.load(
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
map_location=map_location,
)
if isinstance(optimizer, ZeroDistributedOptimizer):
# NOTE: only reshard when the tp size changed (after merging the tp shards)
# or we get a new dp size
if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
# NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
current_dp_rank = dist.get_rank(parallel_context.dp_pg)
OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"]
for param_index in state_dict["state"]:
param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
for state_name in OPTIMIZER_STATE_NAMES:
sliced_tensor = get_sliced_tensor(
param=state_dict["state"][param_index][state_name],
start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
)
state_dict["state"][param_index][state_name] = sliced_tensor
optimizer.load_state_dict(state_dict, map_location=map_location)
def load_lr_scheduler(
lr_scheduler,
is_zero,
parallel_context: ParallelContext,
root_folder: Path,
):
root_folder = root_folder / "lr_scheduler"
state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context, is_zero))
lr_scheduler.load_state_dict(state_dict)
lr_scheduler._initial_step() # NOTE: this is required to set the initial learning rate