Commit 61e92904 authored by chenzk

v1.0
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from torch import distributed as torch_dist
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup
class DifferentiableIdentity(torch.autograd.Function):
"""All-reduce gradients in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
return tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllReduceSum.apply(grad_output, group), None
class DifferentiableAllReduceSum(torch.autograd.Function):
"""All-reduce in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
if group.size() == 1:
return tensor
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
return tensor
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class DifferentiableAllGather(torch.autograd.Function):
"""All gather in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
if group.size() == 1:
return tensor
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
unsharded_batch_size = sharded_batch_size * group.size()
unsharded_tensor = torch.empty(
unsharded_batch_size,
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group)
return unsharded_tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None
class DifferentiableReduceScatterSum(torch.autograd.Function):
"""Reduce scatter in a differentiable fashion"""
@staticmethod
def forward(ctx, tensor, group: Optional[ProcessGroup]):
ctx.group = group
if group.size() == 1:
return tensor
# TODO @thomasw21: shard along another dimension
unsharded_batch_size, *rest_size = tensor.shape
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
assert unsharded_batch_size % group.size() == 0
# TODO @thomasw21: Collectives seem to require tensors to be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
tensor = tensor.contiguous()
sharded_tensor = torch.empty(
unsharded_batch_size // group.size(),
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableAllGather.apply(grad_output, group), None
# -----------------
# Helper functions.
# -----------------
def differentiable_identity(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableIdentity.apply(tensor, group)
def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllReduceSum.apply(tensor, group)
def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableAllGather.apply(tensor, group)
def differentiable_reduce_scatter_sum(tensor, group: Optional[ProcessGroup] = None):
return DifferentiableReduceScatterSum.apply(tensor, group)
from enum import Enum, auto
# TODO @thomasw21: python 3.11 introduces `StrEnum` which would've been great to use.
class TensorParallelLinearMode(Enum):
ALL_REDUCE = auto()
REDUCE_SCATTER = auto()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from torch.nn import functional as F
import nanotron.distributed as dist
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1
class _ShardedCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(
ctx,
sharded_logits, # (batch_size, length, sharded_hidden_size)
target, # (batch_size, length)
group: dist.ProcessGroup,
):
# Maximum value along last dimension across all GPUs.
logits_max = torch.max(sharded_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
# Subtract the maximum value.
sharded_logits = sharded_logits - logits_max.unsqueeze(dim=-1)
# Get the shard's indices
sharded_hidden_size = sharded_logits.shape[-1]
rank = dist.get_rank(group)
start_index = rank * sharded_hidden_size
end_index = start_index + sharded_hidden_size
# Create a mask of out-of-shard target ids (1 means the id needs to be masked).
target_mask = (target < start_index) | (target >= end_index)
masked_target = target.clone() - start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, shard-size] and target to a 1-D tensor of size [*].
logits_2d = sharded_logits.view(-1, sharded_hidden_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
if predicted_logits_1d.is_contiguous():
predicted_logits_1d = predicted_logits_1d.clone()
else:
predicted_logits_1d = predicted_logits_1d.contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
dist.all_reduce(predicted_logits, op=dist.ReduceOp.SUM, group=group)
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = sharded_logits
torch.exp(sharded_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=group)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss.view_as(target)
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
sharded_hidden_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, sharded_hidden_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None, None
def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtype: torch.dtype = None):
"""Helper function for the cross entropy."""
if dtype is not None:
# Cast input to specific dtype.
sharded_logits = sharded_logits.to(dtype=dtype)
return _ShardedCrossEntropy.apply(sharded_logits, target, group)
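# Example (illustrative sketch): computing the loss with vocab-sharded logits over a tensor
# parallel group, without materializing the full vocabulary on any rank. `tp_pg` and the
# helper name are placeholders; upcasting to fp32 is an assumption for numerical stability.
def _example_sharded_loss(sharded_logits: torch.Tensor, targets: torch.Tensor, tp_pg: dist.ProcessGroup) -> torch.Tensor:
    # sharded_logits: (batch, seq_len, vocab_size // tp_pg.size()); targets: (batch, seq_len)
    per_token_loss = sharded_cross_entropy(sharded_logits, targets, group=tp_pg, dtype=torch.float32)
    return per_token_loss.mean()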
class _ColumnLinearAsyncCommunication(torch.autograd.Function):
"""Adapted from https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215"""
@staticmethod
@assert_cuda_max_connections_set_to_1
def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
ctx.use_bias = bias is not None
ctx.tp_mode = tp_mode
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.tensor_shape = tensor.size()
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
group_size = group.size()
current_rank = dist.get_rank(group)
if group_size == 1:
gathered_tensor = tensor
ctx.save_for_backward(tensor, weight)
return F.linear(gathered_tensor, weight, bias)
else:
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
# ctx.save_for_backward(tensor, weight)
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
gathered_batch_size = sharded_batch_size * group.size()
if tp_recompute_allgather:
gathered_tensor = MemoryBuffer().get(
"allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
)
else:
gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)
handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
# Compute a shard of column_linear at the same time as the AllGather
# We can already compute the matmul of the shard we hold locally with the current rank's weight
# We assume that rank 0 holds w0, rank 1 holds w1, etc.
# weights: w0 w1 w2 w3
# rank 0: X - - -
# rank 1: - X - -
# rank 2: - - X -
# rank 3: - - - X
# We call the corresponding shard of output "same_device_shard"
output_size = weight.shape[0]
gathered_output = torch.empty(
gathered_batch_size,
*intermediate_size,
output_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
before_shard, same_device_shard, after_shard = torch.split(
gathered_output,
split_size_or_sections=[
sharded_batch_size * current_rank,
sharded_batch_size,
sharded_batch_size * (group_size - current_rank - 1),
],
dim=0,
)
first_dims = math.prod([sharded_batch_size, *intermediate_size])
if bias is None:
torch.mm(
input=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=tensor.view(first_dims, hidden_size),
mat2=weight.t(),
out=same_device_shard.view(first_dims, output_size),
)
# Wait communication
handle.wait()
if tp_recompute_allgather:
ctx.save_for_backward(tensor, weight)
else:
ctx.save_for_backward(gathered_tensor, weight)
# Compute all the other shards that are obtained from AllGather
# weights: w0 w1 w2 w3
# rank 0: - X X X
# rank 1: X - X X
# rank 2: X X - X
# rank 3: X X X -
# Since these rows may not be contiguous (e.g. for r1 and r2), as they are separated by "same_device_shard",
# we need to compute them separately, i.e. "before_shard" and "after_shard"
# For r0, "before_shard" is empty. For r3, "after_shard" is empty.
if before_shard.numel() > 0:
first_dims = math.prod(before_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[: sharded_batch_size * current_rank].view(first_dims, hidden_size),
mat2=weight.t(),
out=before_shard.view(first_dims, output_size),
)
if after_shard.numel() > 0:
first_dims = math.prod(after_shard.shape[:-1])
if bias is None:
torch.mm(
input=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
else:
torch.addmm(
input=bias[None, :],
mat1=gathered_tensor[sharded_batch_size * (current_rank + 1) :].view(
first_dims, hidden_size
),
mat2=weight.t(),
out=after_shard.view(first_dims, output_size),
)
return gathered_output
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
@staticmethod
@assert_cuda_max_connections_set_to_1
def backward(ctx, grad_output):
tensor, weight = ctx.saved_tensors
group = ctx.group
use_bias = ctx.use_bias
tp_mode = ctx.tp_mode
handle1: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather:
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
if group.size() == 1:
total_tensor = tensor
else:
unsharded_batch_size = sharded_batch_size * group.size()
unsharded_tensor = MemoryBuffer().get(
"allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the tensor gradient computation
total_tensor = unsharded_tensor
else:
total_tensor = tensor
grad_tensor = grad_output.matmul(weight)
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)
handle2: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
if group.size() == 1:
sub_grad_tensor = grad_tensor
else:
sub_grad_tensor = torch.empty(
ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
)
# reduce_scatter
handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
# Asynchronous all-reduce
handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
else:
raise ValueError()
grad_bias = grad_output.sum(dim=0) if use_bias else None
if handle1 is not None:
handle1.wait()
# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = grad_output.t().matmul(total_tensor)
if handle2 is not None:
handle2.wait()
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return sub_grad_tensor, grad_weight, grad_bias, None, None, None
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
return grad_tensor, grad_weight, grad_bias, None, None, None
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
"""
Column linear with memory_buffer for the allgather, context parallel
enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
async communication disabled.
"""
@staticmethod
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_recompute_allgather: bool,
):
# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if group.size() == 1:
total_input = input.contiguous()
elif tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
# Prepare context.
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.input_size = input.shape
if tp_recompute_allgather:
ctx.save_for_backward(input, weight, bias)
else:
ctx.save_for_backward(total_input, weight, bias)
# Get linear output.
out = F.linear(total_input, weight, bias)
return out
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Either allgather the inputs again or get them from context.
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if group.size() == 1 or not tp_recompute_allgather:
total_input, weight, bias = ctx.saved_tensors
else:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.contiguous()
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim)
# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
if group.size() == 1:
sub_grad_input = grad_input
else:
# Seems that `reduce_scatter` needs contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
# We set grad_input to be contiguous in case it isn't already.
grad_input = grad_input.contiguous()
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None
return sub_grad_input, grad_weight, grad_bias, None, None
def column_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool = True,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
return F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
input, weight, bias, group, tp_recompute_allgather
)
raise ValueError(f"Got unexpected mode: {tp_mode}.")
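# Example (illustrative sketch): calling `column_linear` directly with a column-sharded weight
# in the simple ALL_REDUCE mode (input replicated across tp ranks, output features sharded).
# `tp_pg` and the shapes in the comment are placeholder assumptions.
def _example_column_linear(x: torch.Tensor, sharded_weight: torch.Tensor, tp_pg: dist.ProcessGroup) -> torch.Tensor:
    # sharded_weight: (out_features // tp_pg.size(), in_features)
    return column_linear(
        input=x,
        weight=sharded_weight,
        bias=None,
        group=tp_pg,
        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
        async_communication=False,
    )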
class _RowLinearAsyncCommunication(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, weight, bias, group, tp_mode):
assert (
tp_mode is TensorParallelLinearMode.REDUCE_SCATTER
), f"async communication in RowLinear only supports REDUCE_SCATTER, got {tp_mode}"
if group is None:
group = dist.distributed_c10d._get_default_group()
ctx.use_bias = bias is not None
ctx.group = group
out = F.linear(tensor, weight, bias)
if group.size() > 1:
out = differentiable_reduce_scatter_sum(out, group=group)
ctx.save_for_backward(tensor, weight)
return out
@staticmethod
@assert_cuda_max_connections_set_to_1
def backward(ctx, grad_output):
tensor, weight = ctx.saved_tensors
group = ctx.group
use_bias = ctx.use_bias
handle: Optional[dist.Work] = None
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = grad_output.shape
if group.size() == 1:
total_grad_output = grad_output
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_output = MemoryBuffer().get(
"allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
# total_grad_output: [b, s, h_out]
# weight: [h_out, h_in/n]
# total_grad_tensor: [b, s, h_in/n]
# grad_output: [b/n, s, h_out]
sharded_batch_size, *rest_size_grad_output = grad_output.shape
rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]]
if group.size() == 1:
total_grad_tensor = grad_output.matmul(weight)
else:
unsharded_batch_size = sharded_batch_size * group.size()
total_grad_tensor = torch.empty(
unsharded_batch_size,
*rest_size_grad_tensor,
device=grad_output.device,
dtype=grad_output.dtype,
requires_grad=False,
)
before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split(
total_grad_tensor,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# compute local shard
torch.mm(
input=grad_output.view(-1, grad_output.shape[-1]),
mat2=weight,
out=same_device_shard_grad_tensor.view(-1, weight.shape[1]),
)
if handle is not None:
handle.wait()
before_shard_grad_output, _, after_shard_grad_output = torch.split(
total_grad_output,
split_size_or_sections=[
sharded_batch_size * dist.get_rank(group),
sharded_batch_size,
sharded_batch_size * (group.size() - dist.get_rank(group) - 1),
],
dim=0,
)
# before shard compute
if before_shard_grad_tensor.numel() > 0:
torch.mm(
input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]),
mat2=weight,
out=before_shard_grad_tensor.view(-1, weight.shape[1]),
)
# after shard compute
if after_shard_grad_tensor.numel() > 0:
torch.mm(
input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]),
mat2=weight,
out=after_shard_grad_tensor.view(-1, weight.shape[1]),
)
# Convert the tensor shapes to 2D for execution compatibility
tensor = tensor.contiguous()
tensor_first_dims, tensor_last_dim = tensor.shape[:-1], tensor.shape[-1]
tensor = tensor.view(math.prod(tensor_first_dims), tensor_last_dim)
# Convert the tensor shapes to 2D for execution compatibility
total_grad_output_first_dims, total_grad_output_last_dim = (
total_grad_output.shape[:-1],
total_grad_output.shape[-1],
)
total_grad_output = total_grad_output.view(math.prod(total_grad_output_first_dims), total_grad_output_last_dim)
# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = total_grad_output.t().matmul(tensor)
grad_bias = total_grad_output.sum(dim=0) if use_bias else None
return total_grad_tensor, grad_weight, grad_bias, None, None
def row_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
):
if async_communication:
return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
out = F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
out = differentiable_all_reduce_sum(out, group=group)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
out = differentiable_reduce_scatter_sum(out, group=group)
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")
return out
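# Example (illustrative sketch): the row-parallel counterpart. Each rank holds a slice of the
# input features and a row-sharded weight; the partial outputs are summed across the group
# (ALL_REDUCE) or reduce-scattered along the batch (REDUCE_SCATTER). `tp_pg` is a placeholder.
def _example_row_linear(x_shard: torch.Tensor, sharded_weight: torch.Tensor, tp_pg: dist.ProcessGroup) -> torch.Tensor:
    # sharded_weight: (out_features, in_features // tp_pg.size())
    return row_linear(
        input=x_shard,
        weight=sharded_weight,
        bias=None,
        group=tp_pg,
        tp_mode=TensorParallelLinearMode.ALL_REDUCE,
        async_communication=False,
    )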
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.distributed import get_global_rank
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.sharded_parameters import (
SplitConfig,
create_sharded_parameter_from_config,
mark_all_parameters_in_module_as_sharded,
)
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.functional import (
column_linear,
row_linear,
)
from nanotron.parallel.tied_parameters import create_tied_parameter
class TensorParallelColumnLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
tp_recompute_allgather: bool = True,
):
self.pg = pg
self.world_size = pg.size()
assert out_features % self.world_size == 0
self.in_features = in_features
self.out_features = out_features // self.world_size
self.tp_recompute_allgather = tp_recompute_allgather
super().__init__(
in_features=self.in_features,
out_features=self.out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.mode = mode
self.async_communication = async_communication
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == out_features
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal out_features ({out_features})"
split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)
mark_all_parameters_in_module_as_sharded(
self,
pg=self.pg,
split_config=split_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return column_linear(
input=x,
weight=self.weight,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
tp_recompute_allgather=self.tp_recompute_allgather,
)
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}"
class TensorParallelRowLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
):
self.pg = pg
self.world_size = pg.size()
assert in_features % self.world_size == 0
self.in_features = in_features // self.world_size
self.out_features = out_features
# No need to shard the bias term, only rank 0 would have it
bias = dist.get_rank(self.pg) == 0 and bias
super().__init__(
in_features=self.in_features,
out_features=self.out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.mode = mode
self.async_communication = async_communication
if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication:
raise ValueError("async_communication is not supported for ALL_REDUCE mode")
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == in_features
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal in_features ({in_features})"
split_config = SplitConfig(split_dim=1, contiguous_chunks=contiguous_chunks)
self._mark_all_parameters_in_module_as_sharded(split_config)
def _mark_all_parameters_in_module_as_sharded(self, split_config: SplitConfig):
for name, param in list(self.named_parameters()):
if name == "bias":
# `bias` only exists in rank 0 because it's not sharded
new_param = NanotronParameter(tensor=param)
else:
new_param = create_sharded_parameter_from_config(
parameter=param,
pg=self.pg,
split_config=split_config,
)
setattr(self, name, new_param)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return row_linear(
input=x,
weight=self.weight,
bias=self.bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_in_features={self.in_features * self.world_size}"
class TiedLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
):
self.pg = pg
self.world_size = pg.size()
self.mode = mode
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
self._mark_all_parameters_in_module_as_tied()
def _mark_all_parameters_in_module_as_tied(self):
for name, param in list(self.named_parameters()):
new_param = create_tied_parameter(
parameter=param,
name=name,
global_ranks=tuple(sorted((get_global_rank(self.pg, i) for i in range(self.pg.size())))),
reduce_op=None if self.mode is TensorParallelLinearMode.ALL_REDUCE else dist.ReduceOp.SUM,
root_module=self,
)
setattr(self, name, new_param)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = super().forward(x)
if self.mode is TensorParallelLinearMode.ALL_REDUCE:
y = differentiable_identity(y, group=self.pg)
elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
y = differentiable_all_gather(y, group=self.pg)
else:
raise ValueError(f"Got unexpected mode: {self.mode}.")
return y
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
):
self.pg = pg
self.rank = dist.get_rank(self.pg)
self.world_size = pg.size()
self.original_num_embeddings = num_embeddings
# TODO @thomasw21: Fix and remove that constraint. Typically there's no reason to have such a constraint.
assert num_embeddings % self.world_size == 0
block_size = num_embeddings // self.world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.rank * block_size
self.max_id = (self.rank + 1) * block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
self.mode = mode
if contiguous_chunks is not None:
assert (
sum(contiguous_chunks) == num_embeddings
), f"Sum of contiguous chunks ({sum(contiguous_chunks)}) must equal num_embeddings ({num_embeddings})"
split_config = SplitConfig(split_dim=0, contiguous_chunks=contiguous_chunks)
mark_all_parameters_in_module_as_sharded(self, pg=self.pg, split_config=split_config)
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.pg.size() > 1:
# `0` if input is in the correct interval, else `1`
input_mask = torch.logical_or(self.min_id > input_ids, input_ids >= self.max_id)
# translate for [0, self.max_id - self.min_id[
masked_input = input_ids.clone() - self.min_id
# default all out of bounds values to `0`
masked_input[input_mask] = 0
else:
masked_input = input_ids
out = super().forward(masked_input)
if self.pg.size() > 1:
out = out * (~input_mask[..., None])
if self.mode is TensorParallelLinearMode.ALL_REDUCE:
out = differentiable_all_reduce_sum(out, group=self.pg)
elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
out = differentiable_reduce_scatter_sum(out, group=self.pg)
else:
raise ValueError(f"Got unexpected mode: {self.mode}.")
return out
def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_num_embeddings={self.original_num_embeddings}"
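# Example (illustrative sketch): the usual column-then-row pairing for a tensor-parallel MLP.
# Assumes `tp_pg` is an already-created tensor-parallel process group and that `hidden_size`
# and `4 * hidden_size` are divisible by `tp_pg.size()`; all names and sizes are placeholders.
def _example_tp_mlp(tp_pg: dist.ProcessGroup, hidden_size: int = 1024, device=None, dtype=None) -> nn.Sequential:
    mode = TensorParallelLinearMode.ALL_REDUCE
    return nn.Sequential(
        # Column-parallel: shards the 4*hidden output features; no communication in the forward pass.
        TensorParallelColumnLinear(hidden_size, 4 * hidden_size, pg=tp_pg, mode=mode, device=device, dtype=dtype),
        nn.GELU(),
        # Row-parallel: shards the 4*hidden input features; all-reduces the output in the forward pass.
        TensorParallelRowLinear(4 * hidden_size, hidden_size, pg=tp_pg, mode=mode, device=device, dtype=dtype),
    )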
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from torch import nn
from nanotron import distributed as dist
from nanotron import logging
from nanotron.logging import log_rank
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_parameter_and_parent_module
logger = logging.get_logger(__name__)
def create_tied_parameter(
parameter: nn.Parameter,
name: str,
global_ranks: Tuple[int, ...],
reduce_op: Optional[dist.ReduceOp],
root_module: nn.Module,
) -> NanotronParameter:
if not isinstance(parameter, NanotronParameter):
parameter = NanotronParameter(tensor=parameter)
parameter.mark_as_tied(name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module)
return parameter
def tie_parameters(
root_module: nn.Module,
ties: List[Tuple[str, Tuple[int, ...]]],
parallel_context: ParallelContext,
reduce_op: Optional[dist.ReduceOp],
):
"""
Tie parameters.
Within a single device, tied parameters are replaced with a single Parameter
Across devices, we add metadata to Parameters that require extra synchronization.
:param root_module: nn.Module
:param ties: List[Tuple[str, Tuple[int, ...]]]: a tie is (param_target, global_ranks)
:param parallel_context: ParallelContext
:return:
"""
if len(ties) < 1:
raise ValueError("Can't tie nothing")
# TODO @thomasw21: When we support Zero3 this isn't true anymore
dp_ranks = tuple(
sorted(
{
parallel_context.get_local_ranks(world_rank=global_rank)[2]
for _, global_ranks in ties
for global_rank in global_ranks
}
)
)
assert (
len(dp_ranks) == 1
), f"Tying weights has to happen within a single replica of the model. Got the ranks from the following replicas: {dp_ranks}"
name = ties[0][0]
global_ranks = tuple(sorted(set().union(*(tie[1] for tie in ties))))
new_param = None
world_rank = dist.get_rank(parallel_context.world_pg)
for tie_target, tie_model_ranks in ties:
if world_rank not in tie_model_ranks:
continue
param, parent_module, param_name = get_parameter_and_parent_module(target=tie_target, root_module=root_module)
# If they are physically in the same device, then we tie them
if new_param is None:
new_param = create_tied_parameter(
parameter=param, name=name, global_ranks=global_ranks, reduce_op=reduce_op, root_module=root_module
)
# Re-assign it to the original name. We assign the raw tensor instead of the parameter since we moved it already.
setattr(parent_module, param_name, new_param)
def create_pg_for_tied_weights(root_module: nn.Module, parallel_context: ParallelContext):
"""Tied weights are tied across a specific set of global ranks; we use this method to create a process group for each distinct set of global ranks"""
group_ranks = {
param.get_tied_info().global_ranks
for name, param in root_module.named_parameters()
if isinstance(param, NanotronParameter) and param.is_tied
}
world_group_ranks = [None] * parallel_context.world_pg.size()
dist.all_gather_object(world_group_ranks, group_ranks, group=parallel_context.world_pg)
all_group_ranks = sorted(
set().union(*world_group_ranks),
)
for global_ranks in all_group_ranks:
if global_ranks not in parallel_context.world_ranks_to_pg:
parallel_context.world_ranks_to_pg[global_ranks] = dist.new_group(global_ranks)
def get_tied_id_to_param(
parameters: List[NanotronParameter], root_module: nn.Module
) -> Dict[Tuple[str, Tuple[int, ...]], NanotronParameter]:
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in root_module.named_modules()}
# Fix the root_model
module_id_to_prefix[id(root_module)] = ""
return {
(
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix),
param.get_tied_info().global_ranks, # TODO @nouamane: merge groups which tie the same parameter
): param
for param in parameters
if param.is_tied
}
def sync_tied_weights_gradients(
module: nn.Module, # TODO: NanotronModel
parallel_context: ParallelContext,
grad_accumulator: Optional[GradientAccumulator],
):
tied_id_to_param = get_tied_id_to_param(
parameters=[param for param in module.parameters() if param.requires_grad], root_module=module
)
# Only first and last rank should print the warning
for rank in [0, parallel_context.world_pg.size() - 1]:
log_rank(
f"[Debug Tied Weights] Syncing the following tied weights: {tied_id_to_param.keys()}",
logger=logger,
level=logging.DEBUG,
group=parallel_context.world_pg,
rank=rank,
)
# Group tensors to reduce by process groups
# Important to use ordered dict in order to be synchronized across all ranks
group_ranks_and_reduce_op_to_tensors_to_reduce = OrderedDict()
for (name, group_ranks), tied_param in sorted(tied_id_to_param.items(), key=lambda x: x[0]):
tied_info = tied_param.get_tied_info()
# Some weights don't require any syncing, because they are by design synchronised
if tied_info.reduce_op is None:
continue
if grad_accumulator is not None:
tied_grad = grad_accumulator.get_grad_buffer(name=name)
else:
tied_grad = tied_param.grad
log_rank(
f"Syncing tied weights {name} across ranks {group_ranks} ...",
logger=logger,
level=logging.DEBUG,
group=parallel_context.world_ranks_to_pg[group_ranks],
rank=0,
)
key = (group_ranks, tied_info.reduce_op)
if key in group_ranks_and_reduce_op_to_tensors_to_reduce:
group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)].append(tied_grad)
else:
group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)] = [tied_grad]
for (group_ranks, reduce_op), tensors in group_ranks_and_reduce_op_to_tensors_to_reduce.items():
dist.all_reduce_coalesced(tensors=tensors, op=reduce_op, group=parallel_context.world_ranks_to_pg[group_ranks])
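# Example (illustrative sketch): tying an input embedding to an output projection that live on
# different pipeline ranks, then creating the process groups used for gradient syncing. The
# module targets and global ranks are placeholders; after each backward pass the tied gradients
# are reduced with `sync_tied_weights_gradients`.
def _example_tie_embeddings(model: nn.Module, parallel_context: ParallelContext) -> None:
    ties = [
        ("token_embedding.weight", (0,)),  # global rank holding the input embedding (placeholder)
        ("lm_head.weight", (3,)),          # global rank holding the output projection (placeholder)
    ]
    tie_parameters(root_module=model, ties=ties, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM)
    create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context)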
import functools
import operator
import os
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
from nanotron.utils import Singleton
class MemoryBuffer(metaclass=Singleton):
"""
Global memory buffer to store intermediate activations that do not need to be cached for the backward pass.
"""
def __init__(self):
self.buffer = {}
def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
required_numel = functools.reduce(operator.mul, shape, 1)
if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
self.buffer[name, dtype] = torch.empty(
required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[name, dtype][:required_numel].view(shape)
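# Example (illustrative sketch): reusing the singleton buffer for repeated all-gathers instead
# of allocating a fresh tensor every step. Assumes a CUDA device and an initialized process
# group; `tp_pg` and the buffer name are placeholders.
def _example_reuse_gather_buffer(local_shard: torch.Tensor, tp_pg: dist.ProcessGroup) -> torch.Tensor:
    gathered_shape = (local_shard.shape[0] * tp_pg.size(), *local_shard.shape[1:])
    gathered = MemoryBuffer().get("allgather", gathered_shape, dtype=local_shard.dtype)
    dist.all_gather_into_tensor(gathered, local_shard.contiguous(), group=tp_pg)
    return gathered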
def assert_cuda_max_connections_set_to_1(func):
flag_is_set_to_1 = None
@functools.wraps(func)
def wrapper(*args, **kwargs):
nonlocal flag_is_set_to_1
if flag_is_set_to_1 is None:
assert os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") == "1"
flag_is_set_to_1 = True
return func(*args, **kwargs)
return wrapper
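# Usage note (illustrative): the decorator above checks the environment only once, so the
# variable must be set before any CUDA work starts, either in the launcher command
# (e.g. `CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun ...`) or at the very top of the entry script,
# as sketched below. With a single hardware connection the async all-gather / reduce-scatter
# issued first is scheduled before the matmul that overlaps it, which is what the async
# tensor-parallel paths rely on.
def _example_set_max_connections() -> None:
    # Must run before the first CUDA context is created to take effect (placeholder helper).
    os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")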
def initial_sync(model: nn.Module, parallel_context: ParallelContext):
# Synchronize across dp: basic assumption
sorted_name_params = sorted(model.named_parameters(), key=lambda x: x[0])
for name, param in sorted_name_params:
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=parallel_context.dp_pg)
# Synchronize across tied weights: basic assumption
for (_, group_ranks), param in sorted(
get_tied_id_to_param(parameters=model.parameters(), root_module=model).items(), key=lambda x: x[0]
):
group = parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
import contextlib
import random
from dataclasses import dataclass
from typing import MutableMapping, Optional, Tuple
import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup
@dataclass
class RandomState:
random: Tuple[int, Tuple[int, ...], None]
numpy: Tuple[str, np.ndarray, int, int, float]
torch_cpu: torch.Tensor
torch_cuda: Optional[torch.Tensor]
def __eq__(self, other):
return (
isinstance(other, RandomState)
and all(v1 == v2 for v1, v2 in zip(self.random, other.random))
and all(
np.array_equal(v1, v2) if isinstance(v1, np.ndarray) else v1 == v2
for v1, v2 in zip(self.numpy, other.numpy)
)
and torch.equal(self.torch_cpu, other.torch_cpu)
and (
other.torch_cuda is None if self.torch_cuda is None else torch.equal(self.torch_cuda, other.torch_cuda)
)
)
class RandomStates(MutableMapping[str, RandomState]):
def __init__(self, dict: dict):
for key, value in dict.items():
self.check_type(key, value)
# TODO @thomasw21: We make a copy for safety measure.
self._dict = dict.copy()
@staticmethod
def check_type(key, value):
if not isinstance(key, str):
raise ValueError(f"Expected key to be of type str. Got {type(key)}")
if not isinstance(value, RandomState):
raise ValueError(f"Expected value to be of type `nanotron.dataclass.RandomState`. Got {type(value)}")
def __getitem__(self, item):
return self._dict[item]
def __iter__(self):
return self._dict.__iter__()
def __len__(self):
return len(self._dict)
def __delitem__(self, key):
raise ValueError("Can't delete a random states key")
def __setitem__(self, key, value):
if key not in self._dict:
raise ValueError("Can't add new random states after initialisation")
self.check_type(key, value)
return self._dict.__setitem__(key, value)
def __eq__(self, other):
if not isinstance(other, RandomStates):
return False
return self._dict == other._dict
def set_random_seed(seed: int):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def set_random_state(random_state: RandomState):
random.setstate(random_state.random)
np.random.set_state(random_state.numpy)
torch.set_rng_state(random_state.torch_cpu)
if torch.cuda.is_available():
torch.cuda.set_rng_state(random_state.torch_cuda, "cuda")
else:
assert random_state.torch_cuda is None
def get_current_random_state():
"""Returns a snapshot of current random state"""
return RandomState(
random=random.getstate(),
numpy=np.random.get_state(),
torch_cpu=torch.random.get_rng_state(),
torch_cuda=torch.cuda.get_rng_state("cuda") if torch.cuda.is_available() else None,
)
@contextlib.contextmanager
def branch_random_state(random_states: RandomStates, key: str, enabled: bool):
"""
Context manager handling random state:
- upon entering: Stores current random state and set new random state defined by key.
- upon exiting: updates key in `random_states` to the new current random state, and set back the old one.
"""
if not enabled:
yield
return
old_random_state = get_current_random_state()
# Get the new state associated to the key
new_random_state = random_states[key]
set_random_state(new_random_state)
try:
yield
finally:
# Update state from parallel_context with the newest state
new_random_state = get_current_random_state()
random_states[key] = new_random_state
# Set the old state back
set_random_state(old_random_state)
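# Example (illustrative sketch): running dropout under a dedicated RNG stream that is kept
# synchronized across tensor-parallel ranks. The "tp_synced" key is a placeholder and must
# already exist in `random_states`.
def _example_branch_tp_rng(random_states: RandomStates, x: torch.Tensor) -> torch.Tensor:
    with branch_random_state(random_states, key="tp_synced", enabled=True):
        # Inside the context the "tp_synced" state drives the RNG; on exit it is saved back
        # and the previous global state is restored.
        return torch.nn.functional.dropout(x, p=0.1, training=True)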
def get_synced_random_state(
random_state: RandomState,
pg: ProcessGroup,
):
# We use rank 0 as a reference and broadcast random states from that rank to all the other ranks within a group in order to sync them
reference_rank = 0
if dist.get_rank(pg) == reference_rank:
random_states = [random_state]
else:
random_states = [None]
# TODO @thomasw21: broadcast tensor using `broadcast` in order not to use pickle
dist.broadcast_object_list(
random_states, src=dist.get_global_rank(pg, reference_rank), group=pg, device=torch.device("cuda")
)
new_random_state = random_states[0]
assert new_random_state is not None
return new_random_state
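# Example (illustrative sketch): building a RandomStates container whose "tp_synced" entry is
# identical on every rank of `tp_pg`, by broadcasting rank 0's snapshot. The key name and
# `tp_pg` are placeholders.
def _example_make_synced_states(tp_pg: ProcessGroup) -> RandomStates:
    return RandomStates(
        {"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=tp_pg)}
    )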
from .fsspec import check_path_is_local, fs_copy, fs_open
from .s3_mover import S3Mover
__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
import contextlib
from pathlib import Path
from typing import Tuple, Union
import fsspec
from fsspec.implementations import local
def get_filesystem_and_path(path: Path, storage_options=None) -> Tuple[fsspec.AbstractFileSystem, str]:
# Use supported filesystems in `fsspec`. If you need another one, please use `fsspec.registry.register_implementation`
# DO NOT USE `mode` argument as it adds a suffix `0.part` when using `mode="w"`.
fs, _, paths = fsspec.core.get_fs_token_paths(str(path), storage_options=storage_options)
assert len(paths) == 1
return fs, paths[0]
@contextlib.contextmanager
def fs_open(
file: Union[str, Path],
mode="r",
):
# TODO @thomasw21: pass storage options.
fs, path = get_filesystem_and_path(file)
with fs.open(path, mode=mode) as f:
yield f
def fs_copy(
input_file: Union[str, Path],
output_file: Union[str, Path],
):
"""Copy file from input to output (possibly on s3/other fs)"""
with fs_open(input_file, mode="rb") as fi, fs_open(output_file, mode="wb") as fo:
fo.write(fi.read())
def check_path_is_local(path: Path, storage_options=None) -> bool:
return isinstance(get_filesystem_and_path(path=path, storage_options=storage_options)[0], local.LocalFileSystem)
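# Example (illustrative sketch): the same helpers work for local paths and for fsspec-backed
# remote paths (e.g. s3://...), as long as the corresponding fsspec implementation is
# installed. The paths below are placeholders.
def _example_backup_config(local_cfg: str = "checkpoints/config.yaml", remote_cfg: str = "s3://my-bucket/run/config.yaml") -> None:
    assert check_path_is_local(Path(local_cfg)), "expected the source to be a local path"
    fs_copy(local_cfg, remote_cfg)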
import glob
import json
import os
import subprocess
import time
from datetime import datetime
from enum import Enum
from typing import Optional, Union
import torch
from datasets.download.streaming_download_manager import xPath
from filelock import FileLock, Timeout
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import human_format
logger = logging.get_logger(__name__)
class S3Mover:
# TODO @eliebak update the doc to state that this class is also used to download checkpoints to the disk with start_downloading
"""Take care of uploading a checkpoint to S3 in the background and removing it from the disk.
Args:
local_path: Path to the checkpoints on the local disk
s3_path: Path to the checkpoints on S3
remove_after_upload: If True, remove the checkpoint from the disk after uploading it to S3
s5cmd_numworkers: Number of workers to use for the s5cmd command
s5cmd_concurrency: Concurrency to use for the s5cmd command
s5cmd_path: Path to the s5cmd command
dummy: If True, don't actually upload/remove/etc anything. Useful in multi-process setups where only one process per node should upload.
Usage:
# Create a mover - use dummy=True for all the process that shouldn't do anything (e.g. all but one per node)
mover = S3Mover(local_path=/scratch/my-checkpoints,
s3_path=s3://my-bucket/my-checkpoints,
remove_after_upload=True,
s5cmd_numworkers=96,
s5cmd_concurrency=10,
s5cmd_path=/admin/user/my/bin/s5cmd,
dummy=False)
while training:
# from time to time, update the state
mover_status = mover.update()
...
# When saving a checkpoint, check if the previous checkpoint has been uploaded and removed
# in a distributed setting
"""
class S3MoverState(Enum):
IDLE = "IDLE"
UPLOADING = "UPLOADING"
DOWNLOADING = "DOWNLOADING"
REMOVING_CHECKPOINT = "REMOVING_CHECKPOINT"
class DummyPopen:
def __init__(self, *args, **kwargs):
pass
def poll(self):
return 0
def communicate(self):
return ("", "")
def __init__(
self,
local_path: xPath,
s3_path: xPath,
post_upload_callback: Optional[callable] = None,
remove_after_upload: Optional[bool] = True,
s5cmd_numworkers: Optional[int] = None,
s5cmd_concurrency: Optional[int] = None,
s5cmd_path: Optional[str] = None,
s5cmd_credentials: Optional[str] = None,
clean_up_local_on_start: bool = False,
dummy: bool = False,
s3_region: str = "us-east-1",
):
self.process: Optional[Union[subprocess.Popen, S3Mover.DummyPopen]] = None
self.remove_after_upload = remove_after_upload
self.s5cmd_numworkers = s5cmd_numworkers
self.s5cmd_concurrency = s5cmd_concurrency
self.s5cmd_path = s5cmd_path if s5cmd_path is not None else "s5cmd"
self.s5cmd_credentials = s5cmd_credentials
self.lock_file = None
self.dummy = dummy
self.s3_region = s3_region
self.post_upload_callback = post_upload_callback
self.post_upload_callback_outputs = None
local_path = str(local_path)
if not local_path.startswith("/scratch/"):
self._warning(f"The local path is not on the scratch drive: {local_path}")
if not local_path.endswith("/"):
local_path += "/"
s3_path = str(s3_path)
if not s3_path.endswith("/"):
s3_path += "/"
self.local_path = local_path
self.s3_path = s3_path
s3_bucket, s3_prefix = s3_path.replace("s3://", "").split("/", maxsplit=1)
self.s3_path_direct_link = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?region={self.s3_region}&prefix={s3_prefix}&showversions=false"
self._reset_state()
if clean_up_local_on_start:
self._start_removing()
def _warning(self, message):
if self.dummy:
return
logger.warning(message)
def _info(self, message):
if self.dummy:
return
logger.info(message)
def _reset_state(self):
self.state = self.S3MoverState.IDLE
self.num_uploaded_files = 0
if self.lock_file is not None:
self._release_lock()
self.lock_file = None
self.stdout = ""
self.start_time: datetime = None
self.cmd = ""
def _popen(self, cmd: list):
self.stdout = ""
self.start_time = datetime.now()
self.cmd = cmd
if self.dummy:
return self.DummyPopen(cmd)
else:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
os.set_blocking(process.stdout.fileno(), False)
return process
def _acquire_lock(self, file_path: str) -> bool:
if self.dummy:
return True
if file_path.endswith("/"):
lock_file_path = file_path[:-1] + ".lock"
else:
lock_file_path = file_path + ".lock"
self.lock_file = FileLock(lock_file_path)
try:
self.lock_file.acquire(timeout=1)
except Timeout:
message = f"[S3] The checkpoint files {lock_file_path} are currently locked by another process. "
self._warning(message)
return False
return True
def get_state_as_int(self) -> int:
"""Return the state as an int"""
if self.state == self.S3MoverState.IDLE:
return 0
elif self.state == self.S3MoverState.UPLOADING:
return 1
elif self.state == self.S3MoverState.DOWNLOADING:
return 2
elif self.state == self.S3MoverState.REMOVING_CHECKPOINT:
return 3
else:
return -1
def _release_lock(self):
if self.dummy:
return
if self.lock_file is not None and self.lock_file.is_locked:
self.lock_file.release()
def get_current_stdout(self) -> str:
"""Return the current stdout of the process if any"""
if self.process is None or isinstance(self.process, self.DummyPopen):
return ""
try:
stdout = self.process.stdout.read()
except ValueError:
stdout = "" # The buffer is already closed: "ValueError: read of closed file"
if stdout:
self.stdout += stdout.decode()
return self.stdout
def wait_for_completion(self):
while self.state != self.S3MoverState.IDLE:
_ = self.update()
time.sleep(0.5)
def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None):
"""Wait for the previous checkpoint to be fully uploaded and removed in a distributed setting.
Waits for all processes to be ready
"""
if group is None:
group = dist.torch_dist.distributed_c10d._get_default_group()
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
dist.barrier()
all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
if all_saved != group.size() and self.state != self.S3MoverState.IDLE:
self._warning(
f"Waiting for the previous checkpoint save to finish - S3Mover {dist.get_rank(group)} is still in {self.state} state.",
)
while all_saved != group.size():
stdout = self.get_current_stdout()
stdout_lines = [lst for lst in stdout.split("\n") if lst]
if self.state != self.S3MoverState.IDLE:
self._warning(
f"[S3] Waiting {self.state.value}: {all_saved} / {group.size()}. Stdout: {len(stdout_lines)} end: {stdout_lines[-1:]}",
)
# Sync all our saves over NCCL; we could do a dist barrier later, but this helps us avoid losing NCCL connections down the line
test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
dist.barrier()
all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
time.sleep(1)
def is_previous_save_finished(self) -> bool:
"""Return True if a potential previous checkpoint has been fully uploaded to S3
and removed from the drive
"""
self.update()
return self.state == self.S3MoverState.IDLE
def _start_downloading(self, sub_folder: Optional[str] = None) -> (bool, str):
self._warning(
f"[S3] Downloading checkpoint in background from {self.s3_path} to {self.local_path} (direct link: {self.s3_path_direct_link})"
)
cmd = [self.s5cmd_path, "--json"]
if self.s5cmd_credentials is not None:
cmd += ["--credentials-file", self.s5cmd_credentials]
if self.s5cmd_numworkers is not None:
cmd += ["--numworkers", str(self.s5cmd_numworkers)]
cmd += ["cp"]
if self.s5cmd_concurrency is not None:
cmd += ["--concurrency", str(self.s5cmd_concurrency)]
cmd += [self.s3_path + "*", self.local_path]
self.process = self._popen(cmd)
self.state = self.S3MoverState.DOWNLOADING
return True
def _post_downloading(self) -> bool:
self.get_current_stdout()
s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
total_files = len([i for i in s5cmd_results if i["success"]])
total_not_downloaded_files = len([i for i in s5cmd_results if not i["success"]])
if total_not_downloaded_files == 0:
all_upload = "all files"
success = True
else:
all_upload = "not all files"
success = False
total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
total_time = (datetime.now() - self.start_time).total_seconds()
self._warning(
f"[S3] Successfully downloaded {total_files} files for a total of {human_format(total_size)}B in {total_time} "
f"sec ({all_upload}) from S3 at {self.s3_path} to {self.local_path} "
f"(direct link: {self.s3_path_direct_link})"
)
return success
def _start_uploading(
self,
) -> (bool, str):
# Get a file lock on the first file
local_files = glob.glob(self.full_local_path + "/**/*.*", recursive=True)
locked = self._acquire_lock(local_files[0])
if not locked:
return False
if not os.path.exists(self.full_local_path):
message = f"[S3] Checkpoint {self.full_local_path} does not exist, cannot upload to S3"
self._warning(message)
return False
self._warning(
f"[S3] Uploading checkpoint in background from {self.full_local_path} to {self.full_s3_path} (direct link: {self.s3_path_direct_link})"
)
cmd = [self.s5cmd_path, "--json"]
if self.s5cmd_credentials is not None:
cmd += ["--credentials-file", self.s5cmd_credentials]
if self.s5cmd_numworkers is not None:
cmd += ["--numworkers", str(self.s5cmd_numworkers)]
cmd += ["cp", "--exclude", "*.lock", "--exclude", "*.lock.*"]
if self.s5cmd_concurrency is not None:
cmd += ["--concurrency", str(self.s5cmd_concurrency)]
cmd += [self.full_local_path, self.full_s3_path]
self.process = self._popen(cmd)
self.state = self.S3MoverState.UPLOADING
return True
def _post_uploading(self) -> bool:
self.get_current_stdout()
s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
local_files = glob.glob(self.full_local_path + "/**/*.?*", recursive=True)
total_files = len([i for i in s5cmd_results if i["success"]])
self.num_uploaded_files = total_files
if len(local_files) == total_files:
all_upload = "all files"
success = True
else:
all_upload = f"not all files: {len(local_files)} out of {total_files}"
success = False
total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
total_time = (datetime.now() - self.start_time).total_seconds()
        self._warning(
            f"[S3] Successfully uploaded {total_files} files for a total of {human_format(total_size)}B in {total_time} sec "
            f"({all_upload}) from {self.full_local_path} to S3 at {self.full_s3_path} "
            f"(direct link: {self.s3_path_direct_link})"
        )
if self.post_upload_callback:
self.post_upload_callback_outputs = self.post_upload_callback(uploaded_files=s5cmd_results)
self._release_lock()
return success
    def _start_removing(self) -> bool:
        top_dir_in_local_checkpoint = [d for d in glob.glob(self.local_path + "/*") if os.path.isdir(d)]
        names_dir = [os.path.basename(d) for d in top_dir_in_local_checkpoint]
        if len(names_dir) == 0:
            # The local checkpoint is already empty (or another process already started removing it), so skip with a no-op
            self._warning("[S3] Local checkpoint is empty, skipping removal")
cmd = ["echo", "'skipping'"]
self.process = self._popen(cmd)
self.state = self.S3MoverState.REMOVING_CHECKPOINT
return True
self._warning(f"[S3] Removing checkpoint in background: {names_dir}")
locked = self._acquire_lock(top_dir_in_local_checkpoint[0])
if not locked:
return False
cmd = ["rm", "-rfv"] + top_dir_in_local_checkpoint
self.process = self._popen(cmd)
self.state = self.S3MoverState.REMOVING_CHECKPOINT
return True
def _post_removing(self) -> bool:
self.get_current_stdout()
local_files = [
loc_f
for loc_f in self.stdout.split("\n")
if "directory" not in loc_f.lower() and loc_f and ".lock" not in loc_f
]
if len(local_files) == self.num_uploaded_files:
all_removed = "all files"
success = True
else:
all_removed = "not all files"
success = False
self._release_lock()
total_time = (datetime.now() - self.start_time).total_seconds()
        self._warning(
            f"[S3] Successfully removed {len(local_files)} local files ({all_removed}) from {self.local_path} "
            f"(uploaded to {self.s3_path_direct_link}) in {total_time} sec"
        )
return success
def update(self) -> (str, str):
"""Update the state of the mover: UPLOADING => REMOVING_DUPLICATED => DUPLICATING => REMOVING_CHECKPOINT => IDLE
Returns:
(str, str): The state and the stdout of the process if any
"""
if self.process is None:
self._reset_state()
return self.state, self.stdout
return_code = self.process.poll()
if return_code is None:
# Still running
return self.state, self.stdout
if return_code != 0:
self.get_current_stdout()
self._warning(
f"[S3] Error running command {self.cmd} during process {self.state.value}, "
f"return code {return_code}, return message {self.stdout}"
)
return self.state, self.stdout
if self.state == self.S3MoverState.DOWNLOADING:
self._post_downloading()
self._reset_state()
elif self.state == self.S3MoverState.UPLOADING:
self._post_uploading()
if self.remove_after_upload:
self._start_removing()
else:
self._reset_state()
elif self.state == self.S3MoverState.REMOVING_CHECKPOINT:
self._post_removing()
self._reset_state()
return self.state.value, self.stdout
def start_uploading(self, sub_folder=None):
"""Start uploading last saved checkpoint to S3 in the background.
        After calling this method, you should regularly call `update` so the state can advance
        (removing the local copy if requested, then back to idle).
For a blocking upload, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
"""
self.update()
if self.state != self.S3MoverState.IDLE:
message = "[S3] Cannot move to S3 as the previous checkpoint has not been uploaded and removed"
self._warning(message)
return False
self.full_local_path = self.local_path + (f"/{sub_folder}" if sub_folder else "")
self.full_s3_path = self.s3_path + (f"/{sub_folder}" if sub_folder else "")
return self._start_uploading()
def start_downloading(self):
"""Start downloading a checkpoint from S3 in the background.
        After calling this method, you should regularly call `update` to refresh the state.
For a blocking download, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
"""
self.update()
if self.state != self.S3MoverState.IDLE:
message = f"[S3] Cannot download from S3 as the state is not IDLE but {self.state.value}"
self._warning(message)
return False
return self._start_downloading()
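

# Illustrative usage sketch (not part of the original file): a minimal polling loop around an
# already-constructed S3Mover instance, following the non-blocking flow described in the docstrings
# above. `_example_background_upload` and `poll_interval` are placeholder names; constructor
# arguments are intentionally omitted since they are not shown here.
def _example_background_upload(mover, sub_folder=None, poll_interval: float = 5.0) -> None:
    import time

    if not mover.start_uploading(sub_folder=sub_folder):
        return  # a previous checkpoint is still being uploaded or removed
    # Poll until the mover is back to IDLE (upload done and, if remove_after_upload is set,
    # the local copy removed).
    while True:
        mover.update()
        if mover.state == mover.S3MoverState.IDLE:
            break
        time.sleep(poll_interval)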
from contextlib import contextmanager
from typing import Callable, Optional
import torch
from nanotron import distributed as dist
from nanotron import logging, optim
from nanotron.config import Config
from nanotron.logging import get_logger, log_rank
from nanotron.models import NanotronModel
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
logger = get_logger(__name__)
def assert_tensor_synced_across_pg(
tensor: torch.Tensor,
pg: dist.ProcessGroup,
msg: Optional[Callable[[str], str]] = None,
reference_rank: int = 0,
):
"""Assert that `tensor` is synced across `pg` with reference rank. Note that this always passes for reference rank"""
if dist.get_rank(pg) == reference_rank:
reference_tensor = tensor
else:
reference_tensor = torch.empty_like(tensor)
dist.broadcast(
reference_tensor,
src=dist.get_global_rank(group=pg, group_rank=reference_rank),
group=pg,
)
# TODO @nouamane: Getting Greatest absolute difference: 4.6e-10 at large scale when syncing tied weights
torch.testing.assert_close(tensor, reference_tensor, msg=msg)
# TODO @nouamanetazi: remove this with SANITY_CHECKS
@contextmanager
def assert_fail_except_rank_with(exception_class, rank_exception, pg):
try:
yield
except exception_class:
if rank_exception == dist.get_rank(pg):
raise AssertionError(f"Expected rank {rank_exception} to not raise {exception_class}.")
else:
return
except Exception as e:
raise AssertionError(f"Expected {exception_class} to be raised, but got {type(e)} instead:\n{e}")
if dist.get_rank(pg) != rank_exception:
raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")
def before_tbi_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that the model params are synchronized across dp
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synchronized across DP {err}",
)
# SANITY CHECK: Tied weights are synchronized
tied_params_list = sorted(
get_tied_id_to_param(
parameters=unwrapped_model.parameters(),
root_module=unwrapped_model,
).items(),
key=lambda x: x[0],
)
for (name, group_ranks), param in tied_params_list:
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=param,
pg=group,
msg=lambda err: f"[Before train] Tied weights {name} are not synchronized. {err}",
)
# SANITY CHECK: Check that model grads are zeroed or None
for name, param in unwrapped_model.named_parameters():
if param.grad is not None:
torch.testing.assert_close(
param.grad,
torch.zeros_like(param.grad),
atol=0,
rtol=0,
msg="Model half precision grads must be zeroed or None in first accumulation step.",
)
# SANITY CHECK: Check that the grad accumulator buffers are ready for DDP
if grad_accumulator is not None:
for _, elt in grad_accumulator.fp32_grad_buffers.items():
fp32_grad_buffer = elt["fp32_grad"]
torch.testing.assert_close(
fp32_grad_buffer,
torch.zeros_like(fp32_grad_buffer),
atol=0,
rtol=0,
msg="Grad accumulator buffers must be zeroed in first accumulation step.",
)
# TODO: add checks for memory contiguousness
# SANITY CHECK: Check that optimizer's lr is synchronized with lr_scheduler
for i, group in enumerate(lr_scheduler.optimizer.param_groups):
assert (
group["lr"] == lr_scheduler.get_last_lr()[i]
), f"Optimizer and LR scheduler are not in sync. Got {group['lr']} and {lr_scheduler.get_last_lr()[i]}"
            break  # only the first param group is checked against the scheduler
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_tbi_sanity_checks()
def after_tbi_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check that gradient flow on the entire model
# SANITY CHECK: Check that all parameters that required gradients, have actually a gradient
# SANITY CHECK: Check for nan/inf
for name, param in unwrapped_model.named_parameters():
if not param.requires_grad:
continue
if param.is_tied:
tied_info = param.get_tied_info()
name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=unwrapped_model.module_id_to_prefix
)
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
            if grad is None:
                log_rank(
                    f"Process rank {dist.get_rank(parallel_context.world_pg)}/{parallel_context.world_pg.size()}: {name} is missing gradient",
                    logger=logger,
                    level=logging.ERROR,
                )
                continue
            if torch.isnan(grad).any() or torch.isinf(grad).any():
                raise ValueError(f"Gradient of {name} is nan or inf")
# SANITY CHECK: run model specific sanity checks
unwrapped_model.after_tbi_sanity_checks()
def before_optim_step_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
optimizer: optim.BaseOptimizer,
) -> None:
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Test tied weights gradients are synchronized
for (name, group_ranks), param in sorted(
get_tied_id_to_param(parameters=unwrapped_model.parameters(), root_module=unwrapped_model).items(),
key=lambda x: x[0],
):
if not param.requires_grad:
continue
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
assert grad is not None, f"Grad is None for {name}"
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=grad,
pg=group,
msg=lambda err: f"[Before optimizer step] Tied weights grads for {name} are not synchronized. {err}",
)
# SANITY CHECK: Test gradients are synchronized across DP
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
if not param.requires_grad:
continue
if param.is_tied:
tied_info = param.get_tied_info()
name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=unwrapped_model.module_id_to_prefix
)
if grad_accumulator is not None:
grad = grad_accumulator.get_grad_buffer(name=name)
else:
grad = param.grad
assert grad is not None, f"Grad is None for {name}"
assert_tensor_synced_across_pg(
tensor=grad,
pg=parallel_context.dp_pg,
msg=lambda err: f"[Before optimizer step] weights grads for {name} are not synchronized across DP. {err}",
)
# SANITY CHECK: Check that the model params are synchronized across dp
for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synchronized across DP {err}",
)
# SANITY CHECK: Tied weights are synchronized
tied_params_list = sorted(
get_tied_id_to_param(parameters=unwrapped_model.parameters(), root_module=unwrapped_model).items(),
key=lambda x: x[0],
)
for (name, group_ranks), param in tied_params_list:
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=param,
pg=group,
msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}",
)
# SANITY CHECK: Check that optimizer states are synchronized across DP
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)
# SANITY CHECK: run model specific sanity checks
unwrapped_model.before_optim_step_sanity_checks()
def after_optim_step_sanity_checks(
config: Config,
parallel_context: ParallelContext,
unwrapped_model: NanotronModel,
grad_accumulator: GradientAccumulator,
) -> None:
if not config.general.ignore_sanity_checks:
        # SANITY CHECK: Check that gradients are cleared
for name, param in unwrapped_model.named_parameters():
if not param.requires_grad:
continue
if param.grad is not None:
log_rank(
f"Process rank { dist.get_rank(parallel_context.world_pg)}/{parallel_context.world_pg.size()}: {name} still has gradient despite having ran the optimizer",
logger=logger,
level=logging.ERROR,
)
# SANITY CHECK: run model specific sanity checks
unwrapped_model.after_optim_step_sanity_checks()
def check_optim_state_in_sync(optim_state_dict: dict, pg: dist.ProcessGroup):
for _, optim_state in sorted(optim_state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
if name == "step":
continue
assert_tensor_synced_across_pg(
                tensor=tensor, pg=pg, msg=lambda err: f"{name} is not synced across DP. {err}"
)
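

# Illustrative sketch (not part of the original file): where the hooks above are meant to sit in a
# single training iteration. `run_forward_backward` and `run_optim_step` are placeholder callables,
# not real nanotron APIs; only the four *_sanity_checks functions come from this file.
def _example_training_iteration(
    config, parallel_context, unwrapped_model, grad_accumulator, optimizer, lr_scheduler,
    run_forward_backward, run_optim_step,
) -> None:
    before_tbi_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator, lr_scheduler)
    run_forward_backward()  # train-batch iteration (tbi): forward + backward over the accumulation steps
    after_tbi_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator)
    before_optim_step_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator, optimizer)
    run_optim_step()  # optimizer.step(), zero_grad, lr_scheduler.step(), ...
    after_optim_step_sanity_checks(config, parallel_context, unwrapped_model, grad_accumulator)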
import math
from abc import abstractmethod
from enum import Enum, auto
from typing import Dict
from nanotron.config import ModelArgs
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from torch import nn
from torch.nn import init
class ParametrizationMethod(Enum):
STANDARD = auto()
SPECTRAL_MUP = auto()
class Parametrizator:
def __init__(self, config: ModelArgs):
self.config = config
def parametrize(self, param_name: str, module: nn.Module):
if not isinstance(module, tuple(self.MODULE_TO_PARAMETRIZE.keys())):
raise Exception(f"Parameter {param_name} was not initialized")
return self.MODULE_TO_PARAMETRIZE[type(module)](param_name, module)
class StandardParametrizator(Parametrizator):
def __init__(self, config: ModelArgs):
super().__init__(config)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._parametrize_column_linear,
TensorParallelRowLinear: self._parametrize_row_linear,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
}
self.std = config.init_method.std
self.num_layers = config.model_config.num_hidden_layers
def _parametrize_column_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_row_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
std = self.std / math.sqrt(2 * self.num_layers)
init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_embedding(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
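

# Illustrative sketch (not part of the original file): applying a parametrizator to every supported
# leaf module of a model. `model` and `model_args` are placeholder names; the in-place init calls are
# wrapped in torch.no_grad() since they mutate parameters directly.
def _example_init_model_weights(model: nn.Module, model_args: ModelArgs) -> None:
    import torch

    parametrizator = StandardParametrizator(config=model_args)
    with torch.no_grad():
        for module in model.modules():
            if not isinstance(module, tuple(parametrizator.MODULE_TO_PARAMETRIZE.keys())):
                continue
            for param_name, _ in module.named_parameters(recurse=False):
                parametrizator.parametrize(param_name, module)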
class SpectralMupParametrizator(Parametrizator):
"""
A Spectral Condition for Feature Learning by Greg Yang, et al.
https://arxiv.org/abs/2310.17813
"""
def __init__(self, config: ModelArgs):
super().__init__(config)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._parametrize_mup_weight,
TensorParallelRowLinear: self._parametrize_mup_weight,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
}
self.std = 1.0
@staticmethod
def _compute_spectral_std(std: float, fan_in: int, fan_out: int):
"""
Parametrization 1 (Spectral parametrization)
Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al.
σₗ = Θ(1/√nₗ₋₁ min{1, √(nₗ/nₗ₋₁)})
"""
return (std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))
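    # Worked example (illustrative numbers, not from the original file): with std=1.0,
    # fan_in=1024 and fan_out=4096, the rule gives
    # (1.0 / sqrt(1024)) * min(1, sqrt(4096 / 1024)) = (1 / 32) * 1 = 0.03125.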
def _parametrize_mup_weight(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
data = module.weight if param_name == "weight" else module.bias
fan_in, fan_out = init._calculate_fan_in_and_fan_out(data)
world_size = module.world_size
if isinstance(module, TensorParallelColumnLinear):
fan_out = fan_out * world_size
elif isinstance(module, TensorParallelRowLinear):
fan_in = fan_in * world_size
else:
raise ValueError(f"Unknown module {module}")
std = SpectralMupParametrizator._compute_spectral_std(std=self.std, fan_in=fan_in, fan_out=fan_out)
init.normal_(data, mean=0.0, std=std)
def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
# NOTE: you're free to change the initialization of layer norm
# as it's not a part of µTransfer
if "weight" == param_name:
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
def _parametrize_embedding(self, param_name: str, module: nn.Module):
assert param_name in ["weight"]
# NOTE: you're free to change the initialization of input embedding/lm head
if "weight" == param_name:
init.normal_(module.weight, mean=0.0, std=self.std)
class LearningRateForParametrizator:
def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]):
self.lr = lr
self.names_to_modules = names_to_modules
@abstractmethod
def get_lr(self, param_name: str, module: nn.Module) -> float:
raise NotImplementedError
class LearningRateForSP(LearningRateForParametrizator):
"""All parameters get the same learning rate."""
    def get_lr(self, param_name: str, param: nn.Parameter) -> float:
return self.lr
class LearningRateForSpectralMup(LearningRateForParametrizator):
"""
A Spectral Condition for Feature Learning by Greg Yang, et al.
NOTE: each parameter gets a custom learning rate based on its fan-in and fan-out.
"""
def __init__(self, lr: float, names_to_modules: Dict[str, nn.Module]):
super().__init__(lr, names_to_modules)
self.MODULE_TO_PARAMETRIZE = {
TensorParallelColumnLinear: self._get_mup_lr,
TensorParallelRowLinear: self._get_mup_lr,
TritonRMSNorm: self._get_global_lr,
TensorParallelEmbedding: self._get_global_lr,
}
def _get_mup_lr(self, param: nn.Parameter, module: nn.Module):
"""
Parametrization 1 (Spectral parametrization)
Page 8, A Spectral Condition for Feature Learning by Greg Yang, et al.
ηₗ = Θ(nₗ/nₗ₋₁)
"""
fan_in, fan_out = init._calculate_fan_in_and_fan_out(param)
world_size = module.world_size
if isinstance(module, TensorParallelColumnLinear):
fan_out = fan_out * world_size
elif isinstance(module, TensorParallelRowLinear):
fan_in = fan_in * world_size
else:
raise ValueError(f"Unknown module {module}")
return self.lr * (fan_out / fan_in)
def _get_global_lr(self, param: nn.Parameter, module: nn.Module) -> float:
return self.lr
def get_lr(self, param_name: str, param: nn.Parameter) -> float:
"""Return the learning rate for the given parameter."""
# NOTE: param_name should be like 'model.token_position_embeddings.pp_block.token_embedding.weight'
# since names_to_modules map module_name to module
# so we remove the .weight and .bias from param_name to get the module_name
module_name = param_name.rsplit(".", 1)[0]
module = self.names_to_modules[module_name]
return self.MODULE_TO_PARAMETRIZE[type(module)](param, module)
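

# Worked example (illustrative numbers, not from the original file): for a TensorParallelColumnLinear
# whose local weight has fan_in=1024 and fan_out=4096 with world_size=1, _get_mup_lr returns
# lr * (fan_out / fan_in) = 1e-3 * (4096 / 1024) = 4e-3, while TritonRMSNorm and
# TensorParallelEmbedding parameters keep the global learning rate of 1e-3 via _get_global_lr.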