Commit ccfcffb1 authored by chenzk

v1.0

FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk23.10-py38
ENV DEBIAN_FRONTEND=noninteractive
# RUN yum update && yum install -y git cmake wget build-essential
RUN source /opt/dtk-23.10/env.sh
# Install pip dependencies
COPY requirements.txt requirements.txt
RUN pip3 install -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -r requirements.txt
torch>=2.1.0dev
lightning==2.1.2
lightning[app]
jsonargparse[signatures] # CLI
pandas
pyarrow
tokenizers
sentencepiece
wandb
zstd
# for finetuning
bitsandbytes==0.40.0
transformers==4.31.0
peft==0.4.0
accelerate==0.21.0
einops==0.6.1
evaluate==0.4.0
scikit-learn==1.2.2
sentencepiece==0.1.99
wandb==0.15.3
# other optional dependencies are
# sentencepiece # pythia, falcon, redpajama
# tokenizers # llama-based models
# bitsandbytes>=0.41.1 # quantize/bnb.py
# scipy # TODO: remove when https://github.com/TimDettmers/bitsandbytes/pull/525 is released
# datasets # quantize/gptq.py
# zstandard # scripts/prepare_redpajama.py
# git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval
from lit_gpt.model import GPT
from lit_gpt.config import Config
from lit_gpt.tokenizer import Tokenizer
from lit_gpt.fused_cross_entropy import FusedCrossEntropyLoss
from lightning_utilities.core.imports import RequirementCache
if not bool(RequirementCache("torch>=2.1.0dev")):
raise ImportError(
"Lit-GPT requires torch nightly (future torch 2.1). Please follow the installation instructions in the"
" repository README.md"
)
_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0")
if not bool(_LIGHTNING_AVAILABLE):
raise ImportError(
"Lit-GPT requires Lightning nightly (future lightning 2.1). Please run:\n"
f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
)
__all__ = ["GPT", "Config", "Tokenizer"]
"""Implementation of the paper:
LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention
https://arxiv.org/abs/2303.16199
Port for Lit-GPT
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from typing_extensions import Self
from lit_gpt.config import Config as BaseConfig
from lit_gpt.model import GPT as BaseModel
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import KVCache, RoPECache, apply_rope
@dataclass
class Config(BaseConfig):
adapter_prompt_length: int = 10
adapter_start_layer: int = 2
class GPT(BaseModel):
"""The implementation is identical to `lit_gpt.model.GPT` with the exception that
the `Block` saves the layer index and passes it down to the attention layer."""
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.rope_cache: Optional[RoPECache] = None
self.mask_cache: Optional[torch.Tensor] = None
self.kv_caches: List[KVCache] = []
self.adapter_kv_caches: List[KVCache] = []
def reset_cache(self) -> None:
super().reset_cache()
self.adapter_kv_caches.clear()
def forward(
self,
idx: torch.Tensor,
max_seq_length: Optional[int] = None,
input_pos: Optional[torch.Tensor] = None,
lm_head_chunk_size: int = 0,
) -> Union[torch.Tensor, List[torch.Tensor]]:
B, T = idx.size()
use_kv_cache = input_pos is not None
block_size = self.config.block_size
if max_seq_length is None:
max_seq_length = block_size
if use_kv_cache: # not relevant otherwise
assert (
max_seq_length >= T
), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"
if self.rope_cache is None:
self.rope_cache = self.build_rope_cache(idx)
# passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
# for the kv-cache support (only during inference), we only create it in that situation
# this will be resolved by https://github.com/pytorch/pytorch/issues/96099
if use_kv_cache and self.mask_cache is None:
self.mask_cache = self.build_mask_cache(idx)
cos, sin = self.rope_cache
if use_kv_cache:
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, :max_seq_length]
else:
cos = cos[:T]
sin = sin[:T]
mask = None
# forward the model itself
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if not use_kv_cache:
for block in self.transformer.h:
x, *_ = block(x, (cos, sin), max_seq_length)
else:
self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1))
self.adapter_kv_caches = self.adapter_kv_caches or [None for _ in range(self.config.n_layer)]
for i, block in enumerate(self.transformer.h):
x, self.kv_caches[i], self.adapter_kv_caches[i] = block(
x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
)
x = self.transformer.ln_f(x)
if lm_head_chunk_size > 0:
# chunk the lm head logits to reduce the peak memory used by autograd
return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
return self.lm_head(x) # (b, t, vocab_size)
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
super()._init_weights(module)
if isinstance(module, CausalSelfAttention):
module.reset_parameters()
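def _lm_head_chunking_sketch() -> None:
    # Hedged sketch (not part of the original module): with a positive
    # `lm_head_chunk_size`, forward returns the logits as a list of chunks split
    # along the time dimension, which keeps the peak autograd memory of the LM head
    # low during training. Concatenating the chunks matches the un-chunked output.
    model = GPT(Config.from_name("pythia-70m"))
    idx = torch.randint(0, model.config.padded_vocab_size, (2, 16))
    chunks = model(idx, lm_head_chunk_size=4)  # list of 4 tensors, each (2, 4, vocab)
    full = model(idx)  # (2, 16, vocab)
    assert torch.allclose(torch.cat(chunks, dim=1), full)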
class Block(nn.Module):
"""The implementation is identical to `lit_gpt.model.Block` with the exception that
we replace the attention layer where adaption is implemented."""
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__()
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config, block_idx)
if not config.shared_attention_norm:
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.mlp = config.mlp_class(config)
self.config = config
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
max_seq_length: int,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
adapter_kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
n_1 = self.norm_1(x)
h, new_kv_cache, new_adapter_kv_cache = self.attn(
n_1, rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache
)
if self.config.parallel_residual:
n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
x = x + h + self.mlp(n_2)
else:
if self.config.shared_attention_norm:
raise NotImplementedError(
"No checkpoint amongst the ones we support uses this configuration"
" (non-parallel residual and shared attention norm)."
)
x = x + h
x = x + self.mlp(self.norm_2(x))
return x, new_kv_cache, new_adapter_kv_cache
class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention
over the adaption prompt."""
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config)
if block_idx >= config.adapter_start_layer:
# adapter embedding layer
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
# gate for adaption
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
self.reset_parameters()
self.block_idx = block_idx
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
max_seq_length: int,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
adapter_kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
qkv = self.attn(x)
# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
q_per_kv = self.config.n_head // self.config.n_query_groups
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
# split batched computation into three
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
# repeat k and v if necessary
if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!)
# for MHA this is a no-op
k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
n_elem = int(self.config.rotary_percentage * self.config.head_size)
cos, sin = rope
q_roped = apply_rope(q[..., :n_elem], cos, sin)
k_roped = apply_rope(k[..., :n_elem], cos, sin)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
if kv_cache is not None:
cache_k, cache_v = kv_cache
cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
# check if reached token limit
if input_pos[-1] >= max_seq_length:
input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
# shift 1 position to the left
cache_k = torch.roll(cache_k, -1, dims=2)
cache_v = torch.roll(cache_v, -1, dims=2)
k = cache_k.index_copy_(2, input_pos, k)
v = cache_v.index_copy_(2, input_pos, v)
kv_cache = k, v
y = self.scaled_dot_product_attention(q, k, v, mask=mask)
if self.block_idx >= self.config.adapter_start_layer:
aT = self.config.adapter_prompt_length
if adapter_kv_cache is not None:
ak, av = adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, C)
aqkv = self.attn(prefix)
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
if self.config.n_query_groups != 1:
# for MHA this is a no-op
ak = ak.repeat_interleave(q_per_kv, dim=2)
av = av.repeat_interleave(q_per_kv, dim=2)
ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
adapter_kv_cache = (ak, av)
amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)
ay = self.scaled_dot_product_attention(q, ak, av, amask)
y = y + self.gating_factor * ay
y = y.reshape(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.proj(y)
return y, kv_cache, adapter_kv_cache
def reset_parameters(self) -> None:
torch.nn.init.zeros_(self.gating_factor)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with older checkpoints."""
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def mark_only_adapter_as_trainable(model: GPT) -> None:
"""Sets `requires_grad=False` for all non-adapter weights."""
for name, param in model.named_parameters():
param.requires_grad = adapter_filter(name, param)
def adapter_filter(key: str, value: Any) -> bool:
return "adapter_wte" in key or "gating_factor" in key
"""Implementation of the paper:
LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model
https://arxiv.org/abs/2304.15010
Port for Lit-GPT
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch.nn as nn
from typing_extensions import Self
import lit_gpt
from lit_gpt.adapter import GPT as BaseModel
from lit_gpt.adapter import Block as BaseBlock
from lit_gpt.adapter import Config as BaseConfig
from lit_gpt.adapter import KVCache, RoPECache
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import apply_rope
from lit_gpt.utils import map_old_state_dict_weights
@dataclass
class Config(BaseConfig):
@property
def mlp_class(self) -> Type:
return getattr(lit_gpt.adapter_v2, self._mlp_class)
def adapter_filter(key: str, value: Any) -> bool:
adapter_substrings = (
# regular adapter v1 parameters
"adapter_wte",
"gating_factor",
# adapter v2: new bias and scale used in Linear
"adapter_scale",
"adapter_bias",
# adapter v2: Norm parameters are now trainable
"norm_1",
"norm_2",
"ln_f",
)
return any(s in key for s in adapter_substrings)
class AdapterV2Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, **kwargs) -> None:
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)
self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)
self.reset_parameters()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.adapter_scale * (self.linear(x) + self.adapter_bias)
def reset_parameters(self) -> None:
nn.init.zeros_(self.adapter_bias)
nn.init.ones_(self.adapter_scale)
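def _adapter_v2_linear_sketch() -> None:
    # Hedged sketch (not part of the original module): at initialization
    # `adapter_scale` is 1 and `adapter_bias` is 0, so AdapterV2Linear is an
    # identity wrapper around the underlying nn.Linear; the per-channel scale and
    # bias only change the output once they are trained.
    layer = AdapterV2Linear(8, 4, bias=False)
    x = torch.randn(2, 8)
    assert torch.allclose(layer(x), layer.linear(x))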
class GPT(BaseModel):
def __init__(self, config: Config) -> None:
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.rope_cache: Optional[RoPECache] = None
self.mask_cache: Optional[torch.Tensor] = None
self.kv_caches: List[KVCache] = []
self.adapter_kv_caches: List[KVCache] = []
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
super()._init_weights(module)
if isinstance(module, CausalSelfAttention):
module.reset_parameters()
if isinstance(module, AdapterV2Linear):
module.reset_parameters()
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {"lm_head.weight": "lm_head.linear.weight"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class Block(BaseBlock):
"""The implementation is identical to `lit_gpt.model.Block` with the exception that
we replace the attention layer where adaption is implemented."""
def __init__(self, config: Config, block_idx: int) -> None:
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
nn.Module.__init__(self)
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config, block_idx)
if not config.shared_attention_norm:
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.mlp = config.mlp_class(config)
self.config = config
class CausalSelfAttention(BaseCausalSelfAttention):
def __init__(self, config: Config, block_idx: int) -> None:
"""Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for
parameter-efficient fine-tuning.
*Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for
query, key and value for each head) we can do this in a single pass with a single weight matrix.
"""
# Skip the parent class __init__ altogether and replace it to avoid useless allocations
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias)
# output projection
self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias)
if block_idx >= config.adapter_start_layer:
# adapter embedding layer
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
# gate for adaption
self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
self.reset_parameters()
self.block_idx = block_idx
self.config = config
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
max_seq_length: int,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
adapter_kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
qkv = self.attn(x)
# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
q_per_kv = self.config.n_head // self.config.n_query_groups
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
# split batched computation into three
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
# repeat k and v if necessary
if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!)
# for MHA this is a no-op
k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
n_elem = int(self.config.rotary_percentage * self.config.head_size)
cos, sin = rope
q_roped = apply_rope(q[..., :n_elem], cos, sin)
k_roped = apply_rope(k[..., :n_elem], cos, sin)
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
if kv_cache is not None:
cache_k, cache_v = kv_cache
cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
# check if reached token limit
if input_pos[-1] >= max_seq_length:
input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
# shift 1 position to the left
cache_k = torch.roll(cache_k, -1, dims=2)
cache_v = torch.roll(cache_v, -1, dims=2)
k = cache_k.index_copy_(2, input_pos, k)
v = cache_v.index_copy_(2, input_pos, v)
kv_cache = k, v
y = self.scaled_dot_product_attention(q, k, v, mask=mask)
if self.block_idx >= self.config.adapter_start_layer:
aT = self.config.adapter_prompt_length
if adapter_kv_cache is not None:
ak, av = adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, C)
aqkv = self.attn(prefix)
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
_, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
if self.config.n_query_groups != 1:
# for MHA this is a no-op
ak = ak.repeat_interleave(q_per_kv, dim=2)
av = av.repeat_interleave(q_per_kv, dim=2)
ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs)
av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs)
adapter_kv_cache = (ak, av)
amask = torch.ones(T, aT, dtype=torch.bool, device=x.device)
ay = self.scaled_dot_product_attention(q, ak, av, amask)
y = y + self.gating_factor * ay
y = y.reshape(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.proj(y)
return y, kv_cache, adapter_kv_cache
def reset_parameters(self) -> None:
torch.nn.init.zeros_(self.gating_factor)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
# For compatibility with older checkpoints
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc.weight": "fc.linear.weight",
"fc.bias": "fc.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class LLaMAMLP(lit_gpt.model.LLaMAMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc_1.weight": "fc_1.linear.weight",
"fc_1.bias": "fc_1.linear.bias",
"fc_2.weight": "fc_2.linear.weight",
"fc_2.bias": "fc_2.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def mark_only_adapter_v2_as_trainable(model: GPT) -> None:
"""Sets requires_grad=False for all non-adapter weights"""
for name, param in model.named_parameters():
param.requires_grad = adapter_filter(name, param)
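def _adapter_v2_finetune_setup_sketch() -> None:
    # Hedged sketch (not part of the original module): typical adapter v2 set-up.
    # `mark_only_adapter_v2_as_trainable` leaves only the parameters matched by
    # `adapter_filter` trainable (adapter prompts, gating factors, the new
    # per-channel scale/bias, and the norm layers), and the same filter can be used
    # to collect a small adapter-only state dict for checkpointing.
    model = GPT(Config.from_name("pythia-70m"))
    mark_only_adapter_v2_as_trainable(model)
    adapter_state = {k: v for k, v in model.state_dict().items() if adapter_filter(k, v)}
    assert all(not p.requires_grad or adapter_filter(n, p) for n, p in model.named_parameters())
    assert len(adapter_state) < len(model.state_dict())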
from dataclasses import dataclass
from typing import Any, Literal, Optional, Type
import torch
from typing_extensions import Self
import lit_gpt.model
from lit_gpt.utils import find_multiple
@dataclass
class Config:
org: str = "Lightning-AI"
name: str = "lit-GPT"
block_size: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
n_layer: int = 16
n_head: int = 32
n_embd: int = 4096
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
# Example with `n_head=4`
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ │ │ │
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
# MHA GQA MQA
# n_query_groups=4 n_query_groups=2 n_query_groups=1
#
# credit https://arxiv.org/pdf/2305.13245.pdf
n_query_groups: Optional[int] = None
shared_attention_norm: bool = False
_norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
_mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
intermediate_size: Optional[int] = None
condense_ratio: int = 1
def __post_init__(self):
# error checking
assert self.n_embd % self.n_head == 0
# pad the vocab size up to the next multiple of `padding_multiple` for better hardware utilization
if self.padded_vocab_size is None:
self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
# compute the number of query groups
if self.n_query_groups is not None:
assert self.n_head % self.n_query_groups == 0
else:
self.n_query_groups = self.n_head
# compute the intermediate size for MLP if not set
if self.intermediate_size is None:
if self._mlp_class == "LLaMAMLP":
raise ValueError("The config needs to set the `intermediate_size`")
self.intermediate_size = 4 * self.n_embd
@property
def head_size(self) -> int:
return self.n_embd // self.n_head
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
conf_dict = name_to_config[name].copy()
conf_dict.update(kwargs)
return cls(**conf_dict)
@property
def mlp_class(self) -> Type:
# `self._mlp_class` cannot be the type to keep the config json serializable
return getattr(lit_gpt.model, self._mlp_class)
@property
def norm_class(self) -> Type:
# `self._norm_class` cannot be the type to keep the config json serializable
if self._norm_class == "RMSNorm":
from lit_gpt.rmsnorm import RMSNorm
return RMSNorm
elif self._norm_class == "FusedRMSNorm":
from lit_gpt.rmsnorm import FusedRMSNorm
return FusedRMSNorm
return getattr(torch.nn, self._norm_class)
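def _config_defaults_sketch() -> None:
    # Hedged sketch (not part of the original module): how `__post_init__` and the
    # properties fill in derived values. `padded_vocab_size` is rounded up to a
    # multiple of `padding_multiple`, `n_query_groups` defaults to `n_head` (plain
    # multi-head attention), and the GptNeoxMLP intermediate size defaults to
    # 4 * n_embd.
    config = Config(vocab_size=50254, padding_multiple=128)
    assert config.padded_vocab_size == 50304  # 50254 rounded up to a multiple of 128
    assert config.n_query_groups == config.n_head == 32
    assert config.intermediate_size == 4 * config.n_embd == 16384
    assert config.head_size == 128  # n_embd // n_head = 4096 // 32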
########################
# Stability AI StableLM
########################
configs = [
# https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
dict(org="stabilityai", name="stablelm-base-alpha-3b", padding_multiple=512),
# https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
dict(org="stabilityai", name="stablelm-base-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
# https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
dict(org="stabilityai", name="stablelm-tuned-alpha-3b", n_head=32, padding_multiple=512),
# https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
dict(org="stabilityai", name="stablelm-tuned-alpha-7b", n_head=48, n_embd=6144, padding_multiple=256),
]
####################
# EleutherAI Pythia
####################
pythia = [
# https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
dict(org="EleutherAI", name="pythia-70m", block_size=2048, n_layer=6, n_embd=512, n_head=8, padding_multiple=128),
# https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
dict(
org="EleutherAI", name="pythia-160m", block_size=2048, n_layer=12, n_embd=768, n_head=12, padding_multiple=128
),
# https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
dict(
org="EleutherAI", name="pythia-410m", block_size=2048, n_layer=24, n_embd=1024, n_head=16, padding_multiple=128
),
# https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
dict(org="EleutherAI", name="pythia-1b", block_size=2048, n_layer=16, n_embd=2048, n_head=8, padding_multiple=128),
# https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
dict(
org="EleutherAI", name="pythia-1.4b", block_size=2048, n_layer=24, n_embd=2048, n_head=16, padding_multiple=128
),
# https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
dict(
org="EleutherAI", name="pythia-2.8b", block_size=2048, n_layer=32, n_embd=2560, n_head=32, padding_multiple=128
),
# https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
dict(
org="EleutherAI", name="pythia-6.9b", block_size=2048, n_layer=32, n_embd=4096, n_head=32, padding_multiple=256
),
# https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
dict(
org="EleutherAI", name="pythia-12b", block_size=2048, n_layer=36, n_embd=5120, n_head=40, padding_multiple=512
),
]
configs.extend(pythia)
for c in pythia:
copy = c.copy()
copy["name"] = f"{c['name']}-deduped"
configs.append(copy)
####################################
# togethercomputer RedPajama INCITE
####################################
redpajama_incite = [
# https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
dict(
org="togethercomputer",
name="RedPajama-INCITE-{}-3B-v1",
block_size=2048,
n_layer=32,
n_embd=2560,
n_head=32,
padding_multiple=256,
rotary_percentage=1.0,
parallel_residual=False,
),
# https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
dict(
org="togethercomputer",
name="RedPajama-INCITE-7B-{}",
block_size=2048,
n_layer=32,
n_embd=4096,
n_head=32,
padding_multiple=256,
rotary_percentage=1.0,
parallel_residual=False,
),
# this redirects to the checkpoint above. kept for those who had the old weights already downloaded
dict(
org="togethercomputer",
name="RedPajama-INCITE-{}-7B-v0.1",
block_size=2048,
n_layer=32,
n_embd=4096,
n_head=32,
padding_multiple=256,
rotary_percentage=1.0,
parallel_residual=False,
),
]
for c in redpajama_incite:
for kind in ("Base", "Chat", "Instruct"):
copy = c.copy()
copy["name"] = c["name"].format(kind)
configs.append(copy)
#################
# TII UAE Falcon
#################
falcon = [
# https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
dict(
org="tiiuae",
name="falcon-7b{}",
block_size=2048,
padded_vocab_size=65024,
n_layer=32,
n_head=71,
n_embd=4544,
rotary_percentage=1.0,
parallel_residual=True,
n_query_groups=1,
bias=False,
# this is not in the config, but in the original model implementation, only for this config
shared_attention_norm=True,
),
# https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
dict(
org="tiiuae",
name="falcon-40b{}",
block_size=2048,
padded_vocab_size=65024,
n_layer=60,
n_head=128,
n_embd=8192,
rotary_percentage=1.0,
parallel_residual=True,
n_query_groups=8,
bias=False,
),
]
for c in falcon:
for kind in ("", "-instruct"):
copy = c.copy()
copy["name"] = c["name"].format(kind)
configs.append(copy)
#############################
# StatNLP Research
#############################
tiny_LLaMA = [
# https://twitter.com/cwolferesearch/status/1691929174175264858
dict(
org="StatNLP-research",
name="tiny_LLaMA_1b",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=22,
n_head=32,
n_embd=2048,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
_mlp_class="LLaMAMLP",
intermediate_size=5632,
n_query_groups=4,
),
dict(
org="StatNLP-research",
name="tiny_LLaMA_120M",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=12,
n_head=12,
n_embd=768,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=2048,
n_query_groups=1,
),
dict(
org="StatNLP-research",
name="code_tiny_LLaMA_1b",
block_size=8192,
vocab_size=32000,
padding_multiple=64,
n_layer=22,
n_head=32,
n_embd=2048,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="FusedRMSNorm",
norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
_mlp_class="LLaMAMLP",
intermediate_size=5632,
n_query_groups=4,
condense_ratio=4,
),
]
configs.extend(tiny_LLaMA)
#############################
# OpenLM Research Open LLaMA
#############################
open_LLaMA = [
# https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
dict(
org="openlm-research",
name="open_llama_3b",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=26,
n_head=32,
n_embd=3200,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=8640,
),
# https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
dict(
org="openlm-research",
name="open_llama_7b",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
dict(
org="openlm-research",
name="open_llama_13b",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
),
]
configs.extend(open_LLaMA)
###############
# LMSYS Vicuna
###############
vicuna = [
# https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
dict(
org="lmsys",
name="vicuna-7b-v1.3",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
dict(
org="lmsys",
name="vicuna-13b-v1.3",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
),
# https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
dict(
org="lmsys",
name="vicuna-33b-v1.3",
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=60,
n_head=52,
n_embd=6656,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=17920,
),
dict(
org="lmsys",
name="vicuna-7b-v1.5",
block_size=4096,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
),
dict(
org="lmsys",
name="vicuna-7b-v1.5-16k",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
condense_ratio=4,
),
dict(
org="lmsys",
name="vicuna-13b-v1.5",
block_size=4096,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
),
dict(
org="lmsys",
name="vicuna-13b-v1.5-16k",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
condense_ratio=4,
),
]
configs.extend(vicuna)
#################
# LMSYS LongChat
#################
long_chat = [
# https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
dict(
org="lmsys",
name="longchat-7b-16k",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
condense_ratio=8,
),
# https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
dict(
org="lmsys",
name="longchat-13b-16k",
block_size=16384,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
condense_ratio=8,
),
]
configs.extend(long_chat)
######################
# NousResearch Hermes
######################
nous_research = [
# https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
dict(
org="NousResearch",
name="Nous-Hermes-13b",
block_size=2048,
padded_vocab_size=32001,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
)
]
configs.extend(nous_research)
###############
# Meta LLaMA 2
###############
llama_2 = [
# https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
dict(
org="meta-llama",
name="Llama-2-7b{}-hf",
block_size=4096,
vocab_size=32000,
padding_multiple=64,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
),
dict(
org="meta-llama",
name="CodeLlama-2-7b-hf",
block_size=4096,
vocab_size=32016,
padded_vocab_size=32016,
padding_multiple=64,
n_layer=32,
n_head=32,
n_embd=4096,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=11008,
),
# https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
dict(
org="meta-llama",
name="Llama-2-13b{}-hf",
block_size=4096,
vocab_size=32000,
padding_multiple=64,
n_layer=40,
n_head=40,
n_embd=5120,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=13824,
),
# https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
dict(
org="meta-llama",
name="Llama-2-70b{}-hf",
block_size=4096,
vocab_size=32000,
padding_multiple=64,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=28672,
),
]
for c in llama_2:
for kind in ("", "-chat"):
copy = c.copy()
copy["name"] = c["name"].format(kind)
configs.append(copy)
##########################
# Stability AI FreeWilly2
##########################
freewilly_2 = [
# https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
dict(
org="stabilityai",
name="FreeWilly2",
block_size=4096,
vocab_size=32000,
padding_multiple=64,
n_layer=80,
n_head=64,
n_embd=8192,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=28672,
)
]
configs.extend(freewilly_2)
name_to_config = {config["name"]: config for config in configs}
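def _name_expansion_sketch() -> None:
    # Hedged sketch (not part of the original module): several entries above are
    # templates (note the "{}" in their names) that the loops expand into concrete
    # checkpoint names, so the base and the chat/instruct/deduped variants all
    # resolve to the same architecture.
    assert "Llama-2-7b-hf" in name_to_config and "Llama-2-7b-chat-hf" in name_to_config
    assert "pythia-70m-deduped" in name_to_config
    assert Config.from_name("Llama-2-7b-chat-hf").n_layer == 32
    assert Config.from_name("falcon-7b-instruct").n_query_groups == 1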
# Copyright (c) 2023, Tri Dao.
import torch
import torch.nn as nn
import xentropy_cuda_lib
# `all_gather_into_tensor` is the new placeholder for `_all_gather_base`. It
# requires a recent version of PyTorch. The following line is for backward
# compatibility with older PyTorch versions.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
logits,
labels,
smoothing=0.0,
ignored_index=-100,
inplace_backward=False,
process_group=None,
):
"""
logits: (batch, vocab_size)
labels: (batch,)
If process_group is not None, we're doing Tensor Parallel: each process is responsible for
one part of the vocab. The loss needs to be aggregated across processes.
"""
batch, vocab_size = logits.shape
assert labels.shape == (batch,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
ctx.total_classes = world_size * vocab_size
if world_size == 1:
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
losses.masked_fill_(labels == ignored_index, 0)
labels_local = labels
else:
rank = torch.distributed.get_rank(process_group)
vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
# Create a mask of valid vocab ids (1 means it needs to be masked).
labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index)
ignored_mask = labels == ignored_index
labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index)
# For tensor parallel cross entropy with smoothing, we want to pass in the total number
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
# last dimension of the input tensor.
losses, lse_local = xentropy_cuda_lib.forward(
logits, labels_local, smoothing, world_size * vocab_size
)
assert lse_local.shape == (batch,)
assert losses.shape == (batch,)
losses.masked_fill_(ignored_mask, 0)
# For labels == ignored_index, the loss is always 0.
# If there's no smoothing, if labels are in the vocab of this partition, losses contains
# lse_local - predicted logit, and 0 otherwise.
# If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
# 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes)
# For labels not in the vocab of this partition, losses contains
# 0.1 * (lse_local - sum logit / total_classes).
lse_allgather = torch.empty(
world_size, batch, dtype=lse_local.dtype, device=lse_local.device
)
torch.distributed.all_gather_into_tensor(
lse_allgather, lse_local.contiguous(), group=process_group
)
handle_losses = torch.distributed.all_reduce(
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
)
lse = torch.logsumexp(lse_allgather, dim=0)
# If there's no smoothing, the total losses are lse_local - predicted_logit,
# we just have to subtract the lse_local and add the lse (global).
# If there's smoothing=0.1, the total losses are
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
lse_local = lse_allgather[
rank_per_sample, torch.arange(batch, device=lse_allgather.device)
]
handle_losses.wait()
if smoothing == 0.0:
losses += lse - lse_local
else:
losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
lse - lse_allgather.sum(dim=0)
)
losses.masked_fill_(ignored_mask, 0)
ctx.save_for_backward(logits, lse, labels_local)
ctx.smoothing = smoothing
ctx.ignored_index = ignored_index
ctx.inplace_backward = inplace_backward
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, lse, labels = ctx.saved_tensors
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
grad_logits = xentropy_cuda_lib.backward(
grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
)
return grad_logits, None, None, None, None, None, None
class FusedCrossEntropyLoss(nn.Module):
def __init__(
self,
ignore_index=-100,
reduction="mean",
label_smoothing=0.0,
inplace_backward=True,
process_group=None,
):
super().__init__()
if reduction not in ["mean", "none"]:
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
self.ignore_index = ignore_index
self.reduction = reduction
self.label_smoothing = label_smoothing
self.inplace_backward = inplace_backward
self.process_group = process_group
def forward(self, input, target):
assert input.is_cuda and target.is_cuda
# SoftmaxCrossEntropyLoss implicitly casts to float
if len(input.shape) == 3:
input = input.view(-1, input.size(-1))
target = target.view(-1)
loss = SoftmaxCrossEntropyLossFn.apply(
input,
target,
self.label_smoothing,
self.ignore_index,
self.inplace_backward,
self.process_group,
)
if self.reduction == "mean":
return loss.sum() / (target != self.ignore_index).sum()
else:
return loss
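# Hedged reference (not part of the original module): in the single-process case
# (process_group=None) FusedCrossEntropyLoss is meant to match the standard PyTorch
# cross entropy with `ignore_index` and `label_smoothing`, averaged over the
# non-ignored targets. The fused path requires CUDA tensors and the
# `xentropy_cuda_lib` extension; the reference below runs anywhere and is useful
# for sanity checks.
def reference_cross_entropy(
    logits: torch.Tensor, targets: torch.Tensor, ignore_index: int = -100, label_smoothing: float = 0.0
) -> torch.Tensor:
    import torch.nn.functional as F

    # Flatten (batch, seq, vocab) logits the same way FusedCrossEntropyLoss.forward does.
    if logits.dim() == 3:
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)
    return F.cross_entropy(
        logits, targets, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction="mean"
    )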
# Copyright (c) 2023, Tri Dao.
import math
from typing import Optional, Tuple
import rotary_emb
import torch
from einops import rearrange, repeat
class ApplyRotaryEmb(torch.autograd.Function):
@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]
if inplace:
o1, o2 = x1, x2
else:
o1, o2 = (
out_ro.chunk(2, dim=-1)
if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2])
)
rotary_emb.apply_rotary(
x1,
x2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
o1,
o2,
False,
)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x
@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (
do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
)
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (
dx_ro.chunk(2, dim=-1)
if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2])
)
rotary_emb.apply_rotary(
do1,
do2,
rearrange(cos[:seqlen], "s d -> s 1 d"),
rearrange(sin[:seqlen], "s d -> s 1 d"),
dx1,
dx2,
True,
)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None
apply_rotary_emb_func = ApplyRotaryEmb.apply
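# Hedged reference (not part of the original module; assumes the `rotary_emb` CUDA
# kernel applies the standard GPT-NeoX rotation o1 = x1*cos - x2*sin,
# o2 = x1*sin + x2*cos to the first `rotary_dim` channels and passes the rest
# through). A pure PyTorch version of the non-interleaved, out-of-place case, handy
# for CPU sanity checks against `apply_rotary_emb_func(x, cos, sin)`:
def apply_rotary_emb_torch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2)
    seqlen = x.shape[1]
    rotary_dim = cos.shape[-1] * 2
    x_ro, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_ro.chunk(2, dim=-1)
    cos = cos[:seqlen].unsqueeze(1)  # (seqlen, 1, rotary_dim // 2), broadcasts over heads
    sin = sin[:seqlen].unsqueeze(1)
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)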
# Derived from https://github.com/microsoft/LoRA
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
r"""
Low-Rank Adaptation (LoRA) for LLMs scheme.
┌───────────────────┐
┆ h ┆
└───────────────────┘
|
+
/ \
┌─────────────────┐ ╭───────────────╮ Matrix initialization:
┆ ┆ \ B / B = 0
┆ pretrained ┆ \ r*d / A = N(0, sigma^2)
┆ weights ┆ ╰─────────╯
┆ ┆ | r | r - rank
┆ W e R^(d*d) ┆ | ◀─────▶ |
┆ ┆ ╭─────────╮
└─────────────────┘ / A \
▲ / d*r \
\ ╰───────────────╯
\ ▲
\ /
\ /
┌───────────────────┐
┆ x ┆
└───────────────────┘
With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning a weight matrix of size d*d,
we freeze the pretrained weights and learn two matrices of size d*r and r*d that store the weight updates for the
pretrained weights. The number of trainable parameters is reduced drastically (depending on the rank, of course), yet
multiplying the d*r and r*d matrices yields a d*d matrix that can be summed with the frozen pretrained weights, and
thus the model can be fine-tuned.
The goal of this approach is to move weight updates into a separate matrix which is decomposed with
two matrices of a lower rank.
"""
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self
import lit_gpt
from lit_gpt.config import Config as BaseConfig
from lit_gpt.model import GPT as BaseModel
from lit_gpt.model import Block as BaseBlock
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention
from lit_gpt.model import KVCache, RoPECache
from lit_gpt.utils import map_old_state_dict_weights
class LoRALayer(nn.Module):
def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
"""Store LoRA specific attributes in a class.
Args:
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
lora_alpha: alpha is needed for scaling updates as alpha/r
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
"""
super().__init__()
assert r >= 0
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
class LoRALinear(LoRALayer):
# LoRA implemented in a dense layer
def __init__(
self,
# ↓ this part is for pretrained weights
in_features: int,
out_features: int,
# ↓ the remaining part is for LoRA
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
"""LoRA wrapper around linear class.
This class has three weight matrices:
1. Pretrained weights are stored as `self.linear.weight`
2. LoRA A matrix as `self.lora_A`
3. LoRA B matrix as `self.lora_B`
Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
Args:
in_features: number of input features of the pretrained weights
out_features: number of output features of the pretrained weights
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
lora_alpha: alpha is needed for scaling updates as alpha/r
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
"""
super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.linear.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
self.reset_parameters()
def reset_parameters(self):
"""Reset all the weights, even including pretrained ones."""
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
# Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def merge(self):
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
if self.r > 0 and not self.merged:
# Merge the weights and mark it
self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
# If the weights are merged or the rank is zero (LoRA is disabled), this is just a regular nn.Linear forward pass;
# otherwise, additionally run the forward pass with the LoRA weights and add its output to the pretrained output.
pretrained = self.linear(x)
if self.r == 0 or self.merged:
return pretrained
lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return pretrained + lora
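def _lora_linear_sketch() -> None:
    # Hedged sketch (not part of the original module): with in_features=out_features=128
    # and r=8, the LoRA branch adds 8*128 + 128*8 = 2048 trainable values next to the
    # 128*128 = 16384 pretrained weights. Because lora_B starts at zero, the layer
    # initially behaves exactly like the wrapped nn.Linear; `merge()` folds the
    # (here still zero) update into `linear.weight` and marks the layer as merged.
    layer = LoRALinear(128, 128, r=8, lora_alpha=16, bias=False)
    assert layer.lora_A.numel() + layer.lora_B.numel() == 2048
    x = torch.randn(4, 128)
    assert torch.allclose(layer(x), layer.linear(x))
    layer.merge()
    assert layer.merged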
class LoRAQKVLinear(LoRALinear):
# LoRA implemented in a dense layer
def __init__(
self,
# ↓ this part is for pretrained weights
in_features: int,
out_features: int,
# ↓ the remaining part is for LoRA
n_head: int,
n_query_groups: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
**kwargs,
):
"""LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
This class has three weight matrices:
1. Pretrained weights are stored as `self.linear.weight`
2. LoRA A matrix as `self.lora_A`
3. LoRA B matrix as `self.lora_B`
Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
Args:
in_features: number of input features of the pretrained weights
out_features: number of output features of the pretrained weights
n_head: number of attention heads
n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
lora_alpha: alpha is needed for scaling updates as alpha/r
"This scaling helps to reduce the need to retune hyperparameters when we vary r"
https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
enable_lora: this class is used for the attention mechanism, where q, k and v are calculated with a single weight matrix. If we
don't want to apply LoRA to all of them we can set the corresponding flags to False. For example, to apply LoRA only to `query`
and `value` but keep `key` without weight updates we should pass `[True, False, True]`
"""
super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
self.n_head = n_head
self.n_query_groups = n_query_groups
if isinstance(enable_lora, bool):
enable_lora = [enable_lora] * 3
assert len(enable_lora) == 3
self.enable_lora = enable_lora
# Actual trainable parameters
# To better understand initialization let's imagine that we have such parameters:
# ⚬ in_features: 128 (embeddings_size)
# ⚬ out_features: 384 (3 * embedding_size)
# ⚬ r: 2
# ⚬ enable_lora: [True, False, True]
if r > 0 and any(enable_lora):
self.lora_A = nn.Parameter(self.linear.weight.new_zeros((r * sum(enable_lora), in_features))) # (4, 128)
enable_q, enable_k, enable_v = enable_lora
self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups)
# qkv_shapes will be used to split a tensor with weights correctly
qkv_shapes = (
self.linear.in_features * enable_q,
self.kv_embd_size * enable_k,
self.kv_embd_size * enable_v,
)
self.qkv_shapes = [s for s in qkv_shapes if s]
self.lora_B = nn.Parameter(self.linear.weight.new_zeros(sum(self.qkv_shapes), r)) # (256, 2))
# Notes about shapes above
# - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
# 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
# F.linear function weights are automatically transposed. In addition conv1d requires channels to
# be before seq length
# - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
# 128*2; 2 tells to have two channels per group for group convolution
# Scaling:
# This balances the pretrained model's knowledge and the new task-specific adaptation
# https://lightning.ai/pages/community/tutorial/lora-llm/
# So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
# alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
# tune these values to your needs. This value can be even slightly greater than 1.0!
# https://github.com/cloneofsimo/lora
self.scaling = self.lora_alpha / self.r
# Compute the indices
# Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
# but not keys, then the weights update should be:
#
# [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
# [....................................],
# [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
# ↑ ↑ ↑
# ________________________________________
# | query | key | value |
# ----------------------------------------
self.lora_ind = []
if enable_q:
self.lora_ind.extend(range(0, self.linear.in_features))
if enable_k:
self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
if enable_v:
self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
self.reset_parameters()
def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
"""Properly pad weight updates with zeros.
If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
then the weights update should be:
[[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
[....................................],
[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
↑ ↑ ↑
________________________________________
| query | key | value |
----------------------------------------
Args:
x: tensor with weights update that will be padded with zeros if necessary
Returns:
A tensor with weight updates and zeros for deselected q, k or v
"""
# we need to do zero padding only if LoRA is disabled for one of QKV matrices
if all(self.enable_lora):
return x
# Let's imagine that:
# ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
# ⚬ embeddings_size: 128
# ⚬ self.linear.out_features: 384 (3 * embeddings_size)
# ⚬ enable_lora: [True, False, True]
# Then x has an embeddings_size of 256 (2 * 128, since LoRA is enabled only for query and value, not key),
# while the expected embeddings_size is 384 (self.linear.out_features). So we need to pad from 256 to 384
# with zeros, but only at the positions of the key updates (this is where self.lora_ind comes in handy).
# Note: the double transpose (at the beginning and at the end) is essentially a guard for two-dimensional
# tensors, for example when we want to merge/unmerge LoRA weights and pretrained weights.
x = x.transpose(0, 1)
result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384)
result = result.view(-1, self.linear.out_features) # (4096, 384)
result = result.index_copy(
1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
) # (4096, 384): zeros everywhere except at the columns in self.lora_ind, which receive the (4096, 256) update
return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384)
def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
If the number of heads equals the number of query groups, grouped queries are disabled
(see the scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally
sized query, key and value parts, so we can use the `groups` argument of `conv1d`: with it, the input and
weight matrices are split into equally sized parts and applied separately (like having multiple conv
layers side by side).
Otherwise the QKV matrix consists of unequally sized parts, so we have to split the input and weight
matrices manually, apply each part of the weight matrix to the corresponding part of the input, and
concatenate the results.
Args:
input: input matrix of shape (B, C, T)
weight: weight matrix of shape (C_output, rank, 1).
"C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
Returns:
A tensor with a shape (B, C_output, T)
"""
if self.n_head == self.n_query_groups:
return F.conv1d(input, weight, groups=sum(self.enable_lora)) # (B, C_output, T)
# Notation:
# ⚬ N: number of enabled LoRA layers (self.enable_lora)
# ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
# ⚬ r: rank of all LoRA layers (equal in size)
input_splitted = input.chunk(sum(self.enable_lora), dim=1) # N * (B, C // N, T)
weight_splitted = weight.split(self.qkv_shapes) # N * (C_output', r, 1)
return torch.cat(
[F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 # (B, C_output', T)
) # (B, C_output, T)
def merge(self):
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
# Let's assume that:
# ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
# ⚬ self.lora_A.data: (4, 128)
# ⚬ self.lora_B.data: (256, 2)
if self.r > 0 and any(self.enable_lora) and not self.merged:
delta_w = self.conv1d(
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).squeeze(
0
) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
# W = W + delta_W (merge)
self.linear.weight.data += self.zero_pad(delta_w * self.scaling) # (256, 128) after zero_pad (384, 128)
self.merged = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Do the forward pass.
If the LoRA weights have already been merged into the pretrained ones, this is a plain linear projection.
Otherwise, multiply the input by the pretrained weights, apply LoRA to the input, and sum the two outputs.
Args:
x: input tensor of shape (batch_size, context_length, embedding_size)
Returns:
Output tensor of shape (batch_size, context_length, 3 * embedding_size)
"""
# Let's assume that:
# ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
# ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
# ⚬ self.lora_A.data: (4, 128)
# ⚬ self.lora_B.data: (256, 2)
# If the weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False), this is just a
# regular nn.Linear forward pass; otherwise also run the LoRA branch and add its output to that of the
# pretrained weights.
pretrained = self.linear(x)
if self.r == 0 or not any(self.enable_lora) or self.merged:
return pretrained
after_A = F.linear(self.lora_dropout(x), self.lora_A) # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
# For F.conv1d:
# ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
# ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
after_B = self.conv1d(
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
).transpose(
-2, -1
) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
return pretrained + lora
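# Usage sketch (assumes the constructor keywords shown above and the running example shapes,
# with n_head == n_query_groups == 4, so head_size == 32 and out_features == 384):
#
#     qkv = LoRAQKVLinear(
#         in_features=128, out_features=384, n_head=4, n_query_groups=4,
#         r=2, lora_alpha=1, lora_dropout=0.0, enable_lora=(True, False, True), bias=False,
#     )
#     x = torch.randn(2, 8, 128)   # (batch_size, context_length, embedding_size)
#     y = qkv(x)                   # (2, 8, 384): pretrained projection + zero-padded, scaled LoRA update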
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
"""Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.
Args:
model: model with LoRA layers
bias:
``"none"``: all bias weights will be frozen,
``"lora_only"``: only bias weight for LoRA layers will be unfrozen,
``"all"``: all bias weights will be unfrozen.
Raises:
NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
"""
# freeze all layers except LoRA's
for n, p in model.named_parameters():
if "lora_" not in n:
p.requires_grad = False
# depending on the `bias` value unfreeze bias weights
if bias == "none":
return
if bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "lora_only":
for m in model.modules():
if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
def lora_filter(key: str, value: Any) -> bool:
return "lora_" in key
@dataclass
class Config(BaseConfig):
"""
Args:
r: rank of the weight update matrices. For LoRA to be useful, the rank should be significantly smaller
    than the rank of the model's weight matrices; it can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
alpha: alpha is needed for scaling the updates as alpha/r
    "This scaling helps to reduce the need to retune hyperparameters when we vary r"
    https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
to_*: whether to apply LoRA to the corresponding weights (query, key, value, projection, MLP, LM head)
"""
r: int = 0
alpha: int = 1
dropout: float = 0.0
to_query: bool = False
to_key: bool = False
to_value: bool = False
to_projection: bool = False
to_mlp: bool = False
to_head: bool = False
@property
def mlp_class(self) -> Type:
return getattr(lit_gpt.lora, self._mlp_class)
class GPT(BaseModel):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = LoRALinear(
config.n_embd,
config.padded_vocab_size,
bias=False,
r=(config.r if config.to_head else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.rope_cache: Optional[RoPECache] = None
self.mask_cache: Optional[torch.Tensor] = None
self.kv_caches: List[KVCache] = []
def forward(
self,
idx: torch.Tensor,
max_seq_length: Optional[int] = None,
input_pos: Optional[torch.Tensor] = None,
lm_head_chunk_size: int = 0,
) -> Union[torch.Tensor, List[torch.Tensor]]:
B, T = idx.size()
use_kv_cache = input_pos is not None
block_size = self.config.block_size
if max_seq_length is None:
max_seq_length = block_size
if use_kv_cache: # not relevant otherwise
assert (
max_seq_length >= T
), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"
if self.rope_cache is None:
self.rope_cache = self.build_rope_cache(idx) # 2 * (block_size, head_size * rotary_percentage)
# passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
# for the kv-cache support (only during inference), we only create it in that situation
# this will be resolved by https://github.com/pytorch/pytorch/issues/96099
if use_kv_cache and self.mask_cache is None:
self.mask_cache = self.build_mask_cache(idx) # (1, 1, block_size, block_size)
cos, sin = self.rope_cache
if use_kv_cache:
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, :max_seq_length]
else:
cos = cos[:T]
sin = sin[:T]
mask = None
# forward the model itself
x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
if not use_kv_cache:
for block in self.transformer.h:
x, *_ = block(x, (cos, sin), max_seq_length)
else:
self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1))
for i, block in enumerate(self.transformer.h):
x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i])
x = self.transformer.ln_f(x)
if lm_head_chunk_size > 0:
# chunk the lm head logits to reduce the peak memory used by autograd
return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)]
return self.lm_head(x) # (B, T, vocab_size)
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness."""
super()._init_weights(module)
if isinstance(module, LoRALinear):
module.reset_parameters()
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {"lm_head.weight": "lm_head.linear.weight"}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class Block(BaseBlock):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config)
if not config.shared_attention_norm:
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.mlp = config.mlp_class(config)
self.config = config
class CausalSelfAttention(BaseCausalSelfAttention):
def __init__(self, config: Config) -> None:
"""Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for
parameter-efficient fine-tuning.
*Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for
query, key and value for each head) we can do this in a single pass with a single weight matrix.
"""
# Skip the parent class __init__ altogether and replace it to avoid
# useless allocations
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
self.attn = LoRAQKVLinear(
in_features=config.n_embd,
out_features=shape,
r=config.r,
lora_alpha=config.alpha,
lora_dropout=config.dropout,
enable_lora=(config.to_query, config.to_key, config.to_value),
bias=config.bias,
# for MQA/GQA support
n_head=config.n_head,
n_query_groups=config.n_query_groups,
)
# output projection
self.proj = LoRALinear(
config.n_embd,
config.n_embd,
bias=config.bias,
r=(config.r if config.to_projection else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
)
self.config = config
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc = LoRALinear(
config.n_embd,
config.intermediate_size,
bias=config.bias,
r=(config.r if config.to_mlp else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
)
self.proj = LoRALinear(
config.intermediate_size,
config.n_embd,
bias=config.bias,
r=(config.r if config.to_mlp else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc.weight": "fc.linear.weight",
"fc.bias": "fc.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
class LLaMAMLP(lit_gpt.model.LLaMAMLP):
def __init__(self, config: Config) -> None:
nn.Module.__init__(self)
self.fc_1 = LoRALinear(
config.n_embd,
config.intermediate_size,
bias=config.bias,
r=(config.r if config.to_mlp else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
)
self.fc_2 = LoRALinear(
config.n_embd,
config.intermediate_size,
bias=config.bias,
r=(config.r if config.to_mlp else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
)
self.proj = LoRALinear(
config.intermediate_size,
config.n_embd,
bias=config.bias,
r=(config.r if config.to_mlp else 0),
lora_alpha=config.alpha,
lora_dropout=config.dropout,
)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
mapping = {
"fc_1.weight": "fc_1.linear.weight",
"fc_1.bias": "fc_1.linear.bias",
"fc_2.weight": "fc_2.linear.weight",
"fc_2.bias": "fc_2.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def merge_lora_weights(model: GPT) -> None:
"""Merge LoRA weights into the full-rank weights to speed up inference."""
for module in model.modules():
if isinstance(module, LoRALinear):
module.merge()
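# End-to-end sketch of a LoRA fine-tuning run (the config name and hyperparameters are placeholders):
#
#     config = Config.from_name("pythia-70m", r=8, alpha=16, dropout=0.05, to_query=True, to_value=True)
#     model = GPT(config)
#     mark_only_lora_as_trainable(model)
#     # ... optimize only the parameters with requires_grad=True ...
#     merge_lora_weights(model)  # fold the low-rank updates into the frozen weights before inference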
"""Full definition of a GPT NeoX Language Model, all of it in this single file.
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""
import math
from typing import Any, List, Optional, Tuple
import torch
import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import Self
from flash_attn import flash_attn_func
from lit_gpt.config import Config
from xformers.ops import SwiGLU
from .fused_rotary_embedding import apply_rotary_emb_func
RoPECache = Tuple[torch.Tensor, torch.Tensor]
KVCache = Tuple[torch.Tensor, torch.Tensor]
FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")
class GPT(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
assert config.padded_vocab_size is not None
self.config = config
self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
)
self.rope_cache: Optional[RoPECache] = None
self.mask_cache: Optional[torch.Tensor] = None
self.kv_caches: List[KVCache] = []
def _init_weights(self, module: nn.Module, n_layer) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
# GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf
if isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
# RWKV: set it to 1e-4
# torch.nn.init.uniform_(module.weight, -1e-4, 1e-4)
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
# GPT-NeoX
for name, p in module.named_parameters():
if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, CausalSelfAttention))): #if use xformer swiglu, fc2 layer will be renamed to w3
nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
def reset_cache(self) -> None:
self.kv_caches.clear()
if self.mask_cache is not None and self.mask_cache.device.type == "xla":
# https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179
self.rope_cache = None
self.mask_cache = None
def forward(
self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
) -> torch.Tensor:
B, T = idx.size()
use_kv_cache = input_pos is not None
block_size = self.config.block_size
if max_seq_length is None:
max_seq_length = block_size
if use_kv_cache: # not relevant otherwise
assert (
max_seq_length >= T
), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}"
if self.rope_cache is None:
self.rope_cache = self.build_rope_cache(idx)
# passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
# for the kv-cache support (only during inference), we only create it in that situation
# this will be resolved by https://github.com/pytorch/pytorch/issues/96099
if use_kv_cache and self.mask_cache is None:
self.mask_cache = self.build_mask_cache(idx)
cos, sin = self.rope_cache
if use_kv_cache:
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)
mask = self.mask_cache.index_select(2, input_pos)
mask = mask[:, :, :, :max_seq_length]
else:
cos = cos[:T]
sin = sin[:T]
mask = None
# forward the model itself
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if not use_kv_cache:
for block in self.transformer.h:
x, *_ = block(x, (cos, sin), max_seq_length)
else:
self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)
for i, block in enumerate(self.transformer.h):
x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i])
x = self.transformer.ln_f(x)
return self.lm_head(x) # (b, t, vocab_size)
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))
def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
return build_rope_cache(
seq_len=self.config.block_size,
n_elem=int(self.config.rotary_percentage * self.config.head_size),
dtype=torch.bfloat16,
device=idx.device,
condense_ratio=self.config.condense_ratio,
)
def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool)
return torch.tril(ones).unsqueeze(0).unsqueeze(0)
def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]:
B = idx.size(0)
heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups
k_cache_shape = (
B,
max_seq_length,
heads,
rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size),
)
v_cache_shape = (B, max_seq_length, heads, self.config.head_size)
device = idx.device
return [
(torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
for _ in range(self.config.n_layer)
]
class Block(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.attn = CausalSelfAttention(config)
if not config.shared_attention_norm:
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
self.mlp = config.mlp_class(config)
self.config = config
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
max_seq_length: int,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache]]:
n_1 = self.norm_1(x)
h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache)
if self.config.parallel_residual:
n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
x = x + h + self.mlp(n_2)
else:
if self.config.shared_attention_norm:
raise NotImplementedError(
"No checkpoint amongst the ones we support uses this configuration"
" (non-parallel residual and shared attention norm)."
)
x = x + h
x = x + self.mlp(self.norm_2(x))
return x, new_kv_cache
class CausalSelfAttention(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.config = config
def forward(
self,
x: torch.Tensor,
rope: RoPECache,
max_seq_length: int,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple[torch.Tensor, Optional[KVCache]]:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
qkv = self.attn(x)
# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
q_per_kv = self.config.n_head // self.config.n_query_groups
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs)
# qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
# split batched computation into three
q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
# repeat k and v if necessary
# Peiyuan: we do not need to do this, as FlashAttention 2 already supports GQA
# if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!)
# # for MHA this is a no-op
# k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
# v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs)
k = k.reshape(B, T, -1, self.config.head_size)
v = v.reshape(B, T, -1, self.config.head_size)
cos, sin = rope
# applying RoPE in fp32 significantly stabilizes training
# the fused rope kernel expects (batch_size, seqlen, nheads, headdim)
q = apply_rotary_emb_func(q, cos, sin, False, True)
k = apply_rotary_emb_func(k, cos, sin, False, True)
# n_elem = int(self.config.rotary_percentage * self.config.head_size)
# q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
# k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
# print( (q_roped - q).sum())
# q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
# k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
if kv_cache is not None:
cache_k, cache_v = kv_cache
cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
# check if reached token limit
if input_pos[-1] >= max_seq_length:
input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
# shift 1 position to the left
cache_k = torch.roll(cache_k, -1, dims=1)
cache_v = torch.roll(cache_v, -1, dims=1)
k = cache_k.index_copy_(1, input_pos, k)
v = cache_v.index_copy_(1, input_pos, v)
kv_cache = k, v
y = self.scaled_dot_product_attention(q, k, v, mask=mask)
y = y.reshape(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.proj(y)
return y, kv_cache
def scaled_dot_product_attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
):
scale = 1.0 / math.sqrt(self.config.head_size)
if (
FlashAttention2Available
and mask is None
and q.device.type == "cuda"
and q.dtype in (torch.float16, torch.bfloat16)
):
from flash_attn import flash_attn_func
return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if q.size() != k.size():
k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1)
v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1)
y = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
)
return y.transpose(1, 2)
class GptNeoxMLP(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc(x)
x = torch.nn.functional.gelu(x)
return self.proj(x)
class LLaMAMLP(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
# self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
# self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
# self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x_fc_1 = self.fc_1(x)
# x_fc_2 = self.fc_2(x)
# x = torch.nn.functional.silu(x_fc_1) * x_fc_2
# return self.proj(x)
return self.swiglu(x)
def build_rope_cache(
seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
) -> RoPECache:
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio
# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta)
cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
# added by peiyuan to ensure the same data type as q and k, so the fused rotary embedding can be used
if dtype == torch.bfloat16:
return cos.bfloat16(), sin.bfloat16()
# this is to mimic the behaviour of complex32, else we will get different results
if dtype in (torch.float16, torch.bfloat16, torch.int8):
return cos.half(), sin.half()
return cos, sin
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
head_size = x.size(-1)
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
roped = (x * cos) + (rotated * sin)
return roped.type_as(x)
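# Reference sketch of how the non-fused helpers fit together (shapes assume rotary_percentage == 1.0,
# so n_elem equals the head size):
#
#     cos, sin = build_rope_cache(seq_len=8, n_elem=16, dtype=torch.float32, device=torch.device("cpu"))
#     q = torch.randn(1, 4, 8, 16)                                 # (B, n_head, T, head_size)
#     q_roped = apply_rope(q, cos.repeat(1, 2), sin.repeat(1, 2))  # same shape, rotary-encoded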
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
import os
import random
import struct
import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info
dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}
def code(dtype):
for k in dtypes:
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24 # bytes
class PackedDataset(IterableDataset):
def __init__(
self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
):
self._filenames = filenames
self._n_chunks = n_chunks
self._block_size = block_size
self._seed = seed
self._shuffle = shuffle
self._wrap = wrap
self._num_processes = num_processes
self._process_rank = process_rank
def __iter__(self):
worker_info = get_worker_info()
num_workers = worker_info.num_workers if worker_info is not None else 1
worker_id = worker_info.id if worker_info is not None else 0
num_shards = num_workers * self._num_processes
shard_id = self._process_rank * num_workers + worker_id
max_num_files = len(self._filenames) // num_shards * num_shards
filenames = self._filenames[shard_id:max_num_files:num_shards]
return PackedDatasetIterator(
filenames=filenames,
n_chunks=self._n_chunks,
block_size=self._block_size,
seed=self._seed,
shuffle=self._shuffle,
wrap=self._wrap,
)
class PackedDatasetBuilder(object):
def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
if dtype == "auto":
if vocab_size is None:
raise ValueError("vocab_size cannot be None when dtype='auto'")
if vocab_size is not None and vocab_size < 65500:
self._dtype = np.uint16
else:
self._dtype = np.int32
else:
self._dtype = dtype
self._counter = 0
self._chunk_size = chunk_size
self._outdir = outdir
self._prefix = prefix
self._sep_token = sep_token
self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
self._arr.fill(self._sep_token)
self._idx = 0
self._version = 1
self._filenames = []
def _write_chunk(self):
filename = f"{self._prefix}_{self._counter:010d}.bin"
filename = os.path.join(self._outdir, filename)
with open(filename, "wb") as f:
f.write(HDR_MAGIC)
f.write(struct.pack("<Q", self._version))
f.write(struct.pack("<B", code(self._dtype)))
f.write(struct.pack("<Q", self._chunk_size))
f.write(self._arr.tobytes(order="C"))
self._filenames.append(filename)
self._counter += 1
self._arr.fill(self._sep_token)
self._idx = 0
@property
def dtype(self):
return self._dtype
@property
def filenames(self):
return self._filenames.copy()
def add_array(self, arr):
while self._idx + arr.shape[0] > self._chunk_size:
part_len = self._chunk_size - self._idx
self._arr[self._idx : self._idx + part_len] = arr[:part_len]
self._write_chunk()
arr = arr[part_len:]
arr_len = arr.shape[0]
self._arr[self._idx : self._idx + arr_len] = arr
self._idx += arr_len
def write_reminder(self):
self._write_chunk()
class PackedDatasetIterator:
def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
self._seed = seed
self._shuffle = shuffle
self._rng = np.random.default_rng(seed) if shuffle else None
self._block_idxs = None
self._wrap = wrap
# TODO: instead of filenames, we could have a single text stream
# (or text file) with the sequence of all files to be
# fetched/loaded.
self._filenames = filenames
self._file_idx = 0
self._n_chunks = n_chunks
self._dtype = None
self._block_size = block_size
self._n_blocks = None
self._mmaps = []
self._buffers = []
self._block_idxs = []
self._curr_idx = 0
self._load_n_chunks()
def _read_header(self, path):
with open(path, "rb") as f:
magic = f.read(len(HDR_MAGIC))
assert magic == HDR_MAGIC, "File doesn't match expected format."
version = struct.unpack("<Q", f.read(8))
assert version == (1,)
(dtype_code,) = struct.unpack("<B", f.read(1))
dtype = dtypes[dtype_code]
(chunk_size,) = struct.unpack("<Q", f.read(8))
return dtype, chunk_size
def _close_mmaps(self):
for mmap in self._mmaps:
mmap._mmap.close()
def _load_n_chunks(self):
self._close_mmaps()
self._mmaps = []
self._buffers = []
if self._n_chunks > len(self._filenames[self._file_idx :]):
# if not self._wrap:
# raise StopIteration
self._file_idx = 0
for i in range(self._n_chunks):
filename = self._filenames[self._file_idx + i]
if self._dtype is None:
self._dtype, self._chunk_size = self._read_header(filename)
self._n_blocks = self._chunk_size // self._block_size
# TODO: check header matches with previous files
mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
self._mmaps.append(mmap)
self._buffers.append(memoryview(mmap))
self._file_idx += self._n_chunks
n_all_blocks = self._n_chunks * self._n_blocks
self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
self._curr_idx = 0
def __del__(self):
self._close_mmaps()
del self._mmaps
del self._buffers
def __iter__(self):
return self
def __next__(self):
if self._curr_idx >= len(self._block_idxs):
self._load_n_chunks()
# TODO: trigger fetching next next n_chunks if remote
block_idx = self._block_idxs[self._curr_idx]
chunk_id = block_idx // self._n_blocks
buffer = self._buffers[chunk_id]
elem_id = (block_idx % self._n_blocks) * self._block_size
offset = np.dtype(self._dtype).itemsize * elem_id
arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
self._curr_idx += 1
return torch.from_numpy(arr.astype(np.int64))
class CombinedDataset(IterableDataset):
def __init__(self, datasets, seed, weights=None):
self._seed = seed
self._datasets = datasets
self._weights = weights
n_datasets = len(datasets)
if weights is None:
self._weights = [1 / n_datasets] * n_datasets
def __iter__(self):
return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
class CombinedDatasetIterator:
def __init__(self, datasets, seed, weights):
self._datasets = [iter(el) for el in datasets]
self._weights = weights
self._rng = random.Random(seed)
def __next__(self):
(dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
return next(dataset)
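# Sketch of writing and reading the packed binary format (paths, token ids and sizes are placeholders):
#
#     builder = PackedDatasetBuilder(outdir="data", prefix="train", chunk_size=16 * 2048,
#                                    sep_token=0, dtype="auto", vocab_size=32000)
#     builder.add_array(np.array(token_ids, dtype=builder.dtype))
#     builder.write_reminder()                       # flush the last, partially filled chunk
#     dataset = PackedDataset(builder.filenames, n_chunks=1, block_size=2048, shuffle=True)
#     block = next(iter(dataset))                    # torch.int64 tensor of shape (block_size,)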
import torch
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16
import dropout_layer_norm
import torch
from torch.nn import init
def maybe_align(x, alignment_in_bytes=16):
"""Assume that x already has last dim divisible by alignment_in_bytes"""
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
def _dropout_add_layer_norm_forward(
x0,
residual,
gamma,
beta,
rowscale,
colscale,
dropout_p,
epsilon,
residual_in_fp32=False,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat,
residualmat,
gamma,
beta,
rowscale,
colscale,
None,
None,
dropout_p,
epsilon,
1.0,
0,
None,
residual_in_fp32,
is_rms_norm,
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
rowscale,
colscale,
dropout_p,
has_residual,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
"""
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(xmat.shape)
dxmat = dx.view(xmat.shape) if dx is not None else None
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
if colscale is not None:
assert x0 is not None, "x0 is required to compute the gradient of colscale"
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat,
dxmat,
xmat,
x0mat,
dmask,
mu,
rsigma,
gamma,
rowscale,
colscale,
None,
None,
dropout_p,
1.0,
0,
has_residual,
is_rms_norm,
)
# dresidualmat is None if not has_residual
if colscale is None:
return dx0mat, dresidualmat, dgamma, dbeta
else:
dcolscale = rest[0]
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_subset_forward(
x0,
residual,
gamma,
beta,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32=False,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat,
residualmat,
gamma,
beta,
None,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
None,
residual_in_fp32,
is_rms_norm,
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_subset_backward(
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
colscale,
x0_subset,
out_subset,
dropout_p,
rowscale_const,
x0_numrows,
has_residual,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
"""
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(-1, hidden_size)
dxmat = dx.view(xmat.shape) if dx is not None else None
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
out_subset = out_subset.view(-1) if out_subset is not None else None
if colscale is not None:
assert x0 is not None, "x0 is required to compute the gradient of colscale"
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
dzmat,
dxmat,
xmat,
x0mat,
dmask,
mu,
rsigma,
gamma,
None,
colscale,
x0_subset,
out_subset,
dropout_p,
rowscale_const,
x0_numrows,
has_residual,
is_rms_norm,
)
# dresidualmat is None if not has_residual
if colscale is None:
return dx0mat, dresidualmat, dgamma, dbeta
else:
dcolscale = rest[0]
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
def _dropout_add_layer_norm_parallel_residual_forward(
x0,
x1,
residual,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
residual_in_fp32=False,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes"""
hidden_size = gamma0.numel()
x0mat = x0.view((-1, hidden_size))
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
(
z0mat,
z1mat,
xmat,
dmask0,
dmask1,
mu,
rsigma,
) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
x0mat,
x1mat,
residualmat,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
None,
residual_in_fp32,
is_rms_norm,
)
# dmask0 and dmask1 are None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma
def _dropout_add_layer_norm_parallel_residual_backward(
dz0,
dz1,
dx,
x,
dmask0,
dmask1,
mu,
rsigma,
gamma0,
gamma1,
dropout_p,
has_x1,
has_residual,
is_rms_norm=False,
):
"""Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
"""
hidden_size = gamma0.numel()
xmat = x.view((-1, hidden_size))
dz0mat = dz0.view(xmat.shape)
dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
dxmat = dx.view(xmat.shape) if dx is not None else None
(
dx0mat,
dx1mat,
dresidualmat,
dgamma0,
dbeta0,
dgamma1,
dbeta1,
*rest,
) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
dz0mat,
dz1mat,
dxmat,
xmat,
dmask0,
dmask1,
mu,
rsigma,
gamma0,
gamma1,
dropout_p,
has_x1,
has_residual,
is_rms_norm,
)
# dresidualmat is None if not has_residual
return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x0,
residual,
gamma,
beta,
rowscale,
colscale,
dropout_p,
epsilon,
residual_in_fp32=False,
prenorm=False,
is_rms_norm=False,
return_dmask=False,
):
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0,
residual,
gamma,
beta,
rowscale,
colscale,
dropout_p,
epsilon,
residual_in_fp32,
is_rms_norm,
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
ctx.save_for_backward(
xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
if not return_dmask:
return (
zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
)
else:
dmask = (
dmask.view(x0.shape)
if dropout_p > 0.0
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
ctx.mark_non_differentiable(dmask)
return (
(zmat.view(x0.shape), dmask)
if not prenorm
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
)
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = maybe_align(dz.contiguous(), 16) # this happens!
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
rowscale,
colscale,
dropout_p,
has_residual,
ctx.is_rms_norm,
)
dx0 = dx0mat.view(x.shape)
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (
dx0,
dresidual,
dgamma,
dbeta if ctx.has_beta else None,
None,
dcolscale,
None,
None,
None,
None,
None,
None,
)
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x0,
residual,
gamma,
beta,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32=False,
prenorm=False,
is_rms_norm=False,
return_dmask=False,
):
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
x0,
residual,
gamma,
beta,
colscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32,
is_rms_norm,
)
# Only need to save x0 if we need to compute gradient wrt colscale
x0_saved = x0 if colscale is not None else None
x_shape = (-1, *x0.shape[1:])
ctx.save_for_backward(
xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.rowscale_const = rowscale_const
ctx.x0_numrows = x0.shape[:-1].numel()
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta is not None
z_shape = (-1, *x0.shape[1:])
if not return_dmask:
return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
else:
z = zmat.view(z_shape)
dmask = (
dmask.view(x0.shape)
if dropout_p > 0.0
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
ctx.mark_non_differentiable(dmask)
return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = maybe_align(dz.contiguous(), 16) # this happens!
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
dz,
dx,
x,
x0,
dmask,
mu,
rsigma,
gamma,
colscale,
x0_subset,
out_subset,
dropout_p,
ctx.rowscale_const,
ctx.x0_numrows,
has_residual,
ctx.is_rms_norm,
)
dx0 = dx0mat.view(-1, *x.shape[1:])
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
dcolscale = rest[0] if colscale is not None else None
return (
dx0,
dresidual,
dgamma,
dbeta if ctx.has_beta else None,
dcolscale,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x0,
x1,
residual,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
residual_in_fp32=False,
prenorm=False,
is_rms_norm=False,
return_dmask=False,
):
x0 = maybe_align(x0.contiguous(), 16)
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma0 = maybe_align(gamma0.contiguous(), 16)
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
(
z0mat,
z1mat,
xmat,
dmask0,
dmask1,
mu,
rsigma,
) = _dropout_add_layer_norm_parallel_residual_forward(
x0,
x1,
residual,
gamma0,
beta0,
gamma1,
beta1,
dropout_p,
epsilon,
residual_in_fp32,
is_rms_norm,
)
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
ctx.prenorm = prenorm
ctx.dropout_p = dropout_p
ctx.has_x1 = x1 is not None
ctx.has_residual = residual is not None
ctx.is_rms_norm = is_rms_norm
ctx.has_beta = beta0 is not None
z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
if not return_dmask:
return z if not prenorm else (*z, xmat.view(x0.shape))
else:
dmask0 = (
dmask0.view(x0.shape)
if dropout_p > 0.0
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
dmask1 = (
dmask1.view(x0.shape)
if dropout_p > 0.0 and x1 is not None
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
)
ctx.mark_non_differentiable(dmask0)
ctx.mark_non_differentiable(dmask1)
return (
(*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
)
@staticmethod
def backward(ctx, dz0, dz1, *args):
dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
dropout_p = ctx.dropout_p
has_x1 = ctx.has_x1
has_residual = ctx.has_residual
(
dx0mat,
dx1mat,
dresidualmat,
dgamma0,
dbeta0,
dgamma1,
dbeta1,
) = _dropout_add_layer_norm_parallel_residual_backward(
dz0,
dz1,
dx,
x,
dmask0,
dmask1,
mu,
rsigma,
gamma0,
gamma1,
dropout_p,
has_x1,
has_residual,
ctx.is_rms_norm,
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
return (
dx0,
dx1,
dresidual,
dgamma0,
dbeta0 if ctx.has_beta else None,
dgamma1,
dbeta1 if ctx.has_beta else None,
None,
None,
None,
None,
None,
None,
)
def layer_norm(x, weight, bias, epsilon):
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
def dropout_add_layer_norm(
x0,
residual,
weight,
bias,
dropout_p,
epsilon,
rowscale=None,
layerscale=None,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormFn.apply(
x0,
residual,
weight,
bias,
rowscale,
layerscale,
dropout_p,
epsilon,
residual_in_fp32,
prenorm,
False,
return_dropout_mask,
)
def dropout_add_layer_norm_subset(
x0,
residual,
weight,
bias,
dropout_p,
epsilon,
layerscale=None,
x0_subset=None,
out_subset=None,
rowscale_const=1.0,
out_numrows=0,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormSubsetFn.apply(
x0,
residual,
weight,
bias,
layerscale,
x0_subset,
out_subset,
dropout_p,
epsilon,
rowscale_const,
out_numrows,
residual_in_fp32,
prenorm,
False,
return_dropout_mask,
)
def dropout_add_layer_norm_parallel_residual(
x0,
x1,
residual,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
prenorm=False,
residual_in_fp32=False,
return_dropout_mask=False,
):
"""residual_in_fp32 only has an effect if residual is None.
Otherwise residual dtype is residual.dtype.
"""
return DropoutAddLayerNormParallelResidualFn.apply(
x0,
x1,
residual,
weight0,
bias0,
weight1,
bias1,
dropout_p,
epsilon,
residual_in_fp32,
prenorm,
False,
return_dropout_mask,
)
class DropoutAddLayerNorm(torch.nn.Module):
def __init__(
self,
hidden_size,
prenorm=False,
p=0.0,
eps=1e-5,
residual_in_fp32=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.prenorm = prenorm
self.p = p
self.eps = eps
self.residual_in_fp32 = residual_in_fp32
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, x0, residual=None):
return dropout_add_layer_norm(
x0,
residual,
self.weight,
self.bias,
self.p if self.training else 0.0,
self.eps,
prenorm=self.prenorm,
residual_in_fp32=self.residual_in_fp32,
)
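# Usage sketch (requires the compiled `dropout_layer_norm` CUDA extension and CUDA tensors; shapes are
# placeholders). The module computes layer_norm(dropout(x0) + residual), and additionally returns the
# pre-norm sum when `prenorm=True`:
#
#     norm = DropoutAddLayerNorm(hidden_size=1024, p=0.1, eps=1e-5).to("cuda", torch.float16)
#     out = norm(x0, residual)   # x0, residual: (batch, seqlen, 1024) half-precision CUDA tensors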
def rms_norm(x, weight, epsilon):
return DropoutAddLayerNormFn.apply(
x, None, weight, None, None, None, 0.0, epsilon, False, False, True
)
class FusedRMSNorm(torch.nn.Module):
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(size))
self.dim = dim
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
def forward(self, x):
return rms_norm(x, self.weight, self.eps)
class RMSNorm(torch.nn.Module):
"""Root Mean Square Layer Normalization.
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE: the original RMSNorm paper implementation is not equivalent
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return self.weight * x_normed
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
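# Both classes are intended to compute the same normalization, y = weight * x * rsqrt(mean(x ** 2, dim=-1) + eps):
# FusedRMSNorm dispatches to the fused `dropout_layer_norm` kernel above (CUDA only), while RMSNorm is the
# pure-PyTorch fallback, e.g.:
#
#     norm = RMSNorm(size=512)
#     y = norm(torch.randn(2, 16, 512))   # same shape as the input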
import time
from collections import deque
from contextlib import nullcontext
from typing import Any, Callable, Deque, Dict, Optional
import torch
from lightning import Callback, Fabric, LightningModule, Trainer
from lightning.fabric.utilities.rank_zero import rank_zero_only as fabric_rank_zero_only
from lightning.pytorch.utilities.rank_zero import rank_zero_only as trainer_rank_zero_only
from torch.utils.flop_counter import FlopCounterMode
import math
from lit_gpt import GPT, Config
from lit_gpt.utils import num_parameters
GPU_AVAILABLE_FLOPS = {
# source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
# NVIDIA publishes the spec sheet with a 2x sparsity factor, hence the `/ 2` on the dense tensor-core numbers below
"h100-sxm": {
"64-true": 67e12,
"32-true": 67e12,
"16-true": 1.979e15 / 2,
"16-mixed": 1.979e15 / 2,
"bf16-true": 1.979e15 / 2,
"bf16-mixed": 1.979e15 / 2,
"8-true": 3.958e15 / 2,
"8-mixed": 3.958e15 / 2,
},
"h100-pcie": {
"64-true": 51e12,
"32-true": 51e12,
"16-true": 1.513e15 / 2,
"16-mixed": 1.513e15 / 2,
"bf16-true": 1.513e15 / 2,
"bf16-mixed": 1.513e15 / 2,
"8-true": 3.026e15 / 2,
"8-mixed": 3.026e15 / 2,
},
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
# sxm and pcie have same flop counts
"a100": {
"64-true": 19.5e12,
"32-true": 19.5e12,
"16-true": 312e12,
"16-mixed": 312e12,
"bf16-true": 312e12,
"bf16-mixed": 312e12,
},
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
"a10g": {"32-true": 31.2e12, "16-true": 125e12, "16-mixed": 125e12, "bf16-true": 125e12, "bf16-mixed": 125e12},
# source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
"v100-sxm": {"64-true": 7.8e12, "32-true": 15.7e12, "16-true": 125e12, "16-mixed": 125e12},
"v100-pcie": {"64-true": 7e12, "32-true": 14e12, "16-true": 112e12, "16-mixed": 112e12},
"v100s-pcie": {"64-true": 8.2e12, "32-true": 16.4e12, "16-true": 130e12, "16-mixed": 130e12},
# source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
# sxm and pcie have same flop counts
"t4": {"32-true": 8.1e12, "16-true": 65e12, "16-mixed": 65e12, "8-true": 130e12, "int4": 260e12},
# https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
"quadro rtx 5000": {"32-true": 11.2e12, "16-true": 89.2e12, "16-mixed": 89.2e12},
}
TPU_AVAILABLE_FLOPS = {
# flop count for each TPU generation is the same for all precisions
# since bfloat16 precision is always used for performing matrix operations
# for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
# source: https://arxiv.org/pdf/1907.10701.pdf
"v2": 45e12,
# source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
"v3": 123e12,
# source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
"v4": 275e12,
}
def get_flops_available(device: torch.device, precision: str) -> Optional[float]:
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device).lower()
if "h100" in device_name and "hbm3" in device_name:
device_name = "h100-sxm"
elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
device_name = "h100-pcie"
elif "a100" in device_name:
device_name = "a100"
elif "a10g" in device_name:
device_name = "a10g"
elif "v100-sxm" in device_name:
device_name = "v100-sxm"
elif "v100-pcie" in device_name:
device_name = "v100-pcie"
elif "t4" in device_name:
device_name = "t4"
elif "quadro rtx 5000" in device_name:
device_name = "quadro rtx 5000"
else:
device_name = None
if device_name is not None:
try:
return int(GPU_AVAILABLE_FLOPS[device_name][precision])
except KeyError:
raise KeyError(
f"flop count not found for {device_name} with precision: {precision}; "
"MFU cannot be calculated and reported."
)
elif device.type == "xla":
from torch_xla.experimental import tpu
device_name = tpu.get_tpu_env()["TYPE"].lower()
try:
return int(TPU_AVAILABLE_FLOPS[device_name])
except KeyError:
raise KeyError(
f"flop count not found for {device_name} with precision: {precision}; "
"MFU cannot be calculated and reported."
)
return None
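# Sketch of how this value is used further below: the per-device MFU reported by the speed monitor is the
# measured flops/sec divided by this theoretical peak, e.g. (the precision string is a placeholder):
#
#     flops_available = get_flops_available(torch.device("cuda"), precision="bf16-mixed")
#     # mfu = flops_per_sec_per_device / flops_available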
# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py
class SpeedMonitorBase:
"""Logs the training throughput and utilization.
+-------------------------------------+-----------------------------------------------------------+
| Key | Logged data |
+=====================================+===========================================================+
| | Rolling average (over `window_size` most recent |
| `throughput/batches_per_sec` | batches) of the number of batches processed per second |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | Rolling average (over `window_size` most recent |
| `throughput/samples_per_sec` | batches) of the number of samples processed per second |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | Rolling average (over `window_size` most recent |
| `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
| | This may include padding depending on dataset |
+-------------------------------------+-----------------------------------------------------------+
| | Estimates flops by `flops_per_batch * batches_per_sec` |
| `throughput/flops_per_sec` | |
| | |
+-------------------------------------+-----------------------------------------------------------+
| `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/tokens_per_sec` divided by world size. This |
| `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset |
| | |
+-------------------------------------+-----------------------------------------------------------+
| | `throughput/flops_per_sec` divided by world size. Only |
| `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
| | |
+-------------------------------------+-----------------------------------------------------------+
|                                     | `throughput/device/flops_per_sec` divided by the peak    |
| `throughput/device/mfu`             | flops available on the device (model flops utilization)  |
| | |
+-------------------------------------+-----------------------------------------------------------+
| `time/train` | Total elapsed training time |
+-------------------------------------+-----------------------------------------------------------+
| `time/val` | Total elapsed validation time |
+-------------------------------------+-----------------------------------------------------------+
| `time/total` | Total elapsed time (time/train + time/val) |
+-------------------------------------+-----------------------------------------------------------+
Notes:
- The implementation assumes that devices are homogeneous as it normalizes by the world size.
- Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or
batches/sec to measure throughput under this circumstance.
- Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.
There is no widespread, realistic, and reliable implementation to compute them.
We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which
will almost always be an overestimate when compared to the true value.
Args:
window_size (int, optional): Number of batches to use for a rolling average of throughput.
Defaults to 100.
time_unit (str, optional): Time unit to use for `time` logging. Can be one of
'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
"""
def __init__(
self,
flops_available: float,
log_dict: Callable[[Dict, int], None],
window_size: int = 100,
time_unit: str = "hours",
log_iter_interval: int = 1,
):
self.flops_available = flops_available
self.log_dict = log_dict
self.log_iter_interval = log_iter_interval
# Track the batch num samples and wct to compute throughput over a window of batches
self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
self.history_training_loss: Deque[float] = deque(maxlen=log_iter_interval)
self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
self.history_flops: Deque[int] = deque(maxlen=window_size + 1)
self.divider = 1
if time_unit == "seconds":
self.divider = 1
elif time_unit == "minutes":
self.divider = 60
elif time_unit == "hours":
self.divider = 60 * 60
elif time_unit == "days":
self.divider = 60 * 60 * 24
else:
raise ValueError(
f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
)
# Keep track of time spent evaluating
self.total_eval_wct = 0.0
self.iter = -1
def on_train_batch_end(
self,
samples: int, # total samples seen (per device)
train_elapsed: float, # total training time (seconds)
world_size: int,
step_count: int,
flops_per_batch: Optional[int] = None, # (per device)
lengths: Optional[int] = None, # total length of the samples seen (per device)
train_loss: Optional[float] = None,
):
self.iter += 1
metrics = {}
self.history_samples.append(samples)
self.history_training_loss.append(train_loss)
if lengths is not None:
self.history_lengths.append(lengths)
# if lengths are passed, there should be as many values as samples
assert len(self.history_samples) == len(self.history_lengths)
self.history_wct.append(train_elapsed)
if len(self.history_wct) == self.history_wct.maxlen:
elapsed_batches = len(self.history_samples) - 1
elapsed_samples = self.history_samples[-1] - self.history_samples[0]
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
samples_per_sec = elapsed_samples * world_size / elapsed_wct
dev_samples_per_sec = elapsed_samples / elapsed_wct
metrics.update(
{
"throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
"throughput/samples_per_sec": samples_per_sec,
"throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
"throughput/device/samples_per_sec": dev_samples_per_sec,
}
)
if lengths is not None:
elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
avg_length = elapsed_lengths / elapsed_batches
metrics.update(
{
"throughput/tokens_per_sec": samples_per_sec * avg_length,
"throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
"total_tokens": avg_length * world_size * samples,
}
)
if train_loss is not None:
avg_loss = sum(self.history_training_loss) / len(self.history_training_loss)
metrics.update(
{
"metric/train_loss": avg_loss,
"metric/train_ppl": math.exp(avg_loss)
}
)
if flops_per_batch is not None:
# sum of flops per batch across ranks
self.history_flops.append(flops_per_batch * world_size)
if len(self.history_flops) == self.history_flops.maxlen:
elapsed_flops = sum(self.history_flops) - self.history_flops[0]
elapsed_wct = self.history_wct[-1] - self.history_wct[0]
flops_per_sec = elapsed_flops / elapsed_wct
device_flops_per_sec = flops_per_sec / world_size
metrics.update(
{"throughput/flops_per_sec": flops_per_sec, "throughput/device/flops_per_sec": device_flops_per_sec}
)
if self.flops_available:
metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available
metrics.update(
{
"time/train": train_elapsed / self.divider,
"time/val": self.total_eval_wct / self.divider,
"time/total": (train_elapsed + self.total_eval_wct) / self.divider,
"samples": samples,
}
)
if self.iter % self.log_iter_interval == 0:
self.log_dict(metrics, step_count)
def eval_end(self, eval_elapsed: float):
self.total_eval_wct += eval_elapsed # seconds
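# Example: a minimal sketch of driving SpeedMonitorBase by hand. The logger callable, sample
# counts, and timings below are illustrative assumptions; in this repo the class is normally
# driven through SpeedMonitorFabric or SpeedMonitorCallback (defined below):
#
#   monitor = SpeedMonitorBase(flops_available=312e12, log_dict=lambda metrics, step: print(step, metrics),
#                              window_size=2, time_unit="seconds")
#   monitor.on_train_batch_end(samples=8, train_elapsed=1.0, world_size=1, step_count=1,
#                              flops_per_batch=1e12, lengths=8 * 2048, train_loss=2.5)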
class SpeedMonitorFabric(SpeedMonitorBase):
def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
# TODO: this will not work properly if a precision plugin is passed to Fabric
flops_available = get_flops_available(fabric.device, fabric._connector._precision_input)
super().__init__(flops_available, fabric.log_dict, *args, **kwargs)
@fabric_rank_zero_only
def on_train_batch_end(self, *args: Any, **kwargs: Any):
super().on_train_batch_end(*args, **kwargs)
class SpeedMonitorCallback(Callback):
def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
super().__init__()
self.speed_monitor: Optional[SpeedMonitorBase] = None
self.speed_monitor_kwargs = kwargs
self.length_fn = length_fn
self.batch_size = batch_size
self.eval_t0: int = 0
self.train_t0: int = 0
self.total_lengths: int = 0
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
if self.speed_monitor is not None:
return # already setup
# TODO: this will not work properly if a precision plugin is passed to Trainer
flops_available = get_flops_available(
trainer.strategy.root_device, trainer._accelerator_connector._precision_flag
)
self.speed_monitor = SpeedMonitorBase(flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs)
@trainer_rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
if trainer.fit_loop._should_accumulate():
return
self.train_t0 = time.perf_counter()
@trainer_rank_zero_only
def on_train_batch_end(
self, trainer: Trainer, pl_module: LightningModule, outputs: Any, batch: Any, batch_idx: int
) -> None:
self.total_lengths += self.length_fn(batch)
if trainer.fit_loop._should_accumulate():
return
train_elapsed = time.perf_counter() - self.train_t0
assert self.speed_monitor is not None
iter_num = trainer.fit_loop.total_batch_idx
assert (measured_flops := pl_module.measured_flops) is not None
self.speed_monitor.on_train_batch_end(
(iter_num + 1) * self.batch_size,
train_elapsed,
# this assumes that device FLOPs are the same and that all devices have the same batch size
trainer.world_size,
trainer.global_step,  # step_count, required positionally by SpeedMonitorBase.on_train_batch_end
flops_per_batch=measured_flops,
lengths=self.total_lengths,
)
@trainer_rank_zero_only
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
self.eval_t0 = time.perf_counter()
@trainer_rank_zero_only
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
eval_elapsed = time.perf_counter() - self.eval_t0
assert self.speed_monitor is not None
self.speed_monitor.eval_end(eval_elapsed)
def flops_per_param(config: Config, n_params: int) -> int:
flops_per_token = 2 * n_params # each parameter is used for a MAC (2 FLOPS) per network operation
# this assumes that all samples have a fixed length equal to the block size
# which is most likely false during finetuning
flops_per_seq = flops_per_token * config.block_size
attn_flops_per_seq = config.n_layer * 2 * 2 * (config.n_embd * (config.block_size**2))
return flops_per_seq + attn_flops_per_seq
def estimate_flops(model: GPT) -> int:
"""Measures estimated FLOPs for MFU.
Refs:
* https://ar5iv.labs.arxiv.org/html/2205.05198#A1
* https://ar5iv.labs.arxiv.org/html/2204.02311#A2
"""
# using all parameters for this is a naive overestimation because not all model parameters actually contribute to
# this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a roughly fixed
# percentage (~10%) compared to the measured FLOPs, which are lower but more realistic.
# For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
n_trainable_params = num_parameters(model, requires_grad=True)
trainable_flops = flops_per_param(model.config, n_trainable_params)
# forward + backward + gradients (assumes no gradient accumulation)
ops_per_step = 3 if model.training else 1
n_frozen_params = num_parameters(model, requires_grad=False)
frozen_flops = flops_per_param(model.config, n_frozen_params)
# forward + backward
frozen_ops_per_step = 2 if model.training else 1
return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
def measure_flops(model: GPT, x: torch.Tensor) -> int:
"""Measures real FLOPs for HFU"""
flop_counter = FlopCounterMode(model, display=False)
ctx = nullcontext() if model.training else torch.no_grad()
with ctx, flop_counter:
y = model(x)
if model.training:
y.sum().backward()
return flop_counter.get_total_flops()
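# Example: how the training scripts below use these helpers. A model instantiated on the meta
# device is enough for `estimate_flops`, while `measure_flops` needs a real forward (and, in
# training mode, backward) pass. The config name and micro batch size are assumptions taken
# from the pretraining script further down:
#
#   with torch.device("meta"):
#       meta_model = GPT(Config.from_name("tiny_LLaMA_1b"))
#   estimated_flops = estimate_flops(meta_model) * 8  # per batch, assuming micro_batch_size = 8
#   # measured_flops = measure_flops(model, torch.randint(0, 1, (8, model.config.block_size)))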
import json
from pathlib import Path
from typing import Optional
import torch
class Tokenizer:
def __init__(self, checkpoint_dir: Path) -> None:
# some checkpoints have both files, `.model` takes precedence
if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
from sentencepiece import SentencePieceProcessor
self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
self.backend = "sentencepiece"
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
from tokenizers import Tokenizer as HFTokenizer
self.processor = HFTokenizer.from_file(str(vocabulary_path))
self.backend = "huggingface"
with open(checkpoint_dir / "tokenizer_config.json") as fp:
config = json.load(fp)
bos_token = config.get("bos_token")
self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
self.eos_id = self.token_to_id(config["eos_token"])
else:
raise NotImplementedError
@property
def vocab_size(self) -> int:
if self.backend == "huggingface":
return self.processor.get_vocab_size(with_added_tokens=False)
if self.backend == "sentencepiece":
return self.processor.vocab_size()
raise RuntimeError
def token_to_id(self, token: str) -> int:
if self.backend == "huggingface":
id_ = self.processor.token_to_id(token)
elif self.backend == "sentencepiece":
id_ = self.processor.piece_to_id(token)
else:
raise RuntimeError
if id_ is None:
raise ValueError(f"token {token!r} not found in the collection.")
return id_
def encode(
self,
string: str,
device: Optional[torch.device] = None,
bos: bool = False,
eos: bool = True,
max_length: int = -1,
) -> torch.Tensor:
if self.backend == "huggingface":
tokens = self.processor.encode(string).ids
elif self.backend == "sentencepiece":
tokens = self.processor.encode(string)
else:
raise RuntimeError
if bos:
bos_id = self.bos_id
if bos_id is None:
raise NotImplementedError("This tokenizer does not define a bos token")
tokens = [bos_id] + tokens
if eos:
tokens = tokens + [self.eos_id]
if max_length > 0:
tokens = tokens[:max_length]
return torch.tensor(tokens, dtype=torch.int, device=device)
def decode(self, tensor: torch.Tensor) -> str:
tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
return self.processor.decode(tokens)
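# Example: a round trip through the tokenizer. The checkpoint directory is a hypothetical path;
# it only needs to contain either `tokenizer.model` or `tokenizer.json` (plus
# `tokenizer_config.json` for the HuggingFace backend), as checked in `__init__` above:
#
#   tokenizer = Tokenizer(Path("checkpoints/some-llama-checkpoint"))
#   ids = tokenizer.encode("hello world", bos=True, eos=False)  # 1-D tensor of token ids
#   text = tokenizer.decode(ids)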
"""Utility functions for training and inference."""
import pickle
import sys
import warnings
from contextlib import contextmanager
from functools import partial
from io import BytesIO
from pathlib import Path
from types import MethodType
from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union
import torch
import torch.nn as nn
import torch.utils._device
from lightning.fabric.loggers import CSVLogger
from torch.serialization import normalize_storage_type
def find_multiple(n: int, k: int) -> int:
assert k > 0
if n % k == 0:
return n
return n + k - (n % k)
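# Example: `find_multiple` is typically used to pad a vocabulary size up to a multiple of 64,
# e.g. find_multiple(50257, 64) == 50304, since 50257 % 64 == 17 and 50257 + 64 - 17 == 50304.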
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad)
@contextmanager
def quantization(mode: Optional[str] = None):
if mode is None:
yield
return
if mode == "bnb.int8":
from quantize.bnb import InferenceLinear8bitLt
quantized_linear_cls = InferenceLinear8bitLt
elif mode == "bnb.fp4":
from quantize.bnb import Linear4bit
# Use a class instead of `functools.partial` to respect `isinstance` checks and attribute accesses
class QuantizedLinear(Linear4bit):
def __init__(self, *args, **kwargs):
super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs)
quantized_linear_cls = QuantizedLinear
elif mode == "bnb.fp4-dq":
from quantize.bnb import Linear4bit
class QuantizedLinear(Linear4bit):
def __init__(self, *args, **kwargs):
super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs)
quantized_linear_cls = QuantizedLinear
elif mode == "bnb.nf4":
from quantize.bnb import Linear4bit
class QuantizedLinear(Linear4bit):
def __init__(self, *args, **kwargs):
super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs)
quantized_linear_cls = QuantizedLinear
elif mode == "bnb.nf4-dq":
from quantize.bnb import Linear4bit
class QuantizedLinear(Linear4bit):
def __init__(self, *args, **kwargs):
super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs)
quantized_linear_cls = QuantizedLinear
elif mode == "gptq.int4":
from quantize.gptq import ColBlockQuantizedLinear
class QuantizedLinear(ColBlockQuantizedLinear):
def __init__(self, *args, **kwargs):
super().__init__(*args, bits=4, tile_cols=-1, **kwargs)
quantized_linear_cls = QuantizedLinear
else:
raise ValueError(f"Unknown quantization mode: {mode}")
torch_linear_cls = torch.nn.Linear
torch.nn.Linear = quantized_linear_cls
yield
torch.nn.Linear = torch_linear_cls
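# Example: a minimal sketch of the context manager above. While the context is active,
# `torch.nn.Linear` is monkey-patched, so any Linear layers created inside it become the selected
# quantized variant. `GPT` and `config` are assumptions borrowed from the rest of this repo:
#
#   with quantization("bnb.nf4"):
#       model = GPT(config)  # Linear layers created here become 4-bit NF4 layers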
# this is taken from torchhacks https://github.com/lernapparat/torchhacks
class NotYetLoadedTensor:
def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
self.metatensor = metatensor
self.archiveinfo = archiveinfo
self.storageinfo = storageinfo
self.rebuild_args = rebuild_args
@classmethod
def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None):
ret = func(*args)
if isinstance(ret, NotYetLoadedTensor):
old_lt = ret._load_tensor
def _load_tensor():
t = old_lt()
return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)
ret._load_tensor = _load_tensor
return ret
return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)
@classmethod
def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None):
if isinstance(data, NotYetLoadedTensor):
old_lt = data._load_tensor
def _load_tensor():
t = old_lt()
return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)
data._load_tensor = _load_tensor
return data
return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)
@classmethod
def rebuild_tensor_v2(
cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None
):
rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
metatensor = torch._utils._rebuild_tensor_v2(
storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
)
storageinfo = storage.archiveinfo
return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
def _load_tensor(self):
name, storage_cls, fn, device, size = self.storageinfo
dtype = self.metatensor.dtype
uts = (
self.archiveinfo.zipfile_context.zf.get_storage_from_record(
f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage
)
._typed_storage()
._untyped_storage
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True)
return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args]
return func(*loaded_args, **kwargs)
# gc.collect would be costly here, maybe do it optionally
def __getattr__(self, name):
# properties
## TODO: device, is_...??
## TODO: mH, mT, H, T, data, imag, real
## name ???
if name in {
"dtype",
"grad",
"grad_fn",
"layout",
"names",
"ndim",
"output_nr",
"requires_grad",
"retains_grad",
"shape",
"volatile",
}:
return getattr(self.metatensor, name)
if name in {"size"}:
return getattr(self.metatensor, name)
# materializing with contiguous is needed for quantization
if name in {"contiguous"}:
return getattr(self._load_tensor(), name)
raise AttributeError(f"{type(self)} does not have {name}")
def __repr__(self):
return f"NotYetLoadedTensor({repr(self.metatensor)})"
class LazyLoadingUnpickler(pickle.Unpickler):
def __init__(self, file, zipfile_context):
super().__init__(file)
self.zipfile_context = zipfile_context
def find_class(self, module, name):
res = super().find_class(module, name)
if module == "torch._utils" and name == "_rebuild_tensor_v2":
return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
if module == "torch._tensor" and name == "_rebuild_from_type_v2":
return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
if module == "torch._utils" and name == "_rebuild_parameter":
return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
return res
def persistent_load(self, pid):
name, cls, fn, device, size = pid
with warnings.catch_warnings():
warnings.simplefilter("ignore")
s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")
s.archiveinfo = pid
return s
class lazy_load:
def __init__(self, fn):
self.zf = torch._C.PyTorchFileReader(str(fn))
with BytesIO(self.zf.get_record("data.pkl")) as pkl:
mup = LazyLoadingUnpickler(pkl, self)
self.sd = mup.load()
def __enter__(self):
return self.sd
def __exit__(self, exc_type, exc_val, exc_tb):
del self.zf # I don't think there is a way to force closing...
self.zf = None
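# Example: loading a checkpoint lazily so tensors are only materialized when accessed. The path
# below is a hypothetical location following the repo's `lit_model.pth` convention:
#
#   with lazy_load("checkpoints/some-model/lit_model.pth") as state_dict:
#       model.load_state_dict(state_dict)  # NotYetLoadedTensor entries materialize on use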
def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
files = {
"lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
"lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
checkpoint_dir / "tokenizer.model"
).is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
}
if checkpoint_dir.is_dir():
if all(files.values()):
# we're good
return
problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
else:
problem = " is not a checkpoint directory"
# list locally available checkpoints
available = list(Path("checkpoints").glob("*/*"))
if available:
options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
extra = f"\nYou have downloaded locally:{options}\n"
else:
extra = ""
error_message = (
f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
"\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
f"{extra}\nSee all download options by running:\n python scripts/download.py"
)
print(error_message, file=sys.stderr)
raise SystemExit(1)
class SavingProxyForStorage:
def __init__(self, obj, saver, protocol_version=5):
self.protocol_version = protocol_version
self.saver = saver
if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
raise TypeError(f"expected storage, not {type(obj)}")
# this logic is taken from PyTorch 2.0+ torch/serialization.py
if isinstance(obj, torch.storage.TypedStorage):
# PT upstream wants to deprecate this eventually...
storage = obj._untyped_storage
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj._size()
else:
storage = obj
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()
storage_key = saver._write_storage_and_return_key(storage)
location = torch.serialization.location_tag(storage)
self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)
def __reduce_ex__(self, protocol_version):
assert False, "this should be handled out of band"
class SavingProxyForTensor:
def __init__(self, tensor, saver, protocol_version=5):
self.protocol_version = protocol_version
self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version)
assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
self.reduce_args = (storage_proxy, *other_reduce_args)
def __reduce_ex__(self, protocol_version):
if protocol_version != self.protocol_version:
raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
return self.reduce_ret_fn, self.reduce_args
class IncrementalPyTorchPickler(pickle.Pickler):
def __init__(self, saver, *args, **kwargs):
super().__init__(*args, **kwargs)
self.storage_dtypes = {}
self.saver = saver
self.id_map = {}
# this logic is taken from PyTorch 2.0+ torch/serialization.py
def persistent_id(self, obj):
# FIXME: the docs say that persistent_id should only return a string
# but torch store returns tuples. This works only in the binary protocol
# see
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
if isinstance(obj, SavingProxyForStorage):
return obj.storage_info
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
storage = obj._untyped_storage
storage_dtype = obj.dtype
storage_type_str = obj._pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
storage_numel = obj._size()
else:
storage = obj
storage_dtype = torch.uint8
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()
# If storage is allocated, ensure that any other saved storages
# pointing to the same data all have the same dtype. If storage is
# not allocated, don't perform this check
if storage.data_ptr() != 0:
if storage.data_ptr() in self.storage_dtypes:
if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
"Cannot save multiple tensors or storages that view the same data as different types"
)
else:
self.storage_dtypes[storage.data_ptr()] = storage_dtype
storage_key = self.id_map.get(storage._cdata)
if storage_key is None:
storage_key = self.saver._write_storage_and_return_key(storage)
self.id_map[storage._cdata] = storage_key
location = torch.serialization.location_tag(storage)
return ("storage", storage_type, storage_key, location, storage_numel)
return None
class incremental_save:
def __init__(self, name):
self.name = name
self.zipfile = torch._C.PyTorchFileWriter(str(name))
self.has_saved = False
self.next_key = 0
def __enter__(self):
return self
def store_early(self, tensor):
if isinstance(tensor, torch.Tensor):
return SavingProxyForTensor(tensor, self)
raise TypeError(f"can only store tensors early, not {type(tensor)}")
def save(self, obj):
if self.has_saved:
raise RuntimeError("have already saved")
# Write the pickle data for `obj`
data_buf = BytesIO()
pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
pickler.dump(obj)
data_value = data_buf.getvalue()
self.zipfile.write_record("data.pkl", data_value, len(data_value))
self.has_saved = True
def _write_storage_and_return_key(self, storage):
if self.has_saved:
raise RuntimeError("have already saved")
key = self.next_key
self.next_key += 1
name = f"data/{key}"
if storage.device.type != "cpu":
storage = storage.cpu()
num_bytes = storage.nbytes()
self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
return key
def __exit__(self, type, value, traceback):
self.zipfile.write_end_of_file()
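# Example: a sketch of saving a large state dict incrementally, so tensor storages are written
# into the zip archive as they are registered instead of being pickled in memory first. The file
# name and the `model` object are assumptions for illustration:
#
#   with incremental_save("lit_model.pth") as saver:
#       sd = {name: saver.store_early(param) for name, param in model.state_dict().items()}
#       saver.save(sd)  # writes `data.pkl` referencing the already-stored tensors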
T = TypeVar("T")
def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T:
logger = cls(*args, **kwargs)
def merge_by(dicts, key):
from collections import defaultdict
out = defaultdict(dict)
for d in dicts:
if key in d:
out[d[key]].update(d)
return [v for _, v in sorted(out.items())]
def save(self) -> None:
"""Overridden to merge CSV by the step number."""
import csv
if not self.metrics:
return
metrics = merge_by(self.metrics, "step")
keys = sorted({k for m in metrics for k in m})
with self._fs.open(self.metrics_file_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=keys)
writer.writeheader()
writer.writerows(metrics)
logger.experiment.save = MethodType(save, logger.experiment)
return logger
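# Example: how the training scripts further down construct this logger; the run name and flush
# interval are taken from their hyperparameters and are otherwise arbitrary:
#
#   logger = step_csv_logger("out", "tinyllama_1b", flush_logs_every_n_steps=80)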
def chunked_cross_entropy(
logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
) -> torch.Tensor:
# with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
# the memory usage in fine-tuning settings with a low number of parameters.
# as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
# the memory spike's magnitude
# lm_head was chunked (we are fine-tuning)
if isinstance(logits, list):
# don't want to chunk cross entropy
if chunk_size == 0:
logits = torch.cat(logits, dim=1)
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
# chunk cross entropy
logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
loss_chunks = [
torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
]
return torch.cat(loss_chunks).mean()
# no chunking at all
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
if chunk_size == 0:
return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
# lm_head wasn't chunked, chunk cross entropy
logit_chunks = logits.split(chunk_size)
target_chunks = targets.split(chunk_size)
loss_chunks = [
torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
]
return torch.cat(loss_chunks).mean()
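# Example: the two modes of `chunked_cross_entropy` on an un-chunked lm_head output. Shapes are
# assumptions for illustration (B=batch, T=sequence length, V=vocab size); `chunk_size=0` falls
# back to a single cross-entropy call:
#
#   logits = torch.randn(2, 16, 32000)          # (B, T, V)
#   targets = torch.randint(0, 32000, (2, 16))  # (B, T)
#   loss_chunked = chunked_cross_entropy(logits, targets, chunk_size=128)
#   loss_plain = chunked_cross_entropy(logits, targets, chunk_size=0)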
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
for checkpoint_name, attribute_name in mapping.items():
full_checkpoint_name = prefix + checkpoint_name
if full_checkpoint_name in state_dict:
full_attribute_name = prefix + attribute_name
state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
return state_dict
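# Example: renaming checkpoint keys to the attribute names used by this codebase. The mapping,
# prefix, and tensor below are hypothetical, chosen only to show the key rewriting:
#
#   mapping = {"attn.proj.weight": "attn.out_proj.weight"}
#   sd = {"transformer.h.0.attn.proj.weight": torch.zeros(1)}
#   sd = map_old_state_dict_weights(sd, mapping, prefix="transformer.h.0.")
#   # -> {"transformer.h.0.attn.out_proj.weight": tensor([0.])}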
def get_default_supported_precision(training: bool, tpu: bool = False) -> str:
"""Return default precision that is supported by the hardware.
Args:
training: `-mixed` or `-true` version of the precision to use
tpu: whether TPU device is used
Returns:
default precision that is suitable for the task and is supported by the hardware
"""
if tpu:
return "32-true"
if not torch.cuda.is_available() or torch.cuda.is_bf16_supported():
return "bf16-mixed" if training else "bf16-true"
return "16-mixed" if training else "16-true"
# Model code
modelCode=
# Model name
modelName=tinyllama_pytorch
# Model description
modelDescription=With only 1.1B parameters, it shrinks the Llama 2 model size and training data volume and can be used as a plug-and-play replacement in many Llama-based open-source projects.
# Application scenarios
appScenario=inference,training,manufacturing,media,finance,energy,healthcare,smart home,education
# Framework type
frameType=pytorch
import glob
import math
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union
import math
import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader
from functools import partial
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
# from apex.optimizers import FusedAdam  # the torch optimizer has a CUDA backend, which is actually faster
from lit_gpt.model import GPT, Block, Config, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load
from pytorch_lightning.loggers import WandbLogger
from lit_gpt import FusedCrossEntropyLoss
import random
model_name = "tiny_LLaMA_1b"
name = "tinyllama_1b"
out_dir = Path("out") / name
# Hyperparameters
num_of_devices = 8
global_batch_size = 512
learning_rate = 4e-4
micro_batch_size = 8
max_step = 715256 * 2
warmup_steps = 2000
log_step_interval = 10
eval_iters = 100
save_step_interval = 5000
eval_step_interval = 5000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
min_lr = 4e-5
batch_size = global_batch_size // num_of_devices
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0
warmup_iters = warmup_steps * gradient_accumulation_steps
max_iters = max_step * gradient_accumulation_steps
lr_decay_iters = max_iters
log_iter_interval = log_step_interval * gradient_accumulation_steps
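# With the hyperparameter values above, the derived quantities work out to:
#   batch_size = 512 // 8 = 64 (per device)
#   gradient_accumulation_steps = 64 // 8 = 8
#   warmup_iters = 2000 * 8 = 16000
#   max_iters = 715256 * 2 * 8 = 11,444,096 micro-batch iterations
#   log_iter_interval = 10 * 8 = 80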
# Treat all datasets equally by their size. If you want to use a different weight for a dataset, add it to the list with that weight.
train_data_config = [
("train_slim", 0.693584),
("train_star", 0.306416),
]
val_data_config = [
("validation", 1.0),
]
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval)
wandb_logger = WandbLogger()
def setup(
devices: int = 8,
train_data_dir: Path = Path("data/redpajama_sample"),
val_data_dir: Optional[Path] = None,
precision: Optional[str] = None,
tpu: bool = False,
resume: Union[bool, Path] = False,
) -> None:
precision = precision or get_default_supported_precision(training=True, tpu=tpu)
if devices > 1:
if tpu:
# For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
devices = "auto"
strategy = XLAStrategy(sync_module_states=False)
else:
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy=None,
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])
fabric.print(hparams)
#fabric.launch(main, train_data_dir, val_data_dir, resume)
main(fabric, train_data_dir, val_data_dir, resume)
def main(fabric, train_data_dir, val_data_dir, resume):
monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval)
if fabric.global_rank == 0:
out_dir.mkdir(parents=True, exist_ok=True)
config = Config.from_name(model_name)
train_dataloader, val_dataloader = create_dataloaders(
batch_size=micro_batch_size,
block_size=config.block_size,
fabric=fabric,
train_data_dir=train_data_dir,
val_data_dir=val_data_dir,
seed=3407,
)
if val_dataloader is None:
train_dataloader = fabric.setup_dataloaders(train_dataloader)
else:
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
fabric.seed_everything(3407) # same seed for every process to init model (FSDP)
fabric.print(f"Loading model with {config.__dict__}")
t0 = time.perf_counter()
with fabric.init_module(empty_init=False):
model = GPT(config)
model.apply(partial(model._init_weights, n_layer=config.n_layer))
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters {num_parameters(model):,}")
model = fabric.setup(model)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
)
# optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True)
optimizer = fabric.setup_optimizers(optimizer)
state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}
if resume is True:
resume = sorted(out_dir.glob("*.pth"))[-1]
if resume:
fabric.print(f"Resuming training from {resume}")
fabric.load(resume, state)
train_time = time.perf_counter()
train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
model = state["model"]
optimizer = state["optimizer"]
if val_dataloader is not None:
validate(fabric, model, val_dataloader) # sanity check
with torch.device("meta"):
meta_model = GPT(model.config)
# "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
# When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
# consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
estimated_flops = estimate_flops(meta_model) * micro_batch_size
fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
x = torch.randint(0, 1, (micro_batch_size, model.config.block_size))
# measure_flops cannot be run on the meta-device model: it would trigger a FusedRMSNorm error
#measured_flops = measure_flops(meta_model, x)
#fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
del meta_model, x
total_lengths = 0
total_t0 = time.perf_counter()
if fabric.device.type == "xla":
import torch_xla.core.xla_model as xm
xm.mark_step()
initial_iter = state["iter_num"]
curr_iter = 0
loss_func = FusedCrossEntropyLoss()
for train_data in train_dataloader:
# resume loader state. This is not elegant but it works. Should rewrite it in the future.
if resume:
if curr_iter < initial_iter:
curr_iter += 1
continue
else:
resume = False
curr_iter = -1
fabric.barrier()
fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0))
if state["iter_num"] >= max_iters:
break
# determine and set the learning rate for this iteration
lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr
iter_t0 = time.perf_counter()
input_ids = train_data[:, 0 : model.config.block_size].contiguous()
targets = train_data[:, 1 : model.config.block_size + 1].contiguous()
is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
loss = loss_func(logits, targets)
# loss = chunked_cross_entropy(logits, targets, chunk_size=0)
fabric.backward(loss / gradient_accumulation_steps)
if not is_accumulating:
fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
optimizer.step()
optimizer.zero_grad()
state["step_count"] += 1
elif fabric.device.type == "xla":
xm.mark_step()
state["iter_num"] += 1
# input_ids shape: (B, L)
total_lengths += input_ids.size(1)
t1 = time.perf_counter()
fabric.print(
f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:"
f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. "
# print days as well
f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. "
)
monitor.on_train_batch_end(
state["iter_num"] * micro_batch_size,
t1 - total_t0,
# this assumes that device FLOPs are the same and that all devices have the same batch size
fabric.world_size,
state["step_count"],
flops_per_batch=estimated_flops,
lengths=total_lengths,
train_loss=loss.item()
)
if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader)
t1 = time.perf_counter() - t0
monitor.eval_end(t1)
fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
fabric.barrier()
if not is_accumulating and state["step_count"] % save_step_interval == 0:
checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
fabric.save(checkpoint_path, state)
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval_iters, device=fabric.device)
for k, val_data in enumerate(val_dataloader):
if k >= eval_iters:
break
input_ids = val_data[:, 0 : model.config.block_size].contiguous()
targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
logits = model(input_ids)
loss = chunked_cross_entropy(logits, targets, chunk_size=0)
# loss_func = FusedCrossEntropyLoss()
# loss = loss_func(logits, targets)
losses[k] = loss.item()
out = losses.mean()
model.train()
return out
def create_dataloader(
batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train"
) -> DataLoader:
datasets = []
data_config = train_data_config if split == "train" else val_data_config
for prefix, _ in data_config:
filenames = sorted(glob.glob(str(data_dir / f"{prefix}*")))
random.seed(seed)
random.shuffle(filenames)
dataset = PackedDataset(
filenames,
# n_chunks controls the buffer size.
# Note that the buffer size also impacts the random shuffle
# (PackedDataset is an IterableDataset, so shuffling is done by prefetching a buffer and shuffling it)
n_chunks=8,
block_size=block_size,
shuffle=shuffle,
seed=seed+fabric.global_rank,
num_processes=fabric.world_size,
process_rank=fabric.global_rank,
)
datasets.append(dataset)
if not datasets:
raise RuntimeError(
f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
)
weights = [weight for _, weight in data_config]
sum_weights = sum(weights)
weights = [el / sum_weights for el in weights]
combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)
return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
def create_dataloaders(
batch_size: int,
block_size: int,
fabric,
train_data_dir: Path = Path("data/redpajama_sample"),
val_data_dir: Optional[Path] = None,
seed: int = 12345,
) -> Tuple[DataLoader, DataLoader]:
# Increase by one because we need the next word as well
effective_block_size = block_size + 1
train_dataloader = create_dataloader(
batch_size=batch_size,
block_size=effective_block_size,
fabric=fabric,
data_dir=train_data_dir,
shuffle=True,
seed=seed,
split="train"
)
val_dataloader = (
create_dataloader(
batch_size=batch_size,
block_size=effective_block_size,
fabric=fabric,
data_dir=val_data_dir,
shuffle=False,
seed=seed,
split="validation"
)
if val_data_dir
else None
)
return train_dataloader, val_dataloader
# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
# 2) if it > lr_decay_iters, return min learning rate
if it > lr_decay_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)
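# Worked example with the hyperparameters above (learning_rate=4e-4, min_lr=4e-5,
# warmup_iters=16000, lr_decay_iters=max_iters):
#   get_lr(8000)  -> 4e-4 * 8000 / 16000 = 2e-4                      (mid-warmup, linear ramp)
#   halfway through decay: decay_ratio=0.5, coeff=0.5, lr = 4e-5 + 0.5 * (4e-4 - 4e-5) = 2.2e-4
#   get_lr(lr_decay_iters)      -> coeff = 0 -> min_lr = 4e-5
#   get_lr(it > lr_decay_iters) -> min_lr = 4e-5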
if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")
from jsonargparse import CLI
CLI(setup)
import glob
import math
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union
import math
import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader
from functools import partial
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
# from apex.optimizers import FusedAdam  # the torch optimizer has a CUDA backend, which is actually faster
from lit_gpt.model import GPT, Block, Config, CausalSelfAttention
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load
from pytorch_lightning.loggers import WandbLogger
from lit_gpt import FusedCrossEntropyLoss
import random
model_name = "tiny_LLaMA_1b"
name = "tiny_LLaMA_1b"
out_dir = Path("out") / name
checkpoint_path = "out/TinyLlama-1.1B-intermediate-step-240k-503b/lit_model.pth"
# Hyperparameters
num_of_devices = 6
global_batch_size = 360
learning_rate = 2e-4
min_lr = 2e-5
micro_batch_size = 6
max_step = 10000
warmup_steps = 0
log_step_interval = 1
eval_iters = 1000000
save_step_interval = 2000
eval_step_interval = 2000
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
batch_size = global_batch_size // num_of_devices
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0
warmup_iters = warmup_steps * gradient_accumulation_steps
max_iters = max_step * gradient_accumulation_steps
lr_decay_iters = max_iters
log_iter_interval = log_step_interval * gradient_accumulation_steps
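# With the hyperparameter values above, the derived quantities work out to:
#   batch_size = 360 // 6 = 60 (per device)
#   gradient_accumulation_steps = 60 // 6 = 10
#   warmup_iters = 0, so get_lr() starts directly in the cosine-decay branch
#   max_iters = 10000 * 10 = 100,000 micro-batch iterations
#   log_iter_interval = 1 * 10 = 10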
# Treat all datasets equally by their size. If you want to use a different weight for a dataset, add it to the list with that weight.
train_data_config = [
("train_starcoder", 1),
]
val_data_config = [
("validation", 1.0),
]
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval)
wandb_logger = WandbLogger()
def setup(
devices: int = 8,
train_data_dir: Path = Path("data/redpajama_sample"),
val_data_dir: Optional[Path] = None,
precision: Optional[str] = None,
tpu: bool = False,
resume: Union[bool, Path] = False,
) -> None:
precision = precision or get_default_supported_precision(training=True, tpu=tpu)
if devices > 1:
if tpu:
# For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
devices = "auto"
strategy = XLAStrategy(sync_module_states=False)
else:
strategy = FSDPStrategy(
auto_wrap_policy={Block},
activation_checkpointing_policy=None,
state_dict_type="full",
limit_all_gathers=True,
cpu_offload=False,
)
else:
strategy = "auto"
fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])
fabric.print(hparams)
fabric.launch(main, train_data_dir, val_data_dir, resume)
# main(fabric, train_data_dir, val_data_dir, resume)
def main(fabric, train_data_dir, val_data_dir, resume):
monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval)
if fabric.global_rank == 0:
out_dir.mkdir(parents=True, exist_ok=True)
config = Config.from_name(model_name)
train_dataloader, val_dataloader = create_dataloaders(
batch_size=micro_batch_size,
block_size=config.block_size,
fabric=fabric,
train_data_dir=train_data_dir,
val_data_dir=val_data_dir,
seed=3407,
)
if val_dataloader is None:
train_dataloader = fabric.setup_dataloaders(train_dataloader)
else:
train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
fabric.seed_everything(3407) # same seed for every process to init model (FSDP)
fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)
model = fabric.setup(model)
fabric.load_raw(checkpoint_path, model, strict=True)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters {num_parameters(model):,}")
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
)
# import bitsandbytes as bnb
# optimizer = bnb.optim.AdamW8bit(
# model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2)
# )
# optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True)
optimizer = fabric.setup_optimizers(optimizer)
state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}
if resume is True:
resume = sorted(out_dir.glob("*.pth"))[-1]
if resume:
fabric.print(f"Resuming training from {resume}")
fabric.load(resume, state)
train_time = time.perf_counter()
train(fabric, state, train_dataloader, val_dataloader, monitor, resume)
fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
def train(fabric, state, train_dataloader, val_dataloader, monitor, resume):
model = state["model"]
optimizer = state["optimizer"]
if val_dataloader is not None:
validate(fabric, model, val_dataloader) # sanity check
with torch.device("meta"):
meta_model = GPT(model.config)
# "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
# When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
# consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
estimated_flops = estimate_flops(meta_model) * micro_batch_size
fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
x = torch.randint(0, 1, (micro_batch_size, model.config.block_size))
# measure_flops cannot be run on the meta-device model: it would trigger a FusedRMSNorm error
#measured_flops = measure_flops(meta_model, x)
#fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
del meta_model, x
total_lengths = 0
total_t0 = time.perf_counter()
if fabric.device.type == "xla":
import torch_xla.core.xla_model as xm
xm.mark_step()
initial_iter = state["iter_num"]
curr_iter = 0
loss_func = FusedCrossEntropyLoss()
for train_data in train_dataloader:
# resume loader state. This is not elegant but it works. Should rewrite it in the future.
if resume:
if curr_iter < initial_iter:
curr_iter += 1
continue
else:
resume = False
curr_iter = -1
fabric.barrier()
fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0))
if state["iter_num"] >= max_iters:
break
# determine and set the learning rate for this iteration
lr = get_lr(state["iter_num"]) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr
iter_t0 = time.perf_counter()
input_ids = train_data[:, 0 : model.config.block_size].contiguous()
targets = train_data[:, 1 : model.config.block_size + 1].contiguous()
is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0
with fabric.no_backward_sync(model, enabled=is_accumulating):
logits = model(input_ids)
loss = loss_func(logits, targets)
# loss = chunked_cross_entropy(logits, targets, chunk_size=0)
fabric.backward(loss / gradient_accumulation_steps)
if not is_accumulating:
fabric.clip_gradients(model, optimizer, max_norm=grad_clip)
optimizer.step()
optimizer.zero_grad()
state["step_count"] += 1
elif fabric.device.type == "xla":
xm.mark_step()
state["iter_num"] += 1
# input_ids shape: (B, L)
total_lengths += input_ids.size(1)
t1 = time.perf_counter()
fabric.print(
f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:"
f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. "
# print days as well
f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. "
)
monitor.on_train_batch_end(
state["iter_num"] * micro_batch_size,
t1 - total_t0,
# this assumes that device FLOPs are the same and that all devices have the same batch size
fabric.world_size,
state["step_count"],
flops_per_batch=estimated_flops,
lengths=total_lengths,
train_loss=loss.item()
)
if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0:
t0 = time.perf_counter()
val_loss = validate(fabric, model, val_dataloader)
t1 = time.perf_counter() - t0
monitor.eval_end(t1)
fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"])
fabric.barrier()
if not is_accumulating and state["step_count"] % save_step_interval == 0:
checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth"
fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}")
fabric.save(checkpoint_path, state)
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor:
fabric.print("Validating ...")
model.eval()
losses = torch.zeros(eval_iters, device=fabric.device)
for k, val_data in enumerate(val_dataloader):
if k >= eval_iters:
break
input_ids = val_data[:, 0 : model.config.block_size].contiguous()
targets = val_data[:, 1 : model.config.block_size + 1].contiguous()
logits = model(input_ids)
loss = chunked_cross_entropy(logits, targets, chunk_size=0)
# loss_func = FusedCrossEntropyLoss()
# loss = loss_func(logits, targets)
losses[k] = loss.item()
out = losses.mean()
model.train()
return out
def create_dataloader(
batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train"
) -> DataLoader:
datasets = []
data_config = train_data_config if split == "train" else val_data_config
for prefix, _ in data_config:
filenames = sorted(glob.glob(str(data_dir / f"{prefix}*")))
random.seed(seed)
random.shuffle(filenames)
dataset = PackedDataset(
filenames,
# n_chunks controls the buffer size.
# Note that the buffer size also impacts the random shuffle
# (PackedDataset is an IterableDataset, so shuffling is done by prefetching a buffer and shuffling it)
n_chunks=8,
block_size=block_size,
shuffle=shuffle,
seed=seed+fabric.global_rank,
num_processes=fabric.world_size,
process_rank=fabric.global_rank,
)
datasets.append(dataset)
if not datasets:
raise RuntimeError(
f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset."
)
weights = [weight for _, weight in data_config]
sum_weights = sum(weights)
weights = [el / sum_weights for el in weights]
combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)
return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
def create_dataloaders(
batch_size: int,
block_size: int,
fabric,
train_data_dir: Path = Path("data/redpajama_sample"),
val_data_dir: Optional[Path] = None,
seed: int = 12345,
) -> Tuple[DataLoader, DataLoader]:
# Increase by one because we need the next word as well
effective_block_size = block_size + 1
train_dataloader = create_dataloader(
batch_size=batch_size,
block_size=effective_block_size,
fabric=fabric,
data_dir=train_data_dir,
shuffle=True,
seed=seed,
split="train"
)
val_dataloader = (
create_dataloader(
batch_size=batch_size,
block_size=effective_block_size,
fabric=fabric,
data_dir=val_data_dir,
shuffle=False,
seed=seed,
split="validation"
)
if val_data_dir
else None
)
return train_dataloader, val_dataloader
# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
# 2) if it > lr_decay_iters, return min learning rate
if it > lr_decay_iters:
return min_lr
# 3) in between, use cosine decay down to min learning rate
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return min_lr + coeff * (learning_rate - min_lr)
if __name__ == "__main__":
# Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
# torch.backends.cuda.enable_flash_sdp(False)
torch.set_float32_matmul_precision("high")
from jsonargparse import CLI
CLI(setup)
torch>=2.1.0dev
lightning==2.1.2
lightning[app]
jsonargparse[signatures] # CLI
pandas
pyarrow
tokenizers
sentencepiece
wandb
zstd
# for finetuning
bitsandbytes==0.40.0
transformers==4.31.0
peft==0.4.0
accelerate==0.21.0
einops==0.6.1
evaluate==0.4.0
scikit-learn==1.2.2
sentencepiece==0.1.99
wandb==0.15.3
# other optional dependencies are
# sentencepiece # pythia, falcon, redpajama
# tokenizers # llama-based models
# bitsandbytes>=0.41.1 # quantize/bnb.py
# scipy # TODO: remove when https://github.com/TimDettmers/bitsandbytes/pull/525 is released
# datasets # quantize/gptq.py
# zstandard # scripts/prepare_redpajama.py
# git+https://github.com/EleutherAI/lm-evaluation-harness.git@master # eval