Commit 61e92904 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import log_rank
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
if TYPE_CHECKING:
from nanotron.config import NanotronConfigs
from nanotron.parallel.parameters import NanotronParameter
logger = logging.get_logger(__name__)
class NanotronModel(nn.Module, metaclass=ABCMeta):
"""Abstract class for Nanotron models
We make the following assumptions:
- When building PP blocks, we assume that the modules order are in the same order as the forward pass."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.parallel_context: "ParallelContext"
self.config: "NanotronConfigs"
self.module_id_to_prefix: dict[int, str]
# Attributes defined when building the model
self.input_pp_rank: int
self.output_pp_rank: int
# Useful mapping to get param names
self.module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in self.named_modules()}
self.module_id_to_prefix[id(self)] = ""
def get_named_params_with_correct_tied(self) -> Iterator[Tuple[str, "NanotronParameter"]]:
"""Return named parameters with correct tied params names.
For example in the case of tied kv heads in MQA, we need to make sure tied params names are correct."""
def params_gen():
for name, param in self.named_parameters():
if param.is_tied:
yield (
param.get_tied_info().get_full_name_from_module_id_to_prefix(
module_id_to_prefix=self.module_id_to_prefix
),
param,
)
else:
yield name, param
yield from params_gen()
@abstractmethod
def init_model_randomly(self, config):
...
def tie_custom_params(self) -> None:
"""Tie custom parameters. For example for MQA marks kv heads as tied."""
pass
def get_embeddings_lm_head_tied_names(self) -> list[str]:
"""Returns the names of the embeddings and lm_head weights that are tied together. Returns empty list if not tied.
Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
"""
return []
def get_named_params_without_weight_decay(self) -> List[str]:
"""Return a list of named parameters that should not have weight decay applied to them."""
return []
def before_tbi_sanity_checks(self) -> None:
pass
def after_tbi_sanity_checks(self) -> None:
pass
def before_optim_step_sanity_checks(self) -> None:
pass
def after_optim_step_sanity_checks(self) -> None:
pass
def log_modules(self, level: int = logging.DEBUG, group: Optional[ProcessGroup] = None, rank: int = 0):
assert hasattr(self, "parallel_context"), "`NanotronModel` needs to have a `parallel_context` attribute"
for name, module in self.named_modules():
if not isinstance(module, PipelineBlock):
continue
log_rank(
f"module_name: {name} | PP: {module.rank}/{self.parallel_context.pp_pg.size()}",
logger=logger,
level=level,
group=group,
rank=rank,
)
@property
def named_modules_in_pp_rank(self) -> Dict[str, nn.Module]:
"""Return the named modules that only belongs to the current pp rank.
An example output:
{
'module_name': module,
...
}
NOTE: not include module_name.weight or bias, but only module_name
"""
def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]:
"""
Return all the leaf modules (modules without any child modules) in a PyTorch module.
"""
leaf_modules = []
for n, m in module.named_modules():
if not list(m.children()):
leaf_modules.append((n, m))
return leaf_modules
modules = get_leaf_modules(self)
named_modules_in_current_pp_rank = {}
for name, module in modules:
if isinstance(module, PipelineBlock):
# NOTE: these are the modules that aren't belong to the current pp rank
continue
named_modules_in_current_pp_rank[name] = module
return named_modules_in_current_pp_rank
class DTypeInvariantTensor(torch.Tensor):
"""DTypeInvariantTensor is a subclass of torch.Tensor that disallows modification of its dtype. Note that the data
and other attributes of the tensor can still be modified."""
def __new__(cls, *args, **kwargs):
tensor = super().__new__(cls, *args, **kwargs)
return tensor
def detach(self, *args, **kwargs):
raise RuntimeError("Cannot detach an DTypeInvariantTensor")
def to(self, *args, **kwargs):
if "dtype" in kwargs or any(isinstance(arg, torch.dtype) for arg in args):
raise RuntimeError("Cannot change the type of an DTypeInvariantTensor")
else:
return super().to(*args, **kwargs)
def type(self, *args, **kwargs):
raise RuntimeError("Cannot change the type of an DTypeInvariantTensor")
def float(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to float")
def double(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to double")
def half(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to half")
def long(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to long")
def int(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to int")
def short(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to short")
def char(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to char")
def byte(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to byte")
def bool(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to bool")
def bfloat16(self, *args, **kwargs):
raise RuntimeError("Cannot convert the type of an DTypeInvariantTensor to bfloat16")
def build_model(
model_builder: Callable[[], NanotronModel],
parallel_context: ParallelContext,
dtype: torch.dtype,
target_pp_ranks: Optional[List[int]] = None,
device: Optional[torch.device] = torch.device("cuda"),
) -> NanotronModel:
"""Build the model and set the pp ranks for each pipeline block."""
# TODO: classes dont take same args
log_rank("Building model..", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg)
model: NanotronModel = model_builder()
# If no target pp ranks are specified, we assume that we want to use all pp ranks
if target_pp_ranks is None:
pp_size = parallel_context.pp_pg.size()
target_pp_ranks = list(range(pp_size))
else:
pp_size = len(target_pp_ranks)
# Set rank for each pipeline block
log_rank("Setting PP block ranks...", logger=logger, level=logging.INFO, rank=0, group=parallel_context.world_pg)
pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)]
# "cuda" is already defaulted for each process to it's own cuda device
with init_on_device_and_dtype(device=device, dtype=dtype):
# TODO: https://github.com/huggingface/nanotron/issues/65
# Balance compute across PP blocks
block_compute_costs = model.get_block_compute_costs()
block_cumulative_costs = np.cumsum(
[
block_compute_costs[module.module_builder] if module.module_builder in block_compute_costs else 0
for module in pipeline_blocks
]
)
thresholds = [block_cumulative_costs[-1] * ((rank + 1) / pp_size) for rank in range(pp_size)]
assert thresholds[-1] >= block_cumulative_costs[-1]
target_pp_rank_idx = 0
for block, cumulative_cost in zip(pipeline_blocks, block_cumulative_costs):
assert target_pp_rank_idx < pp_size
block.build_and_set_rank(target_pp_ranks[target_pp_rank_idx])
if cumulative_cost > thresholds[target_pp_rank_idx]:
target_pp_rank_idx += 1
model.input_pp_rank = target_pp_ranks[0]
model.output_pp_rank = target_pp_ranks[target_pp_rank_idx]
return model
# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does.
@contextmanager
def init_on_device_and_dtype(
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float,
):
"""
A context manager under which models are initialized with all parameters on the specified device.
Args:
device (`torch.device` defaults to `cpu`):
Device to initialize all parameters on.
dtype (`torch.dtype` defaults to `torch.float`):
Dtype to initialize all parameters on.
include_buffers (`bool`, defaults to `False`):
Whether or not to also default all buffers constructors given previous arguments.
Example:
```python
import torch.nn as nn
from accelerate import init_on_device
with init_on_device_and_dtype(device=torch.device("cuda")):
tst = nn.Liner(100, 100) # on `cuda` device
```
"""
old_register_parameter = nn.Module.register_parameter
old_register_buffer = nn.Module.register_buffer
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
if isinstance(param, DTypeInvariantTensor):
# if param is DTypeInvariantTensor we should avoid updating it
param.data = param.data.to(device)
else:
param.data = param.data.to(device, dtype)
def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
if buffer is not None:
if isinstance(buffer, DTypeInvariantTensor):
# if buffer is DTypeInvariantTensor we should avoid updating it
buffer.data = buffer.data.to(device)
else:
module._buffers[name] = module._buffers[name].to(device, dtype)
# Patch tensor creation
tensor_constructors_to_patch = {
torch_function_name: getattr(torch, torch_function_name)
for torch_function_name in ["empty", "zeros", "ones", "full"]
}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
kwargs["dtype"] = dtype
return fn(*args, **kwargs)
return wrapper
try:
nn.Module.register_parameter = register_empty_parameter
nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
nn.Module.register_parameter = old_register_parameter
nn.Module.register_buffer = old_register_buffer
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
setattr(torch, torch_function_name, old_torch_function)
def check_model_has_grad(model: NanotronModel, parallel_context: "ParallelContext"):
"""Check that there's at least a parameter in current PP rank that has a gradient."""
for param in model.parameters():
if param.requires_grad:
return True
raise ValueError(
f"Can't use DDP because model in PP={dist.get_rank(parallel_context.pp_pg)} has no gradient. Consider increasing the number of layers of your model, or put a smaller PP size.\n"
f"Model: {model}"
)
# coding=utf-8
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMa model."""
from typing import Dict, List, Optional, Union
import torch
from torch import nn
from torch.utils.checkpoint import CheckpointFunction
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config, LlamaConfig, ParallelismArgs
from nanotron.config.models_config import RandomInit, SpectralMupInit
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelLinearMode,
TensorParallelRowLinear,
)
from nanotron.random import RandomStates
from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator
from nanotron.utils import checkpoint_method
logger = logging.get_logger(__name__)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 10000.0):
super().__init__()
assert dim % 2 == 0
self.dim = dim
self.end = end
self.theta = theta
# TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
# TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex
self.freqs_cis: torch.Tensor
self._initialized_buffer = False
def init_rotary_embeddings(self):
if self._initialized_buffer is True:
# Buffer if already initialized
return
self.register_buffer(
"freqs_cis",
torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"),
persistent=False,
)
assert self.freqs_cis.device.type == "cuda"
# TODO @nouamane: One we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
if self.freqs_cis.dtype != torch.float:
self.freqs_cis = self.freqs_cis.to(torch.float)
assert self.freqs_cis.dtype == torch.float
freqs = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim)
).to(
"cuda"
) # should be computed on CPU, otherwise different results with Transformers.
t = torch.arange(self.end, device="cuda")
freqs = torch.outer(t, freqs).float()
complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
freqs = torch.view_as_real(complex_freqs)
self.freqs_cis.copy_(freqs)
self._initialized_buffer = True
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
batch_size, seq_length, num_heads, inner_dim = x.shape
while (
position_ids is not None and position_ids[-1, -1] >= self.end
) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync
self.end *= 2
self._initialized_buffer = False
if self._initialized_buffer is False:
print(f"Initializing rotary embeddings with end={self.end}")
self.init_rotary_embeddings()
dtype = x.dtype
assert inner_dim % 2 == 0
x = x.view(
batch_size, seq_length, num_heads, inner_dim // 2, 2
) # [batch_size, q_length, num_heads, inner_dim]
if x.dtype == torch.bfloat16:
x = x.float()
complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2]
if position_ids is None:
freqs_cis = self.freqs_cis[None, :seq_length, None, :]
else:
# TODO(kunhao): Should None follow the num_heads dimension?
if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully
raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}")
freqs_cis = self.freqs_cis[position_ids][:, :, None, :]
complex_freqs = torch.view_as_complex(freqs_cis)
x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim)
return x_out.type(dtype)
## Copy from transformers. Non interleaved version of RoPE. Will be refactored later
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 500000.0):
super().__init__()
self.dim = dim
self.end = end
self.theta = theta
self.init_rotary_embeddings()
def init_rotary_embeddings(self):
inv_freq = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
) # important to compute on CPU
self.register_buffer(
"inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
)
self.inv_freq = self.inv_freq.to(
torch.float
) # make it float32 before copy to avoid precision loss during copy_
self.inv_freq.copy_(inv_freq)
@torch.no_grad()
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=2):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed
class GLUActivation(nn.Module):
def __init__(self, act_fn_name: str):
super().__init__()
self.act = ACT2FN[act_fn_name]
def forward(self, merged_states: torch.Tensor):
gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
return self.act(gate_states) * up_states
class MLP(nn.Module):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
):
super().__init__()
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
gate_up_contiguous_chunks = (
config.intermediate_size, # shape of gate_linear
config.intermediate_size, # shape of up_linear
)
self.gate_up_proj = TensorParallelColumnLinear(
config.hidden_size,
2 * config.intermediate_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
self.split_silu_mul = GLUActivation(config.hidden_act)
def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
hidden_states = self.down_proj(self.split_silu_mul(merged_states))
return {"hidden_states": hidden_states}
class CoreAttention(nn.Module):
def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
super().__init__()
# TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentically a `d_qkv`
assert (
config.hidden_size % config.num_attention_heads == 0
), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.is_using_mup = config.is_using_mup
self.checkpoint_attention = False # Because flash_attn already does checkpointing
@checkpoint_method(attr_name="checkpoint_attention")
def forward(
self,
query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim]
key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
):
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# TODO @thomasw21: Compute once, instead of computing for each layers.
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
# TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
# what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
causal = False if q_sequence_mask.shape[1] == 1 else True
# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
attn_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_sequence_mask.shape[1],
max_seqlen_k=kv_sequence_mask.shape[1],
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=False,
)
return attn_output
def pad_to_right(tensor, mask, new_tensor=None):
"""Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
Args:
tensor: (batch_size, seqlen, d1, d2)
mask: (batch_size, seqlen)
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
Returns:
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
right_padded_mask: (batch_size, seqlen)
"""
# First, we need to find the number of padding for each row
unpad_seqlens = mask.sum(1)
# Then, we need to find the maximum length of the tensor
max_seqlen = mask.shape[1]
# We can then create the indices to select the padded values
# The indices are the same for each row
indices = torch.arange(max_seqlen, device=mask.device)
# We can then create the mask for the padded values
right_padded_mask = indices < unpad_seqlens[:, None]
# We select the useful values
useful_values = tensor[mask]
# We create the new tensor (if not provided)
new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
# We fill the new tensor with the useful values
new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
return new_tensor, right_padded_mask
class CausalSelfAttention(nn.Module, AttachableStore):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
assert (
config.num_attention_heads % tp_pg.size() == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
try:
assert (
config.num_key_value_heads % tp_pg.size() == 0
), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})."
except AttributeError:
log_rank(
"WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads",
logger=logger,
level=logging.WARNING,
rank=0,
)
# If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads
config.num_key_value_heads = config.num_attention_heads
assert (
config.num_attention_heads % config.num_key_value_heads == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})."
self.n_local_q_heads = config.num_attention_heads // tp_pg.size()
self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size()
self.n_repeats = config.num_attention_heads // config.num_key_value_heads
self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.d_model = config.hidden_size
self.is_using_mup = config.is_using_mup
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
# build the slice config for self.qkv for save/load
# shard are done within the contiguous chunk
qkv_contiguous_chunks = (
config.num_attention_heads * self.d_qk, # shape of q
config.num_key_value_heads * self.d_qk, # shape of k
config.num_key_value_heads * self.d_qk, # shape of v
)
self.qkv_proj = TensorParallelColumnLinear(
self.d_model,
config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
if config.rope_interleaved:
self.rotary_embedding = RotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
theta=config.rope_theta,
)
else:
self.rotary_embedding = LlamaRotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
theta=config.rope_theta,
)
self.rope_interleaved = config.rope_interleaved
# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(
dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved
)
self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
)
self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_size]
sequence_mask, # [batch_size, seq_length]
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
qkv_states = self.qkv_proj(
hidden_states
) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
q_length, batch_size, _ = qkv_states.shape
if self.is_gqa:
query_states, key_states, value_states = torch.split(
qkv_states,
[
self.n_local_q_heads * self.d_qk,
self.n_local_kv_heads * self.d_qk,
self.n_local_kv_heads * self.d_qk,
],
dim=-1,
)
query_states = (
query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk)
)
key_states = (
key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
)
value_states = (
value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
)
else:
query_states, key_states, value_states = (
qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk)
.permute(2, 1, 0, 3, 4)
.contiguous()
) # [3, batch_size, seq_length, n_local_q_heads, d_qk]
store = self.get_local_store()
if store is not None: # Inference case
# Double check that we use store only at inference time
assert key_states.requires_grad is False
assert value_states.requires_grad is False
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
position_offsets = position_ids[:, -1]
# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
# interleaved version.
if self.rope_interleaved:
query_states = self.rotary_embedding(query_states, position_ids=position_ids)
key_states = self.rotary_embedding(key_states, position_ids=position_ids)
# non interleaved version.
else:
cos, sin = self.rotary_embedding(value_states, position_ids)
query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
# preallocate k_cache, v_cache to self.prefill_kv_len
k_cache = torch.zeros(
(
batch_size,
self.prefill_kv_len,
self.n_local_kv_heads,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
)
v_cache = torch.zeros(
(batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
dtype=query_states.dtype,
device=query_states.device,
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
output_unpad = flash_attn_varlen_func(
q=query_unpad, # (total_q, n_local_q_heads, d_qk)
k=key_unpad, # (total_kv, n_local_kv_heads, d_qk)
v=value_unpad, # (total_kv, n_local_kv_heads, d_v)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=True, # True in prefill phase, False in subsequent phases
return_attn_probs=False,
) # (total_unpadded, n_local_q_heads, d_v)
attention_output = bert_padding.pad_input(
output_unpad, indices_q, batch_size, q_length
) # (batch_size, q_length, n_local_q_heads, d_v)
pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
v_cache = store["value"]
# NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
# Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
if self.rotary_embedding.end > old_rotary_embed_end:
k_cache = torch.cat(
[
k_cache,
torch.zeros(
(
batch_size,
self.rotary_embedding.end - old_rotary_embed_end,
self.n_local_kv_heads,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
),
],
dim=1,
)
v_cache = torch.cat(
[
v_cache,
torch.zeros(
(
batch_size,
self.rotary_embedding.end - old_rotary_embed_end,
self.n_local_kv_heads,
self.d_v,
),
dtype=query_states.dtype,
device=query_states.device,
),
],
dim=1,
)
assert (
k_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
assert (
v_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
# [batch_size, seq_length, num_heads, d_qk]
query_states = query_states.view(
batch_size, q_length, self.n_local_q_heads, self.d_qk
) # [batch_size, q_length, self.n_heads, d_qk]
kv_length = key_states.shape[1]
key_states = key_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size, kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size, kv_length, self.n_heads, d_v]
# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
attention_output = flash_attn_with_kvcache(
query_states,
k_cache,
v_cache,
key_states,
value_states,
rotary_cos=None,
rotary_sin=None,
# TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0)
cache_seqlens=position_offsets.contiguous(),
softmax_scale=softmax_scale,
causal=True,
rotary_interleaved=False, # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention
)
store.update(
{
"key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
"value": v_cache,
"position_offsets": position_offsets,
}
)
else: # Training case
# Apply rotary embeddings to query/key states
# NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk]
# Here it is, [batch_size, seq_length, num_heads, d_qk]
# [2, batch_size, seq_length, num_heads, d_qk]
key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
# [batch_size, seq_length, 2, num_heads, d_qk]
key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()
query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
# [batch_size, seq_length, num_heads, d_qk]
key_states, value_states = torch.split(key_value_states, 1, dim=2)
q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask
kv_length = key_states.shape[1]
# [batch_size, seq_length, num_heads, d_qk]
# Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
query_states = query_states.view(
batch_size * q_length, self.n_local_q_heads, self.d_qk
) # [batch_size * q_length, self.n_heads, d_qk]
key_states = key_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size * kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size * kv_length, self.n_heads, d_v]
attention_output = self.attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
)
attention_output = (
attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
)
output = self.o_proj(attention_output)
return {"hidden_states": output, "sequence_mask": sequence_mask}
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
super().__init__()
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = CausalSelfAttention(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)
self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
self.recompute_layer = parallel_config.recompute_layer
def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> List[Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
hidden_states = hidden_states + residual
return hidden_states, output["sequence_mask"]
def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
sequence_mask: torch.Tensor,
) -> List[torch.Tensor]:
return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)
def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
else:
hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)
return {
"hidden_states": hidden_states,
"sequence_mask": sequence_mask,
}
class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
super().__init__()
self.token_embedding = TensorParallelEmbedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
padding_idx=config.pad_token_id,
pg=tp_pg,
mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
)
self.pg = tp_pg
def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
store = self.get_local_store()
if store is not None:
if "past_length" in store:
past_length = store["past_length"]
else:
past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
# Store new past_length in store
store["past_length"] = past_length + cumsum_mask[:, -1]
# Format input in `[seq_length, batch_size]` to support high TP with low batch_size
input_ids = input_ids.transpose(0, 1)
input_embeds = self.token_embedding(input_ids)
return {"input_embeds": input_embeds}
class LlamaModel(nn.Module):
"""Build pipeline graph"""
def __init__(
self,
config: LlamaConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
# Declare all the nodes
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
self.config = config
self.parallel_config = parallel_config
self.parallel_context = parallel_context
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
self.token_position_embeddings = PipelineBlock(
p2p=self.p2p,
module_builder=Embedding,
module_kwargs={
"tp_pg": parallel_context.tp_pg,
"config": config,
"parallel_config": parallel_config,
},
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)
log_rank(f"Initialize RoPE Theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0)
if config.rope_interleaved:
log_rank(
"The RoPE interleaved version differs from the Transformers implementation. It's better to set rope_interleaved=False if you need to convert the weights to Transformers",
logger=logger,
level=logging.INFO,
rank=0,
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=LlamaDecoderLayer,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
"tp_pg": parallel_context.tp_pg,
"layer_idx": layer_idx,
},
module_input_keys={"hidden_states", "sequence_mask"},
module_output_keys={"hidden_states", "sequence_mask"},
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
) # TODO
self.lm_head = PipelineBlock(
p2p=self.p2p,
# Understand that this means that we return sharded logits that are going to need to be gathered
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
"pg": parallel_context.tp_pg,
"bias": False,
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
)
self.cast_to_fp32 = PipelineBlock(
p2p=self.p2p,
module_builder=lambda: lambda x: x.float(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]
def forward_with_hidden_states(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
# all tensors are optional as most ranks don't need anything from the dataloader.
output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
hidden_encoder_states = {
"hidden_states": output["input_embeds"],
"sequence_mask": input_mask,
}
for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)
hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
sharded_logits = self.lm_head(x=hidden_states)["logits"]
fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
return fp32_sharded_logits, hidden_states
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
model_config = self.config
d_ff = model_config.intermediate_size
d_qkv = model_config.hidden_size // model_config.num_attention_heads
block_compute_costs = {
# CausalSelfAttention (qkv proj + attn out) + MLP
LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
world_size = self.parallel_context.world_pg.size()
try:
num_key_values_heads = self.config.num_key_value_heads
except AttributeError:
num_key_values_heads = self.config.num_attention_heads
model_flops, hardware_flops = get_flops(
num_layers=self.config.num_hidden_layers,
hidden_size=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
num_key_value_heads=num_key_values_heads,
vocab_size=self.config.vocab_size,
ffn_hidden_size=self.config.intermediate_size,
seq_len=sequence_length,
batch_size=global_batch_size,
)
model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
return model_flops_per_s, hardware_flops_per_s
@torch.jit.script
def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
class Loss(nn.Module):
def __init__(self, tp_pg: dist.ProcessGroup):
super().__init__()
self.tp_pg = tp_pg
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
) -> Dict[str, torch.Tensor]:
# Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
loss = sharded_cross_entropy(
sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
).transpose(0, 1)
# TODO @thomasw21: It's unclear what kind of normalization we want to do.
loss = masked_mean(loss, label_mask, dtype=torch.float)
# I think indexing causes a sync we don't actually want
# loss = loss[label_mask].sum()
return {"loss": loss}
class LlamaForTraining(NanotronModel):
def __init__(
self,
config: LlamaConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: Optional[RandomStates] = None,
):
super().__init__()
self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=Loss,
module_kwargs={"tp_pg": parallel_context.tp_pg},
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
},
module_output_keys={"loss"},
)
self.parallel_context = parallel_context
self.config = config
self.parallel_config = parallel_config
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
loss = self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
return {"loss": loss}
@torch.no_grad()
def init_model_randomly(self, config: Config):
"""Initialize model parameters randomly.
Note:
Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
"""
init_method = config.model.init_method
if isinstance(init_method, RandomInit):
parametrizator_cls = StandardParametrizator
elif isinstance(init_method, SpectralMupInit):
parametrizator_cls = SpectralMupParametrizator
else:
raise ValueError(f"Unknown init method {init_method}")
parametrizator = parametrizator_cls(config=config.model)
log_rank(
f"Parametrizing model parameters using {parametrizator.__class__.__name__}",
logger=logger,
level=logging.INFO,
rank=0,
)
model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)
module_name, param_name = param_name.rsplit(".", 1)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if full_param_name in initialized_parameters:
# Already initialized
continue
module = model.get_submodule(module_name)
parametrizator.parametrize(param_name, module)
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if param.is_tied
else name
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
def get_embeddings_lm_head_tied_names(self):
"""Get the names of the tied embeddings and lm_head weights"""
if self.config.tie_word_embeddings is True:
return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
else:
return []
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
return self.model.get_block_compute_costs()
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
def get_flops(
num_layers,
hidden_size,
num_heads,
num_key_value_heads,
vocab_size,
seq_len,
ffn_hidden_size,
batch_size=1,
):
"""Counts flops in an decoder-only model
Args:
num_layers: number of decoder layers
hidden_size: hidden size of the model
num_heads: number of heads in the model
num_key_value_heads: number of key/value heads in the model
ffn_hidden_size: hidden size of the FFN
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
"""
if num_key_value_heads is None:
num_key_value_heads = num_heads
hidden_size_per_head = hidden_size // num_heads
# In the following we mark the reduced dimension with parentheses
# decoder
# self attention
## qkv projection
decoder_qkv_proj_flops_fwd = (
2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head
+ 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head
)
## qk logits
decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len
## v logits
decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head
## attn out
decoder_attn_out_flops_fwd = (
2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size
)
# FF
## 1st layer
decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
## 2nd layer
decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
decoder_flops_fwd = (
decoder_qkv_proj_flops_fwd
+ decoder_qk_logits_flops_fwd
+ decoder_v_logits_flops_fwd
+ decoder_attn_out_flops_fwd
+ decoder_ffn_1_flops_fwd
+ decoder_ffn_2_flops_fwd
)
# lm head
lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
# the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to
# both input and weight tensors
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
hardware_flops = model_flops # TODO: This is a placeholder for now
return model_flops, hardware_flops
# coding=utf-8
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder (GPT with Multi-Query Attention, RoPe, SWA and GQA).
Some dependencies to update before using:
- install `torch>=2.0`
- install `flash-attn>=2.5.0`
"""
import inspect
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import LayerNorm, init
from torch.nn import functional as F
from nanotron import distributed as dist
from nanotron.config import ParallelismArgs, Starcoder2Config
from nanotron.generation.generate_store import AttachableStore
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
from nanotron.nn.layer_norm import TritonLayerNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.sharded_parameters import (
SplitConfig,
mark_all_parameters_in_module_as_sharded,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.functional import (
column_linear,
sharded_cross_entropy,
)
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from nanotron.parallel.tied_parameters import tie_parameters
from nanotron.random import RandomStates, branch_random_state
from nanotron.utils import checkpoint_method
def pad_to_right(tensor, mask, new_tensor=None):
"""Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
Args:
tensor: (batch_size, seqlen, d1, d2)
mask: (batch_size, seqlen)
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
Returns:
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
right_padded_mask: (batch_size, seqlen)
"""
# First, we need to find the number of padding for each row
unpad_seqlens = mask.sum(1)
# Then, we need to find the maximum length of the tensor
max_seqlen = mask.shape[1]
# We can then create the indices to select the padded values
# The indices are the same for each row
indices = torch.arange(max_seqlen, device=mask.device)
# We can then create the mask for the padded values
right_padded_mask = indices < unpad_seqlens[:, None]
# We select the useful values
useful_values = tensor[mask]
# We create the new tensor (if not provided)
new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
# We fill the new tensor with the useful values
new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
return new_tensor, right_padded_mask
# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
@torch.jit.script
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
class StarcoderRotaryEmbedding(nn.Module):
"""Implementation of RotaryEmbedding from GPT-NeoX."""
def __init__(self, head_dim: int, base: int):
super().__init__()
self.base = base
self.head_dim = head_dim
self.seq_len_cached = -1
# TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
self.inv_freq: torch.Tensor
self.register_buffer(
"inv_freq",
torch.empty(head_dim // 2, dtype=torch.float),
persistent=False,
)
self.cos_cached: Optional[torch.Tensor] = None
self.sin_cached: Optional[torch.Tensor] = None
self._initialized_buffer = False
def init_rotary_embeddings(self):
if self._initialized_buffer is True:
# Buffer if already initialized
return
assert self.inv_freq.device.type == "cuda"
# TODO @nouamane: One we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
if self.inv_freq.dtype != torch.float:
self.inv_freq = self.inv_freq.to(torch.float)
assert self.inv_freq.dtype == torch.float
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float, device="cuda") / self.head_dim)
)
self._initialized_buffer = True
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
self.seq_len_cached = total_length
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, head_dim]
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
self.cos_cached = emb.cos()[None, :, None, :] # [1, seq_len, 1, head_dim]
self.sin_cached = emb.sin()[None, :, None, :]
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)
return (
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
)
def forward(self, query, key, past_key_values_length=0):
"""
Args:
query: [batch_size, seq_len, num_heads, head_dim]
key: [batch_size, seq_len, num_heads, head_dim]
past_key_values_length: int
Returns:
query: [batch_size, seq_len, num_heads, head_dim]
key: [batch_size, seq_len, num_heads, head_dim]
"""
# TODO @nouamane: support position_ids
if self._initialized_buffer is False:
self.init_rotary_embeddings()
seq_len = query.shape[1]
cos, sin = self.cos_sin(
seq_len, past_key_values_length, query.device, query.dtype
) # [1, seq_len, 1, head_dim]
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
class MLP(nn.Module):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
):
super().__init__()
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
d_ff = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
self.c_fc = TensorParallelColumnLinear(
config.hidden_size,
d_ff,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication,
)
self.act = torch.jit.script(ACT2FN[config.activation_function])
self.c_proj = TensorParallelRowLinear(
d_ff,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
return {"hidden_states": hidden_states}
class CoreAttention(nn.Module):
"""
Attention module similar to CoreAttention where only the query is multi-headed.
"""
def __init__(self, config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int):
super().__init__()
from flash_attn.flash_attn_interface import flash_attn_varlen_func
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_varlen_func).parameters)
assert (
config.hidden_size % config.num_attention_heads == 0
), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
self.d_qk = config.hidden_size // config.num_attention_heads
# we still divide the value dimension by the number of heads https://arxiv.org/pdf/1911.02150.pdf
self.d_v = config.hidden_size // config.num_attention_heads
self.dropout = config.attn_pdrop
assert config.scale_attn_weights, "Scale is only supported in torch 2.1.0"
# self.scale_factor = 1.0
# if config.scale_attn_weights:
# self.scale_factor = self.scale_factor / (self.d_qk**0.5)
self.checkpoint_attention = False # Because flash_attn already does checkpointing
if config.sliding_window_size is not None:
assert (
_flash_supports_window_size
), "Current version of flash-attn doesn't support sliding window: `pip install flash-attn>=2.3`"
self.sliding_window_size = config.sliding_window_size if layer_idx not in config.global_attn_layers else None
@checkpoint_method(attr_name="checkpoint_attention")
def forward(
self,
query_states: torch.Tensor, # [batch_size * q_length, num_heads, inner_dim]
key_states: torch.Tensor, # [batch_size * kv_length, 1, inner_dim]
value_states: torch.Tensor, # [batch_size * kv_length, 1, inner_dim]
q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
):
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# TODO @thomasw21: Compute once, instead of computing for each layers.
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
# TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
# what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
causal = False if q_sequence_mask.shape[1] == 1 else True
attn_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_sequence_mask.shape[1],
max_seqlen_k=kv_sequence_mask.shape[1],
dropout_p=self.dropout if self.training else 0.0,
softmax_scale=None, # defaults to 1/sqrt(d_qk)
causal=causal,
window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1),
return_attn_probs=False,
)
return attn_output
# Hack to propagage gradient correctly
def get_sliced_parameter(coalesced_tensor: torch.Tensor, slice_object: slice):
with torch.no_grad():
# This allows us to create a leaf tensor, despite sharing the underlying storage
result = NanotronParameter(tensor=coalesced_tensor[slice_object])
# We need sliced tensor to also get the gradient in order to run optimizer on them
# TODO @thomasw21: It's really had to make sure that our sliced view keeps the same memory space as the original gradient
def get_grad_view(orig_grad):
assert orig_grad.is_contiguous()
if result.grad is None:
# The gradient was reset to None, we need to reset the coalesced_tensor.grad as well
coalesced_tensor.grad = None
# TODO @thomasw21: Can I trigger hooks that we've set in `register_hook`
if coalesced_tensor.grad is None:
result.grad = orig_grad[slice_object]
else:
result.grad = coalesced_tensor.grad[slice_object]
return orig_grad
# If `coalesced_tensor` requires gradient, then we need to update the `result` grad attribute upon backward step.
if coalesced_tensor.requires_grad is True:
coalesced_tensor.register_hook(get_grad_view)
return result
class _MQAColumnLinearReduceScatterAsyncCommunication(torch.autograd.Function):
"""This computes `q` and `kv` computation in MQA setting.
Basic assumptions:
- `kv.weight` and `kv.bias` (if not None) are duplicated across tp_pg
- `tp_mode` is REDUCE_SCATTER
- `async_communication` is set to True
What this function does:
- in the forward pass:
- overlap input `all_gather` with `kv` computation
- overlap kv output `all_gather` with `q` computation
- in the backward pass:
- overlap input `all_gather` with gradient_input computation
- overlap gradient_input `reduce_scatter` with `kv` and `q` gradient computation
"""
@staticmethod
def forward(
ctx,
x: torch.Tensor,
q_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
kv_weight: torch.Tensor,
kv_bias: Optional[torch.Tensor],
# Basically we assume that `qkv_weight` is already the concatenated version of `q.weight` and `kv.weight`
qkv_weight: torch.Tensor,
tp_pg: dist.ProcessGroup,
) -> Tuple[torch.Tensor, torch.Tensor]:
ctx.tp_pg = tp_pg
ctx.use_q_bias = q_bias is not None
ctx.use_kv_bias = kv_bias is not None
ctx.split_q_and_kv_id = q_weight.shape[0]
# All gather x if needed
gathered_x: torch.Tensor
gather_x_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
gathered_x = x
else:
first_dim = x.shape[0]
last_dims = x.shape[1:]
unsharded_first_dim = first_dim * tp_pg.size()
gathered_x = torch.empty(
unsharded_first_dim,
*last_dims,
device=x.device,
dtype=x.dtype,
requires_grad=x.requires_grad,
)
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
x = x.contiguous()
gather_x_handle = dist.all_gather_into_tensor(gathered_x, x, group=tp_pg, async_op=True)
# Compute kv (we assume that kv is duplicated across TP)
kv_out = F.linear(x, kv_weight, kv_bias)
# Wait for communication to finish
if gather_x_handle is not None:
gather_x_handle.wait()
# All gather `kv` output
gathered_kv_out: torch.Tensor
gather_kv_out_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
gathered_kv_out = kv_out
else:
first_dim = kv_out.shape[0]
last_dims = kv_out.shape[1:]
unsharded_first_dim = first_dim * tp_pg.size()
gathered_kv_out = torch.empty(
unsharded_first_dim,
*last_dims,
device=x.device,
dtype=x.dtype,
requires_grad=x.requires_grad,
)
gather_kv_out_handle = dist.all_gather_into_tensor(gathered_kv_out, kv_out, group=tp_pg, async_op=True)
# Compute q
q_out = F.linear(gathered_x, q_weight, q_bias)
# Wait for communication to finish
if gather_kv_out_handle is not None:
gather_kv_out_handle.wait()
ctx.save_for_backward(x, qkv_weight)
return q_out, gathered_kv_out
@staticmethod
def backward(
ctx, grad_q: torch.Tensor, grad_kv: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], None, None]:
tp_pg = ctx.tp_pg
split_q_and_kv_id = ctx.split_q_and_kv_id
use_q_bias = ctx.use_q_bias
use_kv_bias = ctx.use_kv_bias
x, qkv_weight = ctx.saved_tensors
# Gather `x`
gathered_x: torch.Tensor
gather_x_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
gathered_x = x
else:
first_dim = x.shape[0]
last_dims = x.shape[1:]
unsharded_batch_size = first_dim * tp_pg.size()
gathered_x = torch.empty(
unsharded_batch_size,
*last_dims,
device=x.device,
dtype=x.dtype,
requires_grad=False,
)
gather_x_handle = dist.all_gather_into_tensor(gathered_x, x, group=tp_pg, async_op=True)
# Backward computation on `kv` and `q` with regards to input
grad_qkv = torch.concat([grad_q, grad_kv], dim=-1)
grad_tensor = grad_qkv.matmul(qkv_weight)
# Wait for gather `x` to finish
if gather_x_handle is not None:
gather_x_handle.wait()
# Reduce scatter gradients with regards to input
sub_gradient_tensor: torch.Tensor
sub_gradient_tensor_handle: Optional[dist.Work] = None
if tp_pg.size() == 1:
sub_gradient_tensor = grad_tensor
else:
sub_gradient_tensor = torch.empty(
x.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
)
# reduce_scatter
sub_gradient_tensor_handle = dist.reduce_scatter_tensor(
sub_gradient_tensor, grad_tensor, group=tp_pg, async_op=True
)
# Backward computation for `q` and `kv` with regards to
# flat_gathered_x = gathered_x.view(math.prod(gathered_x.shape[:-1]), gathered_x.shape[-1])
# flat_grad_kv = grad_kv.reshape(math.prod(grad_kv.shape[:-1]), grad_kv.shape[-1])
# flat_grad_q = grad_q.reshape(math.prod(grad_q.shape[:-1]), grad_q.shape[-1])
# grad_kv_weight = flat_grad_kv.t().matmul(flat_gathered_x)
# grad_kv_bias = flat_grad_kv.sum(dim=0) if use_kv_bias else None
# grad_q_weight = flat_grad_q.t().matmul(flat_gathered_x)
# grad_q_bias = flat_grad_q.sum(dim=0) if use_q_bias else None
flat_gathered_x = gathered_x.view(math.prod(gathered_x.shape[:-1]), gathered_x.shape[-1])
flat_grad_qkv = grad_qkv.view(math.prod(grad_qkv.shape[:-1]), grad_qkv.shape[-1])
grad_q_weight, grad_kv_weight = torch.split(
flat_grad_qkv.t().matmul(flat_gathered_x),
split_size_or_sections=[split_q_and_kv_id, grad_qkv.shape[-1] - split_q_and_kv_id],
dim=0,
)
if use_q_bias is True:
if use_kv_bias is True:
grad_qkv_bias = flat_grad_qkv.sum(dim=0)
grad_q_bias, grad_kv_bias = torch.split(
grad_qkv_bias,
split_size_or_sections=[split_q_and_kv_id, grad_qkv.shape[-1] - split_q_and_kv_id],
dim=0,
)
else:
grad_kv_bias = None
grad_q_bias = flat_grad_qkv[:, :split_q_and_kv_id].sum(dim=0)
else:
grad_q_bias = None
if use_kv_bias is False:
grad_kv_bias = flat_grad_qkv[:, split_q_and_kv_id:].sum(dim=0)
else:
grad_kv_bias = None
# Wait for `reduce_scatter`
if sub_gradient_tensor_handle is not None:
sub_gradient_tensor_handle.wait()
return sub_gradient_tensor, grad_q_weight, grad_q_bias, grad_kv_weight, grad_kv_bias, None, None
class MQAColumnLinears(nn.Module):
def __init__(
self,
in_features: int,
q_out_features: int,
kv_out_features: int,
pg: dist.ProcessGroup,
mode: TensorParallelLinearMode,
bias=True,
device=None,
dtype=None,
async_communication: bool = False,
):
super().__init__()
self.pg = pg
self.world_size = pg.size()
assert in_features % self.world_size == 0
self.in_features = in_features
self.q_out_features = q_out_features // self.world_size
self.kv_out_features = kv_out_features
# Tp mode
self.mode = mode
self.async_communication = async_communication
self.use_MQAColumnLinearReduceScatterAsyncCommunication = (
self.mode is TensorParallelLinearMode.REDUCE_SCATTER and self.async_communication is True
)
# allocating tensor
# We don't need to make them persistent as we expose this storage via `self.q` and `self.kv`
self.register_buffer(
"_qkv_weight",
torch.empty(
self.q_out_features + self.kv_out_features,
self.in_features,
device=device,
dtype=dtype,
# We use another specific path that doesn't use `_qkv_weight`
requires_grad=not self.use_MQAColumnLinearReduceScatterAsyncCommunication,
),
persistent=False,
)
if bias is True:
self.register_buffer(
"_qkv_bias",
torch.empty(
self.q_out_features + self.kv_out_features,
device=device,
dtype=dtype,
requires_grad=not self.use_MQAColumnLinearReduceScatterAsyncCommunication,
),
persistent=False,
)
else:
self._qkv_bias = None
# Register parameters
# We are very lucky because the sharding allows parameters to still be contiguous.
# We use a hack to propagate gradients
q_param_dict = {"weight": get_sliced_parameter(self._qkv_weight, slice_object=slice(self.q_out_features))}
kv_param_dict = {
"weight": get_sliced_parameter(self._qkv_weight, slice_object=slice(self.q_out_features, None))
}
if bias is True:
q_param_dict["bias"] = get_sliced_parameter(self._qkv_bias, slice_object=slice(self.q_out_features))
kv_param_dict["bias"] = get_sliced_parameter(self._qkv_bias, slice_object=slice(self.q_out_features, None))
self.q = nn.ParameterDict(q_param_dict)
self.kv = nn.ParameterDict(kv_param_dict)
# Marking as tied/sharded
mark_all_parameters_in_module_as_sharded(self.q, pg=self.pg, split_config=SplitConfig(split_dim=0))
# Init
self.reset_parameters()
def reset_parameters(self) -> None:
"""Copied from nn.Linear.reset_parameters"""
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
# https://github.com/pytorch/pytorch/issues/57109
init.kaiming_uniform_(self._qkv_weight, a=math.sqrt(5))
if self._qkv_bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self._qkv_weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self._qkv_bias, -bound, bound)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.use_MQAColumnLinearReduceScatterAsyncCommunication:
assert self._qkv_weight.requires_grad is False
assert self._qkv_bias is None or self._qkv_bias.requires_grad is False
return _MQAColumnLinearReduceScatterAsyncCommunication.apply(
x, self.q.weight, self.q.bias, self.kv.weight, self.kv.bias, self._qkv_weight, self.pg
)
qkv = column_linear(
input=x,
weight=self._qkv_weight,
bias=self._qkv_bias,
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
)
q, kv = torch.split(qkv, dim=-1, split_size_or_sections=[self.q_out_features, self.kv_out_features])
return q, kv
class CausalSelfMQA(nn.Module, AttachableStore):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
assert (
config.num_attention_heads % tp_pg.size() == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
self.tp_pg_size = tp_pg.size()
self.n_heads = config.num_attention_heads // tp_pg.size()
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.d_model = config.hidden_size
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
self.mode = tp_mode
self.pg = tp_pg
# only Q_size is parallelized
self.qkv = MQAColumnLinears(
in_features=self.d_model,
q_out_features=config.num_attention_heads * self.d_qk,
kv_out_features=self.d_qk + self.d_v,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication,
)
self.maybe_rotary = (
StarcoderRotaryEmbedding(head_dim=self.d_qk, base=config.rope_theta)
if config.use_rotary_embeddings
else lambda q, k, t: (q, k)
)
self.o = TensorParallelRowLinear(
config.num_attention_heads * self.d_v,
self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
assert config.multi_query is True
assert config.grouped_query is False
self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_dim]
sequence_mask, # [batch_size, seq_length]
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
batch_size = hidden_states.shape[1]
def unshape(states):
"""Given a [batch_dim * seq_length, num_heads, d_v] returns a [seq_length, batch_dim, num_heads * d_v]"""
if states.ndim == 3:
total = states.shape[0]
assert total % batch_size == 0
seq_length = total // batch_size
else:
seq_length = states.shape[1]
return (
states.view(batch_size, seq_length, self.n_heads, self.d_v)
.transpose(0, 1)
.contiguous()
.view(seq_length, batch_size, self.n_heads * self.d_v)
)
def shape(
query_states, # [q_length, batch_size, num_heads * d_qk]
kv_states, # [kv_length, batch_size, d_qk + d_v]
):
# Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
q_length = query_states.shape[0]
kv_length = kv_states.shape[0]
query_states = query_states.view(
q_length, batch_size, self.n_heads, self.d_qk
) # [q_length, batch_size, num_heads, d_qk]
query_states = (
query_states.permute(1, 0, 2, 3).contiguous().view(batch_size, q_length, self.n_heads, self.d_qk)
) # [batch_size, q_length, num_heads, d_qk]
key_states, value_states = torch.split(
kv_states, [self.d_qk, self.d_v], dim=-1
) # [kv_length, batch_size, d_qk], [kv_length, batch_size, d_v]
key_states = (
key_states.transpose(0, 1).contiguous().view(batch_size, kv_length, self.d_qk).unsqueeze(dim=2)
) # [batch_size, kv_length, 1, d_qk]
value_states = (
value_states.transpose(0, 1).contiguous().view(batch_size, kv_length, self.d_v).unsqueeze(dim=2)
) # [batch_size, kv_length, 1, d_v]
return query_states, key_states, value_states
# get query/key/value states
query_states, kv_states = self.qkv(
hidden_states
) # [seq_length, batch_size, num_heads * d_qk], [seq_length, batch_size, d_qk + d_v]
query_states, key_states, value_states = shape(query_states=query_states, kv_states=kv_states)
# [batch_size, q_length, num_heads, d_qk], [batch_size, kv_length, 1, d_qk], [batch_size, kv_length, 1, d_v]
seq_length_dim = 1
q_length = query_states.shape[seq_length_dim]
# Get cached key/values from store if available
store = self.get_local_store()
if store is not None: # Inference case
# Double check that we use store only at inference time
assert kv_states.requires_grad is False
assert value_states.requires_grad is False
# Compute rotary embeddings
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
past_key_values_length = store["past_key_values_length"]
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
past_key_values_length = 0
position_offsets = position_ids[:, -1]
query_states, key_states = self.maybe_rotary(
query_states, key_states, past_key_values_length=past_key_values_length
)
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
# preallocate k_cache, v_cache to self.prefill_kv_len
k_cache = torch.zeros(
(
batch_size,
self.prefill_kv_len,
1,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
)
v_cache = torch.zeros(
(batch_size, self.prefill_kv_len, 1, self.d_v),
dtype=query_states.dtype,
device=query_states.device,
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
output_unpad = flash_attn_varlen_func(
q=query_unpad, # (total_q, n_heads, d_qk)
k=key_unpad, # (total_kv, 1, d_qk)
v=value_unpad, # (total_kv, 1, d_v)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True, # True in prefill phase, False in subsequent phases
return_attn_probs=False,
) # (total_unpadded, n_local_q_heads, d_v)
attention_output = bert_padding.pad_input(
output_unpad, indices_q, batch_size, q_length
) # (batch_size, q_length, n_local_q_heads, d_v)
pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
v_cache = store["value"]
# [batch_size, seq_length, num_heads, d_qk]
query_states = query_states.view(
batch_size, q_length, self.n_heads, self.d_qk
) # [batch_size, q_length, self.n_heads, d_qk]
kv_length = key_states.shape[1]
key_states = key_states.view(batch_size, kv_length, 1, self.d_qk) # [batch_size, kv_length, 1, d_qk]
value_states = value_states.view(batch_size, kv_length, 1, self.d_v) # [batch_size, kv_length, 1, d_v]
attention_output = flash_attn_with_kvcache(
query_states,
k_cache,
v_cache,
key_states,
value_states,
rotary_cos=None,
rotary_sin=None,
# TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0)
cache_seqlens=position_offsets.contiguous(),
softmax_scale=None,
causal=True,
rotary_interleaved=False, # GPT-NeoX style
)
store.update(
{
"key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
"value": v_cache,
"position_offsets": position_offsets,
"past_key_values_length": past_key_values_length,
}
)
else:
query_states, key_states = self.maybe_rotary(query_states, key_states, past_key_values_length=0)
q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask
kv_length = key_states.shape[seq_length_dim]
query_states = query_states.view(batch_size * q_length, self.n_heads, self.d_qk)
key_states = key_states.view(batch_size * kv_length, 1, self.d_qk)
value_states = value_states.view(batch_size * kv_length, 1, self.d_v)
attention_output = self.attention(
query_states=query_states, # [batch_size * q_length, num_heads, d_qk]
key_states=key_states, # [batch_size * kv_length, 1, d_qk]
value_states=value_states, # [batch_size * kv_length, 1, d_v]
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
) # [batch_size, num_heads, seq_length, d_v]
output = self.o(unshape(attention_output))
return {"hidden_states": output, "sequence_mask": sequence_mask}
############################
# GQA
############################
class CausalSelfGQA(nn.Module, AttachableStore):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
assert (
config.num_attention_heads % tp_pg.size() == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
self.maybe_rotary = (
StarcoderRotaryEmbedding(head_dim=self.head_dim, base=config.rope_theta)
if config.use_rotary_embeddings
else lambda q, k, t: (q, k)
)
self.num_kv_heads = config.num_kv_heads if (not config.multi_query) else 1
self.n_local_q_heads = self.num_heads // tp_pg.size()
self.n_local_kv_heads = config.num_kv_heads // tp_pg.size()
assert (
config.num_kv_heads >= tp_pg.size()
), f"Number of kv heads ({config.num_kv_heads}) must be >= TP size ({tp_pg.size()})."
self.n_repeats = self.n_local_q_heads // self.n_local_kv_heads
qkv_contiguous_chunks = None
self.query_key_value = TensorParallelColumnLinear(
self.hidden_size,
self.num_heads * self.head_dim + 2 * self.num_kv_heads * self.head_dim,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
)
self.dense = TensorParallelRowLinear(
self.hidden_size,
self.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=True,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
assert config.multi_query is False
assert config.grouped_query is True
self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
def forward(
self,
hidden_states, # (seq_length, batch_size, hidden_size)
sequence_mask, # (batch_size, seq_length)
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
fused_qkv = self.query_key_value(
hidden_states
) # [seq_length, batch_size, n_local_q_heads * head_dim + 2 * n_local_kv_heads * head_dim]
q_length, batch_size, _ = fused_qkv.size()
qkv = fused_qkv.view(q_length, batch_size, self.n_local_kv_heads, self.n_repeats + 2, self.head_dim)
query, key, value = torch.split(qkv, [self.n_repeats, 1, 1], dim=3)
query_states = query.transpose(0, 1).reshape(
batch_size, q_length, self.n_local_q_heads, self.head_dim
) # TODO @nouamane: can we transpose qkv instead?
key_states = key.transpose(0, 1).reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim)
value_states = value.transpose(0, 1).reshape(batch_size, q_length, self.n_local_kv_heads, self.head_dim)
# Get cached key/values from store if available
store = self.get_local_store()
if store is not None:
# Double check that we use store only at inference time
assert key_states.requires_grad is False
assert value_states.requires_grad is False
# Compute rotary embeddings
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
past_key_values_length = store["past_key_values_length"]
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
past_key_values_length = 0
position_offsets = position_ids[:, -1]
query_states, key_states = self.maybe_rotary(
query_states, key_states, past_key_values_length=past_key_values_length
)
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
# preallocate k_cache, v_cache to self.prefill_kv_len
k_cache = torch.zeros(
(
batch_size,
self.prefill_kv_len,
self.n_local_kv_heads,
self.head_dim,
),
dtype=query_states.dtype,
device=query_states.device,
)
v_cache = torch.zeros(
(batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.head_dim),
dtype=query_states.dtype,
device=query_states.device,
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
output_unpad = flash_attn_varlen_func(
q=query_unpad, # (total_q, self.n_local_q_heads, d_qk)
k=key_unpad, # (total_kv, self.n_local_kv_heads, d_qk)
v=value_unpad, # (total_kv, self.n_local_kv_heads, d_v)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True, # True in prefill phase, False in subsequent phases
return_attn_probs=False,
) # (total_unpadded, n_local_q_heads, d_v)
attention_output = bert_padding.pad_input(
output_unpad, indices_q, batch_size, q_length
) # (batch_size, q_length, n_local_q_heads, d_v)
pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
v_cache = store["value"]
# [batch_size, seq_length, num_heads, d_qk]
query_states = query_states.view(
batch_size, q_length, self.n_local_q_heads, self.head_dim
) # [batch_size, q_length, self.n_local_q_heads, self.head_dim]
kv_length = key_states.shape[1]
key_states = key_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim]
value_states = value_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size, kv_length, self.n_local_kv_heads, self.head_dim]
attention_output = flash_attn_with_kvcache(
query_states,
k_cache,
v_cache,
key_states,
value_states,
rotary_cos=None,
rotary_sin=None,
# TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0)
cache_seqlens=position_offsets.contiguous(),
softmax_scale=None,
causal=True,
rotary_interleaved=False, # GPT-NeoX style
)
# Update store
if past_key_values_length == 0:
past_key_values_length = sequence_mask.shape[1] - 1 # we add 1 when we load the value
else:
past_key_values_length += 1
store.update(
{
"key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
"value": v_cache,
"position_offsets": position_offsets,
"past_key_values_length": past_key_values_length,
}
)
else:
# Apply rotary embeddings to query/key states
query_states, key_states = self.maybe_rotary(query_states, key_states, past_key_values_length=0)
q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask
kv_length = key_states.shape[1]
# [batch_size, seq_length, num_heads, head_dim]
# Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
query_states = query_states.reshape(
batch_size * q_length, self.n_local_q_heads, self.head_dim
) # [batch_size * q_length, self.n_local_q_heads, head_dim]
key_states = key_states.reshape(
batch_size * kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size * kv_length, self.n_local_kv_heads, head_dim]
value_states = value_states.reshape(
batch_size * kv_length, self.n_local_kv_heads, self.head_dim
) # [batch_size * kv_length, self.n_local_kv_heads, head_dim]
attention_output = self.attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
) # [batch_size * seq_length, self.n_local_q_heads, head_dim]
attention_output = attention_output.view(batch_size, q_length, self.n_local_q_heads * self.head_dim).transpose(
0, 1
)
output = self.dense(attention_output)
return {"hidden_states": output, "sequence_mask": sequence_mask}
@torch.jit.script
def dropout_add(x, residual, prob, training):
# type: (Tensor, Tensor, float, bool) -> Tensor
# From: https://github.com/NVIDIA/Megatron-LM/blob/285068c8108e0e8e6538f54fe27c3ee86c5217a2/megatron/model/transformer.py#L586
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
@torch.jit.script
def dropout_add_fused_train(x: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor:
return dropout_add(x, residual, prob, True)
class GPTBlock(nn.Module):
def __init__(
self,
config: Starcoder2Config,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
random_states: RandomStates,
layer_idx: int,
):
super(GPTBlock, self).__init__()
self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
if config.multi_query is True:
self.attn = CausalSelfMQA(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)
elif config.grouped_query is True:
self.attn = CausalSelfGQA(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)
else:
raise ValueError("Either `multi_query` or `grouped_query` must be True") # TODO: @nouamane not necessarily
self.attn_dropout = config.attn_pdrop
self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
self.ff_dropout = config.resid_pdrop
self.random_states = random_states
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
if self.training:
with branch_random_state(
self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE
):
hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout)
else:
# No need for random state context manager
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
hidden_states = self.ff(hidden_states=hidden_states)["hidden_states"]
if self.training:
with branch_random_state(
self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE
):
hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout)
else:
# No need for random state context manager
hidden_states = hidden_states + residual
return {
"hidden_states": hidden_states,
"sequence_mask": output["sequence_mask"],
}
class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: Starcoder2Config, parallel_config: Optional[ParallelismArgs]):
super().__init__()
self.token_embedding = TensorParallelEmbedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
pg=tp_pg,
mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
)
self.pg = tp_pg
def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
# store = self.get_local_store()
# if store is not None:
# if "past_length" in store:
# past_length = store["past_length"]
# else:
# past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
# cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
# # Store new past_length in store
# store["past_length"] = past_length + cumsum_mask[:, -1]
# Format input in `[seq_length, batch_size]` to support high TP with low batch_size
input_ids = input_ids.transpose(0, 1)
input_embeds = self.token_embedding(input_ids)
return {"input_embeds": input_embeds}
class GPTModel(nn.Module):
"""Build pipeline graph"""
def __init__(
self,
config: Starcoder2Config,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: RandomStates,
):
super().__init__()
# Declare all the nodes
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
self.random_states = random_states
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
self.token_embeddings = PipelineBlock(
p2p=self.p2p,
module_builder=Embedding,
module_kwargs={
"tp_pg": parallel_context.tp_pg,
"config": config,
"parallel_config": parallel_config,
},
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)
self.embeds_dropout = PipelineBlock(
p2p=self.p2p,
module_builder=nn.Dropout,
module_kwargs={"p": config.embd_pdrop},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=GPTBlock,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
"tp_pg": parallel_context.tp_pg,
"random_states": random_states,
"layer_idx": layer_idx,
},
module_input_keys={"hidden_states", "sequence_mask"},
module_output_keys={"hidden_states", "sequence_mask"},
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonLayerNorm,
module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
)
self.lm_head = PipelineBlock(
p2p=self.p2p,
# Understand that this means that we return sharded logits that are going to need to be gathered
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
"pg": parallel_context.tp_pg,
"bias": False,
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": parallel_config.tp_linear_async_communication
if parallel_config is not None
else False,
},
module_input_keys={"x"},
module_output_keys={"logits"},
)
self.cast_to_fp32 = PipelineBlock(
p2p=self.p2p,
module_builder=lambda: lambda x: x.float(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
# all tensors are optional as most ranks don't need anything from the dataloader.
input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]
with branch_random_state(
self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE
):
hidden_states = self.embeds_dropout(input=input_embeds)["hidden_states"]
hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask}
for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)
hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
sharded_logits = self.lm_head(x=hidden_states)["logits"]
fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
return fp32_sharded_logits
@torch.jit.script
def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
class Loss(nn.Module):
def __init__(self, tp_pg: dist.ProcessGroup):
super().__init__()
self.tp_pg = tp_pg
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
) -> Dict[str, torch.Tensor]:
# Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
loss = sharded_cross_entropy(
sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
).transpose(
0, 1
) # TODO @nouamane: case where TP=1 should be simpler
# TODO @thomasw21: It's unclear what kind of normalization we want to do.
loss = masked_mean(loss, label_mask, dtype=torch.float)
# I think indexing causes a sync we don't actually want
# loss = loss[label_mask].sum()
return {"loss": loss}
class Starcoder2ForTraining(NanotronModel):
def __init__(
self,
config: Starcoder2Config,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: RandomStates,
):
super().__init__()
self.model = GPTModel(
config=config,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=random_states,
)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=Loss,
module_kwargs={"tp_pg": parallel_context.tp_pg},
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
},
module_output_keys={"loss"},
)
self.config: Starcoder2Config = config
self.parallel_config = parallel_config
self.parallel_context = parallel_context
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
) -> Union[torch.Tensor, TensorPointer]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
return {
"loss": self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
}
def tie_custom_params(self) -> None:
# find all params with names qkv.kv.weight and qkv.kv.bias in them
for module_name, module in self.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if ".qkv.kv." in name:
assert not param.is_tied, f"Parameter {name} is already tied"
shared_weights = [
(
name,
# sync across TP group
tuple(sorted(dist.get_process_group_ranks(self.parallel_context.tp_pg))),
)
]
tie_parameters(
root_module=self,
ties=shared_weights,
parallel_context=self.parallel_context,
# We always SUM grads, because kv weights are always duplicated in MQA
reduce_op=dist.ReduceOp.SUM,
)
@torch.no_grad()
def init_model_randomly(self, config):
"""Initialize model parameters randomly.
Note:
Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
"""
model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
std = config.model.init_method.std
sigma = config.model.init_method.std
num_layers = config.model.model_config.num_hidden_layers
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)
module_name, param_name = param_name.rsplit(".", 1)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if full_param_name in initialized_parameters:
# Already initialized
continue
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, LayerNorm):
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, MQAColumnLinears):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
raise Exception(f"Parameter {full_param_name} was not initialized")
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if param.is_tied
else name
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
def get_embeddings_lm_head_tied_names(self) -> List[str]:
return [
"model.token_embeddings.pp_block.token_embedding.weight",
"model.lm_head.pp_block.weight",
]
def before_tbi_sanity_checks(self):
# SANITY CHECK: Check ".qkv.kv." params are tied
for name, kv_param in self.named_parameters():
if ".qkv.kv." in name:
assert kv_param.is_tied, f"{name} is not tied (kv weights/biases should be tied in GPTBigcode)"
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
model_config = self.config
d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size
d_qkv = model_config.hidden_size // model_config.num_attention_heads
block_compute_costs = {
# CausalSelfAttention (qkv proj + attn out) + MLP
GPTBlock: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 2 * d_ff * model_config.hidden_size,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
world_size = self.parallel_context.world_pg.size()
model_flops, hardware_flops = get_flops(
num_layers=self.config.num_hidden_layers,
hidden_size=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
vocab_size=self.config.vocab_size,
ffn_hidden_size=self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size,
seq_len=sequence_length,
batch_size=global_batch_size,
kv_channels=None,
glu_activation=False,
)
model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
return model_flops_per_s, hardware_flops_per_s
def get_flops(
num_layers,
hidden_size,
num_heads,
vocab_size,
seq_len,
kv_channels=None,
ffn_hidden_size=None,
batch_size=1,
glu_activation=False,
):
"""Counts flops in an decoder-only model
Args:
num_layers: number of decoder layers
hidden_size: hidden size of the model
num_heads: number of heads in the model
kv_channels: hidden size of the key and value heads
ffn_hidden_size: hidden size of the FFN
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info.
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
"""
if kv_channels is None:
assert hidden_size % num_heads == 0
kv_channels = hidden_size // num_heads
if ffn_hidden_size is None:
ffn_hidden_size = 4 * hidden_size
# In the following we mark the reduced dimension with parentheses
# decoder
# self attention (MQA)
## q projection
decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels
## kv projection, shared across heads
decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels
## qk logits
decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len
### SWA (sliding window attention / local attention)
# window_size = 4096
# decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size
## v logits
decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels
# decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels
## attn out
decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size
# FF
## 1st layer
decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
if glu_activation:
# 3 matmuls instead of 2 in FFN
# ref. https://arxiv.org/pdf/2002.05202.pdf
# Used for example in T5 v1.1
decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
## 2nd layer
decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
decoder_flops_fwd = (
decoder_q_proj_flops_fwd
+ decoder_kv_proj_flops_fwd
+ decoder_qk_logits_flops_fwd
+ decoder_v_logits_flops_fwd
+ decoder_attn_out_flops_fwd
+ decoder_ffn_1_flops_fwd
+ decoder_ffn_2_flops_fwd
)
# lm head
lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
# the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to
# both input and weight tensors
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
hardware_flops = model_flops # TODO @nouamanetazi: This is a placeholder for now
return model_flops, hardware_flops
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
import torch
from packaging import version
from torch import Tensor, nn
from nanotron import logging
logger = logging.get_logger(__name__)
class PytorchGELUTanh(nn.Module):
"""
A fast C implementation of the tanh approximation of the GeLU activation function. See
https://arxiv.org/abs/1606.08415.
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
match due to rounding errors.
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.12.0"):
raise ImportError(
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
"PytorchGELUTanh. Please upgrade torch."
)
def forward(self, input: Tensor) -> Tensor:
return nn.functional.gelu(input, approximate="tanh")
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class GELUActivation(nn.Module):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
super().__init__()
if use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class FastGELUActivation(nn.Module):
"""
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Module):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Module):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
return torch.clip(gelu(x), self.min, self.max)
class AccurateGELUActivation(nn.Module):
"""
Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
https://github.com/hendrycks/GELUs
Implemented along with MEGA (Moving Average Equipped Gated Attention)
"""
def __init__(self):
super().__init__()
self.precomputed_constant = math.sqrt(2 / math.pi)
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
class SiLUActivation(nn.Module):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
def forward(self, input: Tensor) -> Tensor:
return nn.functional.silu(input)
class MishActivation(nn.Module):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.9.0"):
self.act = self._mish_python
else:
self.act = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
return input * torch.tanh(nn.functional.softplus(input))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class LinearActivation(nn.Module):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
return input
class LaplaceActivation(nn.Module):
"""
Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
https://arxiv.org/abs/2209.10655
Inspired by squared relu, but with bounded range and gradient for better stability
"""
def forward(self, input, mu=0.707107, sigma=0.282095):
input = (input - mu).div(sigma * math.sqrt(2.0))
return 0.5 * (1.0 + torch.erf(input))
class ReLUSquaredActivation(nn.Module):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
def forward(self, input):
relu_applied = nn.functional.relu(input)
squared = torch.square(relu_applied)
return squared
class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)
ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"gelu_pytorch_tanh": PytorchGELUTanh,
"gelu_accurate": AccurateGELUActivation,
"laplace": LaplaceActivation,
"linear": LinearActivation,
"mish": MishActivation,
"quick_gelu": QuickGELUActivation,
"relu": nn.ReLU,
"relu2": ReLUSquaredActivation,
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": SiLUActivation,
"swish": SiLUActivation,
"tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
import torch
from torch import nn
class TritonLayerNorm(nn.LayerNorm):
def forward(
self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
from flash_attn.ops.triton.layer_norm import layer_norm_fn
return layer_norm_fn(
input,
self.weight,
self.bias,
residual=residual,
eps=self.eps,
dropout_p=dropout_p,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=False,
return_dropout_mask=return_dropout_mask,
)
# This is equivalent to LLaMA RMSNorm
# https://github.com/huggingface/transformers/blob/28952248b19db29ca25ccf34a5eec413376494a9/src/transformers/models/llama/modeling_llama.py#L112
class TritonRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
def forward(
self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
from flash_attn.ops.triton.layer_norm import layer_norm_fn
return layer_norm_fn(
input,
self.weight,
None,
residual=residual,
eps=self.eps,
dropout_p=dropout_p,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
is_rms_norm=True,
return_dropout_mask=return_dropout_mask,
)
from nanotron.optim.base import BaseOptimizer
from nanotron.optim.inherit_from_other_optimizer import InheritFromOtherOptimizer
from nanotron.optim.named_optimizer import NamedOptimizer
from nanotron.optim.optimizer_from_gradient_accumulator import OptimizerFromGradientAccumulator
from nanotron.optim.zero import ZeroDistributedOptimizer
__all__ = [
"BaseOptimizer",
"InheritFromOtherOptimizer",
"NamedOptimizer",
"OptimizerFromGradientAccumulator",
"ZeroDistributedOptimizer",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment