"examples/runtime/multimodal/llava_onevision_server.py" did not exist on "446ea3327735e125e19d37b6a2c25aed7ead68f3"
Commit 61e92904 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
from enum import Enum, auto
class DTypes(Enum):
FP8E4M3 = auto()
FP8E5M2 = auto()
KFLOAT16 = auto()
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.meta import FP8Meta
@torch.no_grad()
def fp8_matmul_kernel(
mat_a: FP8Tensor,
transpose_a: bool,
mat_b: FP8Tensor,
transpose_b: bool,
use_split_accumulator: bool,
) -> torch.Tensor:
assert (
mat_a.device != "cpu" and mat_b.device != "cpu"
), "The tensors must be on a CUDA device in order to use the FP8 kernel!!"
device = mat_a.device
_empty_tensor = torch.Tensor()
output = torch.empty(mat_a.shape[0], mat_b.shape[1], device=device, dtype=torch.float32)
workspace = torch.empty(33_554_432, dtype=torch.int8, device=device)
accumulate = False
out_dtype = getattr(tex.DType, "kFloat32")
# NOTE: currently TE don't support adding bias in FP8
# along with matmul, it only takes an empty bias
bias = torch.tensor([], dtype=torch.float32)
TE_CONFIG_TRANSPOSE_BIAS = False
mat_a_fp8_meta: FP8Meta = mat_a.fp8_meta
mat_b_fp8_meta: FP8Meta = mat_b.fp8_meta
# NOTE: these are the fixed configs that TE only takes
# so we have to TE the A and B matrix to match these configs
TE_CONFIG_TRANSPOSE_A = True
TE_CONFIG_TRANSPOSE_B = False
SCALE = AMAX = _empty_tensor
mat_a = tex.fp8_transpose(mat_a, mat_a_fp8_meta.te_dtype) if transpose_a is False else mat_a
mat_b = tex.fp8_transpose(mat_b, mat_b_fp8_meta.te_dtype) if transpose_b is True else mat_b
tex.te_gemm(
mat_a,
mat_a_fp8_meta.inverse_scale,
mat_a_fp8_meta.te_dtype,
TE_CONFIG_TRANSPOSE_A,
mat_b,
mat_b_fp8_meta.inverse_scale,
mat_b_fp8_meta.te_dtype,
TE_CONFIG_TRANSPOSE_B,
output,
SCALE,
out_dtype,
AMAX,
bias,
out_dtype,
_empty_tensor,
TE_CONFIG_TRANSPOSE_BIAS,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
0,
)
return output
from typing import Optional, Tuple, TypedDict, Union
import torch
import torch.nn.functional as F
import transformer_engine as te # noqa
from torch import nn
from nanotron.fp8.constants import INITIAL_AMAX, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.kernel import fp8_matmul_kernel
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor
class FP8LinearMeta(TypedDict):
"""FP8 metadata for FP8Linear."""
input_grad: FP8Meta
weight_grad: FP8Meta
output_grad: FP8Meta
class FP8Linear(nn.Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True, device: Optional[torch.device] = None):
super().__init__(in_features, out_features, bias, device)
# TODO(xrsrke): add device, and 2 fp8 dtypes
if self.weight.device != torch.device("cpu"):
self.weight = FP8Parameter(self.weight, dtype=DTypes.FP8E4M3)
# NOTE: quantization metadata for input gradients, weight gradients, and output gradients
# TODO(xrsrke): don't fixed this
fp8e4m3_scale = update_scaling_factor(
amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR),
dtype=DTypes.FP8E4M3,
)
fp8e5m2_scale = update_scaling_factor(
amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32),
dtype=DTypes.FP8E5M2,
)
self.fp8_meta: FP8LinearMeta = {
# kfloat8_e4m3
"input_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
"weight_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
# kfloat8_e5m2
"output_grad": FP8Meta(amax=1, dtype=DTypes.FP8E5M2, scale=fp8e5m2_scale),
}
def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
# NOTE: only do fp8 kernel if both input and weight are on CUDA device
if input.device == torch.device("cpu") or self.weight.device == torch.device("cpu"):
return F.linear(input, self.weight, self.bias)
# NOTE: just a phony tensor to make pytorch trigger the backward pass
# because weight and bias's requires_grad are set to False
# so that we can compute the gradients using the fp8 kernels by ourselves
phony = torch.empty(0, device=input.device, requires_grad=True)
output, _ = _FP8Matmul.apply(input, self.weight, self.fp8_meta, phony)
# TODO(xrsrke): add support for adding bias in fp8
# TODO(xrsrke): support return an fp8 tensor as output
# since we will quantize it back to FP8 anyway in the next linear
output = output if self.bias is None else output + self.bias
return output
class _FP8Matmul(torch.autograd.Function):
@staticmethod
@torch.no_grad()
def forward(
ctx, input: FP8Tensor, weight: FP8Tensor, fp8_meta: FP8LinearMeta, phony: torch.Tensor
) -> torch.Tensor:
if type(input) == torch.Tensor:
input = FP8Tensor(input, dtype=DTypes.FP8E4M3)
ctx.save_for_backward(input, weight)
ctx.fp8_meta = fp8_meta
# NOTE: pass FP8Tensor instead of FP8Parameter
output = fp8_matmul_kernel(
mat_a=weight.data, transpose_a=True, mat_b=input, transpose_b=False, use_split_accumulator=False
)
return output, phony
@staticmethod
@torch.no_grad()
def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[torch.Tensor, None, None, None]:
"""
∂L/∂X = ∂L/∂Y @ Wᵀ
∂L/∂W = Xᵀ @ ∂L/∂Y
Source: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html
"""
# TODO(xrsrke): investigate how does grad_output.contiguous() affect the outputs
input, weight = ctx.saved_tensors
if type(grad_output) == torch.Tensor:
grad_output = torch.ones_like(grad_output)
grad_output = grad_output.contiguous()
grad_output = FP8Tensor(grad_output, dtype=DTypes.FP8E5M2)
grad_input = fp8_matmul_kernel(
mat_a=grad_output, transpose_a=True, mat_b=weight, transpose_b=True, use_split_accumulator=True
)
grad_weight = fp8_matmul_kernel(
mat_a=input, transpose_a=False, mat_b=grad_output, transpose_b=False, use_split_accumulator=True
)
weight.grad = grad_weight
return grad_input, None, None, None
from dataclasses import dataclass
from typing import Union
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8.constants import DTYPE_TO_FP8_MAX
from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype
@dataclass
class FP8Meta:
"""Metadata for FP8Tensor."""
amax: Union[int, float]
scale: torch.Tensor
# TODO(xrsrke): change to Literal[torch.int8, torch.uint8]
dtype: torch.dtype
@property
def te_dtype(self) -> tex.DType:
return convert_torch_dtype_to_te_dtype(self.dtype)
def __post_init__(self):
# NOTE: transformer engine only accepts torch tensors
self.amax = torch.tensor(self.amax, device="cuda") if not isinstance(self.amax, torch.Tensor) else self.amax
@property
def fp8_max(self) -> float:
"""Return the maximum normal value for the current dtype."""
return DTYPE_TO_FP8_MAX[self.dtype]
@property
def inverse_scale(self) -> torch.Tensor:
return 1 / self.scale
def __repr__(self) -> str:
return f"FP8Meta(amax={self.amax}, scale={self.scale}, inverse_scale={self.inverse_scale}, dtype={self.dtype})"
import torch
from torch import nn
from nanotron.fp8.constants import FP8_DTYPES
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import FP8Tensor
class FP8Parameter(nn.Parameter):
"""
A custom FP8 parameter class that allows gradients
to flow into FP8 tensors (which are integer tensors).
"""
def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True) -> nn.Parameter:
assert isinstance(data, torch.Tensor), "data must be a tensor"
assert data.dtype not in FP8_DTYPES, "Currently only support turn a non-fp8 tensor to an fp8 parameter"
assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors"
# TODO(xrsrke): if the tensor is on cpu, then bypass quantization
with torch.no_grad():
# TODO(xrsrke): support take an FP8 Tensor as data
# currently we can't only quantize a tensor to FP8 after the parameter is created
# because it raise "Only Tensors of floating point and complex dtype can require gradients"
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self._data = FP8Tensor(data, dtype=dtype)
return self
@property
def data(self) -> FP8Tensor:
return self._data
@data.setter
def data(self, data: FP8Tensor):
self._data = data
@property
def fp8_meta(self) -> FP8Meta:
return self.data.fp8_meta
def __repr__(self) -> str:
return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad}"
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8.constants import DTYPE_TO_FP8_MAX, FP8_DTYPES, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes
class FP8Tensor(torch.Tensor):
"""FP8 Tensor."""
def __new__(cls, tensor: torch.Tensor, dtype: DTypes) -> torch.Tensor:
assert isinstance(tensor, torch.Tensor), "tensor must be a tensor"
assert tensor.dtype not in FP8_DTYPES, "The tensor already quantized to FP8"
# TODO(xrsrke): there is a circular import issue
# between tensor.py and meta.py fix this
from nanotron.fp8.meta import FP8Meta
# TODO(xrsrke): if the tensor is on cpu, then bypass the quantization
# because the current kernels only support gpu tensor
assert tensor.device != torch.device("cpu"), "FP8Tensor only supports CUDA device"
assert isinstance(dtype, DTypes)
amax = tensor.abs().max().clone()
scale = update_scaling_factor(amax, torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), dtype)
fp8_meta = FP8Meta(amax, scale, dtype)
fp8_tensor = convert_tensor_to_fp8(tensor, fp8_meta)
# TODO(xrsrke): move update inverse scaling to FP8Meta's initialization
obj = torch.Tensor._make_subclass(cls, fp8_tensor)
obj.fp8_meta = fp8_meta
return obj
def __repr__(self) -> str:
return f"FP8Tensor({self}, fp8_meta={self.fp8_meta})"
def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType:
# NOTE: transformer engine maintains it own dtype mapping
# so we need to manually map torch dtypes to TE dtypes
TORCH_DTYPE_TE_DTYPE_NAME_MAPPING = {
torch.int32: "kInt32",
torch.float32: "kFloat32",
torch.float16: "kFloat16",
torch.bfloat16: "kBFloat16",
# torch.fp8e5m2: "kFloat8E5M2",
# torch.fp8e4m3: "kFloat8E4M3",
# torch.int8: "kFloat8E5M2",
# torch.uint8: "kFloat8E4M3",
DTypes.FP8E4M3: "kFloat8E4M3",
DTypes.FP8E5M2: "kFloat8E5M2",
DTypes.KFLOAT16: "kFloat16",
}
return getattr(tex.DType, TORCH_DTYPE_TE_DTYPE_NAME_MAPPING[dtype])
# TODO(xrsrke): add type hint for meta after fixing
# circular import between tensor.py and meta.py
def convert_tensor_to_fp8(tensor: torch.Tensor, meta) -> FP8Tensor:
te_dtype = convert_torch_dtype_to_te_dtype(meta.dtype)
# TODO(xrsrke): after casting to fp8, update the scaling factor
# TODO(xrsrke): it's weird that TE only take inverse_scale equal to 1
inverse_scale = torch.tensor(1.0, device=tensor.device, dtype=torch.float32)
return tex.cast_to_fp8(tensor, meta.scale, meta.amax, inverse_scale, te_dtype)
def convert_tensor_from_fp8(tensor: torch.Tensor, meta, dtype: torch.dtype) -> torch.Tensor:
assert isinstance(tensor, torch.Tensor)
assert isinstance(dtype, torch.dtype)
tensor_dtype = convert_torch_dtype_to_te_dtype(meta.dtype)
output_dtype = convert_torch_dtype_to_te_dtype(dtype)
return tex.cast_from_fp8(tensor, meta.inverse_scale, tensor_dtype, output_dtype)
def update_scaling_factor(
amax: torch.Tensor, scaling_factor: torch.Tensor, dtype: DTypes, margin: float = 0
) -> torch.Tensor:
"""
Update the scaling factor to quantize a tensor to FP8.
Credits: https://github.com/Azure/MS-AMP/blob/d562f0f0bcfc9b712fa0726b73428753ff1300ab/msamp/common/tensor/meta.py#L39
"""
assert amax.dtype == torch.float32
# TODO(xrsrke): can we use lower precision for scaling_factor?
assert scaling_factor.dtype == torch.float32
# NOTE: Since fp8_max is a fixed number based on two FP8 data types,
# we prefer not to take fp8_max in the input arguments.
fp8_max = torch.tensor(DTYPE_TO_FP8_MAX[dtype], dtype=torch.float32)
# NOTE: torch.jit only take a concrete value rather than a DTYPE_TO_FP8_MAX[dtype],
# so we create an inner function to bypass that
@torch.jit.script
def _inner(amax: torch.Tensor, fp8_max: torch.Tensor, scaling_factor: torch.Tensor, margin: float):
# NOTE: calculate the number of bits to shift the exponent
ratio = fp8_max / amax
exp = torch.floor(torch.log2(ratio)) - margin
new_scaling_factor = torch.round(torch.pow(2, torch.abs(exp)))
new_scaling_factor = torch.where(amax > 0.0, new_scaling_factor, scaling_factor)
new_scaling_factor = torch.where(torch.isfinite(amax), new_scaling_factor, scaling_factor)
new_scaling_factor = torch.where(exp < 0, 1 / new_scaling_factor, new_scaling_factor)
return new_scaling_factor
return _inner(amax, fp8_max, scaling_factor, margin)
import torch
import transformer_engine as te # noqa
from nanotron.fp8.constants import FP8_GPU_NAMES
def is_fp8_available() -> bool:
"""Check if FP8 is available on the current device."""
if torch.cuda.is_available():
device_name = torch.cuda.get_device_name(torch.cuda.current_device()).lower()
return any(gpu_name in device_name for gpu_name in FP8_GPU_NAMES)
else:
return False
from .sampler import BasicSampler, GreedySampler, Sampler, SamplerType, TopKSampler, TopPSampler
__all__ = ["BasicSampler", "GreedySampler", "Sampler", "SamplerType", "TopKSampler", "TopPSampler"]
import dataclasses
import time
from itertools import chain, islice
from typing import TYPE_CHECKING, Generator, Iterable, List, Optional, Tuple, Union
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import BenchArgs, GenerationArgs
from nanotron.distributed import ProcessGroup, get_global_rank
from nanotron.generation.generate_store import Store, attach_store
from nanotron.generation.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler
from nanotron.helpers import log_throughput
from nanotron.models.llama import LlamaModel
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.block import get_min_max_rank
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.p2p import P2PTensorMetaData, view_as_contiguous
from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import get_untyped_storage
if TYPE_CHECKING:
try:
from transformers import PreTrainedTokenizer
except ImportError:
PreTrainedTokenizer = None
logger = logging.get_logger(__name__)
@dataclasses.dataclass
class GenerationInput:
text: str
@dataclasses.dataclass
class GenerationInputs:
input_ids: Union[torch.Tensor, TensorPointer] # [B, S]
input_masks: Union[torch.Tensor, TensorPointer]
@dataclasses.dataclass
class GenerationOutput:
input_ids: Union[torch.Tensor, TensorPointer]
generation_ids: Union[torch.Tensor, TensorPointer]
return_logits: Optional[Union[torch.Tensor, TensorPointer]] = None
@dataclasses.dataclass
class GenerationStates:
new_input_ids: Union[torch.Tensor, TensorPointer]
new_input_mask: Union[torch.Tensor, TensorPointer]
store: Store
# The rest of the state I need to reconstruct the generated output
generation_ids: List[Union[torch.Tensor, TensorPointer]]
generation_mask: List[Union[torch.Tensor, TensorPointer]]
@dataclasses.dataclass
class TokenizerConfig:
max_input_length: Optional[int]
truncation: Optional[Union[str, bool]] = None
padding: Optional[Union[str, bool]] = None
def chunks(iterable, chunk_size: int) -> Generator[List, None, None]:
"""Yield successive n-sized chunks from `iterable`"""
assert chunk_size >= 1
iterator = iter(iterable)
for first in iterator:
yield list(chain([first], islice(iterator, chunk_size - 1)))
def micro_batcher(
input_iter: Iterable[GenerationInput],
tokenizer: "PreTrainedTokenizer",
max_micro_batch_size: int,
tokenizer_config: TokenizerConfig,
parallel_context: ParallelContext,
input_rank: int,
) -> Generator[GenerationInputs, None, None]:
"""
Returns:
input_ids: [max_micro_batch_size, max_input_length]
input_masks: [max_micro_batch_size, max_input_length]
"""
if tokenizer_config.padding is None:
tokenizer_config.padding = "max_length" if tokenizer_config.max_input_length is not None else True
if tokenizer_config.truncation is None:
tokenizer_config.truncation = True if tokenizer_config.max_input_length is not None else None
for micro_batch_id, micro_batch in enumerate(chunks(input_iter, chunk_size=max_micro_batch_size)):
if len(micro_batch) == 0:
# Empty micro batches don't matter
return
if micro_batch_id % parallel_context.dp_pg.size() != dist.get_rank(parallel_context.dp_pg):
# Each dp is responsible for its own micro batches
continue
if dist.get_rank(parallel_context.pp_pg) == input_rank:
encodings = tokenizer(
[elt.text for elt in micro_batch],
return_tensors="pt",
return_attention_mask=True,
padding=tokenizer_config.padding,
max_length=tokenizer_config.max_input_length,
truncation=tokenizer_config.truncation,
# pad_to_multiple_of=8
)
encodings["attention_mask"] = encodings.attention_mask.to(dtype=torch.bool, device="cuda")
encodings.to("cuda")
yield GenerationInputs(input_ids=encodings.input_ids, input_masks=encodings.attention_mask)
else:
yield GenerationInputs(
input_ids=TensorPointer(group_rank=input_rank), input_masks=TensorPointer(group_rank=input_rank)
)
def micro_splitter(
input_ids: torch.Tensor,
input_mask: torch.Tensor,
max_micro_batch_size: int,
parallel_context: ParallelContext,
input_rank: int,
) -> Generator[GenerationInputs, None, None]:
"""
Returns:
input_ids: [max_micro_batch_size, max_input_length]
input_masks: [max_micro_batch_size, max_input_length]
"""
for micro_batch_id, (micro_batch_ids, micro_batch_mask) in enumerate(
zip(torch.split(input_ids, max_micro_batch_size), torch.split(input_mask, max_micro_batch_size))
):
if len(micro_batch_ids) == 0:
# Empty micro batches don't matter
return
# if micro_batch_id % parallel_context.dp_pg.size() != dist.get_rank(parallel_context.dp_pg):
# # Each dp is responsible for its own micro batches
# continue
if dist.get_rank(parallel_context.pp_pg) == input_rank:
micro_batch_mask = micro_batch_mask.to(dtype=torch.bool, device="cuda")
micro_batch_mask.to("cuda")
yield GenerationInputs(input_ids=micro_batch_ids.clone(), input_masks=micro_batch_mask.clone())
else:
yield GenerationInputs(
input_ids=TensorPointer(group_rank=input_rank), input_masks=TensorPointer(group_rank=input_rank)
)
@torch.inference_mode()
def decode_text(
input_iter: Iterable[GenerationInput],
tokenizer: "PreTrainedTokenizer",
model: LlamaModel,
parallel_context: ParallelContext,
generation_config: GenerationArgs,
tokenizer_config: Optional[TokenizerConfig],
max_micro_batch_size: int,
max_new_tokens: int,
is_bench: bool = False,
logits_are_batch_first: bool = True,
) -> Generator[GenerationOutput, None, None]:
"""We assume the following:
- Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input.
- Only a specific rank will output the generated text_ids as `torch.Tensor`, the others return a `TensorPointer`. # TODO @thomasw21: Maybe all ranks should return the text.
- We assume that within a model replica, the inputs are already synchronized.
"""
decoder_input_rank, decoder_logit_rank = get_min_max_rank(module=model)
if generation_config:
if isinstance(generation_config.sampler, str):
sampler_type = SamplerType(generation_config.sampler.upper())
else:
sampler_type = generation_config.sampler
else:
sampler_type = SamplerType.GREEDY
# Compute flag
is_decoder_input_rank = dist.get_rank(parallel_context.pp_pg) == decoder_input_rank
is_decoder_logit_rank = dist.get_rank(parallel_context.pp_pg) == decoder_logit_rank
max_nb_microbatches = decoder_logit_rank - decoder_input_rank + 1
p2p = model.p2p
# replicate input for n_samples times when using TOP_P or TOP_K samplers, in order to get diverse results
if generation_config and generation_config.n_samples:
if sampler_type != SamplerType.TOP_P and sampler_type != SamplerType.TOP_K:
raise ValueError("Only support n_samples for TOP_P and TOP_K sampler")
input_iter = [
GenerationInput(text=input.text) for input in input_iter for _ in range(generation_config.n_samples)
]
# That's annoying but I need this as soon as there's a change communication "cross"
pipeline_state = PipelineEvalBatchState()
with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state):
# We query the first `pipeline_size` batches
for batches in chunks(
iterable=micro_batcher(
input_iter=input_iter,
tokenizer=tokenizer,
max_micro_batch_size=max_micro_batch_size,
tokenizer_config=tokenizer_config,
input_rank=decoder_input_rank,
parallel_context=parallel_context,
),
chunk_size=max_nb_microbatches,
):
if len(batches) == 0:
# It means we're out of element
return
# Number of micro batches
number_states_in_buffer = len(batches)
# Otherwise the pipelining doesn't work
assert number_states_in_buffer <= max_nb_microbatches
is_max_nb_microbatches = number_states_in_buffer == max_nb_microbatches
# Initialize decoder states
decoder_states: Iterable[GenerationStates] = (
GenerationStates(
new_input_ids=batch.input_ids,
new_input_mask=batch.input_masks,
store=Store(),
generation_ids=[batch.input_ids],
generation_mask=[batch.input_masks],
)
for batch in batches
)
if is_bench:
start_time, elapsed_time_first_iteration = time.perf_counter(), 0
for generation_iter in range(max_new_tokens):
if is_bench and generation_iter == 0:
torch.cuda.synchronize()
elapsed_time_first_iteration = start_time - time.perf_counter()
all_new_decoder_input_ids_and_mask_same_rank: List[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
] = []
new_decoder_states: List[GenerationStates] = []
for state_id, state in enumerate(decoder_states):
new_decoder_states.append(state)
# Get the new logits
if generation_config.use_cache:
with attach_store(model=model, store=state.store):
# transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size]
sharded_logits = model(
input_ids=state.new_input_ids,
input_mask=state.new_input_mask,
)
else:
if isinstance(state.new_input_ids, torch.Tensor):
batch_generated_ids = torch.cat(state.generation_ids, dim=-1)
batch_generated_mask = torch.cat(state.generation_mask, dim=-1)
else:
batch_generated_ids = state.new_input_ids
batch_generated_mask = state.new_input_mask
sharded_logits = model(
input_ids=batch_generated_ids,
input_mask=batch_generated_mask,
)
if isinstance(sharded_logits, torch.Tensor) and logits_are_batch_first:
sharded_logits = sharded_logits.transpose(0, 1)
# Communicate
# TODO @thomasw21: Make a diagram to show how this works
nb_send: int = 0
if is_decoder_input_rank:
if is_max_nb_microbatches:
if generation_iter == 0:
if state_id == number_states_in_buffer - 1:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if number_states_in_buffer - 1 == state_id or generation_iter == 0:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if state_id == number_states_in_buffer - 1:
if not is_max_nb_microbatches:
nb_send = len(pipeline_state.microbatches_activations_to_send)
for _ in range(nb_send):
pipeline_state.run_communication()
if is_decoder_logit_rank:
assert isinstance(sharded_logits, torch.Tensor)
# run a logit chooser.
if sampler_type == SamplerType.GREEDY:
sampler = GreedySampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_K:
sampler = TopKSampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_P:
sampler = TopPSampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.BASIC:
sampler = BasicSampler(pg=parallel_context.tp_pg)
else:
raise NotImplementedError(f"Sampler type {sampler_type} is not implemented")
new_decoder_input_ids = sampler(sharded_logits=sharded_logits[:, -1, :])
# TODO @thomasw21: Handle this correctly, ie from some point after <eos> this should only generate masked tokens
# TODO @thomasw21: Actually I can probably build this thing on the next device directly. Will save some communication
new_decoder_input_mask = torch.ones(
size=(new_decoder_input_ids.shape[0], 1),
dtype=torch.bool,
device=new_decoder_input_ids.device,
)
# TODO @thomasw21: We need to have stop condition.
# broadcast new_tokens to everyone
if decoder_input_rank == decoder_logit_rank:
# It's the same rank so no need to do anything too fancy
all_new_decoder_input_ids_and_mask_same_rank.append(
(new_decoder_input_ids, new_decoder_input_mask)
)
else:
pipeline_state.register_send_activation(
new_decoder_input_ids, to_rank=decoder_input_rank, p2p=p2p
)
pipeline_state.register_send_activation(
new_decoder_input_mask, to_rank=decoder_input_rank, p2p=p2p
)
if not is_max_nb_microbatches and state_id == number_states_in_buffer - 1:
# Send new_decoder_input_ids AND new_decoder_input_ids
pipeline_state.run_communication()
pipeline_state.run_communication()
else:
assert isinstance(sharded_logits, TensorPointer)
all_new_decoder_input_ids_and_mask: Iterable[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
]
if is_decoder_input_rank:
# We receive the tensor from other ranks unless `decoder_input_rank` == `decoder_logit_rank` in which case `all_new_decoder_input_ids` is already populated.
if decoder_input_rank == decoder_logit_rank:
# `all_new_decoder_input_ids_and_mask_same_rank` is already populated. Since `decoder_input_rank` and `decoder_logit_rank` are the same, there's no need to communicate as we can just store the new input_ids in a list.
assert len(all_new_decoder_input_ids_and_mask_same_rank) == number_states_in_buffer
all_new_decoder_input_ids_and_mask = all_new_decoder_input_ids_and_mask_same_rank
else:
def generator():
for _ in range(number_states_in_buffer):
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
while len(pipeline_state.activations_buffer) < 2:
pipeline_state.run_communication()
new_decoder_input_ids = pipeline_state.activations_buffer.popleft()
new_decoder_input_mask = pipeline_state.activations_buffer.popleft()
yield new_decoder_input_ids, new_decoder_input_mask
all_new_decoder_input_ids_and_mask = iter(generator())
else:
all_new_decoder_input_ids_and_mask = (
(TensorPointer(group_rank=decoder_input_rank), TensorPointer(group_rank=decoder_input_rank))
for _ in range(number_states_in_buffer)
)
# Create new decoder states
decoder_states = (
GenerationStates(
new_input_ids=new_decoder_input_ids_and_mask[0],
new_input_mask=new_decoder_input_ids_and_mask[1],
store=state.store,
generation_ids=state.generation_ids + [new_decoder_input_ids_and_mask[0]],
generation_mask=state.generation_mask + [new_decoder_input_ids_and_mask[1]],
)
for state, new_decoder_input_ids_and_mask in zip(
new_decoder_states, all_new_decoder_input_ids_and_mask
)
)
if is_bench:
# Compute throughput (tok/s/gpu). Note that the first generation is done with full seq_len, so we don't count it.
torch.cuda.synchronize()
total_time_sec = time.perf_counter() - start_time - elapsed_time_first_iteration
# We generate 1 token per iteration per batch (batch=microbatch)
# Number of tokens generated every iteration: gbs/iteration_time
global_batch_size = len(batches) * parallel_context.dp_pg.size()
tokens_per_sec = global_batch_size * max_new_tokens / total_time_sec
model_tflops, hardware_tflops = model.get_flops_per_sec(
iteration_time_in_sec=total_time_sec,
sequence_length=max_new_tokens,
global_batch_size=global_batch_size,
)
bench_config = BenchArgs(
model_name=model.config._name_or_path,
sequence_length=max_new_tokens,
micro_batch_size=max_micro_batch_size,
batch_accumulation_per_replica=1,
benchmark_csv_path="benchmark.csv",
)
model_size = sum(
[p.numel() * p.data.element_size() for p in chain(model.parameters(), model.buffers())]
)
log_throughput(
bench_config,
parallel_context,
model_tflops,
hardware_tflops,
tokens_per_sec,
bandwidth=model_size * tokens_per_sec / 1e9,
)
# Flush communication
for _ in range(
max(
len(pipeline_state.microbatches_activations_to_send),
len(pipeline_state.microbatches_activations_to_recv),
)
):
pipeline_state.run_communication()
assert len(pipeline_state.microbatches_activations_to_send) == 0
assert len(pipeline_state.microbatches_activations_to_recv) == 0
# Yield result
decoder_states = list(decoder_states)
for state, batch in zip(decoder_states, batches):
if is_decoder_input_rank:
assert all(isinstance(elt, torch.Tensor) for elt in state.generation_ids)
batch_generated_ids = torch.cat(state.generation_ids, dim=-1)
batch_generated_mask = torch.cat(state.generation_mask, dim=-1)
else:
assert all(isinstance(elt, TensorPointer) for elt in state.generation_ids)
batch_generated_ids = TensorPointer(group_rank=decoder_input_rank)
batch_generated_mask = TensorPointer(group_rank=decoder_input_rank)
# Broadcast all data
batch_generated_ids, batch_generated_mask = broadcast_tensors(
[batch_generated_ids, batch_generated_mask],
group_src=decoder_input_rank,
group=parallel_context.pp_pg,
)
batch.input_ids, batch.input_masks = broadcast_tensors(
[batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=parallel_context.pp_pg
)
# Flush the store to release memory
state.store.flush()
assert len(state.store) == 0
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
assert (
batch_generated_ids.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_ids.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_mask.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_mask.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_ids.shape[1] == batch_generated_mask.shape[1]
), f"Sequence length needs to match {batch_generated_ids.shape[1]} != {batch_generated_mask.shape[0]}"
for i, (generated_ids, generated_mask) in enumerate(zip(batch_generated_ids, batch_generated_mask)):
# TODO @thomasw21: We could actually have all ranks return the output, since it's been already broadcasted
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
input_ids = batch.input_ids[i]
input_mask = batch.input_masks[i]
yield GenerationOutput(
input_ids=input_ids[input_mask],
generation_ids=generated_ids[generated_mask],
)
else:
yield GenerationOutput(
input_ids=TensorPointer(group_rank=decoder_input_rank),
generation_ids=TensorPointer(group_rank=decoder_input_rank),
)
@torch.inference_mode()
def decode_tokenized(
input_ids: torch.Tensor,
input_mask: torch.Tensor,
model: LlamaModel,
parallel_context: ParallelContext,
generation_config: GenerationArgs,
max_micro_batch_size: int,
max_new_tokens: int,
returns_logits: Optional[bool] = False,
) -> Generator[GenerationOutput, None, None]:
"""We assume the following:
- Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input.
- Only a specific rank will output the generated text_ids as `torch.Tensor`, the others return a `TensorPointer`. # TODO @thomasw21: Maybe all ranks should return the text.
- We assume that within a model replica, the inputs are already synchronized.
"""
if returns_logits:
raise NotImplementedError("return_logits is not implemented yet")
if generation_config:
if isinstance(generation_config.sampler, str):
sampler_type = SamplerType(generation_config.sampler.upper())
else:
sampler_type = generation_config.sampler
else:
sampler_type = SamplerType.GREEDY
decoder_input_rank, decoder_logit_rank = get_min_max_rank(module=model)
# Compute flag
is_decoder_input_rank = dist.get_rank(parallel_context.pp_pg) == decoder_input_rank
is_decoder_logit_rank = dist.get_rank(parallel_context.pp_pg) == decoder_logit_rank
max_nb_microbatches = decoder_logit_rank - decoder_input_rank + 1
# TODO @thomasw21: Fix this as we shouldn't get P2P like that
p2p = model.p2p
# That's annoying but I need this as soon as there's a change communication "cross"
pipeline_state = PipelineEvalBatchState()
with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state):
# We query the first `pipeline_size` batches
for batches in chunks(
iterable=micro_splitter(
input_ids,
input_mask,
max_micro_batch_size=max_micro_batch_size,
parallel_context=parallel_context,
input_rank=decoder_input_rank,
),
chunk_size=max_nb_microbatches,
):
if len(batches) == 0:
# It means we're out of element
return
# Number of micro batches
number_states_in_buffer = len(batches)
# Otherwise the pipelining doesn't work
assert number_states_in_buffer <= max_nb_microbatches
is_max_nb_microbatches = number_states_in_buffer == max_nb_microbatches
# Initialize decoder states
decoder_states: Iterable[GenerationStates] = (
GenerationStates(
new_input_ids=batch.input_ids,
new_input_mask=batch.input_masks,
store=Store(),
generation_ids=[batch.input_ids],
generation_mask=[batch.input_masks],
)
for batch in batches
)
for generation_iter in range(max_new_tokens):
all_new_decoder_input_ids_and_mask_same_rank: List[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
] = []
new_decoder_states: List[GenerationStates] = []
for state_id, state in enumerate(decoder_states):
new_decoder_states.append(state)
# Get the new logits
with attach_store(model=model, store=state.store):
# transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size]
sharded_logits = model(
input_ids=state.new_input_ids,
input_mask=state.new_input_mask,
)
if isinstance(sharded_logits, torch.Tensor):
sharded_logits = sharded_logits.transpose(0, 1)
# Communicate
# TODO @thomasw21: Make a diagram to show how this works
nb_send: int = 0
if is_decoder_input_rank:
if is_max_nb_microbatches:
if generation_iter == 0:
if state_id == number_states_in_buffer - 1:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if number_states_in_buffer - 1 == state_id or generation_iter == 0:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if state_id == number_states_in_buffer - 1:
if not is_max_nb_microbatches:
nb_send = len(pipeline_state.microbatches_activations_to_send)
for _ in range(nb_send):
pipeline_state.run_communication()
if is_decoder_logit_rank:
assert isinstance(sharded_logits, torch.Tensor)
# run a logit chooser.
if sampler_type == SamplerType.GREEDY:
sampler = GreedySampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_K:
sampler = TopKSampler(
pg=parallel_context.tp_pg,
k=generation_config.top_k,
temperature=generation_config.temperature,
)
elif sampler_type == SamplerType.TOP_P:
sampler = TopPSampler(
pg=parallel_context.tp_pg,
p=generation_config.top_p,
temperature=generation_config.temperature,
)
elif sampler_type == SamplerType.BASIC:
sampler = BasicSampler(pg=parallel_context.tp_pg)
else:
raise NotImplementedError(f"Sampler type {sampler_type} is not implemented")
new_decoder_input_ids = sampler(sharded_logits=sharded_logits[:, -1, :])
# TODO @thomasw21: Handle this correctly, ie from some point after <eos> this should only generate masked tokens
# TODO @thomasw21: Actually I can probably build this thing on the next device directly. Will save some communication
new_decoder_input_mask = torch.ones(
size=(new_decoder_input_ids.shape[0], 1),
dtype=torch.bool,
device=new_decoder_input_ids.device,
)
# TODO @thomasw21: We need to have stop condition.
# broadcast new_tokens to everyone
if decoder_input_rank == decoder_logit_rank:
# It's the same rank so no need to do anything too fancy
all_new_decoder_input_ids_and_mask_same_rank.append(
(new_decoder_input_ids, new_decoder_input_mask)
)
else:
pipeline_state.register_send_activation(
new_decoder_input_ids, to_rank=decoder_input_rank, p2p=p2p
)
pipeline_state.register_send_activation(
new_decoder_input_mask, to_rank=decoder_input_rank, p2p=p2p
)
if not is_max_nb_microbatches and state_id == number_states_in_buffer - 1:
# Send new_decoder_input_ids AND new_decoder_input_ids
pipeline_state.run_communication()
pipeline_state.run_communication()
else:
assert isinstance(sharded_logits, TensorPointer)
all_new_decoder_input_ids_and_mask: Iterable[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
]
if is_decoder_input_rank:
# We receive the tensor from other ranks unless `decoder_input_rank` == `decoder_logit_rank` in which case `all_new_decoder_input_ids` is already populated.
if decoder_input_rank == decoder_logit_rank:
# `all_new_decoder_input_ids_and_mask_same_rank` is already populated. Since `decoder_input_rank` and `decoder_logit_rank` are the same, there's no need to communicate as we can just store the new input_ids in a list.
assert len(all_new_decoder_input_ids_and_mask_same_rank) == number_states_in_buffer
all_new_decoder_input_ids_and_mask = all_new_decoder_input_ids_and_mask_same_rank
else:
def generator():
for _ in range(number_states_in_buffer):
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
while len(pipeline_state.activations_buffer) < 2:
pipeline_state.run_communication()
new_decoder_input_ids = pipeline_state.activations_buffer.popleft()
new_decoder_input_mask = pipeline_state.activations_buffer.popleft()
yield new_decoder_input_ids, new_decoder_input_mask
all_new_decoder_input_ids_and_mask = iter(generator())
else:
all_new_decoder_input_ids_and_mask = (
(TensorPointer(group_rank=decoder_input_rank), TensorPointer(group_rank=decoder_input_rank))
for _ in range(number_states_in_buffer)
)
# Create new decoder states
decoder_states = (
GenerationStates(
new_input_ids=new_decoder_input_ids_and_mask[0],
new_input_mask=new_decoder_input_ids_and_mask[1],
store=state.store,
generation_ids=state.generation_ids + [new_decoder_input_ids_and_mask[0]],
generation_mask=state.generation_mask + [new_decoder_input_ids_and_mask[1]],
)
for state, new_decoder_input_ids_and_mask in zip(
new_decoder_states, all_new_decoder_input_ids_and_mask
)
)
# Flush communication
for _ in range(
max(
len(pipeline_state.microbatches_activations_to_send),
len(pipeline_state.microbatches_activations_to_recv),
)
):
pipeline_state.run_communication()
assert len(pipeline_state.microbatches_activations_to_send) == 0
assert len(pipeline_state.microbatches_activations_to_recv) == 0
# Yield result
decoder_states = list(decoder_states)
for state, batch in zip(decoder_states, batches):
if is_decoder_input_rank:
assert all(isinstance(elt, torch.Tensor) for elt in state.generation_ids)
batch_generated_ids = torch.cat(state.generation_ids, dim=-1)
batch_generated_mask = torch.cat(state.generation_mask, dim=-1)
else:
assert all(isinstance(elt, TensorPointer) for elt in state.generation_ids)
batch_generated_ids = TensorPointer(group_rank=decoder_input_rank)
batch_generated_mask = TensorPointer(group_rank=decoder_input_rank)
# Broadcast all data
batch_generated_ids, batch_generated_mask = broadcast_tensors(
[batch_generated_ids, batch_generated_mask],
group_src=decoder_input_rank,
group=parallel_context.pp_pg,
)
batch.input_ids, batch.input_masks = broadcast_tensors(
[batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=parallel_context.pp_pg
)
# Flush the store to release memory
state.store.flush()
assert len(state.store) == 0
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
assert (
batch_generated_ids.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_ids.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_mask.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_mask.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_ids.shape[1] == batch_generated_mask.shape[1]
), f"Sequence length needs to match {batch_generated_ids.shape[1]} != {batch_generated_mask.shape[0]}"
for i, (generated_ids, generated_mask) in enumerate(zip(batch_generated_ids, batch_generated_mask)):
# TODO @thomasw21: We could actually have all ranks return the output, since it's been already broadcasted
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
input_ids = batch.input_ids[i]
input_mask = batch.input_masks[i]
yield GenerationOutput(
input_ids=input_ids[input_mask],
generation_ids=generated_ids[generated_mask],
)
else:
yield GenerationOutput(
input_ids=TensorPointer(group_rank=decoder_input_rank),
generation_ids=TensorPointer(group_rank=decoder_input_rank),
)
# Distributed utilities
def broadcast_tensors(
tensors: List[Union[torch.Tensor, TensorPointer]], group_src: int, group: Optional[ProcessGroup] = None
) -> List[torch.Tensor]:
result = []
for tensor in tensors:
if dist.get_rank(group) == group_src:
assert isinstance(tensor, torch.Tensor)
meta = [
[
tensor.dtype,
tensor.requires_grad,
tensor.shape,
get_untyped_storage(tensor).size(),
tensor.stride(),
tensor.is_contiguous(),
tensor.storage_offset(),
]
]
else:
assert isinstance(tensor, TensorPointer)
meta = [None]
dist.broadcast_object_list(meta, src=get_global_rank(group_rank=group_src, group=group), group=group)
dtype, requires_grad, shape, untyped_storage_size, stride, is_contiguous, storage_offset = meta[0]
meta = P2PTensorMetaData(
dtype=dtype,
requires_grad=requires_grad,
shape=shape,
untyped_storage_size=untyped_storage_size,
stride=stride,
is_contiguous=is_contiguous,
storage_offset=storage_offset,
)
if dist.get_rank(group) != group_src:
tensor = meta.create_empty_storage(device=torch.device("cuda"))
else:
tensor = view_as_contiguous(tensor)
dist.broadcast(tensor, src=get_global_rank(group_rank=group_src, group=group), group=group)
# Set shape and stride
tensor = tensor.as_strided(size=tuple(meta.shape), stride=tuple(meta.stride))
result.append(tensor)
return result
import collections
import contextlib
from torch import nn
class Store(collections.defaultdict):
"""
We use the store to locally store on gpu some states so that we don't have to communicate.
This is useful at inference if we don't want to recompute kv_cache for example, or that we don't want to communicate it through the pipeline
"""
def __init__(self):
super().__init__(dict)
def flush(self):
# TODO @thomasw21: There's probably a simpler way than doing this.
for key in list(self.keys()):
del self[key]
class AttachableStore:
def _attach_store(self, store: Store):
assert not hasattr(self, "_store"), "You can't assign a store when there's already one attached"
self._store = store
def _detach_store(self):
delattr(self, "_store")
def get_local_store(self):
if hasattr(self, "_store"):
if isinstance(self, nn.Module):
assert self.training is False, "Store is used only in evaluation mode"
return self._store[id(self)]
else:
return None
@contextlib.contextmanager
def attach_store(model: nn.Module, store: Store):
list_module_containing_store = []
for module in model.modules():
if not isinstance(module, AttachableStore):
continue
module._attach_store(store)
list_module_containing_store.append(module)
try:
yield
finally:
for module in list_module_containing_store:
module._detach_store()
from dataclasses import dataclass
from enum import Enum, auto
from typing import Sequence
import torch
from nanotron import distributed as dist
def all_gather_batches(in_tensor: torch.Tensor, in_split: Sequence[int], group: dist.ProcessGroup) -> torch.Tensor:
# All gather along first dimension, allow un-equal splits
out_tensor = torch.empty((sum(in_split),) + in_tensor.shape[1:], dtype=in_tensor.dtype, device=in_tensor.device)
out_split_list = list(torch.split(out_tensor, in_split, dim=0))
dist.all_gather(out_split_list, in_tensor, group=group)
return out_tensor
class SamplerType(Enum):
TOP_P = auto()
TOP_K = auto()
GREEDY = auto()
BASIC = auto()
class Sampler:
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@dataclass
class TopPSampler(Sampler):
pg: dist.ProcessGroup
p: float = 0.9
temperature: float = 1.0
filter_value: float = 0.0
min_tokens_to_keep: int = 1
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
batch_size, vocab_per_shard = sharded_logits.shape
# Split max_values/max_indices into a list of tensors along batch
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# Prepare tensors for all-to-all operation
# Gather logits from all vocab shards but shard on batch, tp_rank first
sharded_logits_out = torch.empty(
(total_out_size, vocab_per_shard),
dtype=sharded_logits.dtype,
device=sharded_logits.device,
) # [pg_size * sharded_batch_size, vocab_per_shard]
local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))
dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)
logits = torch.cat(local_sharded_logits_out, dim=-1) # [sharded_batch_size, vocab_size]
probs = torch.softmax(logits.to(dtype=torch.float) / self.temperature, dim=-1) # [batch_size, vocab_size]
# Sort the probs and their corresponding indices in descending order
sorted_probs, sorted_indices = torch.sort(probs, descending=False, dim=-1)
# Calculate the cumulative sum of the sorted probs
# the bfloat16 type is not accurate enough for the cumulative sum
cumulative_probs = torch.cumsum(sorted_probs, dim=-1, dtype=torch.float) # [batch_size, vocab_size]
# Find the smallest set of indices for which the cumulative probability mass exceeds p
sorted_indices_to_remove = cumulative_probs <= (1 - self.p)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# Construct the probability mask for original indices
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
filter_probs = probs.masked_fill(indices_to_remove, self.filter_value)
sampled_indices = torch.multinomial(filter_probs, num_samples=1)
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(sampled_indices, in_split, group=self.pg)
return gathered_new_decoder_input_ids
@dataclass
class GreedySampler(Sampler):
pg: dist.ProcessGroup
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
batch_size, vocab_per_shard = sharded_logits.shape
# Find local max logit and its index
# Note that max is deterministic, and always takes the first one.
max_values, max_indices = sharded_logits.max(dim=-1, keepdim=True) # [batch_size, 1]
# Add offset to the max indices
# TODO: We're assuming that TensorColumnLinear shards in a specific manner, i.e. rank 0 gets the first.
# It might require us to expose something from TensorColumnLinear.
max_indices = max_indices + (dist.get_rank(self.pg) * vocab_per_shard)
# Split max_values/max_indices into a list of tensors along batch
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# Prepare tensors for all-to-all operation
# Gather max logits and their indices from all shards, tp_rank first
max_values_out_mat = torch.empty(
(total_out_size, 1),
dtype=max_values.dtype,
device=max_values.device,
)
max_indices_out_mat = torch.empty(
(total_out_size, 1),
dtype=max_indices.dtype,
device=max_indices.device,
)
local_max_values_in = list(torch.split(max_values, in_split, dim=0))
local_max_indices_in = list(torch.split(max_indices, in_split, dim=0))
local_max_values_out = list(torch.split(max_values_out_mat, out_split, dim=0))
local_max_indices_out = list(torch.split(max_indices_out_mat, out_split, dim=0))
dist.all_to_all(local_max_values_out, local_max_values_in, group=self.pg)
dist.all_to_all(local_max_indices_out, local_max_indices_in, group=self.pg)
# Concat assumes that the primary dimension is the same across all shards
sharded_max_values = torch.cat(local_max_values_out, dim=-1) # [sharded_batch_size, num_shards]
sharded_max_indices = torch.cat(local_max_indices_out, dim=-1) # [sharded_batch_size, num_shards]
# Find global max logit across all shards
# Note that max is deterministic, and always takes the first one.
# [sharded_batch_size, 1]
_global_max_values, global_max_indices = sharded_max_values.max(dim=-1, keepdim=True)
# Select the corresponding token index from the offsetted gathered indices
sharded_selected_tokens = sharded_max_indices.gather(1, global_max_indices)
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(sharded_selected_tokens, in_split, group=self.pg)
return gathered_new_decoder_input_ids
@dataclass
class TopKSampler(Sampler):
pg: dist.ProcessGroup
k: int = 50
temperature: float = 1.0
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
batch_size, vocab_per_shard = sharded_logits.shape
# Find local top-k logits and their indices
local_top_k_values, local_top_k_indices = torch.topk(sharded_logits, self.k, dim=-1)
# Add offset to the indices
local_top_k_indices = local_top_k_indices + (dist.get_rank(self.pg) * vocab_per_shard)
# Split local_top_k_values into a list of tensors along batch
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# The last shard could be smaller than shard_batch_size
local_top_k_values_in = list(torch.split(local_top_k_values, in_split, dim=0))
local_tok_k_indices_in = list(torch.split(local_top_k_indices, in_split, dim=0))
# Prepare tensors for all-to-all operation
# Gather top-k logits and their indices from all shards, tp_rank first
top_k_values_out_mat = torch.empty(
(total_out_size,) + local_top_k_values.shape[1:],
dtype=local_top_k_values.dtype,
device=local_top_k_values.device,
)
top_k_indices_out_mat = torch.empty(
(total_out_size,) + local_top_k_indices.shape[1:],
dtype=local_top_k_indices.dtype,
device=local_top_k_indices.device,
)
local_top_k_values_out = list(torch.split(top_k_values_out_mat, out_split, dim=0))
local_top_k_indices_out = list(torch.split(top_k_indices_out_mat, out_split, dim=0))
dist.all_to_all(local_top_k_values_out, local_top_k_values_in, group=self.pg)
dist.all_to_all(local_top_k_indices_out, local_tok_k_indices_in, group=self.pg)
# Concat assumes that the primary dimension is the same across all shards
sharded_local_top_k_values = torch.cat(local_top_k_values_out, dim=-1) # [sharded_batch_size, k * num_shards]
sharded_local_top_k_indices = torch.cat(
local_top_k_indices_out, dim=-1
) # [sharded_batch_size, k * num_shards]
# Select global top-k from the gathered top-k, now the top-k is across all vocab, batch_size is sharded
sharded_top_k_values, sharded_top_k_indices = torch.topk(
sharded_local_top_k_values, self.k, dim=-1
) # [sharded_batch_size, k]
# Select corresponding indices from the gathered indices
sharded_top_k_indices = sharded_local_top_k_indices.gather(
-1, sharded_top_k_indices
) # [sharded_batch_size, k]
# Apply temperature and compute softmax probabilities
probs = torch.softmax(sharded_top_k_values.to(dtype=torch.float) / self.temperature, dim=-1)
# Sample from the probabilities
sampled_indices = torch.multinomial(probs, num_samples=1) # [sharded_batch_size]
# Select the corresponding token index from the global top-k indices
new_decoder_input_ids = sharded_top_k_indices.gather(-1, sampled_indices) # [sharded_batch_size]
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(new_decoder_input_ids, in_split, group=self.pg)
return gathered_new_decoder_input_ids
@dataclass
class BasicSampler(Sampler):
"""Basic sampler that samples from the full vocab according to the logits."""
pg: dist.ProcessGroup
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
# We will cross batch and vocab shards to sample from the full vocab and a part of the batch
# (right now logits are sharded on vocab and batch, so we need to do all-to-all)
batch_size, vocab_per_shard = sharded_logits.shape
# Split max_values/max_indices into a list of tensors along batch
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# Prepare tensors for all-to-all operation
# Gather logits from all vocab shards but shard on batch, tp_rank first
sharded_logits_out = torch.empty(
(total_out_size, vocab_per_shard),
dtype=sharded_logits.dtype,
device=sharded_logits.device,
) # [pg_size * sharded_batch_size, vocab_per_shard]
local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))
dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)
logits = torch.cat(local_sharded_logits_out, dim=-1) # [sharded_batch_size, vocab_size]
probs = torch.softmax(logits.to(dtype=torch.float), dim=-1) # [batch_size, vocab_size]
# Sample from the probabilities
sampled_indices = torch.multinomial(probs, num_samples=1)
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(sampled_indices, in_split, group=self.pg)
return gathered_new_decoder_input_ids
import contextlib
import csv
import gc
import math
import os
import time
from datetime import datetime
from functools import partial
from math import ceil
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
from nanotron.distributed import ProcessGroup
from nanotron.logging import LogItem, log_rank
from nanotron.models.base import NanotronModel
from nanotron.optim.base import BaseOptimizer, Optimizer
from nanotron.optim.gradient_accumulator import (
FP32GradBucketManager,
FP32GradientAccumulator,
GradientAccumulator,
get_fp32_accum_hook,
)
from nanotron.optim.named_optimizer import NamedOptimizer
from nanotron.optim.optimizer_from_gradient_accumulator import (
OptimizerFromGradientAccumulator,
)
from nanotron.optim.zero import ZeroDistributedOptimizer
from nanotron.parallel import ParallelContext
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from nanotron.random import (
RandomStates,
get_current_random_state,
get_synced_random_state,
)
from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
from nanotron.serialize.metadata import TrainingMetadata
logger = logging.get_logger(__name__)
def _vocab_size_with_padding(orig_vocab_size: int, pg_size: int, make_vocab_size_divisible_by: int):
"""Pad vocab size so it is divisible by pg_size * make_vocab_size_divisible_by."""
multiple = make_vocab_size_divisible_by * pg_size
after = int(ceil(orig_vocab_size / multiple) * multiple)
if after != orig_vocab_size:
log_rank(
f"[Vocab Size Padding] Padded vocab (size: {orig_vocab_size}) with {after - orig_vocab_size} dummy tokens (new size: {after})",
logger=logger,
level=logging.WARNING,
rank=0,
)
return after
def init_random_states(parallel_config: ParallelismArgs, tp_pg: ProcessGroup):
# Get synchronized random states
if parallel_config is None or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:
random_states = RandomStates(
{"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=tp_pg)}
)
else:
# NOTE: We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
random_states = RandomStates({})
return random_states
def lr_scheduler_builder(optimizer: Optimizer, lr_scheduler_args: LRSchedulerArgs, total_training_steps: int):
if lr_scheduler_args.lr_decay_steps is None:
lr_decay_steps = total_training_steps
if lr_scheduler_args.lr_warmup_steps is not None:
lr_decay_steps -= lr_scheduler_args.lr_warmup_steps
if lr_scheduler_args.lr_decay_starting_step is not None:
lr_decay_steps -= lr_scheduler_args.lr_decay_starting_step
else:
lr_decay_steps = lr_scheduler_args.lr_decay_steps
if lr_scheduler_args.lr_decay_starting_step is None:
if lr_scheduler_args.lr_warmup_steps is not None:
lr_decay_starting_step = lr_scheduler_args.lr_warmup_steps
else:
lr_decay_starting_step = 0
else:
lr_decay_starting_step = lr_scheduler_args.lr_decay_starting_step
def lr_lambda(current_step: int, initial_lr: float):
"""
current_step: current training step
initial_lr: the learning rate of a parameter group
More info on initial_lr:
And in standard parameterization, lr_lambda only takes a single learning rate.
But in µTransfer, each parameter has a custom learning rate (custom_lr = lr_scheduler_args.learning_rate * scaling_factor),
so each parameter group has a custom lr_lambda function.
LR Scheduling function, it has from 2 up to 4 phases:
- warmup,
- optional: constant (if lr_decay_starting_step is set)
- decay
- optional: constant (if lr_decay_steps and/or lr_decay_starting_step are set)
Warmup starts at lr=0 and ends at `lr=lr`
Then it stays constant at lr if lr_decay_starting_step is set and larger than lr_warmup_steps
Then it decays until `min_decay_lr` for lr_decay_steps if set, else: (total_training_steps - lr_warmup_steps or lr_decay_starting_step)
Then it stays constant at min_decay_lr if lr_decay_starting_step is set and total_training_steps is larger)
"""
# No warmup or decay
if lr_scheduler_args.lr_warmup_steps == 0 and lr_decay_steps == 0:
return initial_lr
# Warmup phase
elif lr_scheduler_args.lr_warmup_style is not None and current_step <= lr_scheduler_args.lr_warmup_steps:
if lr_scheduler_args.lr_warmup_style == "linear":
lmbda = initial_lr * current_step / max(lr_scheduler_args.lr_warmup_steps, 1)
elif lr_scheduler_args.lr_warmup_style == "constant":
lmbda = lr_scheduler_args.learning_rate
else:
raise ValueError(f"Unknown warmup style {lr_scheduler_args.lr_warmup_style}")
# Optional constant phase at learning_rate
elif current_step < lr_decay_starting_step:
lmbda = initial_lr
# Decay phase
elif lr_scheduler_args.lr_decay_style is not None and current_step < lr_decay_starting_step + lr_decay_steps:
if lr_scheduler_args.lr_decay_style == "cosine":
lmbda = (
lr_scheduler_args.min_decay_lr
+ (initial_lr - lr_scheduler_args.min_decay_lr)
* (1 + math.cos(math.pi * (current_step - lr_decay_starting_step) / lr_decay_steps))
/ 2
)
elif lr_scheduler_args.lr_decay_style == "linear":
lmbda = (
lr_scheduler_args.min_decay_lr
+ (initial_lr - lr_scheduler_args.min_decay_lr)
* (lr_decay_steps - (current_step - lr_decay_starting_step))
/ lr_decay_steps
)
elif lr_scheduler_args.lr_decay_style == "1-sqrt":
lmbda = lr_scheduler_args.min_decay_lr + (initial_lr - lr_scheduler_args.min_decay_lr) * (
1 - math.sqrt((current_step - lr_decay_starting_step) / lr_decay_steps)
)
else:
raise ValueError(f"Unknown decay style {lr_scheduler_args.lr_decay_style}")
# Optional constant phase at min_decay_lr
else:
lmbda = lr_scheduler_args.min_decay_lr
lmbda /= initial_lr # Normalization for pytorch
return lmbda
def get_lr_lambda_for_param_group(lr: float):
return partial(lr_lambda, initial_lr=lr)
# NOTE: get learning rate scheduler for each param group
lr_lambdas = []
for param_group in optimizer.get_base_optimizer().param_groups:
lr_lambdas.append(get_lr_lambda_for_param_group(lr=param_group["lr"]))
assert len(lr_lambdas) == len(
optimizer.get_base_optimizer().param_groups
), "Custom learning rate functions dont match the number of param groups"
log_rank(
f"[Optimizer Building] There are total {len(lr_lambdas)} custom learning rate function for parameter groups",
logger=logger,
level=logging.DEBUG,
)
lr_scheduler = LambdaLR(optimizer.get_base_optimizer(), lr_lambda=lr_lambdas)
return lr_scheduler
def get_custom_weight_decay_for_named_parameters(
named_parameters: Iterable[Tuple[str, torch.Tensor]],
model: NanotronModel,
module_id_to_prefix: Dict[int, str],
weight_decay: float,
) -> List[Dict[str, Any]]:
"""
Apply weight decay to all parameters except the ones that are in the named_param_without_weight_decay list.
"""
named_param_groups_with_custom_weight_decay = []
exclude_named_params = model.get_named_params_without_weight_decay()
for name, param in named_parameters:
if param.is_tied:
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
else:
pass
if any(name.endswith(substring) for substring in exclude_named_params):
named_param_groups_with_custom_weight_decay.append({"named_params": [(name, param)], "weight_decay": 0.0})
else:
named_param_groups_with_custom_weight_decay.append(
{"named_params": [(name, param)], "weight_decay": weight_decay}
)
log_rank(
f"[Optimizer Building] Creating {len(named_param_groups_with_custom_weight_decay)} param groups with custom weight decay",
logger=logger,
level=logging.DEBUG,
)
return named_param_groups_with_custom_weight_decay
def get_custom_lr_for_named_parameters(
parametrization_method: ParametrizationMethod,
lr: float,
named_parameters: Iterable[Tuple[str, torch.Tensor]],
model: NanotronModel,
) -> List[Dict[str, Any]]:
"""
Get custom learning rates for parameters based on the parametrization method.
NOTE: in some paramtrization methods, we use a global learning rate for all parameters,
in others we use a custom learning rate for each parameter (eg: spectral µTransfer).
"""
assert parametrization_method in [ParametrizationMethod.SPECTRAL_MUP, ParametrizationMethod.STANDARD]
lr_mapper_cls = (
LearningRateForSpectralMup
if parametrization_method == ParametrizationMethod.SPECTRAL_MUP
else LearningRateForSP
)
log_rank(
f"[Optimizer Building] Using {lr_mapper_cls.__name__} as learning rate",
logger=logger,
level=logging.INFO,
rank=0,
)
# NOTE: since in the case of pipeline parallelism, each rank only has a subset of the model
# so we only get the parameters that are in the current rank
learning_rate_mapper = lr_mapper_cls(names_to_modules=model.named_modules_in_pp_rank, lr=lr)
named_param_groups_with_custom_lr = []
for (
name,
param,
) in named_parameters:
learning_rate = learning_rate_mapper.get_lr(name, param)
assert isinstance(learning_rate, float), f"Expected a float, got {learning_rate} for parameter {name}"
named_param_groups_with_custom_lr.append({"named_params": [(name, param)], "lr": learning_rate})
log_rank(
f"[Optimizer Building] Creating {len(named_param_groups_with_custom_lr)} param groups with custom learning rates",
logger=logger,
level=logging.DEBUG,
)
return named_param_groups_with_custom_lr
def merge_named_param_groups(
named_param_groups_with_lr: List[Dict[str, Any]],
named_param_groups_with_weight_decay: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
assert len(named_param_groups_with_lr) == len(
named_param_groups_with_weight_decay
), "Named param groups don't match in length"
named_param_groups = []
for group_with_lr, group_with_weight_decay in zip(
named_param_groups_with_lr, named_param_groups_with_weight_decay
):
assert group_with_lr["named_params"] == group_with_weight_decay["named_params"]
named_param_groups.append(
{
"named_params": group_with_lr["named_params"],
"lr": group_with_lr["lr"],
"weight_decay": group_with_weight_decay["weight_decay"],
}
)
return named_param_groups
def init_optimizer_and_grad_accumulator(
parametrization_method: ParametrizationMethod,
model: nn.Module,
optimizer_args: OptimizerArgs,
parallel_context: ParallelContext,
) -> Tuple[BaseOptimizer, GradientAccumulator]:
# Unwrap DDP
unwrapped_model: NanotronModel = model.module if isinstance(model, DistributedDataParallel) else model
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in unwrapped_model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(unwrapped_model)] = ""
named_parameters = list(unwrapped_model.get_named_params_with_correct_tied())
named_param_groups_with_lr = get_custom_lr_for_named_parameters(
parametrization_method=parametrization_method,
named_parameters=named_parameters,
model=unwrapped_model,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
)
named_param_groups_with_weight_decay = get_custom_weight_decay_for_named_parameters(
named_parameters=named_parameters,
model=unwrapped_model,
module_id_to_prefix=module_id_to_prefix,
weight_decay=optimizer_args.weight_decay,
)
named_param_groups = merge_named_param_groups(named_param_groups_with_lr, named_param_groups_with_weight_decay)
# Basic optimizer builder
def basic_optimizer_builder(named_param_groups):
optimizer = None
if optimizer_args.optimizer_factory.name == "adamW":
def optimizer(param_groups):
return torch.optim.AdamW(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
eps=optimizer_args.optimizer_factory.adam_eps,
betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2),
fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
)
elif optimizer_args.optimizer_factory.name == "sgd":
def optimizer(param_groups):
return torch.optim.SGD(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
)
else:
raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported")
return NamedOptimizer(
named_params_or_groups=named_param_groups,
optimizer_builder=optimizer,
)
optimizer_builder = basic_optimizer_builder
# Gradient accumulator builder
grad_accumulator: Optional[GradientAccumulator] = None
if optimizer_args.accumulate_grad_in_fp32:
# TODO @thomasw21: Make an optimizer builder system, instead of doing everything in functional manner
def grad_optimizer_builder(named_param_groups):
result = OptimizerFromGradientAccumulator(
gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(
named_parameters=named_params,
grad_buckets_named_params=named_parameters,
),
named_params_or_groups=named_param_groups,
optimizer_builder=basic_optimizer_builder,
)
# TODO @thomasw21: get better API to get the grad_accumulator
nonlocal grad_accumulator
grad_accumulator = result.gradient_accumulator
return result
optimizer_builder = grad_optimizer_builder
if optimizer_args.zero_stage > 0:
# Build optimizer
optimizer = ZeroDistributedOptimizer(
named_params_or_groups=named_param_groups,
# TODO @thomasw21: We need a better API for gradient accumulation/zero etc ...
optimizer_builder=optimizer_builder,
dp_pg=parallel_context.dp_pg,
)
# SANITY CHECK: assert that optimizer's named_params point to model's params (check only the first one)
if (
len(optimizer.zero_named_param_groups) > 0
and len(optimizer.zero_named_param_groups[0]["named_params"]) > 0
):
optim_model_param_name, optim_model_param = optimizer.zero_named_param_groups[0]["named_params"][0]
if isinstance(model, DistributedDataParallel):
optim_model_param_name = f"module.{optim_model_param_name}"
param = model.get_parameter(optim_model_param_name)
assert param.data_ptr() == optim_model_param.data_ptr()
else:
# Build optimizer
optimizer = optimizer_builder(named_param_groups)
if grad_accumulator is not None and optimizer_args.zero_stage > 0:
# There's a way to only require to reduce_scatter the gradients instead of all_reducing
# In order to do so I need to pass which segments of each parameter should be reduced on which dp rank.
assert isinstance(optimizer, ZeroDistributedOptimizer)
param_name_to_dp_rank_offsets = optimizer.param_name_to_dp_rank_offsets
assert isinstance(grad_accumulator, FP32GradientAccumulator)
grad_accumulator.assign_param_offsets(
dp_rank=dist.get_rank(parallel_context.dp_pg),
param_name_to_offsets=param_name_to_dp_rank_offsets,
)
# Register DDP hook to make fp32 grad accumulation work
if isinstance(model, DistributedDataParallel) and grad_accumulator is not None:
assert isinstance(grad_accumulator, FP32GradientAccumulator)
model.register_comm_hook(
state=FP32GradBucketManager(
dp_pg=parallel_context.dp_pg,
accumulator=grad_accumulator,
param_id_to_name={
id(param): param.get_tied_info().get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
if param.is_tied
else name
for name, param in unwrapped_model.named_parameters()
},
),
hook=get_fp32_accum_hook(
reduce_scatter=optimizer.inherit_from(ZeroDistributedOptimizer), reduce_op=dist.ReduceOp.AVG
),
)
return optimizer, grad_accumulator
def test_equal_dict(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> None:
"""Raise if doesn't match."""
if sub_paths is None:
sub_paths = []
first_keys = set(first.keys())
second_keys = set(second.keys())
assert first_keys == second_keys, f"Keys don't match.\nFirst: {first_keys}\nSecond: {second_keys}"
for key in first_keys:
first_elt = first[key]
second_elt = second[key]
if isinstance(first_elt, dict):
assert isinstance(second_elt, dict), f"{first_elt} doesn't match {second_elt}"
test_equal_dict(first_elt, second_elt, sub_paths=sub_paths + [str(key)])
elif isinstance(first_elt, torch.Tensor):
assert isinstance(second_elt, torch.Tensor), f"{first_elt} doesn't match {second_elt}"
torch.testing.assert_close(
first_elt,
second_elt,
atol=0.0,
rtol=0.0,
msg=lambda msg: f"tensor at {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}\n{msg}",
)
else:
assert (
first_elt == second_elt
), f"{first_elt} doesn't match {second_elt} at key {'.'.join(sub_paths + [str(key)])}"
def get_profiler(config: Config):
if config.profiler is not None:
if config.profiler.profiler_export_path is not None:
on_trace_ready = tensorboard_trace_handler(
config.profiler.profiler_export_path / datetime.now().strftime("%Y%m%d-%H%M%S")
)
else:
on_trace_ready = None
prof = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1, skip_first=3),
on_trace_ready=on_trace_ready,
# record_shapes=True,
# profile_memory=True,
with_stack=True,
)
else:
prof = contextlib.nullcontext()
return prof
def get_all_comps(n: int) -> List[List[List[int]]]:
"""Return a 3D numpy array with a series of pairs to test latency/bandwidth between:
This basically make a square matrix from the triangle of pair-to-pair comparisons
[[[0 1]
[2 3]]
[[0 2]
[1 3]]
[[0 3]
[1 2]]]
"""
# n: power of two
if not ((n & (n - 1) == 0) and n != 0):
# every power of 2 has exactly 1 bit set to 1 (the bit in that number's log base-2 index).
# So when subtracting 1 from it, that bit flips to 0 and all preceding bits flip to 1.
# That makes these 2 numbers the inverse of each other so when AND-ing them, we will get 0 as the result
raise ValueError("n must be a power of two")
def op(lst, d=4, r=1):
lst = lst.reshape(-1, d)
lst[1::2] = np.roll(lst[1::2], r, axis=1)
return lst.T.reshape(-1)
x = np.array(list(range(n)))
comps = []
d = 1
while d < n:
for r in range(d):
comps.append(op(x, d=d, r=r).copy())
d *= 2
ret = np.stack(comps)
return ret.reshape(ret.shape[0], -1, 2).tolist()
def test_all_pair_to_pair(
parallel_context: ParallelContext, throughput_size: int, throughput_iters: int, only_node_to_node: bool = True
):
"""Test all pair-to-pair GPUs throughput
Args:
parallel_context: ParallelContext
throughput_size: size of the tensor to send
throughput_iters: number of warm-up iterations before testing the throughput
only_node_to_node: if True, only test node-to-node throughput
"""
comparisons = get_all_comps(parallel_context.world_pg.size())
wr = dist.get_rank(parallel_context.world_pg)
log_rank(
f"[TEST] Testing throughput between {comparisons}",
logger=logger,
level=logging.WARNING,
group=parallel_context.world_pg,
rank=0,
)
for j, comp in enumerate(comparisons):
dist.barrier(group=parallel_context.world_pg)
for i, (a, b) in enumerate(comp):
dist.barrier(group=parallel_context.world_pg)
if wr not in [a, b]:
continue
if only_node_to_node and (a % 8 != 0 or b % 8 != 0):
# We only check node-to-node throughput
continue
test_tensor = torch.zeros((int(throughput_size),), dtype=torch.uint8, device=torch.device("cuda"))
for k in range(throughput_iters):
pre = time.perf_counter()
torch.cuda.synchronize()
if wr == a:
dist.send(test_tensor, b, group=parallel_context.world_pg, tag=i + k)
elif wr == b:
dist.recv(test_tensor, a, group=parallel_context.world_pg, tag=i + k)
torch.cuda.synchronize()
duration = time.perf_counter() - pre
del test_tensor
gc.collect()
torch.cuda.empty_cache()
tput = (float(throughput_size) / duration) * 8 # *8 for gigabits/second
log_rank(
f"[TEST] {j, i, wr} Results throughput from {a} to {b}: {tput/1e9:.4f} Gbps",
logger=logger,
level=logging.WARNING,
group=parallel_context.world_pg,
rank=None,
)
log_rank(
"[TEST] All comparisons done",
logger=logger,
level=logging.WARNING,
group=parallel_context.world_pg,
rank=0,
)
def create_table_log(
config: Config,
parallel_context: ParallelContext,
model_tflops,
hardware_tflops,
tokens_per_sec,
bandwidth,
slurm_job_id,
):
return [
LogItem("job_id", slurm_job_id, "s"),
LogItem("name", config.general.run, "s"),
LogItem("nodes", math.ceil(parallel_context.world_pg.size() / torch.cuda.device_count()), "d"),
LogItem("seq_len", config.tokens.sequence_length, "d"),
LogItem("mbs", config.tokens.micro_batch_size, "d"),
LogItem("batch_accum", config.tokens.batch_accumulation_per_replica, "d"),
LogItem("gbs", config.global_batch_size, "d"),
LogItem("mTFLOPs", model_tflops, ".2f"),
LogItem("hTFLOPs", hardware_tflops, ".2f"),
LogItem("tok/s/gpu", tokens_per_sec / parallel_context.world_pg.size(), ".2f"),
LogItem("Bandwidth (GB/s)", bandwidth, ".2f"),
LogItem("Mem Alloc (GB)", torch.cuda.max_memory_allocated() / 1024**3, ".2f"),
LogItem("Mem Res (GB)", torch.cuda.max_memory_reserved() / 1024**3, ".2f"),
]
def create_table_output(table_log, column_widths):
header_row = "| " + " | ".join([item.tag.ljust(width) for item, width in zip(table_log, column_widths)]) + " |"
separator_row = "| " + " | ".join(["-" * width for width in column_widths]) + " |"
data_row = (
"| "
+ " | ".join(
[f"{item.scalar_value:{item.log_format}}".ljust(width) for item, width in zip(table_log, column_widths)]
)
+ " |"
)
return f"{header_row}\n{separator_row}\n{data_row}"
def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id):
if not os.path.exists(csv_filename):
os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
with open(csv_filename, mode="w") as fo:
writer = csv.writer(fo)
writer.writerow([item.tag for item in table_log])
writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
# elif model_tflops > 0:
# # replace line with same job_id
# with open(csv_filename, mode="r") as fi:
# lines = fi.readlines()
# with open(csv_filename, mode="w") as fo:
# writer = csv.writer(fo)
# for line in lines:
# if line.startswith(slurm_job_id):
# writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
# else:
# fo.write(line)
else:
with open(csv_filename, mode="a") as fo:
writer = csv.writer(fo)
writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
def log_throughput(
config: Config,
parallel_context: ParallelContext,
model_tflops=0,
hardware_tflops=0,
tokens_per_sec=0,
bandwidth=0,
):
slurm_job_id = os.environ.get("SLURM_JOB_ID", "N/A")
table_log = create_table_log(
config, parallel_context, model_tflops, hardware_tflops, tokens_per_sec, bandwidth, slurm_job_id
)
column_widths = [max(len(item.tag), len(f"{item.scalar_value:{item.log_format}}")) for item in table_log]
table_output = create_table_output(table_log, column_widths)
log_rank(
table_output,
logger=logger,
level=logging.INFO,
rank=0,
)
if dist.get_rank(parallel_context.world_pg) == 0:
write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id)
def compute_remain_train_steps_of_a_data_stage_from_ckp(
stage: DatasetStageArgs, config: Config, metadata: TrainingMetadata
) -> int:
def is_last_stage():
sorted_stages = sorted(config.data_stages, key=lambda x: x.start_training_step)
return sorted_stages[-1].start_training_step == stage.start_training_step
def is_resume_from_training():
return metadata.last_train_step > 0
if is_last_stage() is True:
total_train_steps = config.tokens.train_steps
else:
next_stage = next((s for s in config.data_stages if s.start_training_step > stage.start_training_step), None)
total_train_steps = next_stage.start_training_step
if metadata.last_train_step > stage.start_training_step:
# NOTE: if the last_train_step is larger than the start_training_step of the current stage,
# it means that the training has already passed this stage
# so there is no remaining steps
return 0
else:
last_train_steps = metadata.last_train_step if is_resume_from_training() else stage.start_training_step
return total_train_steps - last_train_steps
def get_consumed_train_samples_of_a_data_stage_from_ckp(
stage: DatasetStageArgs, metadata: TrainingMetadata
) -> Optional[int]:
start_training_step = stage.start_training_step
return next(
(s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step),
None,
)
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities. """
import logging
import os
import sys
from dataclasses import dataclass
from functools import lru_cache
from logging import (
CRITICAL,
DEBUG,
ERROR,
FATAL,
INFO,
NOTSET,
WARNING,
Formatter,
Logger,
)
from typing import TYPE_CHECKING, List, Optional, Union
import torch
from torch import distributed as torch_dist
from nanotron import distributed as dist
if TYPE_CHECKING:
from nanotron.config import LoggingArgs
from nanotron.parallel import ParallelContext
log_levels = {
"debug": DEBUG,
"info": INFO,
"warning": WARNING,
"error": ERROR,
"critical": CRITICAL,
"fatal": FATAL,
"notset": NOTSET,
}
class NewLineStreamHandler(logging.StreamHandler):
"""
We want to apply formatter before each new line
https://stackoverflow.com/a/38458877
"""
def emit(self, record):
lines = record.msg.split("\n")
for line in lines:
record.msg = line
super().emit(record)
DEFAULT_HANDLER = NewLineStreamHandler()
DEFAULT_LOG_LEVEL = logging.WARNING
LIBRARY_NAME = __name__.split(".")[0]
def _get_default_logging_level():
"""
If NANOTRON_LOGGING_LEVEL env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to ``_default_log_level``
"""
env_level_str = os.getenv("NANOTRON_LOGGING_LEVEL", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option NANOTRON_LOGGING_LEVEL={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return DEFAULT_LOG_LEVEL
def get_library_root_logger() -> Logger:
return get_logger(LIBRARY_NAME)
def _configure_library_root_logger() -> None:
library_root_logger = get_library_root_logger()
library_root_logger.addHandler(DEFAULT_HANDLER)
library_root_logger.setLevel(_get_default_logging_level())
def _reset_library_root_logger() -> None:
library_root_logger = get_library_root_logger()
library_root_logger.setLevel(logging.NOTSET)
def get_logger(name: Optional[str] = None, log_level: Optional[str] = None) -> Logger:
"""
Return a logger with the specified name.
"""
logger_already_exists = isinstance(logging.root.manager.loggerDict.get(name, None), Logger)
logger = logging.getLogger(name)
if logger_already_exists or name is None:
# if name is None we return root logger
return logger
# If the logger is in a `nanotron` module then we remove the capability to propagate
if LIBRARY_NAME == name.split(".", 1)[0]:
if log_level is not None:
logger.setLevel(log_level.upper())
elif LEVEL is not None:
logger.setLevel(LEVEL)
else:
logger.setLevel(_get_default_logging_level())
if HANDLER is not None:
logger.handlers.clear()
logger.addHandler(HANDLER)
logger.propagate = False
return logger
def get_verbosity() -> int:
"""
Return the current level for the Nanotron root logger as an int.
Returns:
:obj:`int`: The logging level.
.. note::
Nanotron has following logging levels:
- 50: ``nanotron.logging.CRITICAL`` or ``nanotron.logging.FATAL``
- 40: ``nanotron.logging.ERROR``
- 30: ``nanotron.logging.WARNING`` or ``nanotron.logging.WARN``
- 20: ``nanotron.logging.INFO``
- 10: ``nanotron.logging.DEBUG``
"""
return get_library_root_logger().getEffectiveLevel()
LEVEL = None
def set_verbosity(verbosity: int) -> None:
"""
Set the verbosity level for the all `nanotron` loggers.
Args:
verbosity (:obj:`int`):
Logging level, e.g., one of:
- ``nanotron.logging.CRITICAL`` or ``nanotron.logging.FATAL``
- ``nanotron.logging.ERROR``
- ``nanotron.logging.WARNING`` or ``nanotron.logging.WARN``
- ``nanotron.logging.INFO``
- ``nanotron.logging.DEBUG``
"""
all_nanotron_loggers = {
name: logger
for name, logger in logging.root.manager.loggerDict.items()
if isinstance(logger, Logger) and (name.startswith(f"{LIBRARY_NAME}.") or name == LIBRARY_NAME)
}
for name, logger in all_nanotron_loggers.items():
logger.setLevel(verbosity)
# We update all handles to be at the current verbosity as well.
for handle in logger.handlers:
handle.setLevel(verbosity)
global LEVEL
LEVEL = verbosity
HANDLER = None
def set_formatter(formatter: logging.Formatter) -> None:
"""
Set a new custom formatter as the current handler.
Note: it's important to first set level and then
:param formatter:
:return:
"""
handler = NewLineStreamHandler(sys.stdout)
handler.setFormatter(formatter)
handler.setLevel(get_verbosity())
handler.flush = sys.stderr.flush
all_nanotron_loggers = {
name: logger
for name, logger in logging.root.manager.loggerDict.items()
if isinstance(logger, Logger) and (name.startswith(f"{LIBRARY_NAME}.") or name == LIBRARY_NAME)
}
for name, logger in all_nanotron_loggers.items():
# We keep only a single handler
logger.handlers.clear()
logger.addHandler(handler)
global HANDLER
HANDLER = handler
def log_rank(
msg: str,
logger: Logger,
level: int,
group: Optional[dist.ProcessGroup] = None,
rank: Optional[int] = None,
**kwargs,
):
"""Log only if the current process is the rank specified."""
# Use default group is group is not provided
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
# rank is None means everyone logs
if rank is None or dist.get_rank(group) == rank:
logger.log(level, msg, **kwargs)
@lru_cache(maxsize=None)
def warn_once(
msg: str, logger: Logger, group: Optional[dist.ProcessGroup] = None, rank: Optional[int] = None, **kwargs
):
log_rank(msg=msg, logger=logger, level=logging.WARNING, group=group, rank=rank, **kwargs)
def human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str:
if abs(num) < 1:
return "{:.3g}".format(num)
SIZES = ["", "K", "M", "G", "T", "P", "E"]
num = float("{:.3g}".format(num))
magnitude = 0
i = 0
while abs(num) >= 1000 and i < len(SIZES) - 1:
magnitude += 1
num /= 1000.0 if not divide_by_1024 else 1024.0
i += 1
return "{}{}".format("{:f}".format(num).rstrip("0").rstrip("."), SIZES[magnitude])
def log_memory(logger: logging.Logger):
log_rank(
f" Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
f" Peak allocated {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB."
f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
logger=logger,
level=logging.INFO,
rank=0,
)
torch.cuda.reset_peak_memory_stats()
@dataclass
class LogItem:
tag: str
scalar_value: Union[float, int, str]
log_format: Optional[str] = None
@dataclass
class LoggerWriter:
global_step: int
def add_scalar(self, tag: str, scalar_value: Union[float, int], log_format=None) -> str:
if log_format == "human_format":
log_str = f"{tag}: {human_format(scalar_value)}"
else:
log_str = f"{tag}: {scalar_value:{log_format}}" if log_format is not None else f"{tag}: {scalar_value}"
return log_str
def add_scalars_from_list(self, log_entries: List[LogItem], iteration_step: int):
log_strs = [f"iteration: {iteration_step} / {self.global_step}"]
log_strs += [
self.add_scalar(log_item.tag, log_item.scalar_value, log_item.log_format) for log_item in log_entries
]
log_str = " | ".join(log_strs)
log_rank(log_str, logger=get_logger(__name__), level=logging.INFO)
def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelContext):
node_name = os.environ.get("SLURMD_NODENAME")
expert_parallel_log = (
f"|EXP={dist.get_rank(parallel_context.expert_pg)}" if parallel_context.expert_parallel_size > 1 else ""
)
formatter = Formatter(
fmt=f"%(asctime)s [%(levelname)s|DP={dist.get_rank(parallel_context.dp_pg)}|PP={dist.get_rank(parallel_context.pp_pg)}|"
f"TP={dist.get_rank(parallel_context.tp_pg)}{expert_parallel_log}{'|' + node_name if node_name else ''}]: %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
log_level = log_levels[logging_level]
# main root logger
root_logger = get_logger()
root_logger.setLevel(log_level)
handler = NewLineStreamHandler(sys.stdout)
handler.setLevel(log_level)
handler.setFormatter(formatter)
root_logger.addHandler(handler)
# Nanotron
set_verbosity(log_level)
set_formatter(formatter=formatter)
def set_ranks_logging_level(parallel_context: ParallelContext, logging_config: "LoggingArgs"):
if dist.get_rank(parallel_context.world_pg) == 0:
if logging_config.log_level is not None:
set_logger_verbosity_format(logging_config.log_level, parallel_context=parallel_context)
else:
if logging_config.log_level_replica is not None:
set_logger_verbosity_format(logging_config.log_level_replica, parallel_context=parallel_context)
_configure_library_root_logger()
# flake8: noqa
from .base import DTypeInvariantTensor, NanotronModel, build_model, check_model_has_grad, init_on_device_and_dtype
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment