from enum import Enum, auto
class DTypes(Enum):
FP8E4M3 = auto()
FP8E5M2 = auto()
KFLOAT16 = auto()
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.meta import FP8Meta
@torch.no_grad()
def fp8_matmul_kernel(
mat_a: FP8Tensor,
transpose_a: bool,
mat_b: FP8Tensor,
transpose_b: bool,
use_split_accumulator: bool,
) -> torch.Tensor:
    assert (
        mat_a.device.type != "cpu" and mat_b.device.type != "cpu"
    ), "The tensors must be on a CUDA device in order to use the FP8 kernel!"
device = mat_a.device
_empty_tensor = torch.Tensor()
output = torch.empty(mat_a.shape[0], mat_b.shape[1], device=device, dtype=torch.float32)
workspace = torch.empty(33_554_432, dtype=torch.int8, device=device)
accumulate = False
out_dtype = getattr(tex.DType, "kFloat32")
    # NOTE: TE currently doesn't support adding a bias in the FP8 matmul;
    # it only accepts an empty bias tensor
bias = torch.tensor([], dtype=torch.float32)
TE_CONFIG_TRANSPOSE_BIAS = False
mat_a_fp8_meta: FP8Meta = mat_a.fp8_meta
mat_b_fp8_meta: FP8Meta = mat_b.fp8_meta
    # NOTE: these are the only transpose configs that TE accepts,
    # so we have to transpose the A and B matrices to match them
TE_CONFIG_TRANSPOSE_A = True
TE_CONFIG_TRANSPOSE_B = False
SCALE = AMAX = _empty_tensor
mat_a = tex.fp8_transpose(mat_a, mat_a_fp8_meta.te_dtype) if transpose_a is False else mat_a
mat_b = tex.fp8_transpose(mat_b, mat_b_fp8_meta.te_dtype) if transpose_b is True else mat_b
tex.te_gemm(
mat_a,
mat_a_fp8_meta.inverse_scale,
mat_a_fp8_meta.te_dtype,
TE_CONFIG_TRANSPOSE_A,
mat_b,
mat_b_fp8_meta.inverse_scale,
mat_b_fp8_meta.te_dtype,
TE_CONFIG_TRANSPOSE_B,
output,
SCALE,
out_dtype,
AMAX,
bias,
out_dtype,
_empty_tensor,
TE_CONFIG_TRANSPOSE_BIAS,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
0,
)
return output
from typing import Optional, Tuple, TypedDict, Union
import torch
import torch.nn.functional as F
import transformer_engine as te # noqa
from torch import nn
from nanotron.fp8.constants import INITIAL_AMAX, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.kernel import fp8_matmul_kernel
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor
class FP8LinearMeta(TypedDict):
"""FP8 metadata for FP8Linear."""
input_grad: FP8Meta
weight_grad: FP8Meta
output_grad: FP8Meta
class FP8Linear(nn.Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True, device: Optional[torch.device] = None):
super().__init__(in_features, out_features, bias, device)
# TODO(xrsrke): add device, and 2 fp8 dtypes
if self.weight.device != torch.device("cpu"):
self.weight = FP8Parameter(self.weight, dtype=DTypes.FP8E4M3)
# NOTE: quantization metadata for input gradients, weight gradients, and output gradients
        # TODO(xrsrke): don't hard-code this
fp8e4m3_scale = update_scaling_factor(
amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR),
dtype=DTypes.FP8E4M3,
)
fp8e5m2_scale = update_scaling_factor(
amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32),
dtype=DTypes.FP8E5M2,
)
self.fp8_meta: FP8LinearMeta = {
# kfloat8_e4m3
"input_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
"weight_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
# kfloat8_e5m2
"output_grad": FP8Meta(amax=1, dtype=DTypes.FP8E5M2, scale=fp8e5m2_scale),
}
def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
# NOTE: only do fp8 kernel if both input and weight are on CUDA device
if input.device == torch.device("cpu") or self.weight.device == torch.device("cpu"):
return F.linear(input, self.weight, self.bias)
        # NOTE: a phony tensor that makes PyTorch trigger the backward pass,
        # because the weight's and bias's requires_grad are set to False
        # so that we can compute the gradients ourselves with the FP8 kernels
phony = torch.empty(0, device=input.device, requires_grad=True)
output, _ = _FP8Matmul.apply(input, self.weight, self.fp8_meta, phony)
# TODO(xrsrke): add support for adding bias in fp8
        # TODO(xrsrke): support returning an FP8 tensor as output,
        # since we will quantize it back to FP8 anyway in the next linear
output = output if self.bias is None else output + self.bias
return output
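# NOTE (illustrative): a minimal usage sketch of FP8Linear, assuming a CUDA device with
# FP8 support (see `is_fp8_available` in nanotron.fp8.utils). The helper name below is
# hypothetical and not part of the module itself.
def _example_fp8_linear_usage():
    linear = FP8Linear(in_features=64, out_features=64, bias=False, device=torch.device("cuda"))
    x = torch.randn(64, 64, device="cuda")
    # The weight was quantized to FP8E4M3 at construction; the forward pass goes through
    # `_FP8Matmul`, which calls `fp8_matmul_kernel` and writes `weight.grad` itself in backward.
    return linear(x)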
class _FP8Matmul(torch.autograd.Function):
@staticmethod
@torch.no_grad()
def forward(
ctx, input: FP8Tensor, weight: FP8Tensor, fp8_meta: FP8LinearMeta, phony: torch.Tensor
) -> torch.Tensor:
if type(input) == torch.Tensor:
input = FP8Tensor(input, dtype=DTypes.FP8E4M3)
ctx.save_for_backward(input, weight)
ctx.fp8_meta = fp8_meta
# NOTE: pass FP8Tensor instead of FP8Parameter
output = fp8_matmul_kernel(
mat_a=weight.data, transpose_a=True, mat_b=input, transpose_b=False, use_split_accumulator=False
)
return output, phony
@staticmethod
@torch.no_grad()
def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[torch.Tensor, None, None, None]:
"""
∂L/∂X = ∂L/∂Y @ Wᵀ
∂L/∂W = Xᵀ @ ∂L/∂Y
Source: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html
"""
        # TODO(xrsrke): investigate how grad_output.contiguous() affects the outputs
input, weight = ctx.saved_tensors
if type(grad_output) == torch.Tensor:
grad_output = torch.ones_like(grad_output)
grad_output = grad_output.contiguous()
grad_output = FP8Tensor(grad_output, dtype=DTypes.FP8E5M2)
grad_input = fp8_matmul_kernel(
mat_a=grad_output, transpose_a=True, mat_b=weight, transpose_b=True, use_split_accumulator=True
)
grad_weight = fp8_matmul_kernel(
mat_a=input, transpose_a=False, mat_b=grad_output, transpose_b=False, use_split_accumulator=True
)
weight.grad = grad_weight
return grad_input, None, None, None
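# NOTE (illustrative): a plain float32 reference for the gradient formulas quoted in the
# docstring above (Y = X @ W convention). The helper name is hypothetical; it is only meant
# for sanity-checking the FP8 kernels, not as part of the FP8 backward path itself.
def _reference_linear_backward(x: torch.Tensor, w: torch.Tensor, grad_output: torch.Tensor):
    grad_input = grad_output @ w.T   # dL/dX = dL/dY @ W^T
    grad_weight = x.T @ grad_output  # dL/dW = X^T @ dL/dY
    return grad_input, grad_weight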
from dataclasses import dataclass
from typing import Union
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8.constants import DTYPE_TO_FP8_MAX
from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype
@dataclass
class FP8Meta:
"""Metadata for FP8Tensor."""
amax: Union[int, float]
scale: torch.Tensor
# TODO(xrsrke): change to Literal[torch.int8, torch.uint8]
dtype: torch.dtype
@property
def te_dtype(self) -> tex.DType:
return convert_torch_dtype_to_te_dtype(self.dtype)
def __post_init__(self):
# NOTE: transformer engine only accepts torch tensors
self.amax = torch.tensor(self.amax, device="cuda") if not isinstance(self.amax, torch.Tensor) else self.amax
@property
def fp8_max(self) -> float:
"""Return the maximum normal value for the current dtype."""
return DTYPE_TO_FP8_MAX[self.dtype]
@property
def inverse_scale(self) -> torch.Tensor:
return 1 / self.scale
def __repr__(self) -> str:
return f"FP8Meta(amax={self.amax}, scale={self.scale}, inverse_scale={self.inverse_scale}, dtype={self.dtype})"
import torch
from torch import nn
from nanotron.fp8.constants import FP8_DTYPES
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import FP8Tensor
class FP8Parameter(nn.Parameter):
"""
A custom FP8 parameter class that allows gradients
to flow into FP8 tensors (which are integer tensors).
"""
def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True) -> nn.Parameter:
assert isinstance(data, torch.Tensor), "data must be a tensor"
        assert data.dtype not in FP8_DTYPES, "Currently only supports turning a non-FP8 tensor into an FP8 parameter"
assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors"
# TODO(xrsrke): if the tensor is on cpu, then bypass quantization
with torch.no_grad():
            # TODO(xrsrke): support taking an FP8Tensor as data
            # currently we can only quantize the tensor to FP8 separately (stored in `._data`),
            # because creating the parameter directly from an FP8 (integer) tensor raises
            # "Only Tensors of floating point and complex dtype can require gradients"
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self._data = FP8Tensor(data, dtype=dtype)
return self
@property
def data(self) -> FP8Tensor:
return self._data
@data.setter
def data(self, data: FP8Tensor):
self._data = data
@property
def fp8_meta(self) -> FP8Meta:
return self.data.fp8_meta
def __repr__(self) -> str:
return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad}"
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8.constants import DTYPE_TO_FP8_MAX, FP8_DTYPES, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes
class FP8Tensor(torch.Tensor):
"""FP8 Tensor."""
def __new__(cls, tensor: torch.Tensor, dtype: DTypes) -> torch.Tensor:
assert isinstance(tensor, torch.Tensor), "tensor must be a tensor"
        assert tensor.dtype not in FP8_DTYPES, "The tensor is already quantized to FP8"
# TODO(xrsrke): there is a circular import issue
        # between tensor.py and meta.py; fix this
from nanotron.fp8.meta import FP8Meta
# TODO(xrsrke): if the tensor is on cpu, then bypass the quantization
        # because the current kernels only support GPU tensors
assert tensor.device != torch.device("cpu"), "FP8Tensor only supports CUDA device"
assert isinstance(dtype, DTypes)
amax = tensor.abs().max().clone()
scale = update_scaling_factor(amax, torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), dtype)
fp8_meta = FP8Meta(amax, scale, dtype)
fp8_tensor = convert_tensor_to_fp8(tensor, fp8_meta)
# TODO(xrsrke): move update inverse scaling to FP8Meta's initialization
obj = torch.Tensor._make_subclass(cls, fp8_tensor)
obj.fp8_meta = fp8_meta
return obj
def __repr__(self) -> str:
return f"FP8Tensor({self}, fp8_meta={self.fp8_meta})"
def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType:
    # NOTE: transformer engine maintains its own dtype mapping
# so we need to manually map torch dtypes to TE dtypes
TORCH_DTYPE_TE_DTYPE_NAME_MAPPING = {
torch.int32: "kInt32",
torch.float32: "kFloat32",
torch.float16: "kFloat16",
torch.bfloat16: "kBFloat16",
# torch.fp8e5m2: "kFloat8E5M2",
# torch.fp8e4m3: "kFloat8E4M3",
# torch.int8: "kFloat8E5M2",
# torch.uint8: "kFloat8E4M3",
DTypes.FP8E4M3: "kFloat8E4M3",
DTypes.FP8E5M2: "kFloat8E5M2",
DTypes.KFLOAT16: "kFloat16",
}
return getattr(tex.DType, TORCH_DTYPE_TE_DTYPE_NAME_MAPPING[dtype])
# TODO(xrsrke): add type hint for meta after fixing
# circular import between tensor.py and meta.py
def convert_tensor_to_fp8(tensor: torch.Tensor, meta) -> FP8Tensor:
te_dtype = convert_torch_dtype_to_te_dtype(meta.dtype)
# TODO(xrsrke): after casting to fp8, update the scaling factor
    # TODO(xrsrke): it's weird that TE only takes an inverse_scale equal to 1
inverse_scale = torch.tensor(1.0, device=tensor.device, dtype=torch.float32)
return tex.cast_to_fp8(tensor, meta.scale, meta.amax, inverse_scale, te_dtype)
def convert_tensor_from_fp8(tensor: torch.Tensor, meta, dtype: torch.dtype) -> torch.Tensor:
assert isinstance(tensor, torch.Tensor)
assert isinstance(dtype, torch.dtype)
tensor_dtype = convert_torch_dtype_to_te_dtype(meta.dtype)
output_dtype = convert_torch_dtype_to_te_dtype(dtype)
return tex.cast_from_fp8(tensor, meta.inverse_scale, tensor_dtype, output_dtype)
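# NOTE (illustrative): a quantize/dequantize round-trip sketch, assuming a CUDA device with
# FP8 support. The helper name is hypothetical; the restored tensor only approximates the
# original because FP8E4M3 keeps just 3 mantissa bits.
def _example_fp8_round_trip():
    tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32)
    fp8_tensor = FP8Tensor(tensor, dtype=DTypes.FP8E4M3)
    restored = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32)
    return (tensor - restored).abs().max()  # quantization error, bounded by the FP8 step size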
def update_scaling_factor(
amax: torch.Tensor, scaling_factor: torch.Tensor, dtype: DTypes, margin: float = 0
) -> torch.Tensor:
"""
Update the scaling factor to quantize a tensor to FP8.
Credits: https://github.com/Azure/MS-AMP/blob/d562f0f0bcfc9b712fa0726b73428753ff1300ab/msamp/common/tensor/meta.py#L39
"""
assert amax.dtype == torch.float32
# TODO(xrsrke): can we use lower precision for scaling_factor?
assert scaling_factor.dtype == torch.float32
    # NOTE: since fp8_max is a fixed number determined by the FP8 data type,
    # we prefer not to take fp8_max as an input argument.
fp8_max = torch.tensor(DTYPE_TO_FP8_MAX[dtype], dtype=torch.float32)
    # NOTE: torch.jit only takes a concrete tensor rather than a DTYPE_TO_FP8_MAX[dtype] lookup,
    # so we create an inner function to work around that
@torch.jit.script
def _inner(amax: torch.Tensor, fp8_max: torch.Tensor, scaling_factor: torch.Tensor, margin: float):
# NOTE: calculate the number of bits to shift the exponent
ratio = fp8_max / amax
exp = torch.floor(torch.log2(ratio)) - margin
new_scaling_factor = torch.round(torch.pow(2, torch.abs(exp)))
new_scaling_factor = torch.where(amax > 0.0, new_scaling_factor, scaling_factor)
new_scaling_factor = torch.where(torch.isfinite(amax), new_scaling_factor, scaling_factor)
new_scaling_factor = torch.where(exp < 0, 1 / new_scaling_factor, new_scaling_factor)
return new_scaling_factor
return _inner(amax, fp8_max, scaling_factor, margin)
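# NOTE (illustrative): a worked example of the update rule above, assuming
# DTYPE_TO_FP8_MAX[DTypes.FP8E4M3] == 448 (the largest representable E4M3 value).
# With amax = 3.5: ratio = 448 / 3.5 = 128, exp = floor(log2(128)) = 7, so the new
# scaling factor is 2**7 = 128, i.e. values are multiplied by 128 before casting to FP8.
def _example_update_scaling_factor():
    amax = torch.tensor(3.5, dtype=torch.float32)
    previous_scale = torch.tensor(1.0, dtype=torch.float32)
    return update_scaling_factor(amax, previous_scale, DTypes.FP8E4M3)  # tensor(128.)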
import torch
import transformer_engine as te # noqa
from nanotron.fp8.constants import FP8_GPU_NAMES
def is_fp8_available() -> bool:
"""Check if FP8 is available on the current device."""
if torch.cuda.is_available():
device_name = torch.cuda.get_device_name(torch.cuda.current_device()).lower()
return any(gpu_name in device_name for gpu_name in FP8_GPU_NAMES)
else:
return False
from .sampler import BasicSampler, GreedySampler, Sampler, SamplerType, TopKSampler, TopPSampler
__all__ = ["BasicSampler", "GreedySampler", "Sampler", "SamplerType", "TopKSampler", "TopPSampler"]
import dataclasses
import time
from itertools import chain, islice
from typing import TYPE_CHECKING, Generator, Iterable, List, Optional, Tuple, Union
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import BenchArgs, GenerationArgs
from nanotron.distributed import ProcessGroup, get_global_rank
from nanotron.generation.generate_store import Store, attach_store
from nanotron.generation.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler
from nanotron.helpers import log_throughput
from nanotron.models.llama import LlamaModel
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.block import get_min_max_rank
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.p2p import P2PTensorMetaData, view_as_contiguous
from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import get_untyped_storage
if TYPE_CHECKING:
try:
from transformers import PreTrainedTokenizer
except ImportError:
PreTrainedTokenizer = None
logger = logging.get_logger(__name__)
@dataclasses.dataclass
class GenerationInput:
text: str
@dataclasses.dataclass
class GenerationInputs:
input_ids: Union[torch.Tensor, TensorPointer] # [B, S]
input_masks: Union[torch.Tensor, TensorPointer]
@dataclasses.dataclass
class GenerationOutput:
input_ids: Union[torch.Tensor, TensorPointer]
generation_ids: Union[torch.Tensor, TensorPointer]
return_logits: Optional[Union[torch.Tensor, TensorPointer]] = None
@dataclasses.dataclass
class GenerationStates:
new_input_ids: Union[torch.Tensor, TensorPointer]
new_input_mask: Union[torch.Tensor, TensorPointer]
store: Store
    # The rest of the state needed to reconstruct the generated output
generation_ids: List[Union[torch.Tensor, TensorPointer]]
generation_mask: List[Union[torch.Tensor, TensorPointer]]
@dataclasses.dataclass
class TokenizerConfig:
max_input_length: Optional[int]
truncation: Optional[Union[str, bool]] = None
padding: Optional[Union[str, bool]] = None
def chunks(iterable, chunk_size: int) -> Generator[List, None, None]:
"""Yield successive n-sized chunks from `iterable`"""
assert chunk_size >= 1
iterator = iter(iterable)
for first in iterator:
yield list(chain([first], islice(iterator, chunk_size - 1)))
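# NOTE (illustrative): `chunks` lazily groups any iterable into lists of at most `chunk_size`
# elements; the last chunk may be shorter, e.g.
#   list(chunks(range(5), chunk_size=2)) == [[0, 1], [2, 3], [4]]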
def micro_batcher(
input_iter: Iterable[GenerationInput],
tokenizer: "PreTrainedTokenizer",
max_micro_batch_size: int,
tokenizer_config: TokenizerConfig,
parallel_context: ParallelContext,
input_rank: int,
) -> Generator[GenerationInputs, None, None]:
"""
Returns:
input_ids: [max_micro_batch_size, max_input_length]
input_masks: [max_micro_batch_size, max_input_length]
"""
if tokenizer_config.padding is None:
tokenizer_config.padding = "max_length" if tokenizer_config.max_input_length is not None else True
if tokenizer_config.truncation is None:
tokenizer_config.truncation = True if tokenizer_config.max_input_length is not None else None
for micro_batch_id, micro_batch in enumerate(chunks(input_iter, chunk_size=max_micro_batch_size)):
if len(micro_batch) == 0:
# Empty micro batches don't matter
return
if micro_batch_id % parallel_context.dp_pg.size() != dist.get_rank(parallel_context.dp_pg):
# Each dp is responsible for its own micro batches
continue
if dist.get_rank(parallel_context.pp_pg) == input_rank:
encodings = tokenizer(
[elt.text for elt in micro_batch],
return_tensors="pt",
return_attention_mask=True,
padding=tokenizer_config.padding,
max_length=tokenizer_config.max_input_length,
truncation=tokenizer_config.truncation,
# pad_to_multiple_of=8
)
encodings["attention_mask"] = encodings.attention_mask.to(dtype=torch.bool, device="cuda")
encodings.to("cuda")
yield GenerationInputs(input_ids=encodings.input_ids, input_masks=encodings.attention_mask)
else:
yield GenerationInputs(
input_ids=TensorPointer(group_rank=input_rank), input_masks=TensorPointer(group_rank=input_rank)
)
def micro_splitter(
input_ids: torch.Tensor,
input_mask: torch.Tensor,
max_micro_batch_size: int,
parallel_context: ParallelContext,
input_rank: int,
) -> Generator[GenerationInputs, None, None]:
"""
Returns:
input_ids: [max_micro_batch_size, max_input_length]
input_masks: [max_micro_batch_size, max_input_length]
"""
for micro_batch_id, (micro_batch_ids, micro_batch_mask) in enumerate(
zip(torch.split(input_ids, max_micro_batch_size), torch.split(input_mask, max_micro_batch_size))
):
if len(micro_batch_ids) == 0:
# Empty micro batches don't matter
return
# if micro_batch_id % parallel_context.dp_pg.size() != dist.get_rank(parallel_context.dp_pg):
# # Each dp is responsible for its own micro batches
# continue
if dist.get_rank(parallel_context.pp_pg) == input_rank:
micro_batch_mask = micro_batch_mask.to(dtype=torch.bool, device="cuda")
micro_batch_mask.to("cuda")
yield GenerationInputs(input_ids=micro_batch_ids.clone(), input_masks=micro_batch_mask.clone())
else:
yield GenerationInputs(
input_ids=TensorPointer(group_rank=input_rank), input_masks=TensorPointer(group_rank=input_rank)
)
@torch.inference_mode()
def decode_text(
input_iter: Iterable[GenerationInput],
tokenizer: "PreTrainedTokenizer",
model: LlamaModel,
parallel_context: ParallelContext,
generation_config: GenerationArgs,
tokenizer_config: Optional[TokenizerConfig],
max_micro_batch_size: int,
max_new_tokens: int,
is_bench: bool = False,
logits_are_batch_first: bool = True,
) -> Generator[GenerationOutput, None, None]:
"""We assume the following:
- Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input.
- Only a specific rank will output the generated text_ids as `torch.Tensor`, the others return a `TensorPointer`. # TODO @thomasw21: Maybe all ranks should return the text.
- We assume that within a model replica, the inputs are already synchronized.
"""
decoder_input_rank, decoder_logit_rank = get_min_max_rank(module=model)
if generation_config:
if isinstance(generation_config.sampler, str):
sampler_type = SamplerType(generation_config.sampler.upper())
else:
sampler_type = generation_config.sampler
else:
sampler_type = SamplerType.GREEDY
# Compute flag
is_decoder_input_rank = dist.get_rank(parallel_context.pp_pg) == decoder_input_rank
is_decoder_logit_rank = dist.get_rank(parallel_context.pp_pg) == decoder_logit_rank
max_nb_microbatches = decoder_logit_rank - decoder_input_rank + 1
p2p = model.p2p
    # replicate the input n_samples times when using the TOP_P or TOP_K samplers, in order to get diverse results
if generation_config and generation_config.n_samples:
if sampler_type != SamplerType.TOP_P and sampler_type != SamplerType.TOP_K:
raise ValueError("Only support n_samples for TOP_P and TOP_K sampler")
input_iter = [
GenerationInput(text=input.text) for input in input_iter for _ in range(generation_config.n_samples)
]
    # That's annoying, but we need this as soon as communication crosses pipeline stages
pipeline_state = PipelineEvalBatchState()
with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state):
# We query the first `pipeline_size` batches
for batches in chunks(
iterable=micro_batcher(
input_iter=input_iter,
tokenizer=tokenizer,
max_micro_batch_size=max_micro_batch_size,
tokenizer_config=tokenizer_config,
input_rank=decoder_input_rank,
parallel_context=parallel_context,
),
chunk_size=max_nb_microbatches,
):
if len(batches) == 0:
                # It means we're out of elements
return
# Number of micro batches
number_states_in_buffer = len(batches)
# Otherwise the pipelining doesn't work
assert number_states_in_buffer <= max_nb_microbatches
is_max_nb_microbatches = number_states_in_buffer == max_nb_microbatches
# Initialize decoder states
decoder_states: Iterable[GenerationStates] = (
GenerationStates(
new_input_ids=batch.input_ids,
new_input_mask=batch.input_masks,
store=Store(),
generation_ids=[batch.input_ids],
generation_mask=[batch.input_masks],
)
for batch in batches
)
if is_bench:
start_time, elapsed_time_first_iteration = time.perf_counter(), 0
for generation_iter in range(max_new_tokens):
if is_bench and generation_iter == 0:
torch.cuda.synchronize()
                    elapsed_time_first_iteration = time.perf_counter() - start_time
all_new_decoder_input_ids_and_mask_same_rank: List[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
] = []
new_decoder_states: List[GenerationStates] = []
for state_id, state in enumerate(decoder_states):
new_decoder_states.append(state)
# Get the new logits
if generation_config.use_cache:
with attach_store(model=model, store=state.store):
# transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size]
sharded_logits = model(
input_ids=state.new_input_ids,
input_mask=state.new_input_mask,
)
else:
if isinstance(state.new_input_ids, torch.Tensor):
batch_generated_ids = torch.cat(state.generation_ids, dim=-1)
batch_generated_mask = torch.cat(state.generation_mask, dim=-1)
else:
batch_generated_ids = state.new_input_ids
batch_generated_mask = state.new_input_mask
sharded_logits = model(
input_ids=batch_generated_ids,
input_mask=batch_generated_mask,
)
if isinstance(sharded_logits, torch.Tensor) and logits_are_batch_first:
sharded_logits = sharded_logits.transpose(0, 1)
# Communicate
# TODO @thomasw21: Make a diagram to show how this works
nb_send: int = 0
if is_decoder_input_rank:
if is_max_nb_microbatches:
if generation_iter == 0:
if state_id == number_states_in_buffer - 1:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if number_states_in_buffer - 1 == state_id or generation_iter == 0:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if state_id == number_states_in_buffer - 1:
if not is_max_nb_microbatches:
nb_send = len(pipeline_state.microbatches_activations_to_send)
for _ in range(nb_send):
pipeline_state.run_communication()
if is_decoder_logit_rank:
assert isinstance(sharded_logits, torch.Tensor)
# run a logit chooser.
if sampler_type == SamplerType.GREEDY:
sampler = GreedySampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_K:
sampler = TopKSampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_P:
sampler = TopPSampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.BASIC:
sampler = BasicSampler(pg=parallel_context.tp_pg)
else:
raise NotImplementedError(f"Sampler type {sampler_type} is not implemented")
new_decoder_input_ids = sampler(sharded_logits=sharded_logits[:, -1, :])
# TODO @thomasw21: Handle this correctly, ie from some point after <eos> this should only generate masked tokens
# TODO @thomasw21: Actually I can probably build this thing on the next device directly. Will save some communication
new_decoder_input_mask = torch.ones(
size=(new_decoder_input_ids.shape[0], 1),
dtype=torch.bool,
device=new_decoder_input_ids.device,
)
# TODO @thomasw21: We need to have stop condition.
# broadcast new_tokens to everyone
if decoder_input_rank == decoder_logit_rank:
# It's the same rank so no need to do anything too fancy
all_new_decoder_input_ids_and_mask_same_rank.append(
(new_decoder_input_ids, new_decoder_input_mask)
)
else:
pipeline_state.register_send_activation(
new_decoder_input_ids, to_rank=decoder_input_rank, p2p=p2p
)
pipeline_state.register_send_activation(
new_decoder_input_mask, to_rank=decoder_input_rank, p2p=p2p
)
if not is_max_nb_microbatches and state_id == number_states_in_buffer - 1:
                                # Send new_decoder_input_ids AND new_decoder_input_mask
pipeline_state.run_communication()
pipeline_state.run_communication()
else:
assert isinstance(sharded_logits, TensorPointer)
all_new_decoder_input_ids_and_mask: Iterable[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
]
if is_decoder_input_rank:
# We receive the tensor from other ranks unless `decoder_input_rank` == `decoder_logit_rank` in which case `all_new_decoder_input_ids` is already populated.
if decoder_input_rank == decoder_logit_rank:
# `all_new_decoder_input_ids_and_mask_same_rank` is already populated. Since `decoder_input_rank` and `decoder_logit_rank` are the same, there's no need to communicate as we can just store the new input_ids in a list.
assert len(all_new_decoder_input_ids_and_mask_same_rank) == number_states_in_buffer
all_new_decoder_input_ids_and_mask = all_new_decoder_input_ids_and_mask_same_rank
else:
def generator():
for _ in range(number_states_in_buffer):
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
while len(pipeline_state.activations_buffer) < 2:
pipeline_state.run_communication()
new_decoder_input_ids = pipeline_state.activations_buffer.popleft()
new_decoder_input_mask = pipeline_state.activations_buffer.popleft()
yield new_decoder_input_ids, new_decoder_input_mask
all_new_decoder_input_ids_and_mask = iter(generator())
else:
all_new_decoder_input_ids_and_mask = (
(TensorPointer(group_rank=decoder_input_rank), TensorPointer(group_rank=decoder_input_rank))
for _ in range(number_states_in_buffer)
)
# Create new decoder states
decoder_states = (
GenerationStates(
new_input_ids=new_decoder_input_ids_and_mask[0],
new_input_mask=new_decoder_input_ids_and_mask[1],
store=state.store,
generation_ids=state.generation_ids + [new_decoder_input_ids_and_mask[0]],
generation_mask=state.generation_mask + [new_decoder_input_ids_and_mask[1]],
)
for state, new_decoder_input_ids_and_mask in zip(
new_decoder_states, all_new_decoder_input_ids_and_mask
)
)
if is_bench:
# Compute throughput (tok/s/gpu). Note that the first generation is done with full seq_len, so we don't count it.
torch.cuda.synchronize()
total_time_sec = time.perf_counter() - start_time - elapsed_time_first_iteration
# We generate 1 token per iteration per batch (batch=microbatch)
# Number of tokens generated every iteration: gbs/iteration_time
global_batch_size = len(batches) * parallel_context.dp_pg.size()
tokens_per_sec = global_batch_size * max_new_tokens / total_time_sec
model_tflops, hardware_tflops = model.get_flops_per_sec(
iteration_time_in_sec=total_time_sec,
sequence_length=max_new_tokens,
global_batch_size=global_batch_size,
)
bench_config = BenchArgs(
model_name=model.config._name_or_path,
sequence_length=max_new_tokens,
micro_batch_size=max_micro_batch_size,
batch_accumulation_per_replica=1,
benchmark_csv_path="benchmark.csv",
)
model_size = sum(
[p.numel() * p.data.element_size() for p in chain(model.parameters(), model.buffers())]
)
log_throughput(
bench_config,
parallel_context,
model_tflops,
hardware_tflops,
tokens_per_sec,
bandwidth=model_size * tokens_per_sec / 1e9,
)
# Flush communication
for _ in range(
max(
len(pipeline_state.microbatches_activations_to_send),
len(pipeline_state.microbatches_activations_to_recv),
)
):
pipeline_state.run_communication()
assert len(pipeline_state.microbatches_activations_to_send) == 0
assert len(pipeline_state.microbatches_activations_to_recv) == 0
# Yield result
decoder_states = list(decoder_states)
for state, batch in zip(decoder_states, batches):
if is_decoder_input_rank:
assert all(isinstance(elt, torch.Tensor) for elt in state.generation_ids)
batch_generated_ids = torch.cat(state.generation_ids, dim=-1)
batch_generated_mask = torch.cat(state.generation_mask, dim=-1)
else:
assert all(isinstance(elt, TensorPointer) for elt in state.generation_ids)
batch_generated_ids = TensorPointer(group_rank=decoder_input_rank)
batch_generated_mask = TensorPointer(group_rank=decoder_input_rank)
# Broadcast all data
batch_generated_ids, batch_generated_mask = broadcast_tensors(
[batch_generated_ids, batch_generated_mask],
group_src=decoder_input_rank,
group=parallel_context.pp_pg,
)
batch.input_ids, batch.input_masks = broadcast_tensors(
[batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=parallel_context.pp_pg
)
# Flush the store to release memory
state.store.flush()
assert len(state.store) == 0
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
assert (
batch_generated_ids.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_ids.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_mask.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_mask.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_ids.shape[1] == batch_generated_mask.shape[1]
), f"Sequence length needs to match {batch_generated_ids.shape[1]} != {batch_generated_mask.shape[0]}"
for i, (generated_ids, generated_mask) in enumerate(zip(batch_generated_ids, batch_generated_mask)):
# TODO @thomasw21: We could actually have all ranks return the output, since it's been already broadcasted
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
input_ids = batch.input_ids[i]
input_mask = batch.input_masks[i]
yield GenerationOutput(
input_ids=input_ids[input_mask],
generation_ids=generated_ids[generated_mask],
)
else:
yield GenerationOutput(
input_ids=TensorPointer(group_rank=decoder_input_rank),
generation_ids=TensorPointer(group_rank=decoder_input_rank),
)
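# NOTE (illustrative): a hedged sketch of how `decode_text` is typically driven, assuming an
# already-initialized LlamaModel, tokenizer, ParallelContext and GenerationArgs (their setup is
# outside this file, so the sketch is kept as comments):
#
#     outputs = decode_text(
#         input_iter=(GenerationInput(text=prompt) for prompt in ["Hello world"]),
#         tokenizer=tokenizer,
#         model=model,
#         parallel_context=parallel_context,
#         generation_config=generation_config,          # a nanotron GenerationArgs instance
#         tokenizer_config=TokenizerConfig(max_input_length=None),
#         max_micro_batch_size=2,
#         max_new_tokens=16,
#     )
#     for output in outputs:
#         if isinstance(output.generation_ids, torch.Tensor):
#             print(tokenizer.decode(output.generation_ids))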
@torch.inference_mode()
def decode_tokenized(
input_ids: torch.Tensor,
input_mask: torch.Tensor,
model: LlamaModel,
parallel_context: ParallelContext,
generation_config: GenerationArgs,
max_micro_batch_size: int,
max_new_tokens: int,
returns_logits: Optional[bool] = False,
) -> Generator[GenerationOutput, None, None]:
"""We assume the following:
- Everyone receives ALL the input text. # TODO @thomasw21: technically only specific ranks need to receive input.
- Only a specific rank will output the generated text_ids as `torch.Tensor`, the others return a `TensorPointer`. # TODO @thomasw21: Maybe all ranks should return the text.
- We assume that within a model replica, the inputs are already synchronized.
"""
if returns_logits:
raise NotImplementedError("return_logits is not implemented yet")
if generation_config:
if isinstance(generation_config.sampler, str):
sampler_type = SamplerType(generation_config.sampler.upper())
else:
sampler_type = generation_config.sampler
else:
sampler_type = SamplerType.GREEDY
decoder_input_rank, decoder_logit_rank = get_min_max_rank(module=model)
# Compute flag
is_decoder_input_rank = dist.get_rank(parallel_context.pp_pg) == decoder_input_rank
is_decoder_logit_rank = dist.get_rank(parallel_context.pp_pg) == decoder_logit_rank
max_nb_microbatches = decoder_logit_rank - decoder_input_rank + 1
# TODO @thomasw21: Fix this as we shouldn't get P2P like that
p2p = model.p2p
    # That's annoying, but we need this as soon as communication crosses pipeline stages
pipeline_state = PipelineEvalBatchState()
with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state):
# We query the first `pipeline_size` batches
for batches in chunks(
iterable=micro_splitter(
input_ids,
input_mask,
max_micro_batch_size=max_micro_batch_size,
parallel_context=parallel_context,
input_rank=decoder_input_rank,
),
chunk_size=max_nb_microbatches,
):
if len(batches) == 0:
                # It means we're out of elements
return
# Number of micro batches
number_states_in_buffer = len(batches)
# Otherwise the pipelining doesn't work
assert number_states_in_buffer <= max_nb_microbatches
is_max_nb_microbatches = number_states_in_buffer == max_nb_microbatches
# Initialize decoder states
decoder_states: Iterable[GenerationStates] = (
GenerationStates(
new_input_ids=batch.input_ids,
new_input_mask=batch.input_masks,
store=Store(),
generation_ids=[batch.input_ids],
generation_mask=[batch.input_masks],
)
for batch in batches
)
for generation_iter in range(max_new_tokens):
all_new_decoder_input_ids_and_mask_same_rank: List[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
] = []
new_decoder_states: List[GenerationStates] = []
for state_id, state in enumerate(decoder_states):
new_decoder_states.append(state)
# Get the new logits
with attach_store(model=model, store=state.store):
# transpose: [sequence_length, batch_size, vocab_size] -> [batch_size, sequence_length, vocab_size]
sharded_logits = model(
input_ids=state.new_input_ids,
input_mask=state.new_input_mask,
)
if isinstance(sharded_logits, torch.Tensor):
sharded_logits = sharded_logits.transpose(0, 1)
# Communicate
# TODO @thomasw21: Make a diagram to show how this works
nb_send: int = 0
if is_decoder_input_rank:
if is_max_nb_microbatches:
if generation_iter == 0:
if state_id == number_states_in_buffer - 1:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if number_states_in_buffer - 1 == state_id or generation_iter == 0:
# Send everything
nb_send = len(pipeline_state.microbatches_activations_to_send)
else:
# `2` is because we receive decoder_ids AND decoder_mask from last rank
nb_send = len(pipeline_state.microbatches_activations_to_send) - 2
else:
if state_id == number_states_in_buffer - 1:
if not is_max_nb_microbatches:
nb_send = len(pipeline_state.microbatches_activations_to_send)
for _ in range(nb_send):
pipeline_state.run_communication()
if is_decoder_logit_rank:
assert isinstance(sharded_logits, torch.Tensor)
# run a logit chooser.
if sampler_type == SamplerType.GREEDY:
sampler = GreedySampler(pg=parallel_context.tp_pg)
elif sampler_type == SamplerType.TOP_K:
sampler = TopKSampler(
pg=parallel_context.tp_pg,
k=generation_config.top_k,
temperature=generation_config.temperature,
)
elif sampler_type == SamplerType.TOP_P:
sampler = TopPSampler(
pg=parallel_context.tp_pg,
p=generation_config.top_p,
temperature=generation_config.temperature,
)
elif sampler_type == SamplerType.BASIC:
sampler = BasicSampler(pg=parallel_context.tp_pg)
else:
raise NotImplementedError(f"Sampler type {sampler_type} is not implemented")
new_decoder_input_ids = sampler(sharded_logits=sharded_logits[:, -1, :])
# TODO @thomasw21: Handle this correctly, ie from some point after <eos> this should only generate masked tokens
# TODO @thomasw21: Actually I can probably build this thing on the next device directly. Will save some communication
new_decoder_input_mask = torch.ones(
size=(new_decoder_input_ids.shape[0], 1),
dtype=torch.bool,
device=new_decoder_input_ids.device,
)
# TODO @thomasw21: We need to have stop condition.
# broadcast new_tokens to everyone
if decoder_input_rank == decoder_logit_rank:
# It's the same rank so no need to do anything too fancy
all_new_decoder_input_ids_and_mask_same_rank.append(
(new_decoder_input_ids, new_decoder_input_mask)
)
else:
pipeline_state.register_send_activation(
new_decoder_input_ids, to_rank=decoder_input_rank, p2p=p2p
)
pipeline_state.register_send_activation(
new_decoder_input_mask, to_rank=decoder_input_rank, p2p=p2p
)
if not is_max_nb_microbatches and state_id == number_states_in_buffer - 1:
                                # Send new_decoder_input_ids AND new_decoder_input_mask
pipeline_state.run_communication()
pipeline_state.run_communication()
else:
assert isinstance(sharded_logits, TensorPointer)
all_new_decoder_input_ids_and_mask: Iterable[
Tuple[Union[torch.LongTensor, TensorPointer], Union[torch.BoolTensor, TensorPointer]]
]
if is_decoder_input_rank:
# We receive the tensor from other ranks unless `decoder_input_rank` == `decoder_logit_rank` in which case `all_new_decoder_input_ids` is already populated.
if decoder_input_rank == decoder_logit_rank:
# `all_new_decoder_input_ids_and_mask_same_rank` is already populated. Since `decoder_input_rank` and `decoder_logit_rank` are the same, there's no need to communicate as we can just store the new input_ids in a list.
assert len(all_new_decoder_input_ids_and_mask_same_rank) == number_states_in_buffer
all_new_decoder_input_ids_and_mask = all_new_decoder_input_ids_and_mask_same_rank
else:
def generator():
for _ in range(number_states_in_buffer):
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
pipeline_state.register_recv_activation(from_rank=decoder_logit_rank, p2p=p2p)
while len(pipeline_state.activations_buffer) < 2:
pipeline_state.run_communication()
new_decoder_input_ids = pipeline_state.activations_buffer.popleft()
new_decoder_input_mask = pipeline_state.activations_buffer.popleft()
yield new_decoder_input_ids, new_decoder_input_mask
all_new_decoder_input_ids_and_mask = iter(generator())
else:
all_new_decoder_input_ids_and_mask = (
(TensorPointer(group_rank=decoder_input_rank), TensorPointer(group_rank=decoder_input_rank))
for _ in range(number_states_in_buffer)
)
# Create new decoder states
decoder_states = (
GenerationStates(
new_input_ids=new_decoder_input_ids_and_mask[0],
new_input_mask=new_decoder_input_ids_and_mask[1],
store=state.store,
generation_ids=state.generation_ids + [new_decoder_input_ids_and_mask[0]],
generation_mask=state.generation_mask + [new_decoder_input_ids_and_mask[1]],
)
for state, new_decoder_input_ids_and_mask in zip(
new_decoder_states, all_new_decoder_input_ids_and_mask
)
)
# Flush communication
for _ in range(
max(
len(pipeline_state.microbatches_activations_to_send),
len(pipeline_state.microbatches_activations_to_recv),
)
):
pipeline_state.run_communication()
assert len(pipeline_state.microbatches_activations_to_send) == 0
assert len(pipeline_state.microbatches_activations_to_recv) == 0
# Yield result
decoder_states = list(decoder_states)
for state, batch in zip(decoder_states, batches):
if is_decoder_input_rank:
assert all(isinstance(elt, torch.Tensor) for elt in state.generation_ids)
batch_generated_ids = torch.cat(state.generation_ids, dim=-1)
batch_generated_mask = torch.cat(state.generation_mask, dim=-1)
else:
assert all(isinstance(elt, TensorPointer) for elt in state.generation_ids)
batch_generated_ids = TensorPointer(group_rank=decoder_input_rank)
batch_generated_mask = TensorPointer(group_rank=decoder_input_rank)
# Broadcast all data
batch_generated_ids, batch_generated_mask = broadcast_tensors(
[batch_generated_ids, batch_generated_mask],
group_src=decoder_input_rank,
group=parallel_context.pp_pg,
)
batch.input_ids, batch.input_masks = broadcast_tensors(
[batch.input_ids, batch.input_masks], group_src=decoder_input_rank, group=parallel_context.pp_pg
)
# Flush the store to release memory
state.store.flush()
assert len(state.store) == 0
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
assert (
batch_generated_ids.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_ids.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_mask.shape[0] == batch.input_ids.shape[0]
), f"Batch size needs to match {batch_generated_mask.shape[0]} != {batch.input_ids.shape[0]}"
assert (
batch_generated_ids.shape[1] == batch_generated_mask.shape[1]
), f"Sequence length needs to match {batch_generated_ids.shape[1]} != {batch_generated_mask.shape[0]}"
for i, (generated_ids, generated_mask) in enumerate(zip(batch_generated_ids, batch_generated_mask)):
# TODO @thomasw21: We could actually have all ranks return the output, since it's been already broadcasted
if dist.get_rank(parallel_context.pp_pg) == decoder_input_rank:
input_ids = batch.input_ids[i]
input_mask = batch.input_masks[i]
yield GenerationOutput(
input_ids=input_ids[input_mask],
generation_ids=generated_ids[generated_mask],
)
else:
yield GenerationOutput(
input_ids=TensorPointer(group_rank=decoder_input_rank),
generation_ids=TensorPointer(group_rank=decoder_input_rank),
)
# Distributed utilities
def broadcast_tensors(
tensors: List[Union[torch.Tensor, TensorPointer]], group_src: int, group: Optional[ProcessGroup] = None
) -> List[torch.Tensor]:
result = []
for tensor in tensors:
if dist.get_rank(group) == group_src:
assert isinstance(tensor, torch.Tensor)
meta = [
[
tensor.dtype,
tensor.requires_grad,
tensor.shape,
get_untyped_storage(tensor).size(),
tensor.stride(),
tensor.is_contiguous(),
tensor.storage_offset(),
]
]
else:
assert isinstance(tensor, TensorPointer)
meta = [None]
dist.broadcast_object_list(meta, src=get_global_rank(group_rank=group_src, group=group), group=group)
dtype, requires_grad, shape, untyped_storage_size, stride, is_contiguous, storage_offset = meta[0]
meta = P2PTensorMetaData(
dtype=dtype,
requires_grad=requires_grad,
shape=shape,
untyped_storage_size=untyped_storage_size,
stride=stride,
is_contiguous=is_contiguous,
storage_offset=storage_offset,
)
if dist.get_rank(group) != group_src:
tensor = meta.create_empty_storage(device=torch.device("cuda"))
else:
tensor = view_as_contiguous(tensor)
dist.broadcast(tensor, src=get_global_rank(group_rank=group_src, group=group), group=group)
# Set shape and stride
tensor = tensor.as_strided(size=tuple(meta.shape), stride=tuple(meta.stride))
result.append(tensor)
return result
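# NOTE (illustrative): in `decode_text`/`decode_tokenized` above, the source rank passes real
# tensors to `broadcast_tensors` while the other pipeline ranks pass
# `TensorPointer(group_rank=group_src)`; after the call every rank in `group` holds its own
# materialized copy of the broadcast tensors.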
import collections
import contextlib
from torch import nn
class Store(collections.defaultdict):
"""
    We use the store to keep some states locally on the GPU so that we don't have to communicate them.
    This is useful at inference when we don't want to recompute the kv_cache, for example, or don't want to send it through the pipeline.
"""
def __init__(self):
super().__init__(dict)
def flush(self):
# TODO @thomasw21: There's probably a simpler way than doing this.
for key in list(self.keys()):
del self[key]
class AttachableStore:
def _attach_store(self, store: Store):
assert not hasattr(self, "_store"), "You can't assign a store when there's already one attached"
self._store = store
def _detach_store(self):
delattr(self, "_store")
def get_local_store(self):
if hasattr(self, "_store"):
if isinstance(self, nn.Module):
assert self.training is False, "Store is used only in evaluation mode"
return self._store[id(self)]
else:
return None
@contextlib.contextmanager
def attach_store(model: nn.Module, store: Store):
list_module_containing_store = []
for module in model.modules():
if not isinstance(module, AttachableStore):
continue
module._attach_store(store)
list_module_containing_store.append(module)
try:
yield
finally:
for module in list_module_containing_store:
module._detach_store()
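# NOTE (illustrative): a minimal sketch of how a module mixing in `AttachableStore` interacts
# with `attach_store`. `CachingModule` is a hypothetical stand-in for e.g. an attention layer
# caching its kv state during generation.
def _example_attach_store():
    import torch

    class CachingModule(nn.Module, AttachableStore):
        def forward(self, x):
            store = self.get_local_store()  # per-module dict, or None if no store is attached
            if store is not None:
                store["num_calls"] = store.get("num_calls", 0) + 1
            return x

    model = nn.Sequential(CachingModule()).eval()  # the store is only used in eval mode
    with attach_store(model=model, store=Store()):
        model(torch.zeros(1, 4))
        model(torch.zeros(1, 4))
    # outside the context manager the store is detached from the module again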
from dataclasses import dataclass
from enum import Enum, auto
from typing import Sequence
import torch
from nanotron import distributed as dist
def all_gather_batches(in_tensor: torch.Tensor, in_split: Sequence[int], group: dist.ProcessGroup) -> torch.Tensor:
    # All-gather along the first dimension, allowing unequal splits
out_tensor = torch.empty((sum(in_split),) + in_tensor.shape[1:], dtype=in_tensor.dtype, device=in_tensor.device)
out_split_list = list(torch.split(out_tensor, in_split, dim=0))
dist.all_gather(out_split_list, in_tensor, group=group)
return out_tensor
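# NOTE (illustrative): with a process group of size 2 and in_split=(3, 2), every rank contributes
# its local batch shard and ends up with the full gathered tensor of shape [5, ...]; the splits
# may be unequal, which is why the gather goes through a pre-split output buffer.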
class SamplerType(Enum):
TOP_P = auto()
TOP_K = auto()
GREEDY = auto()
BASIC = auto()
class Sampler:
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@dataclass
class TopPSampler(Sampler):
pg: dist.ProcessGroup
p: float = 0.9
temperature: float = 1.0
filter_value: float = 0.0
min_tokens_to_keep: int = 1
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
batch_size, vocab_per_shard = sharded_logits.shape
        # Split the sharded logits into a list of tensors along the batch dimension
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# Prepare tensors for all-to-all operation
# Gather logits from all vocab shards but shard on batch, tp_rank first
sharded_logits_out = torch.empty(
(total_out_size, vocab_per_shard),
dtype=sharded_logits.dtype,
device=sharded_logits.device,
) # [pg_size * sharded_batch_size, vocab_per_shard]
local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))
dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)
logits = torch.cat(local_sharded_logits_out, dim=-1) # [sharded_batch_size, vocab_size]
probs = torch.softmax(logits.to(dtype=torch.float) / self.temperature, dim=-1) # [batch_size, vocab_size]
        # Sort the probs and their corresponding indices in ascending order
sorted_probs, sorted_indices = torch.sort(probs, descending=False, dim=-1)
# Calculate the cumulative sum of the sorted probs
# the bfloat16 type is not accurate enough for the cumulative sum
cumulative_probs = torch.cumsum(sorted_probs, dim=-1, dtype=torch.float) # [batch_size, vocab_size]
# Find the smallest set of indices for which the cumulative probability mass exceeds p
sorted_indices_to_remove = cumulative_probs <= (1 - self.p)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# Construct the probability mask for original indices
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
filter_probs = probs.masked_fill(indices_to_remove, self.filter_value)
sampled_indices = torch.multinomial(filter_probs, num_samples=1)
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(sampled_indices, in_split, group=self.pg)
return gathered_new_decoder_input_ids
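# NOTE (illustrative): a single-process reference of the nucleus (top-p) filtering step used
# above, without the tensor-parallel all-to-all. Because probabilities are sorted in ascending
# order, tokens whose cumulative mass stays <= (1 - p) are exactly those outside the smallest
# set whose total mass exceeds p. The helper name is hypothetical.
def _reference_top_p_sample(logits: torch.Tensor, p: float = 0.9, temperature: float = 1.0) -> torch.Tensor:
    probs = torch.softmax(logits.to(dtype=torch.float) / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=False, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1, dtype=torch.float)
    sorted_indices_to_remove = cumulative_probs <= (1 - p)
    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
    return torch.multinomial(probs.masked_fill(indices_to_remove, 0.0), num_samples=1)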
@dataclass
class GreedySampler(Sampler):
pg: dist.ProcessGroup
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
batch_size, vocab_per_shard = sharded_logits.shape
# Find local max logit and its index
# Note that max is deterministic, and always takes the first one.
max_values, max_indices = sharded_logits.max(dim=-1, keepdim=True) # [batch_size, 1]
# Add offset to the max indices
# TODO: We're assuming that TensorColumnLinear shards in a specific manner, i.e. rank 0 gets the first.
# It might require us to expose something from TensorColumnLinear.
max_indices = max_indices + (dist.get_rank(self.pg) * vocab_per_shard)
# Split max_values/max_indices into a list of tensors along batch
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# Prepare tensors for all-to-all operation
# Gather max logits and their indices from all shards, tp_rank first
max_values_out_mat = torch.empty(
(total_out_size, 1),
dtype=max_values.dtype,
device=max_values.device,
)
max_indices_out_mat = torch.empty(
(total_out_size, 1),
dtype=max_indices.dtype,
device=max_indices.device,
)
local_max_values_in = list(torch.split(max_values, in_split, dim=0))
local_max_indices_in = list(torch.split(max_indices, in_split, dim=0))
local_max_values_out = list(torch.split(max_values_out_mat, out_split, dim=0))
local_max_indices_out = list(torch.split(max_indices_out_mat, out_split, dim=0))
dist.all_to_all(local_max_values_out, local_max_values_in, group=self.pg)
dist.all_to_all(local_max_indices_out, local_max_indices_in, group=self.pg)
# Concat assumes that the primary dimension is the same across all shards
sharded_max_values = torch.cat(local_max_values_out, dim=-1) # [sharded_batch_size, num_shards]
sharded_max_indices = torch.cat(local_max_indices_out, dim=-1) # [sharded_batch_size, num_shards]
# Find global max logit across all shards
# Note that max is deterministic, and always takes the first one.
# [sharded_batch_size, 1]
_global_max_values, global_max_indices = sharded_max_values.max(dim=-1, keepdim=True)
# Select the corresponding token index from the offsetted gathered indices
sharded_selected_tokens = sharded_max_indices.gather(1, global_max_indices)
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(sharded_selected_tokens, in_split, group=self.pg)
return gathered_new_decoder_input_ids
@dataclass
class TopKSampler(Sampler):
pg: dist.ProcessGroup
k: int = 50
temperature: float = 1.0
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
batch_size, vocab_per_shard = sharded_logits.shape
# Find local top-k logits and their indices
local_top_k_values, local_top_k_indices = torch.topk(sharded_logits, self.k, dim=-1)
# Add offset to the indices
local_top_k_indices = local_top_k_indices + (dist.get_rank(self.pg) * vocab_per_shard)
# Split local_top_k_values into a list of tensors along batch
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# The last shard could be smaller than shard_batch_size
local_top_k_values_in = list(torch.split(local_top_k_values, in_split, dim=0))
        local_top_k_indices_in = list(torch.split(local_top_k_indices, in_split, dim=0))
# Prepare tensors for all-to-all operation
# Gather top-k logits and their indices from all shards, tp_rank first
top_k_values_out_mat = torch.empty(
(total_out_size,) + local_top_k_values.shape[1:],
dtype=local_top_k_values.dtype,
device=local_top_k_values.device,
)
top_k_indices_out_mat = torch.empty(
(total_out_size,) + local_top_k_indices.shape[1:],
dtype=local_top_k_indices.dtype,
device=local_top_k_indices.device,
)
local_top_k_values_out = list(torch.split(top_k_values_out_mat, out_split, dim=0))
local_top_k_indices_out = list(torch.split(top_k_indices_out_mat, out_split, dim=0))
dist.all_to_all(local_top_k_values_out, local_top_k_values_in, group=self.pg)
        dist.all_to_all(local_top_k_indices_out, local_top_k_indices_in, group=self.pg)
# Concat assumes that the primary dimension is the same across all shards
sharded_local_top_k_values = torch.cat(local_top_k_values_out, dim=-1) # [sharded_batch_size, k * num_shards]
sharded_local_top_k_indices = torch.cat(
local_top_k_indices_out, dim=-1
) # [sharded_batch_size, k * num_shards]
# Select global top-k from the gathered top-k, now the top-k is across all vocab, batch_size is sharded
sharded_top_k_values, sharded_top_k_indices = torch.topk(
sharded_local_top_k_values, self.k, dim=-1
) # [sharded_batch_size, k]
# Select corresponding indices from the gathered indices
sharded_top_k_indices = sharded_local_top_k_indices.gather(
-1, sharded_top_k_indices
) # [sharded_batch_size, k]
# Apply temperature and compute softmax probabilities
probs = torch.softmax(sharded_top_k_values.to(dtype=torch.float) / self.temperature, dim=-1)
# Sample from the probabilities
sampled_indices = torch.multinomial(probs, num_samples=1) # [sharded_batch_size]
# Select the corresponding token index from the global top-k indices
new_decoder_input_ids = sharded_top_k_indices.gather(-1, sampled_indices) # [sharded_batch_size]
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(new_decoder_input_ids, in_split, group=self.pg)
return gathered_new_decoder_input_ids
@dataclass
class BasicSampler(Sampler):
"""Basic sampler that samples from the full vocab according to the logits."""
pg: dist.ProcessGroup
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
# We will cross batch and vocab shards to sample from the full vocab and a part of the batch
# (right now logits are sharded on vocab and batch, so we need to do all-to-all)
batch_size, vocab_per_shard = sharded_logits.shape
        # Split the sharded logits into a list of tensors along the batch dimension
# We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
# out_split should be all equal to be able to concat at last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# Prepare tensors for all-to-all operation
# Gather logits from all vocab shards but shard on batch, tp_rank first
sharded_logits_out = torch.empty(
(total_out_size, vocab_per_shard),
dtype=sharded_logits.dtype,
device=sharded_logits.device,
) # [pg_size * sharded_batch_size, vocab_per_shard]
local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))
dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)
logits = torch.cat(local_sharded_logits_out, dim=-1) # [sharded_batch_size, vocab_size]
probs = torch.softmax(logits.to(dtype=torch.float), dim=-1) # [batch_size, vocab_size]
# Sample from the probabilities
sampled_indices = torch.multinomial(probs, num_samples=1)
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(sampled_indices, in_split, group=self.pg)
return gathered_new_decoder_input_ids
import contextlib
import csv
import gc
import math
import os
import time
from datetime import datetime
from functools import partial
from math import ceil
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import LambdaLR
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
from nanotron.distributed import ProcessGroup
from nanotron.logging import LogItem, log_rank
from nanotron.models.base import NanotronModel
from nanotron.optim.base import BaseOptimizer, Optimizer
from nanotron.optim.gradient_accumulator import (
FP32GradBucketManager,
FP32GradientAccumulator,
GradientAccumulator,
get_fp32_accum_hook,
)
from nanotron.optim.named_optimizer import NamedOptimizer
from nanotron.optim.optimizer_from_gradient_accumulator import (
OptimizerFromGradientAccumulator,
)
from nanotron.optim.zero import ZeroDistributedOptimizer
from nanotron.parallel import ParallelContext
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from nanotron.random import (
RandomStates,
get_current_random_state,
get_synced_random_state,
)
from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod
from nanotron.serialize.metadata import TrainingMetadata
logger = logging.get_logger(__name__)
def _vocab_size_with_padding(orig_vocab_size: int, pg_size: int, make_vocab_size_divisible_by: int):
"""Pad vocab size so it is divisible by pg_size * make_vocab_size_divisible_by."""
multiple = make_vocab_size_divisible_by * pg_size
after = int(ceil(orig_vocab_size / multiple) * multiple)
if after != orig_vocab_size:
log_rank(
f"[Vocab Size Padding] Padded vocab (size: {orig_vocab_size}) with {after - orig_vocab_size} dummy tokens (new size: {after})",
logger=logger,
level=logging.WARNING,
rank=0,
)
return after
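# Example (illustrative, assuming a GPT-2 style vocab of 50257, a TP group of size 8 and
# make_vocab_size_divisible_by=128): the padding multiple is 8 * 128 = 1024, so the vocab is
# padded up to the next multiple of 1024:
#   >>> int(ceil(50257 / 1024) * 1024)
#   51200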
def init_random_states(parallel_config: ParallelismArgs, tp_pg: ProcessGroup):
# Get synchronized random states
if parallel_config is None or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:
random_states = RandomStates(
{"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=tp_pg)}
)
else:
# NOTE: We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
random_states = RandomStates({})
return random_states
def lr_scheduler_builder(optimizer: Optimizer, lr_scheduler_args: LRSchedulerArgs, total_training_steps: int):
if lr_scheduler_args.lr_decay_steps is None:
lr_decay_steps = total_training_steps
if lr_scheduler_args.lr_warmup_steps is not None:
lr_decay_steps -= lr_scheduler_args.lr_warmup_steps
if lr_scheduler_args.lr_decay_starting_step is not None:
lr_decay_steps -= lr_scheduler_args.lr_decay_starting_step
else:
lr_decay_steps = lr_scheduler_args.lr_decay_steps
if lr_scheduler_args.lr_decay_starting_step is None:
if lr_scheduler_args.lr_warmup_steps is not None:
lr_decay_starting_step = lr_scheduler_args.lr_warmup_steps
else:
lr_decay_starting_step = 0
else:
lr_decay_starting_step = lr_scheduler_args.lr_decay_starting_step
def lr_lambda(current_step: int, initial_lr: float):
"""
current_step: current training step
initial_lr: the learning rate of a parameter group
More info on initial_lr:
And in standard parameterization, lr_lambda only takes a single learning rate.
But in µTransfer, each parameter has a custom learning rate (custom_lr = lr_scheduler_args.learning_rate * scaling_factor),
so each parameter group has a custom lr_lambda function.
LR Scheduling function, it has from 2 up to 4 phases:
- warmup,
- optional: constant (if lr_decay_starting_step is set)
- decay
- optional: constant (if lr_decay_steps and/or lr_decay_starting_step are set)
Warmup starts at lr=0 and ends at `lr=lr`
Then it stays constant at lr if lr_decay_starting_step is set and larger than lr_warmup_steps
Then it decays until `min_decay_lr` for lr_decay_steps if set, else: (total_training_steps - lr_warmup_steps or lr_decay_starting_step)
Then it stays constant at min_decay_lr if lr_decay_starting_step is set and total_training_steps is larger)
"""
# No warmup or decay
if lr_scheduler_args.lr_warmup_steps == 0 and lr_decay_steps == 0:
return initial_lr
# Warmup phase
elif lr_scheduler_args.lr_warmup_style is not None and current_step <= lr_scheduler_args.lr_warmup_steps:
if lr_scheduler_args.lr_warmup_style == "linear":
lmbda = initial_lr * current_step / max(lr_scheduler_args.lr_warmup_steps, 1)
elif lr_scheduler_args.lr_warmup_style == "constant":
lmbda = lr_scheduler_args.learning_rate
else:
raise ValueError(f"Unknown warmup style {lr_scheduler_args.lr_warmup_style}")
# Optional constant phase at learning_rate
elif current_step < lr_decay_starting_step:
lmbda = initial_lr
# Decay phase
elif lr_scheduler_args.lr_decay_style is not None and current_step < lr_decay_starting_step + lr_decay_steps:
if lr_scheduler_args.lr_decay_style == "cosine":
lmbda = (
lr_scheduler_args.min_decay_lr
+ (initial_lr - lr_scheduler_args.min_decay_lr)
* (1 + math.cos(math.pi * (current_step - lr_decay_starting_step) / lr_decay_steps))
/ 2
)
elif lr_scheduler_args.lr_decay_style == "linear":
lmbda = (
lr_scheduler_args.min_decay_lr
+ (initial_lr - lr_scheduler_args.min_decay_lr)
* (lr_decay_steps - (current_step - lr_decay_starting_step))
/ lr_decay_steps
)
elif lr_scheduler_args.lr_decay_style == "1-sqrt":
lmbda = lr_scheduler_args.min_decay_lr + (initial_lr - lr_scheduler_args.min_decay_lr) * (
1 - math.sqrt((current_step - lr_decay_starting_step) / lr_decay_steps)
)
else:
raise ValueError(f"Unknown decay style {lr_scheduler_args.lr_decay_style}")
# Optional constant phase at min_decay_lr
else:
lmbda = lr_scheduler_args.min_decay_lr
lmbda /= initial_lr # Normalization for pytorch
return lmbda
def get_lr_lambda_for_param_group(lr: float):
return partial(lr_lambda, initial_lr=lr)
# NOTE: get learning rate scheduler for each param group
lr_lambdas = []
for param_group in optimizer.get_base_optimizer().param_groups:
lr_lambdas.append(get_lr_lambda_for_param_group(lr=param_group["lr"]))
assert len(lr_lambdas) == len(
optimizer.get_base_optimizer().param_groups
), "Custom learning rate functions dont match the number of param groups"
log_rank(
f"[Optimizer Building] There are total {len(lr_lambdas)} custom learning rate function for parameter groups",
logger=logger,
level=logging.DEBUG,
)
lr_scheduler = LambdaLR(optimizer.get_base_optimizer(), lr_lambda=lr_lambdas)
return lr_scheduler
def get_custom_weight_decay_for_named_parameters(
named_parameters: Iterable[Tuple[str, torch.Tensor]],
model: NanotronModel,
module_id_to_prefix: Dict[int, str],
weight_decay: float,
) -> List[Dict[str, Any]]:
"""
Apply weight decay to all parameters except the ones returned by `model.get_named_params_without_weight_decay()`.
"""
named_param_groups_with_custom_weight_decay = []
exclude_named_params = model.get_named_params_without_weight_decay()
for name, param in named_parameters:
if param.is_tied:
# NOTE: use the tied parameter's full name so it can be matched against the exclusion list
name = param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if any(name.endswith(substring) for substring in exclude_named_params):
named_param_groups_with_custom_weight_decay.append({"named_params": [(name, param)], "weight_decay": 0.0})
else:
named_param_groups_with_custom_weight_decay.append(
{"named_params": [(name, param)], "weight_decay": weight_decay}
)
log_rank(
f"[Optimizer Building] Creating {len(named_param_groups_with_custom_weight_decay)} param groups with custom weight decay",
logger=logger,
level=logging.DEBUG,
)
return named_param_groups_with_custom_weight_decay
def get_custom_lr_for_named_parameters(
parametrization_method: ParametrizationMethod,
lr: float,
named_parameters: Iterable[Tuple[str, torch.Tensor]],
model: NanotronModel,
) -> List[Dict[str, Any]]:
"""
Get custom learning rates for parameters based on the parametrization method.
NOTE: in some parametrization methods, we use a global learning rate for all parameters,
in others we use a custom learning rate for each parameter (e.g. spectral µTransfer).
"""
assert parametrization_method in [ParametrizationMethod.SPECTRAL_MUP, ParametrizationMethod.STANDARD]
lr_mapper_cls = (
LearningRateForSpectralMup
if parametrization_method == ParametrizationMethod.SPECTRAL_MUP
else LearningRateForSP
)
log_rank(
f"[Optimizer Building] Using {lr_mapper_cls.__name__} as learning rate",
logger=logger,
level=logging.INFO,
rank=0,
)
# NOTE: since in the case of pipeline parallelism, each rank only has a subset of the model
# so we only get the parameters that are in the current rank
learning_rate_mapper = lr_mapper_cls(names_to_modules=model.named_modules_in_pp_rank, lr=lr)
named_param_groups_with_custom_lr = []
for (
name,
param,
) in named_parameters:
learning_rate = learning_rate_mapper.get_lr(name, param)
assert isinstance(learning_rate, float), f"Expected a float, got {learning_rate} for parameter {name}"
named_param_groups_with_custom_lr.append({"named_params": [(name, param)], "lr": learning_rate})
log_rank(
f"[Optimizer Building] Creating {len(named_param_groups_with_custom_lr)} param groups with custom learning rates",
logger=logger,
level=logging.DEBUG,
)
return named_param_groups_with_custom_lr
def merge_named_param_groups(
named_param_groups_with_lr: List[Dict[str, Any]],
named_param_groups_with_weight_decay: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
assert len(named_param_groups_with_lr) == len(
named_param_groups_with_weight_decay
), "Named param groups don't match in length"
named_param_groups = []
for group_with_lr, group_with_weight_decay in zip(
named_param_groups_with_lr, named_param_groups_with_weight_decay
):
assert group_with_lr["named_params"] == group_with_weight_decay["named_params"]
named_param_groups.append(
{
"named_params": group_with_lr["named_params"],
"lr": group_with_lr["lr"],
"weight_decay": group_with_weight_decay["weight_decay"],
}
)
return named_param_groups
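# Example of the merged structure (illustrative, with hypothetical parameter names and values):
#   [
#       {"named_params": [("model.lm_head.weight", param)], "lr": 3e-4, "weight_decay": 0.1},
#       {"named_params": [("model.final_layer_norm.weight", param)], "lr": 3e-4, "weight_decay": 0.0},
#   ]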
def init_optimizer_and_grad_accumulator(
parametrization_method: ParametrizationMethod,
model: nn.Module,
optimizer_args: OptimizerArgs,
parallel_context: ParallelContext,
) -> Tuple[BaseOptimizer, GradientAccumulator]:
# Unwrap DDP
unwrapped_model: NanotronModel = model.module if isinstance(model, DistributedDataParallel) else model
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in unwrapped_model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(unwrapped_model)] = ""
named_parameters = list(unwrapped_model.get_named_params_with_correct_tied())
named_param_groups_with_lr = get_custom_lr_for_named_parameters(
parametrization_method=parametrization_method,
named_parameters=named_parameters,
model=unwrapped_model,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
)
named_param_groups_with_weight_decay = get_custom_weight_decay_for_named_parameters(
named_parameters=named_parameters,
model=unwrapped_model,
module_id_to_prefix=module_id_to_prefix,
weight_decay=optimizer_args.weight_decay,
)
named_param_groups = merge_named_param_groups(named_param_groups_with_lr, named_param_groups_with_weight_decay)
# Basic optimizer builder
def basic_optimizer_builder(named_param_groups):
optimizer = None
if optimizer_args.optimizer_factory.name == "adamW":
def optimizer(param_groups):
return torch.optim.AdamW(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
eps=optimizer_args.optimizer_factory.adam_eps,
betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2),
fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
)
elif optimizer_args.optimizer_factory.name == "sgd":
def optimizer(param_groups):
return torch.optim.SGD(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
)
else:
raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported")
return NamedOptimizer(
named_params_or_groups=named_param_groups,
optimizer_builder=optimizer,
)
optimizer_builder = basic_optimizer_builder
# Gradient accumulator builder
grad_accumulator: Optional[GradientAccumulator] = None
if optimizer_args.accumulate_grad_in_fp32:
# TODO @thomasw21: Make an optimizer builder system, instead of doing everything in functional manner
def grad_optimizer_builder(named_param_groups):
result = OptimizerFromGradientAccumulator(
gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(
named_parameters=named_params,
grad_buckets_named_params=named_parameters,
),
named_params_or_groups=named_param_groups,
optimizer_builder=basic_optimizer_builder,
)
# TODO @thomasw21: get better API to get the grad_accumulator
nonlocal grad_accumulator
grad_accumulator = result.gradient_accumulator
return result
optimizer_builder = grad_optimizer_builder
if optimizer_args.zero_stage > 0:
# Build optimizer
optimizer = ZeroDistributedOptimizer(
named_params_or_groups=named_param_groups,
# TODO @thomasw21: We need a better API for gradient accumulation/zero etc ...
optimizer_builder=optimizer_builder,
dp_pg=parallel_context.dp_pg,
)
# SANITY CHECK: assert that optimizer's named_params point to model's params (check only the first one)
if (
len(optimizer.zero_named_param_groups) > 0
and len(optimizer.zero_named_param_groups[0]["named_params"]) > 0
):
optim_model_param_name, optim_model_param = optimizer.zero_named_param_groups[0]["named_params"][0]
if isinstance(model, DistributedDataParallel):
optim_model_param_name = f"module.{optim_model_param_name}"
param = model.get_parameter(optim_model_param_name)
assert param.data_ptr() == optim_model_param.data_ptr()
else:
# Build optimizer
optimizer = optimizer_builder(named_param_groups)
if grad_accumulator is not None and optimizer_args.zero_stage > 0:
# There's a way to only reduce_scatter the gradients instead of all_reducing them.
# In order to do so I need to pass which segments of each parameter should be reduced on which dp rank.
assert isinstance(optimizer, ZeroDistributedOptimizer)
param_name_to_dp_rank_offsets = optimizer.param_name_to_dp_rank_offsets
assert isinstance(grad_accumulator, FP32GradientAccumulator)
grad_accumulator.assign_param_offsets(
dp_rank=dist.get_rank(parallel_context.dp_pg),
param_name_to_offsets=param_name_to_dp_rank_offsets,
)
# Register DDP hook to make fp32 grad accumulation work
if isinstance(model, DistributedDataParallel) and grad_accumulator is not None:
assert isinstance(grad_accumulator, FP32GradientAccumulator)
model.register_comm_hook(
state=FP32GradBucketManager(
dp_pg=parallel_context.dp_pg,
accumulator=grad_accumulator,
param_id_to_name={
id(param): param.get_tied_info().get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
if param.is_tied
else name
for name, param in unwrapped_model.named_parameters()
},
),
hook=get_fp32_accum_hook(
reduce_scatter=optimizer.inherit_from(ZeroDistributedOptimizer), reduce_op=dist.ReduceOp.AVG
),
)
return optimizer, grad_accumulator
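# Minimal usage sketch (illustrative; assumes `config`, `model` and `parallel_context` are built by the
# trainer elsewhere, this is not a drop-in snippet):
#   optimizer, grad_accumulator = init_optimizer_and_grad_accumulator(
#       parametrization_method=ParametrizationMethod.STANDARD,
#       model=model,
#       optimizer_args=config.optimizer,
#       parallel_context=parallel_context,
#   )
#   lr_scheduler = lr_scheduler_builder(
#       optimizer, lr_scheduler_args=config.optimizer.learning_rate_scheduler, total_training_steps=config.tokens.train_steps
#   )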
def test_equal_dict(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> None:
"""Raise if doesn't match."""
if sub_paths is None:
sub_paths = []
first_keys = set(first.keys())
second_keys = set(second.keys())
assert first_keys == second_keys, f"Keys don't match.\nFirst: {first_keys}\nSecond: {second_keys}"
for key in first_keys:
first_elt = first[key]
second_elt = second[key]
if isinstance(first_elt, dict):
assert isinstance(second_elt, dict), f"{first_elt} doesn't match {second_elt}"
test_equal_dict(first_elt, second_elt, sub_paths=sub_paths + [str(key)])
elif isinstance(first_elt, torch.Tensor):
assert isinstance(second_elt, torch.Tensor), f"{first_elt} doesn't match {second_elt}"
torch.testing.assert_close(
first_elt,
second_elt,
atol=0.0,
rtol=0.0,
msg=lambda msg: f"tensor at {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}\n{msg}",
)
else:
assert (
first_elt == second_elt
), f"{first_elt} doesn't match {second_elt} at key {'.'.join(sub_paths + [str(key)])}"
def get_profiler(config: Config):
if config.profiler is not None:
if config.profiler.profiler_export_path is not None:
on_trace_ready = tensorboard_trace_handler(
config.profiler.profiler_export_path / datetime.now().strftime("%Y%m%d-%H%M%S")
)
else:
on_trace_ready = None
prof = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1, skip_first=3),
on_trace_ready=on_trace_ready,
# record_shapes=True,
# profile_memory=True,
with_stack=True,
)
else:
prof = contextlib.nullcontext()
return prof
def get_all_comps(n: int) -> List[List[List[int]]]:
"""Return a 3D numpy array with a series of pairs to test latency/bandwidth between:
This basically make a square matrix from the triangle of pair-to-pair comparisons
[[[0 1]
[2 3]]
[[0 2]
[1 3]]
[[0 3]
[1 2]]]
"""
# n: power of two
if not ((n & (n - 1) == 0) and n != 0):
# every power of 2 has exactly 1 bit set to 1 (the bit in that number's log base-2 index).
# So when subtracting 1 from it, that bit flips to 0 and all preceding bits flip to 1.
# That makes these 2 numbers the inverse of each other so when AND-ing them, we will get 0 as the result
raise ValueError("n must be a power of two")
def op(lst, d=4, r=1):
lst = lst.reshape(-1, d)
lst[1::2] = np.roll(lst[1::2], r, axis=1)
return lst.T.reshape(-1)
x = np.array(list(range(n)))
comps = []
d = 1
while d < n:
for r in range(d):
comps.append(op(x, d=d, r=r).copy())
d *= 2
ret = np.stack(comps)
return ret.reshape(ret.shape[0], -1, 2).tolist()
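# Example (illustrative): for 4 ranks, every rank is paired with every other rank exactly once,
# one disjoint set of pairs per round:
#   >>> get_all_comps(4)
#   [[[0, 1], [2, 3]], [[0, 2], [1, 3]], [[0, 3], [1, 2]]]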
def test_all_pair_to_pair(
parallel_context: ParallelContext, throughput_size: int, throughput_iters: int, only_node_to_node: bool = True
):
"""Test all pair-to-pair GPUs throughput
Args:
parallel_context: ParallelContext
throughput_size: size of the tensor to send
throughput_iters: number of warm-up iterations before testing the throughput
only_node_to_node: if True, only test node-to-node throughput
"""
comparisons = get_all_comps(parallel_context.world_pg.size())
wr = dist.get_rank(parallel_context.world_pg)
log_rank(
f"[TEST] Testing throughput between {comparisons}",
logger=logger,
level=logging.WARNING,
group=parallel_context.world_pg,
rank=0,
)
for j, comp in enumerate(comparisons):
dist.barrier(group=parallel_context.world_pg)
for i, (a, b) in enumerate(comp):
dist.barrier(group=parallel_context.world_pg)
if wr not in [a, b]:
continue
if only_node_to_node and (a % 8 != 0 or b % 8 != 0):
# We only check node-to-node throughput
continue
test_tensor = torch.zeros((int(throughput_size),), dtype=torch.uint8, device=torch.device("cuda"))
for k in range(throughput_iters):
pre = time.perf_counter()
torch.cuda.synchronize()
if wr == a:
dist.send(test_tensor, b, group=parallel_context.world_pg, tag=i + k)
elif wr == b:
dist.recv(test_tensor, a, group=parallel_context.world_pg, tag=i + k)
torch.cuda.synchronize()
duration = time.perf_counter() - pre
del test_tensor
gc.collect()
torch.cuda.empty_cache()
tput = (float(throughput_size) / duration) * 8 # bytes/s -> bits/s (reported as Gbps below)
log_rank(
f"[TEST] {j, i, wr} Results throughput from {a} to {b}: {tput/1e9:.4f} Gbps",
logger=logger,
level=logging.WARNING,
group=parallel_context.world_pg,
rank=None,
)
log_rank(
"[TEST] All comparisons done",
logger=logger,
level=logging.WARNING,
group=parallel_context.world_pg,
rank=0,
)
def create_table_log(
config: Config,
parallel_context: ParallelContext,
model_tflops,
hardware_tflops,
tokens_per_sec,
bandwidth,
slurm_job_id,
):
return [
LogItem("job_id", slurm_job_id, "s"),
LogItem("name", config.general.run, "s"),
LogItem("nodes", math.ceil(parallel_context.world_pg.size() / torch.cuda.device_count()), "d"),
LogItem("seq_len", config.tokens.sequence_length, "d"),
LogItem("mbs", config.tokens.micro_batch_size, "d"),
LogItem("batch_accum", config.tokens.batch_accumulation_per_replica, "d"),
LogItem("gbs", config.global_batch_size, "d"),
LogItem("mTFLOPs", model_tflops, ".2f"),
LogItem("hTFLOPs", hardware_tflops, ".2f"),
LogItem("tok/s/gpu", tokens_per_sec / parallel_context.world_pg.size(), ".2f"),
LogItem("Bandwidth (GB/s)", bandwidth, ".2f"),
LogItem("Mem Alloc (GB)", torch.cuda.max_memory_allocated() / 1024**3, ".2f"),
LogItem("Mem Res (GB)", torch.cuda.max_memory_reserved() / 1024**3, ".2f"),
]
def create_table_output(table_log, column_widths):
header_row = "| " + " | ".join([item.tag.ljust(width) for item, width in zip(table_log, column_widths)]) + " |"
separator_row = "| " + " | ".join(["-" * width for width in column_widths]) + " |"
data_row = (
"| "
+ " | ".join(
[f"{item.scalar_value:{item.log_format}}".ljust(width) for item, width in zip(table_log, column_widths)]
)
+ " |"
)
return f"{header_row}\n{separator_row}\n{data_row}"
def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id):
if not os.path.exists(csv_filename):
os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
with open(csv_filename, mode="w") as fo:
writer = csv.writer(fo)
writer.writerow([item.tag for item in table_log])
writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
# elif model_tflops > 0:
# # replace line with same job_id
# with open(csv_filename, mode="r") as fi:
# lines = fi.readlines()
# with open(csv_filename, mode="w") as fo:
# writer = csv.writer(fo)
# for line in lines:
# if line.startswith(slurm_job_id):
# writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
# else:
# fo.write(line)
else:
with open(csv_filename, mode="a") as fo:
writer = csv.writer(fo)
writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
def log_throughput(
config: Config,
parallel_context: ParallelContext,
model_tflops=0,
hardware_tflops=0,
tokens_per_sec=0,
bandwidth=0,
):
slurm_job_id = os.environ.get("SLURM_JOB_ID", "N/A")
table_log = create_table_log(
config, parallel_context, model_tflops, hardware_tflops, tokens_per_sec, bandwidth, slurm_job_id
)
column_widths = [max(len(item.tag), len(f"{item.scalar_value:{item.log_format}}")) for item in table_log]
table_output = create_table_output(table_log, column_widths)
log_rank(
table_output,
logger=logger,
level=logging.INFO,
rank=0,
)
if dist.get_rank(parallel_context.world_pg) == 0:
write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id)
def compute_remain_train_steps_of_a_data_stage_from_ckp(
stage: DatasetStageArgs, config: Config, metadata: TrainingMetadata
) -> int:
def is_last_stage():
sorted_stages = sorted(config.data_stages, key=lambda x: x.start_training_step)
return sorted_stages[-1].start_training_step == stage.start_training_step
def is_resume_from_training():
return metadata.last_train_step > 0
if is_last_stage() is True:
total_train_steps = config.tokens.train_steps
else:
next_stage = next((s for s in config.data_stages if s.start_training_step > stage.start_training_step), None)
total_train_steps = next_stage.start_training_step
if metadata.last_train_step > stage.start_training_step:
# NOTE: if the last_train_step is larger than the start_training_step of the current stage,
# it means that the training has already passed this stage
# so there are no remaining steps
return 0
else:
last_train_steps = metadata.last_train_step if is_resume_from_training() else stage.start_training_step
return total_train_steps - last_train_steps
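# Worked example (hypothetical values): with a stage starting at step 10 and the next stage starting at
# step 100, a fresh run (metadata.last_train_step == 0) still has 100 - 10 = 90 steps for this stage;
# if the checkpoint was taken past the stage start (e.g. last_train_step == 40 > 10), 0 steps remain.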
def get_consumed_train_samples_of_a_data_stage_from_ckp(
stage: DatasetStageArgs, metadata: TrainingMetadata
) -> Optional[int]:
start_training_step = stage.start_training_step
return next(
(s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step),
None,
)
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities. """
import logging
import os
import sys
from dataclasses import dataclass
from functools import lru_cache
from logging import (
CRITICAL,
DEBUG,
ERROR,
FATAL,
INFO,
NOTSET,
WARNING,
Formatter,
Logger,
)
from typing import TYPE_CHECKING, List, Optional, Union
import torch
from torch import distributed as torch_dist
from nanotron import distributed as dist
if TYPE_CHECKING:
from nanotron.config import LoggingArgs
from nanotron.parallel import ParallelContext
log_levels = {
"debug": DEBUG,
"info": INFO,
"warning": WARNING,
"error": ERROR,
"critical": CRITICAL,
"fatal": FATAL,
"notset": NOTSET,
}
class NewLineStreamHandler(logging.StreamHandler):
"""
We want to apply formatter before each new line
https://stackoverflow.com/a/38458877
"""
def emit(self, record):
lines = record.msg.split("\n")
for line in lines:
record.msg = line
super().emit(record)
DEFAULT_HANDLER = NewLineStreamHandler()
DEFAULT_LOG_LEVEL = logging.WARNING
LIBRARY_NAME = __name__.split(".")[0]
def _get_default_logging_level():
"""
If the NANOTRON_LOGGING_LEVEL env var is set to one of the valid choices, return that as the default level.
If it is not, fall back to ``DEFAULT_LOG_LEVEL``.
"""
env_level_str = os.getenv("NANOTRON_LOGGING_LEVEL", None)
if env_level_str:
if env_level_str in log_levels:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
f"Unknown option NANOTRON_LOGGING_LEVEL={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }"
)
return DEFAULT_LOG_LEVEL
def get_library_root_logger() -> Logger:
return get_logger(LIBRARY_NAME)
def _configure_library_root_logger() -> None:
library_root_logger = get_library_root_logger()
library_root_logger.addHandler(DEFAULT_HANDLER)
library_root_logger.setLevel(_get_default_logging_level())
def _reset_library_root_logger() -> None:
library_root_logger = get_library_root_logger()
library_root_logger.setLevel(logging.NOTSET)
def get_logger(name: Optional[str] = None, log_level: Optional[str] = None) -> Logger:
"""
Return a logger with the specified name.
"""
logger_already_exists = isinstance(logging.root.manager.loggerDict.get(name, None), Logger)
logger = logging.getLogger(name)
if logger_already_exists or name is None:
# if name is None we return root logger
return logger
# If the logger is in a `nanotron` module then we remove the capability to propagate
if LIBRARY_NAME == name.split(".", 1)[0]:
if log_level is not None:
logger.setLevel(log_level.upper())
elif LEVEL is not None:
logger.setLevel(LEVEL)
else:
logger.setLevel(_get_default_logging_level())
if HANDLER is not None:
logger.handlers.clear()
logger.addHandler(HANDLER)
logger.propagate = False
return logger
def get_verbosity() -> int:
"""
Return the current level for the Nanotron root logger as an int.
Returns:
:obj:`int`: The logging level.
.. note::
Nanotron has following logging levels:
- 50: ``nanotron.logging.CRITICAL`` or ``nanotron.logging.FATAL``
- 40: ``nanotron.logging.ERROR``
- 30: ``nanotron.logging.WARNING`` or ``nanotron.logging.WARN``
- 20: ``nanotron.logging.INFO``
- 10: ``nanotron.logging.DEBUG``
"""
return get_library_root_logger().getEffectiveLevel()
LEVEL = None
def set_verbosity(verbosity: int) -> None:
"""
Set the verbosity level for the all `nanotron` loggers.
Args:
verbosity (:obj:`int`):
Logging level, e.g., one of:
- ``nanotron.logging.CRITICAL`` or ``nanotron.logging.FATAL``
- ``nanotron.logging.ERROR``
- ``nanotron.logging.WARNING`` or ``nanotron.logging.WARN``
- ``nanotron.logging.INFO``
- ``nanotron.logging.DEBUG``
"""
all_nanotron_loggers = {
name: logger
for name, logger in logging.root.manager.loggerDict.items()
if isinstance(logger, Logger) and (name.startswith(f"{LIBRARY_NAME}.") or name == LIBRARY_NAME)
}
for name, logger in all_nanotron_loggers.items():
logger.setLevel(verbosity)
# We update all handles to be at the current verbosity as well.
for handle in logger.handlers:
handle.setLevel(verbosity)
global LEVEL
LEVEL = verbosity
HANDLER = None
def set_formatter(formatter: logging.Formatter) -> None:
"""
Install a new handler with the given formatter on all `nanotron` loggers.
Note: it's important to set the verbosity level first and then the formatter,
since the new handler takes its level from the current verbosity.
:param formatter:
:return:
"""
handler = NewLineStreamHandler(sys.stdout)
handler.setFormatter(formatter)
handler.setLevel(get_verbosity())
handler.flush = sys.stderr.flush
all_nanotron_loggers = {
name: logger
for name, logger in logging.root.manager.loggerDict.items()
if isinstance(logger, Logger) and (name.startswith(f"{LIBRARY_NAME}.") or name == LIBRARY_NAME)
}
for name, logger in all_nanotron_loggers.items():
# We keep only a single handler
logger.handlers.clear()
logger.addHandler(handler)
global HANDLER
HANDLER = handler
def log_rank(
msg: str,
logger: Logger,
level: int,
group: Optional[dist.ProcessGroup] = None,
rank: Optional[int] = None,
**kwargs,
):
"""Log only if the current process is the rank specified."""
# Use the default group if group is not provided
if group is None:
group = torch_dist.distributed_c10d._get_default_group()
# rank is None means everyone logs
if rank is None or dist.get_rank(group) == rank:
logger.log(level, msg, **kwargs)
@lru_cache(maxsize=None)
def warn_once(
msg: str, logger: Logger, group: Optional[dist.ProcessGroup] = None, rank: Optional[int] = None, **kwargs
):
log_rank(msg=msg, logger=logger, level=logging.WARNING, group=group, rank=rank, **kwargs)
def human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str:
if abs(num) < 1:
return "{:.3g}".format(num)
SIZES = ["", "K", "M", "G", "T", "P", "E"]
num = float("{:.3g}".format(num))
magnitude = 0
i = 0
while abs(num) >= 1000 and i < len(SIZES) - 1:
magnitude += 1
num /= 1000.0 if not divide_by_1024 else 1024.0
i += 1
return "{}{}".format("{:f}".format(num).rstrip("0").rstrip("."), SIZES[magnitude])
def log_memory(logger: logging.Logger):
log_rank(
f" Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
f" Peak allocated {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB."
f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
logger=logger,
level=logging.INFO,
rank=0,
)
torch.cuda.reset_peak_memory_stats()
@dataclass
class LogItem:
tag: str
scalar_value: Union[float, int, str]
log_format: Optional[str] = None
@dataclass
class LoggerWriter:
global_step: int
def add_scalar(self, tag: str, scalar_value: Union[float, int], log_format=None) -> str:
if log_format == "human_format":
log_str = f"{tag}: {human_format(scalar_value)}"
else:
log_str = f"{tag}: {scalar_value:{log_format}}" if log_format is not None else f"{tag}: {scalar_value}"
return log_str
def add_scalars_from_list(self, log_entries: List[LogItem], iteration_step: int):
log_strs = [f"iteration: {iteration_step} / {self.global_step}"]
log_strs += [
self.add_scalar(log_item.tag, log_item.scalar_value, log_item.log_format) for log_item in log_entries
]
log_str = " | ".join(log_strs)
log_rank(log_str, logger=get_logger(__name__), level=logging.INFO)
def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelContext):
node_name = os.environ.get("SLURMD_NODENAME")
expert_parallel_log = (
f"|EXP={dist.get_rank(parallel_context.expert_pg)}" if parallel_context.expert_parallel_size > 1 else ""
)
formatter = Formatter(
fmt=f"%(asctime)s [%(levelname)s|DP={dist.get_rank(parallel_context.dp_pg)}|PP={dist.get_rank(parallel_context.pp_pg)}|"
f"TP={dist.get_rank(parallel_context.tp_pg)}{expert_parallel_log}{'|' + node_name if node_name else ''}]: %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
log_level = log_levels[logging_level]
# main root logger
root_logger = get_logger()
root_logger.setLevel(log_level)
handler = NewLineStreamHandler(sys.stdout)
handler.setLevel(log_level)
handler.setFormatter(formatter)
root_logger.addHandler(handler)
# Nanotron
set_verbosity(log_level)
set_formatter(formatter=formatter)
def set_ranks_logging_level(parallel_context: ParallelContext, logging_config: "LoggingArgs"):
if dist.get_rank(parallel_context.world_pg) == 0:
if logging_config.log_level is not None:
set_logger_verbosity_format(logging_config.log_level, parallel_context=parallel_context)
else:
if logging_config.log_level_replica is not None:
set_logger_verbosity_format(logging_config.log_level_replica, parallel_context=parallel_context)
_configure_library_root_logger()
# flake8: noqa
from .base import DTypeInvariantTensor, NanotronModel, build_model, check_model_has_grad, init_on_device_and_dtype
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import log_rank
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
if TYPE_CHECKING:
from nanotron.config import NanotronConfigs
from nanotron.parallel.parameters import NanotronParameter
logger = logging.get_logger(__name__)
class NanotronModel(nn.Module, metaclass=ABCMeta):
"""Abstract class for Nanotron models
We make the following assumptions:
- When building PP blocks, we assume that the modules are declared in the same order as the forward pass."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.parallel_context: "ParallelContext"
self.config: "NanotronConfigs"
self.module_id_to_prefix: dict[int, str]
# Attributes defined when building the model
self.input_pp_rank: int
self.output_pp_rank: int
# Useful mapping to get param names
self.module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in self.named_modules()}
self.module_id_to_prefix[id(self)] = ""
def get_named_params_with_correct_tied(self) -> Iterator[Tuple[str, "NanotronParameter"]]:
"""Return named parameters with correct tied params names.
For example in the case of tied kv heads in MQA, we need to make sure tied params names are correct."""
def params_gen():
for name, param in self.named_parameters():
if param.is_tied:
yield (
param.get_tied_info().get_full_name_from_module_id_to_prefix(
module_id_to_prefix=self.module_id_to_prefix
),
param,
)
else:
yield name, param
yield from params_gen()
@abstractmethod
def init_model_randomly(self, config):
...
def tie_custom_params(self) -> None:
"""Tie custom parameters. For example for MQA marks kv heads as tied."""
pass
def get_embeddings_lm_head_tied_names(self) -> list[str]:
"""Returns the names of the embeddings and lm_head weights that are tied together. Returns empty list if not tied.
Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
"""
return []
def get_named_params_without_weight_decay(self) -> List[str]:
"""Return a list of named parameters that should not have weight decay applied to them."""
return []
def before_tbi_sanity_checks(self) -> None:
pass
def after_tbi_sanity_checks(self) -> None:
pass
def before_optim_step_sanity_checks(self) -> None:
pass
def after_optim_step_sanity_checks(self) -> None:
pass
def log_modules(self, level: int = logging.DEBUG, group: Optional[ProcessGroup] = None, rank: int = 0):
assert hasattr(self, "parallel_context"), "`NanotronModel` needs to have a `parallel_context` attribute"
for name, module in self.named_modules():
if not isinstance(module, PipelineBlock):
continue
log_rank(
f"module_name: {name} | PP: {module.rank}/{self.parallel_context.pp_pg.size()}",
logger=logger,
level=level,
group=group,
rank=rank,
)
@property
def named_modules_in_pp_rank(self) -> Dict[str, nn.Module]:
"""Return the named modules that only belongs to the current pp rank.
An example output:
{
'module_name': module,
...
}
NOTE: keys are module names only (no `.weight` or `.bias` suffixes).
"""
def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]:
"""
Return all the leaf modules (modules without any child modules) in a PyTorch module.
"""
leaf_modules = []
for n, m in module.named_modules():
if not list(m.children()):
leaf_modules.append((n, m))
return leaf_modules
modules = get_leaf_modules(self)
named_modules_in_current_pp_rank = {}
for name, module in modules:
if isinstance(module, PipelineBlock):
# NOTE: these are the modules that don't belong to the current pp rank
continue
named_modules_in_current_pp_rank[name] = module
return named_modules_in_current_pp_rank
class DTypeInvariantTensor(torch.Tensor):
"""DTypeInvariantTensor is a subclass of torch.Tensor that disallows modification of its dtype. Note that the data
and other attributes of the tensor can still be modified."""
def __new__(cls, *args, **kwargs):
tensor = super().__new__(cls, *args, **kwargs)
return tensor
def detach(self, *args, **kwargs):
raise RuntimeError("Cannot detach a DTypeInvariantTensor")
def to(self, *args, **kwargs):
if "dtype" in kwargs or any(isinstance(arg, torch.dtype) for arg in args):
raise RuntimeError("Cannot change the type of a DTypeInvariantTensor")
else:
return super().to(*args, **kwargs)
def type(self, *args, **kwargs):
raise RuntimeError("Cannot change the type of a DTypeInvariantTensor")
def float(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to float")
def double(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to double")
def half(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to half")
def long(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to long")
def int(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to int")
def short(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to short")
def char(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to char")
def byte(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to byte")
def bool(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to bool")
def bfloat16(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of a DTypeInvariantTensor to bfloat16")
def build_model(
model_builder: Callable[[], NanotronModel],
parallel_context: ParallelContext,
dtype: torch.dtype,
target_pp_ranks: Optional[List[int]] = None,
device: Optional[torch.device] = torch.device("cuda"),
) -> NanotronModel:
"""Build the model and set the pp ranks for each pipeline block."""
# TODO: classes don't take the same args
log_rank("Building model..", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg)
model: NanotronModel = model_builder()
# If no target pp ranks are specified, we assume that we want to use all pp ranks
if target_pp_ranks is None:
pp_size = parallel_context.pp_pg.size()
target_pp_ranks = list(range(pp_size))
else:
pp_size = len(target_pp_ranks)
# Set rank for each pipeline block
log_rank("Setting PP block ranks...", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg)
pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)]
# "cuda" is already defaulted for each process to it's own cuda device
with init_on_device_and_dtype(device=device, dtype=dtype):
# TODO: https://github.com/huggingface/nanotron/issues/65
# Balance compute across PP blocks
block_compute_costs = model.get_block_compute_costs()
block_cumulative_costs = np.cumsum(
[
block_compute_costs[module.module_builder] if module.module_builder in block_compute_costs else 0
for module in pipeline_blocks
]
)
thresholds = [block_cumulative_costs[-1] * ((rank + 1) / pp_size) for rank in range(pp_size)]
assert thresholds[-1] >= block_cumulative_costs[-1]
target_pp_rank_idx = 0
for block, cumulative_cost in zip(pipeline_blocks, block_cumulative_costs):
assert target_pp_rank_idx < pp_size
block.build_and_set_rank(target_pp_ranks[target_pp_rank_idx])
if cumulative_cost > thresholds[target_pp_rank_idx]:
target_pp_rank_idx += 1
model.input_pp_rank = target_pp_ranks[0]
model.output_pp_rank = target_pp_ranks[target_pp_rank_idx]
return model
# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does.
@contextmanager
def init_on_device_and_dtype(
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float,
):
"""
A context manager under which models are initialized with all parameters on the specified device.
Args:
device (`torch.device` defaults to `cpu`):
Device to initialize all parameters on.
dtype (`torch.dtype` defaults to `torch.float`):
Dtype to initialize all parameters on.
Example:
```python
import torch.nn as nn
from nanotron.models.base import init_on_device_and_dtype
with init_on_device_and_dtype(device=torch.device("cuda")):
tst = nn.Linear(100, 100)  # on `cuda` device
```
"""
old_register_parameter = nn.Module.register_parameter
old_register_buffer = nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
if isinstance(param, DTypeInvariantTensor):
# if param is DTypeInvariantTensor we should avoid updating it
param.data = param.data.to(device)
else:
param.data = param.data.to(device, dtype)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
if isinstance(buffer, DTypeInvariantTensor):
# if buffer is DTypeInvariantTensor we should avoid updating it
buffer.data = buffer.data.to(device)
else:
module._buffers[name] = module._buffers[name].to(device, dtype)
# Patch tensor creation
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
kwargs["dtype"] = dtype
return fn(*args, **kwargs)
return wrapper
try:
nn.Module.register_parameter = register_empty_parameter
nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
nn.Module.register_parameter = old_register_parameter
nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
def check_model_has_grad(model: NanotronModel, parallel_context: "ParallelContext"):
"""Check that there's at least a parameter in current PP rank that has a gradient."""
for param in model.parameters():
if param.requires_grad:
return True
raise ValueError(
f"Can't use DDP because model in PP={dist.get_rank(parallel_context.pp_pg)} has no gradient. Consider increasing the number of layers of your model, or put a smaller PP size.\n"
f"Model: {model}"
)
# coding=utf-8
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMa model."""
from typing import Dict, List, Optional, Union
import torch
from torch import nn
from torch.utils.checkpoint import CheckpointFunction
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config, LlamaConfig, ParallelismArgs
from nanotron.config.models_config import RandomInit, SpectralMupInit
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelLinearMode,
TensorParallelRowLinear,
)
from nanotron.random import RandomStates
from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator
from nanotron.utils import checkpoint_method
logger = logging.get_logger(__name__)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 10000.0):
super().__init__()
assert dim % 2 == 0
self.dim = dim
self.end = end
self.theta = theta
# TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
# TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex
self.freqs_cis: torch.Tensor
self._initialized_buffer = False
def init_rotary_embeddings(self):
if self._initialized_buffer is True:
# Buffer is already initialized
return
self.register_buffer(
"freqs_cis",
torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"),
persistent=False,
)
assert self.freqs_cis.device.type == "cuda"
# TODO @nouamane: Once we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
if self.freqs_cis.dtype != torch.float:
self.freqs_cis = self.freqs_cis.to(torch.float)
assert self.freqs_cis.dtype == torch.float
freqs = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim)
).to(
"cuda"
) # should be computed on CPU, otherwise different results with Transformers.
t = torch.arange(self.end, device="cuda")
freqs = torch.outer(t, freqs).float()
complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
freqs = torch.view_as_real(complex_freqs)
self.freqs_cis.copy_(freqs)
self._initialized_buffer = True
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
batch_size, seq_length, num_heads, inner_dim = x.shape
while (
position_ids is not None and position_ids[-1, -1] >= self.end
) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync
self.end *= 2
self._initialized_buffer = False
if self._initialized_buffer is False:
print(f"Initializing rotary embeddings with end={self.end}")
self.init_rotary_embeddings()
dtype = x.dtype
assert inner_dim % 2 == 0
x = x.view(
batch_size, seq_length, num_heads, inner_dim // 2, 2
) # [batch_size, seq_length, num_heads, inner_dim // 2, 2]
if x.dtype == torch.bfloat16:
x = x.float()
complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2]
if position_ids is None:
freqs_cis = self.freqs_cis[None, :seq_length, None, :]
else:
# TODO(kunhao): Should None follow the num_heads dimension?
if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully
raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}")
freqs_cis = self.freqs_cis[position_ids][:, :, None, :]
complex_freqs = torch.view_as_complex(freqs_cis)
x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim)
return x_out.type(dtype)
## Copied from transformers. Non-interleaved version of RoPE. Will be refactored later
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 500000.0):
super().__init__()
self.dim = dim
self.end = end
self.theta = theta
self.init_rotary_embeddings()
def init_rotary_embeddings(self):
inv_freq = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
) # important to compute on CPU
self.register_buffer(
"inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
)
self.inv_freq = self.inv_freq.to(
torch.float
) # make it float32 before copy to avoid precision loss during copy_
self.inv_freq.copy_(inv_freq)
@torch.no_grad()
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
# x: [batch_size, seq_length, num_heads, d_qk]; only its dtype and device are used here
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=2):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 2):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
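# Quick illustration of rotate_half (not part of the original file): the second half of the last
# dimension is negated and swapped in front of the first half, e.g. [1, 2, 3, 4] -> [-3, -4, 1, 2],
# which together with the cos/sin terms implements the non-interleaved RoPE rotation.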
class GLUActivation(nn.Module):
def __init__(self, act_fn_name: str):
super().__init__()
self.act = ACT2FN[act_fn_name]
def forward(self, merged_states: torch.Tensor):
gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
return self.act(gate_states) * up_states
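# Shape sketch (illustrative, ignoring the TP sharding of the last dimension): `merged_states` packs the
# gate and up projections along the last dim, so it is [seq_length, batch_size, 2 * intermediate_size],
# and `act(gate) * up` is [seq_length, batch_size, intermediate_size], which `down_proj` maps back to hidden_size.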
class MLP(nn.Module):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
):
super().__init__()
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
gate_up_contiguous_chunks = (
config.intermediate_size, # shape of gate_linear
config.intermediate_size, # shape of up_linear
)
self.gate_up_proj = TensorParallelColumnLinear(
config.hidden_size,
2 * config.intermediate_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
self.split_silu_mul = GLUActivation(config.hidden_act)
def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
hidden_states = self.down_proj(self.split_silu_mul(merged_states))
return {"hidden_states": hidden_states}
class CoreAttention(nn.Module):
def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
super().__init__()
# TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentially a `d_qkv`
assert (
config.hidden_size % config.num_attention_heads == 0
), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.is_using_mup = config.is_using_mup
self.checkpoint_attention = False # Because flash_attn already does checkpointing
@checkpoint_method(attr_name="checkpoint_attention")
def forward(
self,
query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim]
key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
):
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# TODO @thomasw21: Compute once, instead of computing for each layer.
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
# TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
# what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
causal = False if q_sequence_mask.shape[1] == 1 else True
# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
attn_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_sequence_mask.shape[1],
max_seqlen_k=kv_sequence_mask.shape[1],
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=False,
)
return attn_output
def pad_to_right(tensor, mask, new_tensor=None):
"""Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
Args:
tensor: (batch_size, seqlen, d1, d2)
mask: (batch_size, seqlen)
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
Returns:
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
right_padded_mask: (batch_size, seqlen)
"""
# First, we need to find the number of padding for each row
unpad_seqlens = mask.sum(1)
# Then, we need to find the maximum length of the tensor
max_seqlen = mask.shape[1]
# We can then create the indices to select the padded values
# The indices are the same for each row
indices = torch.arange(max_seqlen, device=mask.device)
# We can then create the mask for the padded values
right_padded_mask = indices < unpad_seqlens[:, None]
# We select the useful values
useful_values = tensor[mask]
# We create the new tensor (if not provided)
new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
# We fill the new tensor with the useful values
new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
return new_tensor, right_padded_mask
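# Example (illustrative): with a left-padded mask
#   mask = [[0, 0, 1, 1],
#           [0, 1, 1, 1]]
# the returned right_padded_mask is
#   [[1, 1, 0, 0],
#    [1, 1, 1, 0]]
# and the valid (non-padded) positions of `tensor` are moved to the front of each row of `new_tensor`.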
class CausalSelfAttention(nn.Module, AttachableStore):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
assert (
config.num_attention_heads % tp_pg.size() == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
try:
assert (
config.num_key_value_heads % tp_pg.size() == 0
), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})."
except AttributeError:
log_rank(
"WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads",
logger=logger,
level=logging.WARNING,
rank=0,
)
# If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads
config.num_key_value_heads = config.num_attention_heads
assert (
config.num_attention_heads % config.num_key_value_heads == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})."
self.n_local_q_heads = config.num_attention_heads // tp_pg.size()
self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size()
self.n_repeats = config.num_attention_heads // config.num_key_value_heads
self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.d_model = config.hidden_size
self.is_using_mup = config.is_using_mup
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
# build the slice config for self.qkv for save/load
# sharding is done within each contiguous chunk
qkv_contiguous_chunks = (
config.num_attention_heads * self.d_qk, # shape of q
config.num_key_value_heads * self.d_qk, # shape of k
config.num_key_value_heads * self.d_qk, # shape of v
)
self.qkv_proj = TensorParallelColumnLinear(
self.d_model,
config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
if config.rope_interleaved:
self.rotary_embedding = RotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
theta=config.rope_theta,
)
else:
self.rotary_embedding = LlamaRotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
theta=config.rope_theta,
)
self.rope_interleaved = config.rope_interleaved
# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(
dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved
)
self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
)
self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_size]
sequence_mask, # [batch_size, seq_length]
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
qkv_states = self.qkv_proj(
hidden_states
) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
q_length, batch_size, _ = qkv_states.shape
if self.is_gqa:
query_states, key_states, value_states = torch.split(
qkv_states,
[
self.n_local_q_heads * self.d_qk,
self.n_local_kv_heads * self.d_qk,
self.n_local_kv_heads * self.d_qk,
],
dim=-1,
)
query_states = (
query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk)
)
key_states = (
key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
)
value_states = (
value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
)
else:
query_states, key_states, value_states = (
qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk)
.permute(2, 1, 0, 3, 4)
.contiguous()
) # [3, batch_size, seq_length, n_local_q_heads, d_qk]
store = self.get_local_store()
if store is not None: # Inference case
# Double check that we use store only at inference time
assert key_states.requires_grad is False
assert value_states.requires_grad is False
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
position_offsets = position_ids[:, -1]
# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
# interleaved version.
if self.rope_interleaved:
query_states = self.rotary_embedding(query_states, position_ids=position_ids)
key_states = self.rotary_embedding(key_states, position_ids=position_ids)
# non interleaved version.
else:
cos, sin = self.rotary_embedding(value_states, position_ids)
query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
# preallocate k_cache, v_cache to self.prefill_kv_len
k_cache = torch.zeros(
(
batch_size,
self.prefill_kv_len,
self.n_local_kv_heads,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
)
v_cache = torch.zeros(
(batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
dtype=query_states.dtype,
device=query_states.device,
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
# NOTE: this 1/d_h scale is for µTransfer;
# in standard parametrization (SP), we fall back to flash-attn's default sqrt(1/d_h) (softmax_scale=None)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
output_unpad = flash_attn_varlen_func(
q=query_unpad, # (total_q, n_local_q_heads, d_qk)
k=key_unpad, # (total_kv, n_local_kv_heads, d_qk)
v=value_unpad, # (total_kv, n_local_kv_heads, d_v)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=True, # True in prefill phase, False in subsequent phases
return_attn_probs=False,
) # (total_unpadded, n_local_q_heads, d_v)
attention_output = bert_padding.pad_input(
output_unpad, indices_q, batch_size, q_length
) # (batch_size, q_length, n_local_q_heads, d_v)
pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
v_cache = store["value"]
# NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
# If the rotary embedding end has grown (to enable a larger context), we need to enlarge k_cache and v_cache
if self.rotary_embedding.end > old_rotary_embed_end:
k_cache = torch.cat(
[
k_cache,
torch.zeros(
(
batch_size,
self.rotary_embedding.end - old_rotary_embed_end,
self.n_local_kv_heads,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
),
],
dim=1,
)
v_cache = torch.cat(
[
v_cache,
torch.zeros(
(
batch_size,
self.rotary_embedding.end - old_rotary_embed_end,
self.n_local_kv_heads,
self.d_v,
),
dtype=query_states.dtype,
device=query_states.device,
),
],
dim=1,
)
assert (
k_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
assert (
v_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
# [batch_size, seq_length, num_heads, d_qk]
query_states = query_states.view(
batch_size, q_length, self.n_local_q_heads, self.d_qk
) # [batch_size, q_length, self.n_heads, d_qk]
kv_length = key_states.shape[1]
key_states = key_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size, kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size, kv_length, self.n_heads, d_v]
# NOTE: this 1/d_h scale is for µTransfer;
# in standard parametrization (SP), we fall back to flash-attn's default sqrt(1/d_h) (softmax_scale=None)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
attention_output = flash_attn_with_kvcache(
query_states,
k_cache,
v_cache,
key_states,
value_states,
rotary_cos=None,
rotary_sin=None,
# TODO @nouamane: seems like this doesn't help to indicate padding (for the first iteration it's just 0)
cache_seqlens=position_offsets.contiguous(),
softmax_scale=softmax_scale,
causal=True,
rotary_interleaved=False, # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention
)
store.update(
{
"key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
"value": v_cache,
"position_offsets": position_offsets,
}
)
else: # Training case
# Apply rotary embeddings to query/key states
# NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk]
# Here it is, [batch_size, seq_length, num_heads, d_qk]
# [2, batch_size, seq_length, num_heads, d_qk]
key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
# [batch_size, seq_length, 2, num_heads, d_qk]
key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()
query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
# [batch_size, seq_length, num_heads, d_qk]
key_states, value_states = torch.split(key_value_states, 1, dim=2)
q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask
kv_length = key_states.shape[1]
# [batch_size, seq_length, num_heads, d_qk]
# Shaping for use in the varlen (unpadded) version of flash-attn: `flash_attn_varlen_func`
query_states = query_states.view(
batch_size * q_length, self.n_local_q_heads, self.d_qk
) # [batch_size * q_length, self.n_heads, d_qk]
key_states = key_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size * kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size * kv_length, self.n_heads, d_v]
attention_output = self.attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
)
attention_output = (
attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
)
output = self.o_proj(attention_output)
return {"hidden_states": output, "sequence_mask": sequence_mask}
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
super().__init__()
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = CausalSelfAttention(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)
self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
self.recompute_layer = parallel_config.recompute_layer
def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> List[Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
hidden_states = hidden_states + residual
return hidden_states, output["sequence_mask"]
def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
sequence_mask: torch.Tensor,
) -> List[torch.Tensor]:
return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)
def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
else:
hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)
return {
"hidden_states": hidden_states,
"sequence_mask": sequence_mask,
}
class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
super().__init__()
self.token_embedding = TensorParallelEmbedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
padding_idx=config.pad_token_id,
pg=tp_pg,
mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
)
self.pg = tp_pg
def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
store = self.get_local_store()
if store is not None:
if "past_length" in store:
past_length = store["past_length"]
else:
past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
# Store new past_length in store
store["past_length"] = past_length + cumsum_mask[:, -1]
# Format input in `[seq_length, batch_size]` to support high TP with low batch_size
input_ids = input_ids.transpose(0, 1)
input_embeds = self.token_embedding(input_ids)
return {"input_embeds": input_embeds}
class LlamaModel(nn.Module):
"""Build pipeline graph"""
def __init__(
self,
config: LlamaConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
# Declare all the nodes
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
self.config = config
self.parallel_config = parallel_config
self.parallel_context = parallel_context
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
self.token_position_embeddings = PipelineBlock(
p2p=self.p2p,
module_builder=Embedding,
module_kwargs={
"tp_pg": parallel_context.tp_pg,
"config": config,
"parallel_config": parallel_config,
},
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)
log_rank(f"Initialize RoPE Theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0)
if config.rope_interleaved:
log_rank(
"The RoPE interleaved version differs from the Transformers implementation. It's better to set rope_interleaved=False if you need to convert the weights to Transformers",
logger=logger,
level=logging.INFO,
rank=0,
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=LlamaDecoderLayer,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
"tp_pg": parallel_context.tp_pg,
"layer_idx": layer_idx,
},
module_input_keys={"hidden_states", "sequence_mask"},
module_output_keys={"hidden_states", "sequence_mask"},
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
) # TODO
self.lm_head = PipelineBlock(
p2p=self.p2p,
# Note that this means we return sharded logits that will need to be gathered
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
"pg": parallel_context.tp_pg,
"bias": False,
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
)
self.cast_to_fp32 = PipelineBlock(
p2p=self.p2p,
module_builder=lambda: lambda x: x.float(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]
def forward_with_hidden_states(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
# all tensors are optional as most ranks don't need anything from the dataloader.
output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
hidden_encoder_states = {
"hidden_states": output["input_embeds"],
"sequence_mask": input_mask,
}
for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)
hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
sharded_logits = self.lm_head(x=hidden_states)["logits"]
fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
return fp32_sharded_logits, hidden_states
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
model_config = self.config
d_ff = model_config.intermediate_size
d_qkv = model_config.hidden_size // model_config.num_attention_heads
block_compute_costs = {
# CausalSelfAttention (qkv proj + attn out) + MLP
LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
world_size = self.parallel_context.world_pg.size()
try:
num_key_values_heads = self.config.num_key_value_heads
except AttributeError:
num_key_values_heads = self.config.num_attention_heads
model_flops, hardware_flops = get_flops(
num_layers=self.config.num_hidden_layers,
hidden_size=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
num_key_value_heads=num_key_values_heads,
vocab_size=self.config.vocab_size,
ffn_hidden_size=self.config.intermediate_size,
seq_len=sequence_length,
batch_size=global_batch_size,
)
# NOTE: dividing by world_size and 1e12 means the returned values are in TFLOPs per second per device
model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
return model_flops_per_s, hardware_flops_per_s
@torch.jit.script
def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
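# NOTE: a tiny hedged example (not from the original file) of what `masked_mean` computes:
#   loss = [1.0, 2.0, 3.0], label_mask = [1, 1, 0]  ->  (1.0 * 1 + 2.0 * 1 + 3.0 * 0) / 2 = 1.5
# i.e. the mean of the per-token losses over the non-masked (label_mask == 1) positions only.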
class Loss(nn.Module):
def __init__(self, tp_pg: dist.ProcessGroup):
super().__init__()
self.tp_pg = tp_pg
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
) -> Dict[str, torch.Tensor]:
# Megatron by default casts everything to fp32. `--f16-lm-cross-entropy` is an option you can use to keep the current precision.
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
loss = sharded_cross_entropy(
sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
).transpose(0, 1)
# TODO @thomasw21: It's unclear what kind of normalization we want to do.
loss = masked_mean(loss, label_mask, dtype=torch.float)
# I think indexing causes a sync we don't actually want
# loss = loss[label_mask].sum()
return {"loss": loss}
class LlamaForTraining(NanotronModel):
def __init__(
self,
config: LlamaConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: Optional[RandomStates] = None,
):
super().__init__()
self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=Loss,
module_kwargs={"tp_pg": parallel_context.tp_pg},
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
},
module_output_keys={"loss"},
)
self.parallel_context = parallel_context
self.config = config
self.parallel_config = parallel_config
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
loss = self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
return {"loss": loss}
@torch.no_grad()
def init_model_randomly(self, config: Config):
"""Initialize model parameters randomly.
Note:
Layernorm weights are initialized to 0 or 1 depending on `apply_layernorm_1p`
"""
init_method = config.model.init_method
if isinstance(init_method, RandomInit):
parametrizator_cls = StandardParametrizator
elif isinstance(init_method, SpectralMupInit):
parametrizator_cls = SpectralMupParametrizator
else:
raise ValueError(f"Unknown init method {init_method}")
parametrizator = parametrizator_cls(config=config.model)
log_rank(
f"Parametrizing model parameters using {parametrizator.__class__.__name__}",
logger=logger,
level=logging.INFO,
rank=0,
)
model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)
module_name, param_name = param_name.rsplit(".", 1)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if full_param_name in initialized_parameters:
# Already initialized
continue
module = model.get_submodule(module_name)
parametrizator.parametrize(param_name, module)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if param.is_tied
else name
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
def get_embeddings_lm_head_tied_names(self):
"""Get the names of the tied embeddings and lm_head weights"""
if self.config.tie_word_embeddings is True:
return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
else:
return []
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
return self.model.get_block_compute_costs()
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
def get_flops(
num_layers,
hidden_size,
num_heads,
num_key_value_heads,
vocab_size,
seq_len,
ffn_hidden_size,
batch_size=1,
):
"""Counts flops in an decoder-only model
Args:
num_layers: number of decoder layers
hidden_size: hidden size of the model
num_heads: number of heads in the model
num_key_value_heads: number of key/value heads in the model
ffn_hidden_size: hidden size of the FFN
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
"""
if num_key_value_heads is None:
num_key_value_heads = num_heads
hidden_size_per_head = hidden_size // num_heads
# In the following we mark the reduced dimension with parentheses
# decoder
# self attention
## qkv projection
decoder_qkv_proj_flops_fwd = (
2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head
+ 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head
)
## qk logits
decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len
## v logits
decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head
## attn out
decoder_attn_out_flops_fwd = (
2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size
)
# FF
## 1st layer
decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
## 2nd layer
decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
decoder_flops_fwd = (
decoder_qkv_proj_flops_fwd
+ decoder_qk_logits_flops_fwd
+ decoder_v_logits_flops_fwd
+ decoder_attn_out_flops_fwd
+ decoder_ffn_1_flops_fwd
+ decoder_ffn_2_flops_fwd
)
# lm head
lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
# the bwd pass requires double the flops of the fwd pass for matmuls, since we compute gradients with respect to
# both the input and the weight tensors
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
hardware_flops = model_flops # TODO: This is a placeholder for now
return model_flops, hardware_flops
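# NOTE: a hedged sanity check (not part of the original file) on `get_flops`. When
# num_key_value_heads == num_heads, the matmul terms above reduce to the usual estimate
#   model_flops ≈ 6 * (matmul weight params) * (batch_size * seq_len)
#                 + 12 * num_layers * batch_size * seq_len**2 * hidden_size
# where the first term covers the qkv / attn-out / FFN / lm_head projections (2 flops per
# multiply-add, times 3 for forward + backward) and the second term covers the attention-score
# matmuls (decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd, also times 3).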
# coding=utf-8
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder (GPT with Multi-Query Attention, RoPe, SWA and GQA).
Some dependencies to update before using:
- install `torch>=2.0`
- install `flash-attn>=2.5.0`
"""
import inspect
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import LayerNorm, init
from torch.nn import functional as F
from nanotron import distributed as dist
from nanotron.config import ParallelismArgs, Starcoder2Config
from nanotron.generation.generate_store import AttachableStore
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonLayerNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.sharded_parameters import (
SplitConfig,
mark_all_parameters_in_module_as_sharded,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.functional import (
column_linear,
sharded_cross_entropy,
)
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from nanotron.parallel.tied_parameters import tie_parameters
from nanotron.random import RandomStates, branch_random_state
from nanotron.utils import checkpoint_method
def pad_to_right(tensor, mask, new_tensor=None):
"""Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
Args:
tensor: (batch_size, seqlen, d1, d2)
mask: (batch_size, seqlen)
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
Returns:
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
right_padded_mask: (batch_size, seqlen)
"""
# First, we need to find the number of padding for each row
unpad_seqlens = mask.sum(1)
# Then, we need to find the maximum length of the tensor
max_seqlen = mask.shape[1]
# We can then create the indices to select the padded values
# The indices are the same for each row
indices = torch.arange(max_seqlen, device=mask.device)
# We can then create the mask for the padded values
right_padded_mask = indices < unpad_seqlens[:, None]
# We select the useful values
useful_values = tensor[mask]
# We create the new tensor (if not provided)
new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
# We fill the new tensor with the useful values
new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
return new_tensor, right_padded_mask
# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
@torch.jit.script
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
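# NOTE: a small illustrative example (not part of the original file). For head_dim == 4,
#   rotate_half([x1, x2, x3, x4]) == [-x3, -x4, x1, x2]
# which is the non-interleaved (GPT-NeoX style) rotation applied together with the cached
# cos/sin tables in `StarcoderRotaryEmbedding.forward` below.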
class StarcoderRotaryEmbedding(nn.Module):
"""Implementation of RotaryEmbedding from GPT-NeoX."""
def __init__(self, head_dim: int, base: int):
super().__init__()
self.base = base
self.head_dim = head_dim
self.seq_len_cached = -1
# TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
self.inv_freq: torch.Tensor
self.register_buffer(
"inv_freq",
torch.empty(head_dim // 2, dtype=torch.float),
persistent=False,
)
self.cos_cached: Optional[torch.Tensor] = None
self.sin_cached: Optional[torch.Tensor] = None
self._initialized_buffer = False
def init_rotary_embeddings(self):
if self._initialized_buffer is True:
# Buffer is already initialized
return
assert self.inv_freq.device.type == "cuda"
# TODO @nouamane: Once we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
if self.inv_freq.dtype != torch.float:
self.inv_freq = self.inv_freq.to(torch.float)
assert self.inv_freq.dtype == torch.float
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float, device="cuda") / self.head_dim)
)
self._initialized_buffer = True
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> Tuple[torch.Tensor, torch.Tensor]:
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
self.seq_len_cached = total_length
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, head_dim]
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
self.cos_cached = emb.cos()[None, :, None, :] # [1, seq_len, 1, head_dim]
self.sin_cached = emb.sin()[None, :, None, :]
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)
return (
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
)
def forward(self, query, key, past_key_values_length=0):
"""
Args:
query: [batch_size, seq_len, num_heads, head_dim]
key: [batch_size, seq_len, num_heads, head_dim]
past_key_values_length: int
Returns:
query: [batch_size, seq_len, num_heads, head_dim]
key: [batch_size, seq_len, num_heads, head_dim]
"""
# TODO @nouamane: support position_ids
if self._initialized_buffer is False:
self.init_rotary_embeddings()
seq_len = query.shape[1]
cos, sin = self.cos_sin(
seq_len, past_key_values_length, query.device, query.dtype
) # [1, seq_len, 1, head_dim]
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
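# NOTE: a hedged shape walkthrough (not part of the original file): `cos_sin` returns tensors
# of shape [1, seq_len, 1, head_dim], which broadcast against query/key tensors of shape
# [batch_size, seq_len, num_heads, head_dim] in `forward`, so the same angle table is reused
# across the batch and across all heads.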
class MLP(nn.Module):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
):
super().__init__()
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
d_ff = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
self.c_fc = TensorParallelColumnLinear(
config.hidden_size,
d_ff,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication,
)
self.act = torch.jit.script(ACT2FN[config.activation_function])
self.c_proj = TensorParallelRowLinear(
d_ff,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
return {"hidden_states": hidden_states}
class CoreAttention(nn.Module):
"""
Core attention module where only the query is fully multi-headed (multi-query / grouped-query attention).
"""
def __init__(self, config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int):
super().__init__()
from flash_attn.flash_attn_interface import flash_attn_varlen_func
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_varlen_func).parameters)
assert (
config.hidden_size % config.num_attention_heads == 0
), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
self.d_qk = config.hidden_size // config.num_attention_heads
# we still divide the value dimension by the number of heads https://arxiv.org/pdf/1911.02150.pdf
self.d_v = config.hidden_size // config.num_attention_heads
self.dropout = config.attn_pdrop
assert config.scale_attn_weights, "Scale is only supported in torch 2.1.0"
# self.scale_factor = 1.0
# if config.scale_attn_weights:
# self.scale_factor = self.scale_factor / (self.d_qk**0.5)
self.checkpoint_attention = False # Because flash_attn already does checkpointing
if config.sliding_window_size is not None:
assert (
_flash_supports_window_size
), "Current version of flash-attn doesn't support sliding window: `pip install flash-attn>=2.3`"
self.sliding_window_size = config.sliding_window_size if layer_idx not in config.global_attn_layers else None
@checkpoint_method(attr_name="checkpoint_attention")
def forward(
self,
query_states: torch.Tensor, # [batch_size * q_length, num_heads, inner_dim]
key_states: torch.Tensor, # [batch_size * kv_length, 1, inner_dim]
value_states: torch.Tensor, # [batch_size * kv_length, 1, inner_dim]
q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
):
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# TODO @thomasw21: Compute once, instead of computing for each layer.
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
# TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
# what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
causal = False if q_sequence_mask.shape[1] == 1 else True
attn_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_sequence_mask.shape[1],
max_seqlen_k=kv_sequence_mask.shape[1],
dropout_p=self.dropout if self.training else 0.0,
softmax_scale=None, # defaults to 1/sqrt(d_qk)
causal=causal,
window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1),
return_attn_probs=False,
)
return attn_output
# Hack to propagate gradient correctly
def get_sliced_parameter(coalesced_tensor: torch.Tensor, slice_object: slice):
with torch.no_grad():
# This allows us to create a leaf tensor, despite sharing the underlying storage
result = NanotronParameter(tensor=coalesced_tensor[slice_object])
# We need sliced tensor to also get the gradient in order to run optimizer on them
# TODO @thomasw21: It's really hard to make sure that our sliced view keeps the same memory space as the original gradient
def get_grad_view(orig_grad):
assert orig_grad.is_contiguous()
if result.grad is None:
# The gradient was reset to None, we need to reset the coalesced_tensor.grad as well
coalesced_tensor.grad = None
# TODO @thomasw21: Can I trigger hooks that we've set in `register_hook`
if coalesced_tensor.grad is None:
result.grad = orig_grad[slice_object]
else:
result.grad = coalesced_tensor.grad[slice_object]
return orig_grad
# If `coalesced_tensor` requires gradient, then we need to update the `result` grad attribute upon backward step.
if coalesced_tensor.requires_grad is True:
coalesced_tensor.register_hook(get_grad_view)
return result
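# NOTE: a minimal hedged sketch (not part of the original file) of how `get_sliced_parameter`
# is meant to be used; shapes and names below are hypothetical. Two parameter views share a
# single coalesced storage, and the hook above copies the matching slice of the coalesced
# gradient into each view:
#
#   coalesced = torch.nn.Parameter(torch.randn(6, 4))
#   q_weight = get_sliced_parameter(coalesced, slice_object=slice(0, 4))      # first 4 rows
#   kv_weight = get_sliced_parameter(coalesced, slice_object=slice(4, None))  # last 2 rows
#   coalesced.sum().backward()
#   # after backward, q_weight.grad / kv_weight.grad hold the matching slices of coalesced.grad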
class _MQAColumnLinearReduceScatterAsyncCommunication(torch.autograd.Function):
"""This computes `q` and `kv` computation in MQA setting.
Basic assumptions:
- `kv.weight` and `kv.bias` (if not None) are duplicated across tp_pg
- `tp_mode` is REDUCE_SCATTER
- `async_communication` is set to True
What this function does:
- in the forward pass:
- overlap input `all_gather` with `kv` computation
- overlap kv output `all_gather` with `q` computation
- in the backward pass:
- overlap input `all_gather` with gradient_input computation
- overlap gradient_input `reduce_scatter` with `kv` and `q` gradient computation
"""
@staticmethod
def forward(
ctx,
x: torch.Tensor,
q_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
kv_weight: torch.Tensor,
kv_bias: Optional[torch.Tensor],
# Basically we assume that `qkv_weight` is already the concatenated version of `q.weight` and `kv.weight`
qkv_weight: torch.Tensor,
tp_pg: dist.ProcessGroup,
) -> Tuple[torch.Tensor, torch.Tensor]:
ctx.tp_pg = tp_pg
ctx.use_q_bias = q_bias is not None
ctx.use_kv_bias = kv_bias is not None
ctx.split_q_and_kv_id = q_weight.shape[0]
# All gather x if needed
gathered_x: torch.Tensor
gather_x_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
gathered_x = x
else:
first_dim = x.shape[0]
last_dims = x.shape[1:]
unsharded_first_dim = first_dim * tp_pg.size()
gathered_x = torch.empty(
unsharded_first_dim,
*last_dims,
device=x.device,
dtype=x.dtype,
requires_grad=x.requires_grad,
)
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
x = x.contiguous()
gather_x_handle = dist.all_gather_into_tensor(gathered_x, x, group=tp_pg, async_op=True)
# Compute kv (we assume that kv is duplicated across TP)
kv_out = F.linear(x, kv_weight, kv_bias)
# Wait for communication to finish
if gather_x_handle is not None:
gather_x_handle.wait()
# All gather `kv` output
gathered_kv_out: torch.Tensor
gather_kv_out_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
gathered_kv_out = kv_out
else:
first_dim = kv_out.shape[0]
last_dims = kv_out.shape[1:]
unsharded_first_dim = first_dim * tp_pg.size()
gathered_kv_out = torch.empty(
unsharded_first_dim,
*last_dims,
device=x.device,
dtype=x.dtype,
requires_grad=x.requires_grad,
)
gather_kv_out_handle = dist.all_gather_into_tensor(gathered_kv_out, kv_out, group=tp_pg, async_op=True)
# Compute q
q_out = F.linear(gathered_x, q_weight, q_bias)
# Wait for communication to finish
if gather_kv_out_handle is not None:
gather_kv_out_handle.wait()
ctx.save_for_backward(x, qkv_weight)
return q_out, gathered_kv_out
@staticmethod
def backward(
ctx, grad_q: torch.Tensor, grad_kv: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], None, None]:
tp_pg = ctx.tp_pg
split_q_and_kv_id = ctx.split_q_and_kv_id
use_q_bias = ctx.use_q_bias
use_kv_bias = ctx.use_kv_bias
x, qkv_weight = ctx.saved_tensors
# Gather `x`
gathered_x: torch.Tensor
gather_x_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
gathered_x = x
else:
first_dim = x.shape[0]
last_dims = x.shape[1:]
unsharded_batch_size = first_dim * tp_pg.size()
gathered_x = torch.empty(
unsharded_batch_size,
*last_dims,
device=x.device,
dtype=x.dtype,
requires_grad=False,
)
gather_x_handle = dist.all_gather_into_tensor(gathered_x, x, group=tp_pg, async_op=True)
# Backward computation on `kv` and `q` with regards to input
grad_qkv = torch.concat([grad_q, grad_kv], dim=-1)
grad_tensor = grad_qkv.matmul(qkv_weight)
# Wait for gather `x` to finish
if gather_x_handle is not None:
gather_x_handle.wait()
# Reduce scatter gradients with regards to input
sub_gradient_tensor: torch.Tensor
sub_gradient_tensor_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
sub_gradient_tensor = grad_tensor
else:
sub_gradient_tensor = torch.empty(
x.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
)
# reduce_scatter
sub_gradient_tensor_handle = dist.reduce_scatter_tensor(
sub_gradient_tensor, grad_tensor, group=tp_pg, async_op=True
)
# Backward computation for `q` and `kv` with regards to weights and biases
# flat_gathered_x = gathered_x.view(math.prod(gathered_x.shape[:-1]), gathered_x.shape[-1])
# flat_grad_kv = grad_kv.reshape(math.prod(grad_kv.shape[:-1]), grad_kv.shape[-1])
# flat_grad_q = grad_q.reshape(math.prod(grad_q.shape[:-1]), grad_q.shape[-1])
# grad_kv_weight = flat_grad_kv.t().matmul(flat_gathered_x)
# grad_kv_bias = flat_grad_kv.sum(dim=0) if use_kv_bias else None
# grad_q_weight = flat_grad_q.t().matmul(flat_gathered_x)
# grad_q_bias = flat_grad_q.sum(dim=0) if use_q_bias else None
flat_gathered_x = gathered_x.view(math.prod(gathered_x.shape[:-1]), gathered_x.shape[-1])
flat_grad_qkv = grad_qkv.view(math.prod(grad_qkv.shape[:-1]), grad_qkv.shape[-1])
grad_q_weight, grad_kv_weight = torch.split(
flat_grad_qkv.t().matmul(flat_gathered_x),
split_size_or_sections=[split_q_and_kv_id, grad_qkv.shape[-1] - split_q_and_kv_id],
dim=0,
)
if use_q_bias is True:
if use_kv_bias is True:
grad_qkv_bias = flat_grad_qkv.sum(dim=0)
grad_q_bias, grad_kv_bias = torch.split(
grad_qkv_bias,
split_size_or_sections=[split_q_and_kv_id, grad_qkv.shape[-1] - split_q_and_kv_id],
dim=0,
)
else:
grad_kv_bias = None
grad_q_bias = flat_grad_qkv[:, :split_q_and_kv_id].sum(dim=0)
else:
grad_q_bias = None
if use_kv_bias is True:
grad_kv_bias = flat_grad_qkv[:, split_q_and_kv_id:].sum(dim=0)
else:
grad_kv_bias = None
# Wait for `reduce_scatter`
if sub_gradient_tensor_handle is not None:
sub_gradient_tensor_handle.wait()
return sub_gradient_tensor, grad_q_weight, grad_q_bias, grad_kv_weight, grad_kv_bias, None, None
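# NOTE: a hedged illustration (not part of the original file) of `split_q_and_kv_id` above:
# with q_weight of shape [q_out, in] and kv_weight of shape [kv_out, in], grad_qkv has last
# dimension q_out + kv_out, so flat_grad_qkv.t().matmul(flat_gathered_x) is [q_out + kv_out, in]
# and splitting it at split_q_and_kv_id = q_out along dim 0 recovers the two weight gradients.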
class MQAColumnLinears(nn.Module):
def __init__(
self,
in_features: int,
q_out_features: int,
kv_out_features: int,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
):
super().__init__()
self.pg = pg
self.world_size = pg.size()
assert in_features % self.world_size == 0
self.in_features = in_features
self.q_out_features = q_out_features // self.world_size
self.kv_out_features = kv_out_features
# Tp mode
self.mode = mode
self.async_communication = async_communication
self.use_MQAColumnLinearReduceScatterAsyncCommunication = (
self.mode is TensorParallelLinearMode.REDUCE_SCATTER and self.async_communication is True
)
# allocating tensor
# We don't need to make them persistent as we expose this storage via `self.q` and `self.kv`
self.register_buffer(
"_qkv_weight",
torch.empty(
self.q_out_features + self.kv_out_features,
self.in_features,
device=device,
dtype=dtype,
# We use another specific path that doesn't use `_qkv_weight`
requires_grad=not self.use_MQAColumnLinearReduceScatterAsyncCommunication,
),
persistent=False,
)
if bias is True:
self.register_buffer(
"_qkv_bias",
torch.empty(
self.q_out_features + self.kv_out_features,
device=device,
dtype=dtype,
requires_grad=not self.use_MQAColumnLinearReduceScatterAsyncCommunication,
),
persistent=False,
)
else:
self._qkv_bias = None
# Register parameters
# We are very lucky because the sharding allows parameters to still be contiguous.
# We use a hack to propagate gradients
q_param_dict = {"weight": get_sliced_parameter(self._qkv_weight, slice_object=slice(self.q_out_features))}
kv_param_dict = {
"weight": get_sliced_parameter(self._qkv_weight, slice_object=slice(self.q_out_features, None))
}
if bias is True:
q_param_dict["bias"] = get_sliced_parameter(self._qkv_bias, slice_object=slice(self.q_out_features))
kv_param_dict["bias"] = get_sliced_parameter(self._qkv_bias, slice_object=slice(self.q_out_features, None))
self.q = nn.ParameterDict(q_param_dict)
self.kv = nn.ParameterDict(kv_param_dict)
# Marking as tied/sharded
mark_all_parameters_in_module_as_sharded(self.q, pg=self.pg, split_config=SplitConfig(split_dim=0))
# Init
self.reset_parameters()
def reset_parameters(self) -> None:
"""Copied from nn.Linear.reset_parameters"""
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
init.kaiming_uniform_(self._qkv_weight, a=math.sqrt(5))
if self._qkv_bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self._qkv_weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self._qkv_bias, -bound, bound)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.use_MQAColumnLinearReduceScatterAsyncCommunication:
assert self._qkv_weight.requires_grad is False
assert self._qkv_bias is None or self._qkv_bias.requires_grad is False
return _MQAColumnLinearReduceScatterAsyncCommunication.apply(
x, self.q.weight, self.q.bias, self.kv.weight, self.kv.bias, self._qkv_weight, self.pg
)
qkv = column_linear(
input=x,
weight=self._qkv_weight,
bias=self._qkv_bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
q, kv = torch.split(qkv, dim=-1, split_size_or_sections=[self.q_out_features, self.kv_out_features])
return q, kv
class CausalSelfMQA(nn.Module, AttachableStore):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
assert (
config.num_attention_heads % tp_pg.size() == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
self.tp_pg_size = tp_pg.size()
self.n_heads = config.num_attention_heads // tp_pg.size()
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.d_model = config.hidden_size
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
self.mode = tp_mode
self.pg = tp_pg
# only Q_size is parallelized
self.qkv = MQAColumnLinears(
in_features=self.d_model,
q_out_features=config.num_attention_heads * self.d_qk,
kv_out_features=self.d_qk + self.d_v,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication,
)
self.maybe_rotary = (
StarcoderRotaryEmbedding(head_dim=self.d_qk, base=config.rope_theta)
if config.use_rotary_embeddings
else lambda q, k, past_key_values_length=0: (q, k)
)
self.o = TensorParallelRowLinear(
config.num_attention_heads * self.d_v,
self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
assert config.multi_query is True
assert config.grouped_query is False
self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_dim]
sequence_mask, # [batch_size, seq_length]
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
batch_size = hidden_states.shape[1]
def unshape(states):
"""Given a [batch_dim * seq_length, num_heads, d_v] returns a [seq_length, batch_dim, num_heads * d_v]"""
if states.ndim == 3:
total = states.shape[0]
assert total % batch_size == 0
seq_length = total // batch_size
else:
seq_length = states.shape[1]
return (
states.view(batch_size, seq_length, self.n_heads, self.d_v)
.transpose(0, 1)
.contiguous()
.view(seq_length, batch_size, self.n_heads * self.d_v)
)
def shape(
query_states, # [q_length, batch_size, num_heads * d_qk]
kv_states, # [kv_length, batch_size, d_qk + d_v]
):
# Shaping for use in the varlen (unpadded) version of flash-attn: `flash_attn_varlen_func`
q_length = query_states.shape[0]
kv_length = kv_states.shape[0]
query_states = query_states.view(
q_length, batch_size, self.n_heads, self.d_qk
) # [q_length, batch_size, num_heads, d_qk]
query_states = (
query_states.permute(1, 0, 2, 3).contiguous().view(batch_size, q_length, self.n_heads, self.d_qk)
) # [batch_size, q_length, num_heads, d_qk]
key_states, value_states = torch.split(
kv_states, [self.d_qk, self.d_v], dim=-1
) # [kv_length, batch_size, d_qk], [kv_length, batch_size, d_v]
key_states = (
key_states.transpose(0, 1).contiguous().view(batch_size, kv_length, self.d_qk).unsqueeze(dim=2)
) # [batch_size, kv_length, 1, d_qk]
value_states = (
value_states.transpose(0, 1).contiguous().view(batch_size, kv_length, self.d_v).unsqueeze(dim=2)
) # [batch_size, kv_length, 1, d_v]
return query_states, key_states, value_states
# get query/key/value states
query_states, kv_states = self.qkv(
hidden_states
) # [seq_length, batch_size, num_heads * d_qk], [seq_length, batch_size, d_qk + d_v]
query_states, key_states, value_states = shape(query_states=query_states, kv_states=kv_states)
# [batch_size, q_length, num_heads, d_qk], [batch_size, kv_length, 1, d_qk], [batch_size, kv_length, 1, d_v]
seq_length_dim = 1
q_length = query_states.shape[seq_length_dim]
# Get cached key/values from store if available
store = self.get_local_store()
if store is not None: # Inference case
# Double check that we use store only at inference time
assert kv_states.requires_grad is False
assert value_states.requires_grad is False
# Compute rotary embeddings
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
past_key_values_length = store["past_key_values_length"]
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
past_key_values_length = 0
position_offsets = position_ids[:, -1]
query_states, key_states = self.maybe_rotary(
query_states, key_states, past_key_values_length=past_key_values_length
)
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
# preallocate k_cache, v_cache to self.prefill_kv_len
k_cache = torch.zeros(
(
batch_size,
self.prefill_kv_len,
1,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
)
v_cache = torch.zeros(
(batch_size, self.prefill_kv_len, 1, self.d_v),
dtype=query_states.dtype,
device=query_states.device,
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
output_unpad = flash_attn_varlen_func(
q=query_unpad, # (total_q, n_heads, d_qk)
k=key_unpad, # (total_kv, 1, d_qk)
v=value_unpad, # (total_kv, 1, d_v)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True, # True in prefill phase, False in subsequent phases
return_attn_probs=False,
) # (total_unpadded, n_local_q_heads, d_v)
attention_output = bert_padding.pad_input(
output_unpad, indices_q, batch_size, q_length
) # (batch_size, q_length, n_local_q_heads, d_v)
pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
v_cache = store["value"]
# [batch_size, seq_length, num_heads, d_qk]
query_states = query_states.view(
batch_size, q_length, self.n_heads, self.d_qk
) # [batch_size, q_length, self.n_heads, d_qk]
kv_length = key_states.shape[1]
key_states = key_states.view(batch_size, kv_length, 1, self.d_qk) # [batch_size, kv_length, 1, d_qk]
value_states = value_states.view(batch_size, kv_length, 1, self.d_v) # [batch_size, kv_length, 1, d_v]
attention_output = flash_attn_with_kvcache(
query_states,
k_cache,
v_cache,
key_states,
value_states,
rotary_cos=None,
rotary_sin=None,
# TODO @nouamane: seems like this doesn't help to indicate padding (for the first iteration it's just 0)
cache_seqlens=position_offsets.contiguous(),
softmax_scale=None,
causal=True,
rotary_interleaved=False, # GPT-NeoX style
)
store.update(
{
"key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
"value": v_cache,
"position_offsets": position_offsets,
"past_key_values_length": past_key_values_length,
}
)
else:
query_states, key_states = self.maybe_rotary(query_states, key_states, past_key_values_length=0)
q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask
kv_length = key_states.shape[seq_length_dim]
query_states = query_states.view(batch_size * q_length, self.n_heads, self.d_qk)
key_states = key_states.view(batch_size * kv_length, 1, self.d_qk)
value_states = value_states.view(batch_size * kv_length, 1, self.d_v)
attention_output = self.attention(
query_states=query_states, # [batch_size * q_length, num_heads, d_qk]
key_states=key_states, # [batch_size * kv_length, 1, d_qk]
value_states=value_states, # [batch_size * kv_length, 1, d_v]
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
) # [batch_size * q_length, num_heads, d_v]
output = self.o(unshape(attention_output))
return {"hidden_states": output, "sequence_mask": sequence_mask}
############################
# GQA
############################
class CausalSelfGQA(nn.Module, AttachableStore):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
assert (
config.num_attention_heads % tp_pg.size() == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
self.maybe_rotary = (
StarcoderRotaryEmbedding(head_dim=self.head_dim, base=config.rope_theta)
if config.use_rotary_embeddings
else (lambda q, k, past_key_values_length=0: (q, k))  # accept the keyword used at call sites
)
self.num_kv_heads = config.num_kv_heads if (not config.multi_query) else 1
self.n_local_q_heads = self.num_heads // tp_pg.size()
self.n_local_kv_heads = config.num_kv_heads // tp_pg.size()
assert (
config.num_kv_heads >= tp_pg.size()
), f"Number of kv heads ({config.num_kv_heads}) must be >= TP size ({tp_pg.size()})."
self.n_repeats = self.n_local_q_heads // self.n_local_kv_heads
qkv_contiguous_chunks = None
self.query_key_value = TensorParallelColumnLinear(
self.hidden_size,
self.num_heads * self.head_dim + 2 * self.num_kv_heads * self.head_dim,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
)
self.dense = TensorParallelRowLinear(
self.hidden_size,
self.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
assert config.multi_query is False
assert config.grouped_query is True
self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
def forward(
self,
hidden_states, # (seq_length, batch_size, hidden_size)
sequence_mask, # (batch_size, seq_length)
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
fused_qkv = self.query_key_value(
hidden_states
) # [seq_length, batch_size, n_local_q_heads * head_dim + 2 * n_local_kv_heads * head_dim]
q_length, batch_size, _ = fused_qkv.size()
qkv = fused_qkv.view(q_length, batch_size, self.n_local_kv_heads, self.n_repeats + 2, self.head_dim)
query, key, value = torch.split(qkv, [self.n_repeats, 1, 1], dim=3)
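# NOTE: the fused QKV projection is laid out per local kv head as
# [n_repeats query heads | 1 key head | 1 value head] along dim=3, so the split above yields
# query: [q_length, batch_size, n_local_kv_heads, n_repeats, head_dim]
# key:   [q_length, batch_size, n_local_kv_heads, 1, head_dim]
# value: [q_length, batch_size, n_local_kv_heads, 1, head_dim]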
query_states = query.transpose(0, 1).reshape(
batch_size, q_length, self.n_local_q_heads, self.head_dim
) # TODO @nouamane: can we transpose qkv instead?
key_states = key.transpose(0, 1).reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim)
value_states = value.transpose(0, 1).reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim)
# Get cached key/values from store if available
store = self.get_local_store()
if store is not None:
# Double check that we use store only at inference time
assert key_states.requires_grad is False
assert value_states.requires_grad is False
# Compute rotary embeddings
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
past_key_values_length = store["past_key_values_length"]
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
past_key_values_length = 0
position_offsets = position_ids[:, -1]
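# Example: for a left-padded mask [0, 0, 1, 1, 1] on the first iteration,
# cumsum - 1 gives position_ids = [-1, -1, 0, 1, 2] (pad slots get negative ids,
# real tokens are numbered from 0) and position_offsets = 2, i.e. the index of
# the last real token, which the decode branch later passes as cache_seqlens.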
query_states, key_states = self.maybe_rotary(
query_states, key_states, past_key_values_length=past_key_values_length
)
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
# preallocate k_cache, v_cache to self.prefill_kv_len
k_cache = torch.zeros(
(
batch_size,
self.prefill_kv_len,
self.n_local_kv_heads,
self.head_dim,
),
dtype=query_states.dtype,
device=query_states.device,
)
v_cache = torch.zeros(
(batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.head_dim),
dtype=query_states.dtype,
device=query_states.device,
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
output_unpad = flash_attn_varlen_func(
q=query_unpad, # (total_q, self.n_local_q_heads, d_qk)
k=key_unpad, # (total_kv, self.n_local_kv_heads, d_qk)
v=value_unpad, # (total_kv, self.n_local_kv_heads, d_v)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True, # True in prefill phase, False in subsequent phases
return_attn_probs=False,
) # (total_unpadded, n_local_q_heads, d_v)
attention_output = bert_padding.pad_input(
output_unpad, indices_q, batch_size, q_length
) # (batch_size, q_length, n_local_q_heads, d_v)
pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
v_cache = store["value"]
# [batch_size, seq_length, num_heads, d_qk]
query_states = query_states.view(
batch_size, q_length, self.n_local_q_heads, self.head_dim
) # [batch_size, q_length, self.n_local_q_heads, self.head_dim]
kv_length = key_states.shape[1]
key_states = key_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim]
value_states = value_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim]
attention_output = flash_attn_with_kvcache(
query_states,
k_cache,
v_cache,
key_states,
value_states,
rotary_cos=None,
rotary_sin=None,
# TODO @nouamane: seems like this doesn't help to indicate padding (for the first iteration it's just 0)
cache_seqlens=position_offsets.contiguous(),
softmax_scale=None,
causal=True,
rotary_interleaved=False, # GPT-NeoX style
)
# Update store
if past_key_values_length == 0:
past_key_values_length = sequence_mask.shape[1] - 1 # we add 1 when we load the value
else:
past_key_values_length += 1
store.update(
{
"key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
"value": v_cache,
"position_offsets": position_offsets,
"past_key_values_length": past_key_values_length,
}
)
else:
# Apply rotary embeddings to query/key states
query_states, key_states = self.maybe_rotary(query_states, key_states, past_key_values_length=0)
q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask
kv_length = key_states.shape[1]
# [batch_size, seq_length, num_heads, head_dim]
# Reshape to the flattened [batch_size * seq_length, n_heads, head_dim] layout expected by the variable-length flash-attn path
query_states = query_states.reshape(
batch_size * q_length, self.n_local_q_heads, self.head_dim
) # [batch_size * q_length, self.n_local_q_heads, head_dim]
key_states = key_states.reshape(
batch_size * kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size * kv_length, self.n_local_kv_heads, head_dim]
value_states = value_states.reshape(
batch_size * kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size * kv_length, self.n_local_kv_heads, head_dim]
attention_output = self.attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
) # [batch_size * seq_length, self.n_local_q_heads, head_dim]
attention_output = attention_output.view(batch_size, q_length, self.n_local_q_heads * self.head_dim).transpose(
0, 1
)
output = self.dense(attention_output)
return {"hidden_states": output, "sequence_mask": sequence_mask}
@torch.jit.script
def dropout_add(x, residual, prob, training):
# type: (Tensor, Tensor, float, bool) -> Tensor
# From: https://github.com/NVIDIA/Megatron-LM/blob/285068c8108e0e8e6538f54fe27c3ee86c5217a2/megatron/model/transformer.py#L586
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
@torch.jit.script
def dropout_add_fused_train(x: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor:
return dropout_add(x, residual, prob, True)
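# NOTE: scripting the dropout + residual add is intended to let TorchScript fuse the two
# elementwise ops into a single kernel during training, following the Megatron-LM
# implementation referenced above.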
class GPTBlock(nn.Module):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
random_states: RandomStates,
layer_idx: int,
):
super(GPTBlock, self).__init__()
self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
if config.multi_query is True:
self.attn = CausalSelfMQA(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)
elif config.grouped_query is True:
self.attn = CausalSelfGQA(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)
else:
raise ValueError("Either `multi_query` or `grouped_query` must be True") # TODO: @nouamane not necessarily
self.attn_dropout = config.attn_pdrop
self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
self.ff_dropout = config.resid_pdrop
self.random_states = random_states
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
if self.training:
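# NOTE: in ALL_REDUCE tensor-parallel mode the activations entering dropout are
# replicated across TP ranks, so the dropout mask is drawn from a TP-synchronized
# random state to keep the replicas identical across ranks.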
with branch_random_state(
self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE
):
hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout)
else:
# No need for random state context manager
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
hidden_states = self.ff(hidden_states=hidden_states)["hidden_states"]
if self.training:
with branch_random_state(
self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE
):
hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout)
else:
# No need for random state context manager
hidden_states = hidden_states + residual
return {
"hidden_states": hidden_states,
"sequence_mask": output["sequence_mask"],
}
class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: Starcoder2Config, parallel_config: Optional[ParallelismArgs]):
super().__init__()
self.token_embedding = TensorParallelEmbedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
pg=tp_pg,
mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
)
self.pg = tp_pg
def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
# store = self.get_local_store()
# if store is not None:
# if "past_length" in store:
# past_length = store["past_length"]
# else:
# past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
# cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
# # Store new past_length in store
# store["past_length"] = past_length + cumsum_mask[:, -1]
# Format input in `[seq_length, batch_size]` to support high TP with low batch_size
input_ids = input_ids.transpose(0, 1)
input_embeds = self.token_embedding(input_ids)
return {"input_embeds": input_embeds}
class GPTModel(nn.Module):
"""Build pipeline graph"""
def __init__(
self,
config: Starcoder2Config,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: RandomStates,
):
super().__init__()
# Declare all the nodes
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
self.random_states = random_states
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
self.token_embeddings = PipelineBlock(
p2p=self.p2p,
module_builder=Embedding,
module_kwargs={
"tp_pg": parallel_context.tp_pg,
"config": config,
"parallel_config": parallel_config,
},
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)
self.embeds_dropout = PipelineBlock(
p2p=self.p2p,
module_builder=nn.Dropout,
module_kwargs={"p": config.embd_pdrop},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=GPTBlock,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
"tp_pg": parallel_context.tp_pg,
"random_states": random_states,
"layer_idx": layer_idx,
},
module_input_keys={"hidden_states", "sequence_mask"},
module_output_keys={"hidden_states", "sequence_mask"},
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonLayerNorm,
module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
)
self.lm_head = PipelineBlock(
p2p=self.p2p,
# Note that this returns sharded logits that will need to be gathered
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
"pg": parallel_context.tp_pg,
"bias": False,
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": parallel_config.tp_linear_async_communication
if parallel_config is not None
else False,
},
module_input_keys={"x"},
module_output_keys={"logits"},
)
self.cast_to_fp32 = PipelineBlock(
p2p=self.p2p,
module_builder=lambda: lambda x: x.float(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
# all tensors are optional as most ranks don't need anything from the dataloader.
input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]
with branch_random_state(
self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE
):
hidden_states = self.embeds_dropout(input=input_embeds)["hidden_states"]
hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask}
for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)
hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
sharded_logits = self.lm_head(x=hidden_states)["logits"]
fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
return fp32_sharded_logits
@torch.jit.script
def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
class Loss(nn.Module):
def __init__(self, tp_pg: dist.ProcessGroup):
super().__init__()
self.tp_pg = tp_pg
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
) -> Dict[str, torch.Tensor]:
# Megatron by default casts everything to fp32. `--f16-lm-cross-entropy` is an option you can use to keep the current precision.
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
loss = sharded_cross_entropy(
sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
).transpose(
0, 1
) # TODO @nouamane: case where TP=1 should be simpler
# TODO @thomasw21: It's unclear what kind of normalization we want to do.
loss = masked_mean(loss, label_mask, dtype=torch.float)
# I think indexing causes a sync we don't actually want
# loss = loss[label_mask].sum()
return {"loss": loss}
class Starcoder2ForTraining(NanotronModel):
def __init__(
self,
config: Starcoder2Config,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: RandomStates,
):
super().__init__()
self.model = GPTModel(
config=config,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=random_states,
)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=Loss,
module_kwargs={"tp_pg": parallel_context.tp_pg},
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
},
module_output_keys={"loss"},
)
self.config: Starcoder2Config = config
self.parallel_config = parallel_config
self.parallel_context = parallel_context
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
) -> Union[torch.Tensor, TensorPointer]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
return {
"loss": self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
}
def tie_custom_params(self) -> None:
# find all params with names qkv.kv.weight and qkv.kv.bias in them
for module_name, module in self.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if ".qkv.kv." in name:
assert not param.is_tied, f"Parameter {name} is already tied"
shared_weights = [
(
name,
# sync across TP group
tuple(sorted(dist.get_process_group_ranks(self.parallel_context.tp_pg))),
)
]
tie_parameters(
root_module=self,
ties=shared_weights,
parallel_context=self.parallel_context,
# We always SUM grads, because kv weights are always duplicated in MQA
reduce_op=dist.ReduceOp.SUM,
)
@torch.no_grad()
def init_model_randomly(self, config):
"""Initialize model parameters randomly.
Note:
Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
"""
model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
std = config.model.init_method.std
sigma = config.model.init_method.std
num_layers = config.model.model_config.num_hidden_layers
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)
module_name, param_name = param_name.rsplit(".", 1)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if full_param_name in initialized_parameters:
# Already initialized
continue
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, LayerNorm):
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, MQAColumnLinears):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
raise Exception(f"Parameter {full_param_name} was not initialized")
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if param.is_tied
else name
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
def get_embeddings_lm_head_tied_names(self) -> List[str]:
return [
"model.token_embeddings.pp_block.token_embedding.weight",
"model.lm_head.pp_block.weight",
]
def before_tbi_sanity_checks(self):
# SANITY CHECK: Check ".qkv.kv." params are tied
for name, kv_param in self.named_parameters():
if ".qkv.kv." in name:
assert kv_param.is_tied, f"{name} is not tied (kv weights/biases should be tied in GPTBigcode)"
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
model_config = self.config
d_ff = model_config.n_inner if model_config.n_inner is not None else 4 * model_config.hidden_size
d_qkv = model_config.hidden_size // model_config.num_attention_heads
block_compute_costs = {
# CausalSelfAttention (qkv proj + attn out) + MLP
GPTBlock: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 2 * d_ff * model_config.hidden_size,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs
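# Illustration with hypothetical config values (not taken from any real config):
# with hidden_size=6144, num_attention_heads=48 (so d_qkv=128) and d_ff=4*hidden_size=24576,
# each GPTBlock costs 4*48*128*6144 + 2*24576*6144 ≈ 4.5e8 and the lm_head
# (vocab_size=49152) costs 49152*6144 ≈ 3.0e8; only the relative weights matter
# for load balancing.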
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
world_size = self.parallel_context.world_pg.size()
model_flops, hardware_flops = get_flops(
num_layers=self.config.num_hidden_layers,
hidden_size=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
vocab_size=self.config.vocab_size,
ffn_hidden_size=self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size,
seq_len=sequence_length,
batch_size=global_batch_size,
kv_channels=None,
glu_activation=False,
)
model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
return model_flops_per_s, hardware_flops_per_s
def get_flops(
num_layers,
hidden_size,
num_heads,
vocab_size,
seq_len,
kv_channels=None,
ffn_hidden_size=None,
batch_size=1,
glu_activation=False,
):
"""Counts flops in an decoder-only model
Args:
num_layers: number of decoder layers
hidden_size: hidden size of the model
num_heads: number of heads in the model
kv_channels: hidden size of the key and value heads
ffn_hidden_size: hidden size of the FFN
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info.
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
"""
if kv_channels is None:
assert hidden_size % num_heads == 0
kv_channels = hidden_size // num_heads
if ffn_hidden_size is None:
ffn_hidden_size = 4 * hidden_size
# In the following we mark the reduced dimension with parentheses
# decoder
# self attention (MQA)
## q projection
decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels
## kv projection, shared across heads
decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels
## qk logits
decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len
### SWA (sliding window attention / local attention)
# window_size = 4096
# decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size
## v logits
decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels
# decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels
## attn out
decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size
# FF
## 1st layer
decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
if glu_activation:
# 3 matmuls instead of 2 in FFN
# ref. https://arxiv.org/pdf/2002.05202.pdf
# Used for example in T5 v1.1
decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
## 2nd layer
decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
decoder_flops_fwd = (
decoder_q_proj_flops_fwd
+ decoder_kv_proj_flops_fwd
+ decoder_qk_logits_flops_fwd
+ decoder_v_logits_flops_fwd
+ decoder_attn_out_flops_fwd
+ decoder_ffn_1_flops_fwd
+ decoder_ffn_2_flops_fwd
)
# lm head
lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
# for matmuls, the bwd pass requires roughly double the fwd flops, since gradients are
# computed with respect to both the input and the weight tensors
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
hardware_flops = model_flops # TODO @nouamanetazi: This is a placeholder for now
return model_flops, hardware_flops
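# A minimal usage sketch of get_flops (illustrative only: the helper below and the
# config values in it are hypothetical, not part of the original training code).
# It compares the counted model FLOPs against the common "6 * parameters * tokens"
# rule of thumb for dense decoder-only models.
def _example_flops_sanity_check():
    model_flops, _ = get_flops(
        num_layers=40,
        hidden_size=6144,
        num_heads=48,
        vocab_size=49152,
        seq_len=4096,
        batch_size=1,
    )
    d_qkv = 6144 // 48
    params_per_layer = (
        6144 * 6144  # q projection
        + 6144 * 2 * d_qkv  # shared kv projection (MQA)
        + 6144 * 6144  # attention output projection
        + 2 * 6144 * 4 * 6144  # the two FFN matmuls
    )
    approx_matmul_params = 40 * params_per_layer + 6144 * 49152  # + lm_head
    # The ratio should be modestly above 1.0, because the 6*N*D estimate ignores
    # the seq_len^2 attention-score FLOPs that get_flops does count.
    return model_flops / (6 * approx_matmul_params * 4096)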
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
import torch
from packaging import version
from torch import Tensor, nn
from nanotron import logging
logger = logging.get_logger(__name__)
class PytorchGELUTanh(nn.Module):
"""
A fast C implementation of the tanh approximation of the GeLU activation function. See
https://arxiv.org/abs/1606.08415.
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
match due to rounding errors.
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.12.0"):
raise ImportError(
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
"PytorchGELUTanh. Please upgrade torch."
)
def forward(self, input: Tensor) -> Tensor:
return nn.functional.gelu(input, approximate="tanh")
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class GELUActivation(nn.Module):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
super().__init__()
if use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class FastGELUActivation(nn.Module):
"""
Applies a GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Module):
"""
Applies a GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Module):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes, as
it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
return torch.clip(gelu(x), self.min, self.max)
class AccurateGELUActivation(nn.Module):
"""
Applies a GELU approximation that is faster than the default and more accurate than QuickGELU. See:
https://github.com/hendrycks/GELUs
Implemented along with MEGA (Moving Average Equipped Gated Attention)
"""
def __init__(self):
super().__init__()
self.precomputed_constant = math.sqrt(2 / math.pi)
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
class SiLUActivation(nn.Module):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
def forward(self, input: Tensor) -> Tensor:
return nn.functional.silu(input)
class MishActivation(nn.Module):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.9.0"):
self.act = self._mish_python
else:
self.act = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
return input * torch.tanh(nn.functional.softplus(input))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class LinearActivation(nn.Module):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
return input
class LaplaceActivation(nn.Module):
"""
Applies an elementwise activation based on the Laplace function, introduced in MEGA as an attention activation. See
https://arxiv.org/abs/2209.10655
Inspired by squared relu, but with bounded range and gradient for better stability
"""
def forward(self, input, mu=0.707107, sigma=0.282095):
input = (input - mu).div(sigma * math.sqrt(2.0))
return 0.5 * (1.0 + torch.erf(input))
class ReLUSquaredActivation(nn.Module):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
def forward(self, input):
relu_applied = nn.functional.relu(input)
squared = torch.square(relu_applied)
return squared
class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)
ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"gelu_pytorch_tanh": PytorchGELUTanh,
"gelu_accurate": AccurateGELUActivation,
"laplace": LaplaceActivation,
"linear": LinearActivation,
"mish": MishActivation,
"quick_gelu": QuickGELUActivation,
"relu": nn.ReLU,
"relu2": ReLUSquaredActivation,
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": SiLUActivation,
"swish": SiLUActivation,
"tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
import torch
from torch import nn
class TritonLayerNorm(nn.LayerNorm):
def forward(
self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
from flash_attn.ops.triton.layer_norm import layer_norm_fn
return layer_norm_fn(
input,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
dropout_p=dropout_p,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=False,
return_dropout_mask=return_dropout_mask,
)
# This is equivalent to LLaMA RMSNorm
# https://github.com/huggingface/transformers/blob/28952248b19db29ca25ccf34a5eec413376494a9/src/transformers/models/llama/modeling_llama.py#L112
class TritonRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
def forward(
self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
from flash_attn.ops.triton.layer_norm import layer_norm_fn
return layer_norm_fn(
input,
self.weight,
None,
residual=residual,
eps=self.eps,
dropout_p=dropout_p,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=True,
return_dropout_mask=return_dropout_mask,
)
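# A plain-PyTorch reference for what layer_norm_fn computes with is_rms_norm=True
# (a readability sketch under that assumption; the model code only uses the Triton kernel):
# RMSNorm rescales by the root mean square of the features, with no mean subtraction
# and no bias, matching the LLaMA RMSNorm linked above.
def _reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight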