Commit dfcb88ff authored by chenzk

v1.0.8
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/brrr/nanotron/examples/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- name: General purpose training
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 24
hf_dataset_config_name: null
hf_dataset_or_datasets:
roneneldan/TinyStories: 1.0
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: test
run: mamba
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
initializer_range: 0.02
n_residuals_per_layer: 1
rescale_prenorm_residual: true
make_vocab_size_divisible_by: 1
model_config:
d_model: 1536
dtype: bfloat16
fused_add_norm: true
is_mamba_config: true
num_hidden_layers: 48
pad_token_id: null
pad_vocab_size_multiple: 8
residual_in_fp32: true
rms_norm: true
rms_norm_eps: 1.0e-05
ssm_cfg:
bias: false
conv_bias: true
d_conv: 4
d_state: 16
dt_init: random
dt_init_floor: 0.0001
dt_max: 0.1
dt_min: 0.001
dt_rank: auto
dt_scale: 1.0
expand: 2
use_fast_path: true
vocab_size: 50277
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 90
lr_decay_style: cosine
lr_warmup_steps: 10
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
expert_parallel_size: 1
pp: 2
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 2048
train_steps: 100
val_check_interval: -1
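# Note (illustrative arithmetic based on the values above): the effective global batch size is
# micro_batch_size * batch_accumulation_per_replica * dp = 2 * 1 * 2 = 4 sequences per step,
# i.e. 4 * 2048 = 8192 tokens per optimizer step, or roughly 819k tokens over the 100 train_steps.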
# ruff: noqa: E402
"""
Converts a HF model to a Nanotron model
Command:
torchrun --nproc_per_node=1 convert_hf_to_nanotron.py --inp_path state-spaces/mamba-130m-hf --out_path nanotron_weights
"""
import argparse
import json
from dataclasses import asdict
from pathlib import Path
from typing import Dict
import torch
import yaml
from config import MambaConfig, MambaInit, MambaModelConfig
from mamba import MambaForTraining
from nanotron import logging
from nanotron.config import (
AllForwardAllBackwardPipelineEngine,
GeneralArgs,
LoggingArgs,
ModelArgs,
ParallelismArgs,
TensorParallelLinearMode,
TokenizerArgs,
)
from nanotron.distributed import dist
from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.serialize import save_meta, save_weights
from nanotron.trainer import mark_tied_parameters
from tqdm import tqdm
from transformers import MambaConfig as HFMambaConfig
from transformers import MambaForCausalLM
from transformers.utils import CONFIG_NAME
from transformers.utils.hub import cached_file
logger = logging.get_logger(__name__)
def load_config_hf(model_name):
resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
return json.load(open(resolved_archive_file))
def get_weight_from_hf(
name: str,
ref_module_state_dict: Dict[str, torch.Tensor],
ref_module: MambaForCausalLM,
nanotron_to_hf: Dict[str, str],
get_grad: bool = False,
param_is_tp_sharded: bool = False,
) -> torch.Tensor:
"""From our brrr implementation, we get the equivalent tensor in transformers implementation"""
def _interleave_pattern(N):
"""
interleave_pattern(4) -> [0, 2, 1, 3]
interleave_pattern(8) -> [0, 4, 1, 5, 2, 6, 3, 7]
"""
assert N % 2 == 0, "N must be even"
pattern = []
for i in range(N // 2):
pattern.append(i)
pattern.append(i + N // 2)
return pattern
hf_name = nanotron_to_hf[name]
if get_grad is False:
def _get_tensor(path: str):
return ref_module_state_dict[path]
else:
def _get_tensor(path: str):
param = ref_module.get_parameter(path)
return param.grad
param = _get_tensor(hf_name)
if "in_proj" in hf_name:
# In Nanotron, we use tensor parallel column linear, so the weight needs to be split along the column dimension (i.e. xz.view(...))
# However, the HF weights were trained such that xz.chunk(...) splits the tensor along the row dimension
# Thus, we need to interleave the HF weights to make them compatible with Nanotron
log_rank(
f"Interleaving {hf_name} to make it compatible with Nanotron", logger=logger, level=logging.INFO, rank=0
)
return param[_interleave_pattern(param.shape[0]), :]
return param
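# Worked example (illustrative): with d_inner = 4, the HF in_proj weight rows are laid out as
# [x0, x1, x2, x3, z0, z1, z2, z3], so that xz.chunk(2) recovers x and z as contiguous halves.
# Nanotron's forward pass instead does xz.view(B, d_inner, 2, L).chunk(2, dim=2), which reads
# the rows pairwise as [x0, z0, x1, z1, ...]. _interleave_pattern(8) == [0, 4, 1, 5, 2, 6, 3, 7],
# and indexing the HF rows with it produces exactly that pairwise layout, which is why
# get_weight_from_hf interleaves in_proj weights above.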
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert HF weights from states-space repo to brrr weights")
parser.add_argument("--inp_path", type=str, default="state-spaces/mamba-130m-hf")
parser.add_argument("--out_path", type=str, default="nanotron_weight")
parser.add_argument("--dp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--tp", type=int, default=1)
args = parser.parse_args()
out_path = Path(args.out_path)
parallel_config = ParallelismArgs(
dp=args.dp,
pp=args.pp,
tp=args.tp,
pp_engine=AllForwardAllBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
)
assert (
parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE
and parallel_config.tp_linear_async_communication is False
)
parallel_context = ParallelContext(
data_parallel_size=parallel_config.dp,
pipeline_parallel_size=parallel_config.pp,
tensor_parallel_size=parallel_config.tp,
)
# Set log levels
logging_config = LoggingArgs(
log_level="info",
log_level_replica="info",
)
# Set log levels
set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config)
hf_config = HFMambaConfig.from_pretrained(args.inp_path)
dtype_str = "float32"
# TODO(fmom): Add support for ssm_cfg
yaml_content = f"""
d_model: {hf_config.hidden_size}
dtype: {dtype_str}
fused_add_norm: true
is_mamba_config: true
num_hidden_layers: {hf_config.num_hidden_layers}
pad_token_id: null
pad_vocab_size_multiple: 8
residual_in_fp32: true
rms_norm: true
rms_norm_eps: 1.0e-05
ssm_cfg: null
vocab_size: {hf_config.vocab_size}
"""
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")
attrs = yaml.safe_load(yaml_content)
model_config = MambaModelConfig(**attrs)
model_ref = MambaForCausalLM.from_pretrained(args.inp_path)
model_ref.to(device, dtype=dtype)
model_ref.eval()
nanotron_model = build_model(
model_builder=lambda: MambaForTraining(
config=model_config,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=None,
),
parallel_context=parallel_context,
dtype=dtype,
device=device,
)
device_map = {}
current_pp_rank = dist.get_rank(parallel_context.pp_pg)
tied_embs_ranks = [nanotron_model.model.token_position_embeddings.rank, nanotron_model.model.lm_head.rank]
device_map["backbone.embedding"] = (
nanotron_model.model.token_position_embeddings.rank if current_pp_rank in tied_embs_ranks else "meta"
)
for i in range(model_config.num_hidden_layers):
device_map[f"backbone.layers[{i}]"] = (
nanotron_model.model.decoder[i].rank if current_pp_rank == nanotron_model.model.decoder[i].rank else "meta"
)
device_map["lm_head"] = nanotron_model.model.lm_head.rank if current_pp_rank in tied_embs_ranks else "meta"
# Get mapping of Nanotron layer to HF layer
nanotron_to_hf = {}
# Static mappings
nanotron_to_hf["token_position_embeddings.pp_block.token_embedding.weight"] = "backbone.embeddings.weight"
nanotron_to_hf["final_layer_norm.pp_block.weight"] = "backbone.norm_f.weight"
nanotron_to_hf["lm_head.pp_block.weight"] = "lm_head.weight"
# Dynamic mappings within a loop
for i in range(model_config.num_hidden_layers):
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.A_log"] = f"backbone.layers.{i}.mixer.A_log"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.D"] = f"backbone.layers.{i}.mixer.D"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.in_proj.weight"] = f"backbone.layers.{i}.mixer.in_proj.weight"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.conv1d.weight"] = f"backbone.layers.{i}.mixer.conv1d.weight"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.conv1d.bias"] = f"backbone.layers.{i}.mixer.conv1d.bias"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.x_proj.weight"] = f"backbone.layers.{i}.mixer.x_proj.weight"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.x_proj.bias"] = f"backbone.layers.{i}.mixer.x_proj.bias"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.dt_proj.weight"] = f"backbone.layers.{i}.mixer.dt_proj.weight"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.dt_proj.bias"] = f"backbone.layers.{i}.mixer.dt_proj.bias"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.out_proj.weight"] = f"backbone.layers.{i}.mixer.out_proj.weight"
nanotron_to_hf[f"decoder.{i}.pp_block.mixer.out_proj.bias"] = f"backbone.layers.{i}.mixer.out_proj.bias"
nanotron_to_hf[f"decoder.{i}.pp_block.norm.weight"] = f"backbone.layers.{i}.norm.weight"
# Sync weights
ref_state_dict = model_ref.state_dict()
for name, param in tqdm(
nanotron_model.model.named_parameters(),
total=len(list(nanotron_model.model.named_parameters())),
desc="Converting",
):
param_is_tp_sharded = (
isinstance(param, NanotronParameter)
and param.is_sharded
and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
)
ref_param = get_weight_from_hf(
name=name,
ref_module_state_dict=ref_state_dict,
ref_module=model_ref,
nanotron_to_hf=nanotron_to_hf,
param_is_tp_sharded=param_is_tp_sharded,
)
if param_is_tp_sharded:
sharded_info = param.get_sharded_info()
# copy param data (not just the reference)
with torch.no_grad():
for local_global_slices_pair in sharded_info.local_global_slices_pairs:
local_slices = local_global_slices_pair.local_slices
global_slices = local_global_slices_pair.global_slices
param[local_slices].copy_(ref_param[global_slices])
else:
assert (
ref_param.shape == param.shape
), f"Parameter shape don't match for {name}\n{ref_param.shape} != {param.shape}"
# copy param data (not just the reference)
with torch.no_grad():
param.copy_(ref_param)
ref_param = None
torch.cuda.empty_cache()
# Marks parameters as NanotronParameters
mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context)
sanity_check(root_module=nanotron_model)
save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=out_path)
checkpoint_metadata = {
"last_train_step": 0,
"consumed_train_samples": 0,
}
save_meta(root_folder=out_path, parallel_context=parallel_context, checkpoint_metadata=checkpoint_metadata)
if dist.get_rank() == 0:
with open(out_path / "config.yaml", "w") as f:
config = MambaConfig(
general=GeneralArgs(project="test", run="mamba"),
parallelism=parallel_config,
model=ModelArgs(
init_method=MambaInit(),
model_config=model_config,
),
tokenizer=TokenizerArgs(args.inp_path),
)
log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0)
yaml.dump(config.as_dict(), f)
with open(out_path / "model_config.json", "w") as f:
log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0)
json.dump(asdict(model_config), f)
# ruff: noqa: E402
"""
Converts a nanotron model to HF format
Command:
torchrun --nproc_per_node=1 convert_nanotron_to_hf.py --checkpoint_path=nanotron_weights --save_path=HF_weights
"""
import argparse
import json
from pathlib import Path
import torch
import yaml
from config import MambaModelConfig
from mamba import MambaForTraining
from nanotron import logging
from nanotron.config import (
AllForwardAllBackwardPipelineEngine,
ParallelismArgs,
TensorParallelLinearMode,
)
from nanotron.models import build_model, init_on_device_and_dtype
from nanotron.parallel import ParallelContext
from nanotron.serialize import load_weights
from nanotron.trainer import mark_tied_parameters
from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM
logger = logging.get_logger(__name__)
def convert_checkpoint_and_save(checkpoint_path: Path, save_path: Path):
device = torch.device("cuda")
with open(checkpoint_path / "config.yaml", "r") as f:
attrs = yaml.safe_load(f)
tokenizer_name = attrs["tokenizer"]["tokenizer_name_or_path"]
with open(checkpoint_path / "model_config.json", "r") as f:
attrs = json.load(f)
model_config = MambaModelConfig(**attrs)
dtype = getattr(torch, model_config.dtype)
parallel_config = ParallelismArgs(
dp=1,
pp=1,
tp=1,
pp_engine=AllForwardAllBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
)
parallel_context = ParallelContext(
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_size=1,
)
model_nanotron = build_model(
model_builder=lambda: MambaForTraining(
config=model_config,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=None,
),
parallel_context=parallel_context,
dtype=dtype,
device=device,
)
mark_tied_parameters(model=model_nanotron, parallel_context=parallel_context)
# Load checkpoint directly in memory and then only keep the state dictionary
load_weights(model=model_nanotron, parallel_context=parallel_context, root_folder=checkpoint_path)
model_nanotron_state_dict = model_nanotron.state_dict()
del model_nanotron
# Init the HF model
if model_config.ssm_cfg is None:
model_config_hf = MambaConfig(
vocab_size=model_config.vocab_size,
num_hidden_layers=model_config.num_hidden_layers,
residual_in_fp32=model_config.residual_in_fp32,
layer_norm_epsilon=model_config.rms_norm_eps,
hidden_size=model_config.d_model,
)
else:
model_config_hf = MambaConfig(
vocab_size=model_config.vocab_size,
num_hidden_layers=model_config.num_hidden_layers,
residual_in_fp32=model_config.residual_in_fp32,
layer_norm_epsilon=model_config.rms_norm_eps,
hidden_size=model_config.d_model,
state_size=model_config.ssm_cfg["d_state"],
expand=model_config.ssm_cfg["expand"],
conv_kernel=model_config.ssm_cfg["d_conv"],
use_bias=model_config.ssm_cfg["bias"],
use_conv_bias=model_config.ssm_cfg["conv_bias"],
time_step_rank=model_config.ssm_cfg["dt_rank"],
time_step_scale=model_config.ssm_cfg["dt_scale"],
time_step_min=model_config.ssm_cfg["dt_min"],
time_step_max=model_config.ssm_cfg["dt_max"],
time_step_init_scheme=model_config.ssm_cfg["dt_init"],
time_step_floor=model_config.ssm_cfg["dt_init_floor"],
)
# Initialise the HF model
with init_on_device_and_dtype(device, dtype):
model_hf = MambaForCausalLM._from_config(model_config_hf)
# Get mapping of HF layer to Nanotron layer
hf_to_nanotron = {}
# Static mappings
hf_to_nanotron["backbone.embeddings.weight"] = "token_position_embeddings.pp_block.token_embedding.weight"
hf_to_nanotron["backbone.norm_f.weight"] = "final_layer_norm.pp_block.weight"
hf_to_nanotron["lm_head.weight"] = "lm_head.pp_block.weight"
# Dynamic mappings within a loop
for i in range(model_config.num_hidden_layers):
hf_to_nanotron[f"backbone.layers.{i}.mixer.A_log"] = f"decoder.{i}.pp_block.mixer.A_log"
hf_to_nanotron[f"backbone.layers.{i}.mixer.D"] = f"decoder.{i}.pp_block.mixer.D"
hf_to_nanotron[f"backbone.layers.{i}.mixer.in_proj.weight"] = f"decoder.{i}.pp_block.mixer.in_proj.weight"
hf_to_nanotron[f"backbone.layers.{i}.mixer.conv1d.weight"] = f"decoder.{i}.pp_block.mixer.conv1d.weight"
hf_to_nanotron[f"backbone.layers.{i}.mixer.conv1d.bias"] = f"decoder.{i}.pp_block.mixer.conv1d.bias"
hf_to_nanotron[f"backbone.layers.{i}.mixer.x_proj.weight"] = f"decoder.{i}.pp_block.mixer.x_proj.weight"
hf_to_nanotron[f"backbone.layers.{i}.mixer.x_proj.bias"] = f"decoder.{i}.pp_block.mixer.x_proj.bias"
hf_to_nanotron[f"backbone.layers.{i}.mixer.dt_proj.weight"] = f"decoder.{i}.pp_block.mixer.dt_proj.weight"
hf_to_nanotron[f"backbone.layers.{i}.mixer.dt_proj.bias"] = f"decoder.{i}.pp_block.mixer.dt_proj.bias"
hf_to_nanotron[f"backbone.layers.{i}.mixer.out_proj.weight"] = f"decoder.{i}.pp_block.mixer.out_proj.weight"
hf_to_nanotron[f"backbone.layers.{i}.mixer.out_proj.bias"] = f"decoder.{i}.pp_block.mixer.out_proj.bias"
hf_to_nanotron[f"backbone.layers.{i}.norm.weight"] = f"decoder.{i}.pp_block.norm.weight"
def _reverse_interleave_pattern(N):
"""
Compute the reverse of the interleave pattern given by _interleave_pattern.
Example:
reverse_interleave_pattern(4) -> [0, 2, 1, 3]
reverse_interleave_pattern(8) -> [0, 2, 4, 6, 1, 3, 5, 7]
"""
assert N % 2 == 0, "N must be even"
def __interleave_pattern(N):
"""
interleave_pattern(4) -> [0, 2, 1, 3]
interleave_pattern(8) -> [0, 4, 1, 5, 2, 6, 3, 7]
"""
assert N % 2 == 0, "N must be even"
pattern = []
for i in range(N // 2):
pattern.append(i)
pattern.append(i + N // 2)
return pattern
interleaved_pattern = __interleave_pattern(N)
reverse_pattern = [0] * N
for original_index, interleaved_index in enumerate(interleaved_pattern):
reverse_pattern[interleaved_index] = original_index
return reverse_pattern
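# Illustrative check: _reverse_interleave_pattern(8) == [0, 2, 4, 6, 1, 3, 5, 7], so indexing the
# Nanotron in_proj rows [x0, z0, x1, z1, x2, z2, x3, z3] with it restores the contiguous HF layout
# [x0, x1, x2, x3, z0, z1, z2, z3] that xz.chunk(...) expects.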
# Loop over the state dict and convert the keys to HF format
for module_name_hf, module_hf in model_hf.named_modules():
for param_name_hf, param_hf in module_hf.named_parameters(recurse=False):
# Get the Nanotron parameter
nanotron_key = "model." + hf_to_nanotron[f"{module_name_hf}.{param_name_hf}"]
param = model_nanotron_state_dict[nanotron_key]
if "in_proj" in nanotron_key:
# Undo the interleaving of the Nanotron weights to make them HF compatible
param = param[_reverse_interleave_pattern(param.shape[0]), :]
with torch.no_grad():
param_hf.copy_(param)
# Save the model
model_hf.save_pretrained(save_path)
print(f"Model saved to {save_path}")
# Save the tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.save_pretrained(save_path)
print(f"Tokenizer saved to {save_path}")
def check_converted_model_generation(save_path: Path):
HARDCODED_PROMPT = "What is your "
tokenizer = AutoTokenizer.from_pretrained(save_path)
input_ids = tokenizer(HARDCODED_PROMPT, return_tensors="pt")["input_ids"]
print("Inputs:", tokenizer.batch_decode(input_ids))
model = MambaForCausalLM.from_pretrained(save_path)
out = model.generate(input_ids, max_new_tokens=100)
print("Generation (converted): ", tokenizer.batch_decode(out))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Nanotron weights to HF format")
parser.add_argument("--checkpoint_path", type=str, default="mamba-130m")
parser.add_argument("--save_path", type=str, default="mamba-hf")
args = parser.parse_args()
save_path = Path(args.save_path)
checkpoint_path = Path(args.checkpoint_path)
# Convert Nanotron model to HF format
convert_checkpoint_and_save(checkpoint_path=checkpoint_path, save_path=save_path)
# check if the conversion was successful by generating some text
check_converted_model_generation(save_path=save_path)
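# Suggested round-trip sanity check (not enforced by this script): convert an HF checkpoint with
# convert_hf_to_nanotron.py --inp_path state-spaces/mamba-130m-hf --out_path nanotron_weights,
# convert it back with this script using --checkpoint_path=nanotron_weights --save_path=HF_weights,
# and compare the text printed by check_converted_model_generation against the original HF model's output.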
""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information."""
import math
import os
import uuid
from config import MambaConfig, MambaInit, MambaModelConfig
from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
DataArgs,
DatasetStageArgs,
GeneralArgs,
LoggingArgs,
LRSchedulerArgs,
ModelArgs,
OptimizerArgs,
ParallelismArgs,
PretrainDatasetsArgs,
TokenizerArgs,
TokensArgs,
)
from nanotron.logging import human_format
new_job_id = uuid.uuid4()
job_id = str(new_job_id)[:8]
seed = 42
ssm_cfg_dtype = "bfloat16"
ssm_cfg = {
"d_state": 16,
"d_conv": 4,
"expand": 2,
"dt_rank": "auto",
"dt_min": 0.001,
"dt_max": 0.1,
"dt_init": "random",
"dt_scale": 1.0,
"dt_init_floor": 1e-4,
"conv_bias": True,
"bias": False,
"use_fast_path": True,
}
# https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json
model_config = MambaModelConfig(
d_model=1024,
num_hidden_layers=2,
vocab_size=50278,
ssm_cfg=ssm_cfg,
rms_norm=True,
fused_add_norm=True,
residual_in_fp32=True,
pad_vocab_size_multiple=8,
# Custom
dtype=ssm_cfg_dtype,
rms_norm_eps=1e-5,
)
# NOTE: vocab_size is normally rounded up to the nearest multiple of 10. But here, we don't really care
tie_embedding = model_config.vocab_size * model_config.d_model
expand = 2 if ("expand" not in ssm_cfg) else ssm_cfg["expand"]
ngroups = 1 if ("ngroups" not in ssm_cfg) else ssm_cfg["ngroups"]
d_state = 16 if ("d_state" not in ssm_cfg) else ssm_cfg["d_state"]
d_conv = 4 if ("d_conv" not in ssm_cfg) else ssm_cfg["d_conv"]
dt_rank = (
math.ceil(model_config.d_model / 16)
if ("dt_rank" not in ssm_cfg or ssm_cfg["dt_rank"] == "auto")
else ssm_cfg["dt_rank"]
)
d_inner = int(expand * model_config.d_model)
# conv1d.weight = out_channels * (in_channels // groups) * kernel_size
# conv1d.bias = out_channels
conv1d = d_inner * int(d_inner / d_inner) * d_conv + d_inner
# linear.weight = out_features * in_features
in_proj = model_config.d_model * d_inner * 2 + 0
x_proj = d_inner * (dt_rank + d_state * 2) + 0
out_proj = d_inner * model_config.d_model + 0
dt_proj = dt_rank * d_inner + d_inner
A_log = d_inner * d_state
D = d_inner
norm = model_config.d_model
norm_f = model_config.d_model
num_params = human_format(
(
tie_embedding
+ model_config.num_hidden_layers * (A_log + D + in_proj + conv1d + x_proj + dt_proj + out_proj + norm + norm_f)
)
).replace(".", "p")
print(f"Model has {num_params} parameters")
seed = 42
optimizer = OptimizerArgs(
zero_stage=0,
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=True, # NOTE(fmom): because we are using PP=TP=DP=1
learning_rate_scheduler=LRSchedulerArgs(
learning_rate=0.0015,
lr_warmup_steps=30,
lr_warmup_style="linear",
lr_decay_style="cosine",
min_decay_lr=0.00015,
),
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)
parallelism = ParallelismArgs(
dp=2,
pp=2,
tp=2,
pp_engine="1f1b",
tp_mode="ALL_REDUCE",
tp_linear_async_communication=False,
)
tokens = TokensArgs(sequence_length=2048, train_steps=300, micro_batch_size=8, batch_accumulation_per_replica=1)
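# Illustrative token budget: with dp=2 above, each optimizer step sees 8 * 1 * 2 = 16 sequences
# of length 2048, i.e. 32,768 tokens per step and ~9.8M tokens over the 300 train_steps.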
data_stages = [
DatasetStageArgs(
name="Stable Training Stage",
start_training_step=1,
data=DataArgs(
dataset=PretrainDatasetsArgs(hf_dataset_or_datasets="roneneldan/TinyStories", text_column_name="text"),
seed=seed,
),
)
]
model = ModelArgs(
init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
model_config=model_config,
)
checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)
config = MambaConfig(
general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=100),
parallelism=parallelism,
model=model,
tokenizer=TokenizerArgs("gpt2"),
optimizer=optimizer,
logging=LoggingArgs(),
tokens=tokens,
data_stages=data_stages,
profiler=None,
)
if __name__ == "__main__":
dir = os.path.dirname(__file__)
# Save config as YAML file
config.save_as_yaml(f"{dir}/config_mamba.yaml")
# coding=utf-8
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mamba model."""
import math
from functools import partial
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import MambaModelConfig
from einops import rearrange, repeat
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import ParallelismArgs
from nanotron.config.utils_config import cast_str_to_torch_dtype
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelLinearMode,
TensorParallelRowLinear,
)
from nanotron.random import RandomStates
from selective_scan_interface import mamba_inner_fn, selective_scan_fn
from torch.nn import init
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
logger = logging.get_logger(__name__)
class Mamba(nn.Module, AttachableStore):
def __init__(
self,
d_model: int,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
d_state: int = 16,
d_conv: int = 4,
expand: int = 2,
dt_rank: str = "auto",
dt_min: float = 0.001,
dt_max: float = 0.1,
dt_init: str = "random",
dt_scale: float = 1.0,
dt_init_floor: float = 1e-4,
conv_bias: bool = True,
bias: bool = False,
use_fast_path: bool = True, # Fused kernel options
layer_idx: Optional[int] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
assert (
    tp_mode == TensorParallelLinearMode.ALL_REDUCE or parallel_config.tp_linear_async_communication is False
), "Only ALL_REDUCE and tp_linear_async_communication=False are supported"
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
# Get current tensor parallel rank
self.tp_pg = tp_pg
self.tp_rank = dist.get_rank(self.tp_pg)
self.in_proj = TensorParallelColumnLinear(
in_features=self.d_model,
out_features=self.d_inner * 2,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=False,
contiguous_chunks=None,
)
assert self.d_inner % self.tp_pg.size() == 0
self.conv1d = nn.Conv1d(
in_channels=self.d_inner // self.tp_pg.size(),
out_channels=self.d_inner // self.tp_pg.size(),
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner // self.tp_pg.size(),
padding=d_conv - 1,
**factory_kwargs,
)
self.conv1d.weight = create_sharded_parameter_from_config(
parameter=self.conv1d.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
if conv_bias:
self.conv1d.bias = create_sharded_parameter_from_config(
parameter=self.conv1d.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
self.activation = "silu"
self.act = nn.SiLU()
self.x_proj = TensorParallelRowLinear(
in_features=self.d_inner,
out_features=self.dt_rank + self.d_state * 2,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=None,
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner // self.tp_pg.size(), bias=True, **factory_kwargs)
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner // self.tp_pg.size(), **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
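# softplus^-1(y) = log(exp(y) - 1) = y + log(-expm1(-y)); copying inv_dt into the bias below makes
# F.softplus(self.dt_proj.bias) start out (log-uniformly) inside [dt_min, dt_max], as intended above.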
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
self.dt_proj.weight = create_sharded_parameter_from_config(
parameter=self.dt_proj.weight, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
self.dt_proj.bias = create_sharded_parameter_from_config(
parameter=self.dt_proj.bias, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner // self.tp_pg.size(),
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = create_sharded_parameter_from_config(
parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
# D "skip" parameter
self.D = create_sharded_parameter_from_config(
parameter=torch.ones(self.d_inner // self.tp_pg.size(), device=device),
pg=self.tp_pg,
split_config=SplitConfig(split_dim=0),
)
# self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.out_proj = TensorParallelRowLinear(
in_features=self.d_inner,
out_features=self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=None,
)
def forward(self, hidden_states: Union[torch.Tensor, TensorPointer]):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
batch_size, seqlen, dim = hidden_states.shape
conv_state, ssm_state = None, None
store = self.get_local_store()
if store is not None:
if "key_value_memory_list" not in store:
store["key_value_memory_list"] = []
if "seqlen_offset" not in store:
store["seqlen_offset"] = 0
conv_state, ssm_state = self._get_states_from_cache(batch_size)
if store["seqlen_offset"] > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
store["seqlen_offset"] += 1
return out
else:
store["seqlen_offset"] += 1
# We do matmul and transpose BLH -> HBL at the same time
xz = self.in_proj(hidden_states).transpose(1, 2)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# In the backward pass we write dx and dz next to each other to avoid torch.cat
if self.use_fast_path and store is None: # Doesn't support outputting the states
y = mamba_inner_fn(
d_inner=self.d_inner,
tp_pg=self.tp_pg,
xz=xz,
conv1d_weight=self.conv1d.weight,
conv1d_bias=self.conv1d.bias,
x_proj_weight=self.x_proj.weight,
delta_proj_weight=self.dt_proj.weight,
A=A,
B=None, # input-dependent B
C=None, # input-dependent C
D=self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else:
if self.tp_pg.size() > 1:
x, z = xz.view(batch_size, self.d_inner // 2, 2, seqlen).chunk(2, dim=2)
else:
x, z = xz.view(batch_size, self.d_inner, 2, seqlen).chunk(2, dim=2)
x = x.squeeze(2)
z = z.squeeze(2)
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
def step(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
batch_size, seqlen, dim = hidden_states.shape
dtype = hidden_states.dtype
assert seqlen == 1, "Only support decoding with 1 token at a time for now"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
if self.tp_pg.size() > 1:
x, z = xz.view(batch_size, self.d_inner // 2, 2).chunk(2, dim=2)
else:
x, z = xz.view(batch_size, self.d_inner, 2).chunk(2, dim=2)
x = x.squeeze(2) # (B D)
z = z.squeeze(2) # (B D)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Don't add dt_bias here
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# SSM step
if selective_state_update is None:
# Discretize A and B
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
y = y + self.D.to(dtype) * x
y = y * self.act(z) # (B D)
else:
y = selective_state_update(
ssm_state,
x,
dt,
A,
B,
C,
self.D,
z=z,
dt_bias=self.dt_proj.bias,
dt_softplus=True,
)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
def _get_states_from_cache(self, batch_size: int, initialize_states: bool = False):
assert self.layer_idx is not None
store = self.get_local_store()
if len(store["key_value_memory_list"]) == 0:
conv_state = torch.zeros(
batch_size,
self.d_model * self.expand // self.tp_pg.size(),
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.d_model * self.expand // self.tp_pg.size(),
self.d_state,
device=self.dt_proj.weight.device,
dtype=self.dt_proj.weight.dtype,
)
store["key_value_memory_list"] = (conv_state, ssm_state)
else:
conv_state, ssm_state = store["key_value_memory_list"]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
class Embedding(nn.Module, AttachableStore):
def __init__(
self,
tp_pg: dist.ProcessGroup,
config: MambaModelConfig,
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
self.token_embedding = TensorParallelEmbedding(
num_embeddings=config.vocab_size,
embedding_dim=config.d_model,
padding_idx=config.pad_token_id,
pg=tp_pg,
mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
)
self.pg = tp_pg
def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
store = self.get_local_store()
if store is not None:
if "past_length" in store:
past_length = store["past_length"]
else:
past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
# Store new past_length in store
store["past_length"] = past_length + cumsum_mask[:, -1]
# Format input in `[seq_length, batch_size]` to support high TP with low batch_size
# input_ids = input_ids.transpose(0, 1)
input_embeds = self.token_embedding(input_ids)
return {"input_embeds": input_embeds}
class MambaDecoderLayer(nn.Module):
def __init__(
self,
config: MambaModelConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
if config.ssm_cfg is None:
ssm_cfg = {}
else:
ssm_cfg = config.ssm_cfg
self.layer_idx = layer_idx
self.residual_in_fp32 = config.residual_in_fp32
self.fused_add_norm = config.fused_add_norm
self.mixer = Mamba(
d_model=config.d_model,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
**ssm_cfg,
**factory_kwargs,
)
self.norm = partial(
nn.LayerNorm if not config.rms_norm else RMSNorm,
eps=config.rms_norm_eps,
**factory_kwargs,
)(config.d_model)
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import failed"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
residual: Optional[Union[torch.Tensor, TensorPointer]],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
if not self.fused_add_norm:
# self.layer_idx was assigned when calling create_block
# residual=None happens only at the first block
residual = hidden_states if (self.layer_idx == 0) else hidden_states + residual
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
hidden_states, residual = fused_add_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=None if (self.layer_idx == 0) else residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = self.mixer(hidden_states)
return {
"hidden_states": hidden_states,
"sequence_mask": sequence_mask, # NOTE(fmom): dunno how to use it for now. Just keep it
"residual": residual,
}
class MambaModel(nn.Module):
def __init__(
self,
config: MambaModelConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: Optional[RandomStates] = None,
):
super().__init__()
# Declare all the nodes
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
self.config = config
self.parallel_config = parallel_config
self.parallel_context = parallel_context
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
self.token_position_embeddings = PipelineBlock(
p2p=self.p2p,
module_builder=Embedding,
module_kwargs={
"tp_pg": parallel_context.tp_pg,
"config": config,
"parallel_config": parallel_config,
},
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=MambaDecoderLayer,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
"tp_pg": parallel_context.tp_pg,
"layer_idx": layer_idx,
"device": self.p2p.device,
"dtype": cast_str_to_torch_dtype(config.dtype),
},
module_input_keys={"hidden_states", "sequence_mask", "residual"},
module_output_keys={"hidden_states", "sequence_mask", "residual"},
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=RMSNorm,
module_kwargs={"hidden_size": config.d_model, "eps": config.rms_norm_eps},
module_input_keys={"x", "residual"},
module_output_keys={"hidden_states"},
)
self.lm_head = PipelineBlock(
p2p=self.p2p,
# Understand that this means that we return sharded logits that are going to need to be gathered
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.d_model,
"out_features": config.vocab_size,
"pg": parallel_context.tp_pg,
"bias": False,
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
},
module_input_keys={"x"},
module_output_keys={"logits"},
)
self.cast_to_fp32 = PipelineBlock(
p2p=self.p2p,
module_builder=lambda: lambda x: x.float(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]
def forward_with_hidden_states(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
# all tensors are optional as most ranks don't need anything from the dataloader.
output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
hidden_encoder_states = {
"hidden_states": output["input_embeds"],
"sequence_mask": input_mask,
"residual": output["input_embeds"],
}
for block in self.decoder:
hidden_encoder_states = block(**hidden_encoder_states)
hidden_states = self.final_layer_norm(
x=hidden_encoder_states["hidden_states"],
residual=hidden_encoder_states["residual"],
)["hidden_states"]
sharded_logits = self.lm_head(x=hidden_states)["logits"]
fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
return fp32_sharded_logits, hidden_states
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
# model_config = self.config
# d_ff = model_config.intermediate_size
# d_qkv = model_config.d_model // model_config.num_attention_heads
# block_compute_costs = {
# # CausalSelfAttention (qkv proj + attn out) + MLP
# LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.d_model
# + 3 * d_ff * model_config.d_model,
# # This is the last lm_head
# TensorParallelColumnLinear: model_config.vocab_size * model_config.d_model,
# }
block_compute_costs = {
# MambaDecoderLayer (uniform placeholder cost; see the "Not implemented yet" log below)
MambaDecoderLayer: 1,
# This is the last lm_head
TensorParallelColumnLinear: 0,
}
log_rank(
"get_block_compute_costs() Not implemented yet",
logger=logger,
level=logging.INFO,
rank=0,
)
return block_compute_costs
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""
Get flops per second for a Mamba model.
Terms such as nonlinearities, biases, and layer normalization are omitted (https://arxiv.org/pdf/2001.08361.pdf)
"""
# world_size = self.parallel_context.world_pg.size()
# try:
# num_key_values_heads = self.config.num_key_value_heads
# except AttributeError:
# num_key_values_heads = self.config.num_attention_heads
# model_flops, hardware_flops = get_flops(
# num_layers=self.config.num_hidden_layers,
# hidden_size=self.config.d_model,
# num_heads=self.config.num_attention_heads,
# num_key_value_heads=num_key_values_heads,
# vocab_size=self.config.vocab_size,
# ffn_hidden_size=self.config.intermediate_size,
# seq_len=sequence_length,
# batch_size=global_batch_size,
# recompute_granularity=self.parallel_config.recompute_granularity,
# )
# model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
# hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
world_size = self.parallel_context.world_pg.size()
expand = 2 if ("expand" not in self.config.ssm_cfg) else self.config.ssm_cfg["expand"]
d_state = 16 if ("d_state" not in self.config.ssm_cfg) else self.config.ssm_cfg["d_state"]
dt_rank = (
math.ceil(self.config.d_model / 16)
if ("dt_rank" not in self.config.ssm_cfg or self.config.ssm_cfg["dt_rank"] == "auto")
else self.config.ssm_cfg["dt_rank"]
)
d_inner = int(expand * self.config.d_model)
# embeddings (do not include embeddings as per Chinchilla)
# embeddings = 2 * sequence_length * self.config.vocab_size * self.config.d_model
# selective scan, see : https://github.com/state-spaces/mamba/issues/110
scan = 9 * sequence_length * d_state * self.config.d_model
# linear projections
in_proj = 2 * sequence_length * self.config.d_model * d_inner * 2
x_proj = 2 * sequence_length * d_inner * (dt_rank + d_state * 2)
dt_proj = 2 * sequence_length * dt_rank * d_inner
out_proj = 2 * sequence_length * d_inner * self.config.d_model
# output projection
projection = 2 * sequence_length * self.config.vocab_size * self.config.d_model
forward_flops = self.config.num_hidden_layers * (in_proj + scan + x_proj + dt_proj + out_proj) + projection
backward_flops = 2 * forward_flops
model_flops = forward_flops + backward_flops
model_flops_per_s = model_flops * global_batch_size / (iteration_time_in_sec * world_size * 1e12)
# add hardware flops later
hardware_flops_per_s = 0
return model_flops_per_s, hardware_flops_per_s
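# Rough order-of-magnitude check (illustrative, using the d_model=1536, 48-layer, 50277-vocab config
# from the YAML in this commit with sequence_length=2048): each layer costs ~6.1e10 forward FLOPs
# (in_proj ~3.9e10, out_proj ~1.9e10, x_proj ~1.6e9, dt_proj ~1.2e9, scan ~4.5e8), the lm_head
# projection adds ~3.2e11, so forward_flops ~3.3e12 and model_flops ~9.8e12 per sequence.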
def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
class Loss(nn.Module):
def __init__(self, tp_pg: dist.ProcessGroup):
super().__init__()
self.tp_pg = tp_pg
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
) -> Dict[str, torch.Tensor]:
# Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
# NOTE(fmom): undo transpose for now since Mamba is not using TP
# loss = sharded_cross_entropy(
# sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
# ).transpose(0, 1)
loss = sharded_cross_entropy(sharded_logits, label_ids, group=self.tp_pg, dtype=torch.float)
# TODO @thomasw21: It's unclear what kind of normalization we want to do.
loss = masked_mean(loss, label_mask, dtype=torch.float)
# I think indexing causes a sync we don't actually want
# loss = loss[label_mask].sum()
return {"loss": loss}
class MambaForTraining(NanotronModel):
def __init__(
self,
config: MambaModelConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: Optional[RandomStates] = None,
):
super().__init__()
self.parallel_context = parallel_context
self.config = config
self.parallel_config = parallel_config
self.model = MambaModel(
config=self.config,
parallel_context=self.parallel_context,
parallel_config=self.parallel_config,
random_states=random_states,
)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=Loss,
module_kwargs={"tp_pg": parallel_context.tp_pg},
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
},
module_output_keys={"loss"},
)
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
loss = self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
return {"loss": loss}
def get_named_params_without_weight_decay(self):
# get full name with "A_log", "D"
named_param_without_weight_decay = []
for name, _ in self.model.named_parameters():
if "A_log" in name or "D" in name:
named_param_without_weight_decay.append(name)
return named_param_without_weight_decay
@torch.no_grad()
def init_model_randomly(self, config):
model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
initializer_range = config.model.init_method.initializer_range
n_residuals_per_layer = config.model.init_method.n_residuals_per_layer
num_hidden_layers = config.model.model_config.num_hidden_layers
rescale_prenorm_residual = config.model.init_method.rescale_prenorm_residual
d_model = config.model.model_config.d_model
if config.model.model_config.ssm_cfg is not None:
dt_init = config.model.model_config.ssm_cfg["dt_init"]
dt_rank = config.model.model_config.ssm_cfg["dt_rank"]
dt_scale = config.model.model_config.ssm_cfg["dt_scale"]
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)
module_name, param_name = param_name.rsplit(".", 1)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if full_param_name in initialized_parameters:
# Already initialized
continue
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear) or isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
init.kaiming_uniform_(module.weight, a=math.sqrt(5))
elif "bias" == param_name:
raise ValueError("We don't use bias for TensorParallelColumnLinear and TensorParallelRow")
else:
raise ValueError(f"Who the fuck is {param_name}?")
if rescale_prenorm_residual and full_param_name.endswith("out_proj.weight"):
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
with torch.no_grad():
module.weight /= math.sqrt(n_residuals_per_layer * num_hidden_layers)
elif isinstance(module, nn.Conv1d):
fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight)
if "weight" == param_name:
init.kaiming_uniform_(module.weight, a=math.sqrt(5))
elif "bias" == param_name:
bound = 1 / math.sqrt(fan_in) if (fan_in > 0) else 0
init.uniform_(module.bias, -bound, bound)
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, nn.Linear):
fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight)
if "weight" == param_name:
init.kaiming_uniform_(module.weight, a=math.sqrt(5))
elif "bias" == param_name:
bound = 1 / math.sqrt(fan_in) if (fan_in > 0) else 0
init.uniform_(module.bias, -bound, bound)
else:
raise ValueError(f"Who the fuck is {param_name}?")
if config.model.model_config.ssm_cfg is not None:
if dt_rank == "auto":
dt_init_std = math.ceil(d_model / 16) ** -0.5 * dt_scale
else:
dt_init_std = dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(module.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(module.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, std=initializer_range)
elif isinstance(module, RMSNorm) or isinstance(module, nn.LayerNorm):
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, Mamba):
pass
else:
raise Exception(f"Parameter {full_param_name} was not initialized")
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if param.is_tied
else name
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
@staticmethod
def get_embeddings_lm_head_tied_names():
return [
"model.token_position_embeddings.pp_block.token_embedding.weight",
"model.lm_head.pp_block.weight",
]
# TODO(fmom): implement get_block_compute_costs
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
return self.model.get_block_compute_costs()
# TODO(fmom): implement get_flops_per_sec
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
torch==2.1.0
einops
causal-conv1d==1.1.0
mamba-ssm==1.1.4
flash-attn==2.5.0
"""
Nanotron Inference Script
Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=4 run_generate.py --ckpt-path checkpoints/test/4
```
"""
import argparse
import os
from pathlib import Path
import torch
from config import MambaConfig, MambaModelConfig
from mamba import MambaForTraining
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import (
GenerationArgs,
LoggingArgs,
ParallelismArgs,
get_config_from_file,
)
from nanotron.generation.decode import (
GenerationInput,
TokenizerConfig,
decode_text,
decode_tokenized,
)
from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
OneForwardOneBackwardPipelineEngine,
)
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.random import (
RandomStates,
get_current_random_state,
get_synced_random_state,
set_random_seed,
)
from nanotron.serialize import load_weights
from nanotron.trainer import mark_tied_parameters
try:
from transformers import AutoTokenizer
except ImportError:
AutoTokenizer = None
logger = logging.get_logger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path")
parser.add_argument("--dp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate")
return parser.parse_args()
def main():
args = get_args()
assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist"
config = get_config_from_file(
(args.ckpt_path / "config.yaml").as_posix(), config_class=MambaConfig, model_config_class=MambaModelConfig
)
model_config = config.model.model_config
tokenizer_path = config.tokenizer.tokenizer_name_or_path
parallel_config = ParallelismArgs(
dp=args.dp,
pp=args.pp,
tp=args.tp,
pp_engine=OneForwardOneBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
)
print(parallel_config)
# Initialise all process groups
parallel_context = ParallelContext(
data_parallel_size=parallel_config.dp,
pipeline_parallel_size=parallel_config.pp,
tensor_parallel_size=parallel_config.tp,
)
# Set log levels
logging_config = LoggingArgs(
log_level="info",
log_level_replica="info",
)
# Set log levels
set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config)
log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0)
log_rank(f"tokenizer_path: {tokenizer_path}", logger=logger, level=logging.INFO, rank=0)
# Set random states
set_random_seed(42)
# Get synchronized random states
if parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:
random_states = RandomStates(
{"tp_synced": get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg)}
)
else:
# We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER)
random_states = RandomStates({})
model = build_model(
model_builder=lambda: MambaForTraining(
config=model_config,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=random_states,
),
dtype=getattr(torch, model_config.dtype),
parallel_context=parallel_context,
)
# Mark some parameters as tied
# TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead?
mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
# Sanity check model
sanity_check(root_module=model)
# Load checkpoint
checkpoint_path = args.ckpt_path
log_rank(
f"Loading checkpoint from {checkpoint_path}:",
logger=logger,
level=logging.INFO,
rank=0,
)
load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path)
model.eval()
if AutoTokenizer is not None:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# tokenizer.pad_token_id = tokenizer.eos_token_id
if tokenizer.pad_token_id is None:
if tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
elif getattr(model.config, "pad_token_id", None) is not None:
tokenizer.pad_token_id = int(model.config.pad_token_id)
elif getattr(model.config, "eos_token_id", None) is not None:
tokenizer.pad_token_id = int(model.config.eos_token_id)
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left" # TODO @nouamane: do we want this?
dummy_inputs = [
# "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
# "This film was probably inspired by Godzilla",
"What is your "
]
log_rank("Setup Inference mode for mamba model", logger=logger, level=logging.INFO, rank=0)
# assert config.inference_params.max_batch_size == 1, "Only batch size 1 is supported for inference for now"
outputs = decode_text(
input_iter=(GenerationInput(text=text) for text in dummy_inputs),
tokenizer=tokenizer,
# TODO @thomasw21: From ModelWithLoss extract the model.
model=model.model,
parallel_context=parallel_context,
max_new_tokens=args.max_new_tokens,
max_micro_batch_size=2,
generation_config=GenerationArgs(sampler="greedy", use_cache=True),
tokenizer_config=TokenizerConfig(max_input_length=None),
is_bench=os.environ.get("USE_BENCH", "0") == "1",
logits_are_batch_first=False,
)
for output in outputs:
input_ids = output.input_ids
generated_ids = output.generation_ids
if isinstance(input_ids, TensorPointer):
assert isinstance(generated_ids, TensorPointer)
continue
assert isinstance(generated_ids, torch.Tensor)
log_rank(
f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}",
logger=logger,
level=logging.INFO,
rank=0,
)
log_rank(
f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}",
logger=logger,
level=logging.INFO,
rank=0,
)
log_rank(
"--------------------------------------------------",
logger=logger,
level=logging.INFO,
rank=0,
)
# Model ref
tokens = tokenizer(dummy_inputs, return_tensors="pt")
input_ids = tokens.input_ids.to(device="cuda")
else:
outputs = decode_tokenized(
input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"),
input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"),
model=model.model,
parallel_context=parallel_context,
generation_config=GenerationArgs(sampler="greedy", use_cache=True),
max_micro_batch_size=1,
max_new_tokens=12,
returns_logits=False,
)
for output in outputs:
input_ids = output.input_ids
generated_ids = output.generation_ids
if isinstance(input_ids, TensorPointer):
assert isinstance(generated_ids, TensorPointer)
continue
assert isinstance(generated_ids, torch.Tensor)
log_rank(
f"generation: {generated_ids[len(input_ids) :]}",
logger=logger,
level=logging.INFO,
rank=0,
)
log_rank(
"--------------------------------------------------",
logger=logger,
level=logging.INFO,
rank=0,
)
dist.barrier()
if __name__ == "__main__":
main()
# Copyright (c) 2023, Tri Dao, Albert Gu.
import causal_conv1d_cuda
import selective_scan_cuda
import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn
from einops import rearrange, repeat
from torch.cuda.amp import custom_bwd, custom_fwd
class SelectiveScanFn(torch.autograd.Function):
@staticmethod
def forward(
ctx,
u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
):
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = rearrange(B, "b dstate l -> b 1 dstate l")
ctx.squeeze_B = True
if C.dim() == 3:
C = rearrange(C, "b dstate l -> b 1 dstate l")
ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
ctx.delta_softplus = delta_softplus
ctx.has_z = z is not None
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if not ctx.has_z:
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out if not return_last_state else (out, last_state)
else:
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
@staticmethod
def backward(ctx, dout, *args):
if not ctx.has_z:
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
z = None
out = None
else:
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
if dout.stride(-1) != 1:
dout = dout.contiguous()
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
# Here we just pass in None and dz will be allocated in the C++ code.
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
dout,
x,
out,
None,
ctx.delta_softplus,
False, # option to recompute out_z, not used here
)
dz = rest[0] if ctx.has_z else None
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
return (
du,
ddelta,
dA,
dB,
dC,
dD if D is not None else None,
dz,
ddelta_bias if delta_bias is not None else None,
None,
None,
)
def selective_scan_fn(
u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
not considered in the backward pass.
"""
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
def selective_scan_ref(
u,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
):
"""
u: r(B D L)
delta: r(B D L)
A: c(D N) or r(D N)
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
out: r(B D L)
last_state (optional): r(B D dstate) or c(B D dstate)
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
if A.is_complex():
if is_variable_B:
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
if is_variable_C:
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
else:
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate))
ys = []
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
if not is_variable_B:
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum("bdn,dn->bd", x, C)
else:
if C.dim() == 3:
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
else:
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
return out if not return_last_state else (out, last_state)
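# --- Illustrative usage sketch (not part of the original kernels) ---------------------
# Assumes a CUDA device with the causal-conv1d / mamba-ssm extensions installed.
# The helper name below is hypothetical; it only demonstrates the documented shapes and
# checks the fused kernel against the pure-PyTorch reference above.
def _selective_scan_smoke_test():
    batch, dim, dstate, seqlen = 2, 4, 8, 16
    device, dtype = "cuda", torch.float32
    u = torch.randn(batch, dim, seqlen, device=device, dtype=dtype)     # r(B D L)
    delta = torch.rand(batch, dim, seqlen, device=device, dtype=dtype)  # r(B D L)
    A = -torch.rand(dim, dstate, device=device, dtype=dtype)            # r(D N)
    B = torch.randn(batch, dstate, seqlen, device=device, dtype=dtype)  # variable B: r(B N L)
    C = torch.randn(batch, dstate, seqlen, device=device, dtype=dtype)  # variable C: r(B N L)
    D = torch.randn(dim, device=device, dtype=dtype)                    # r(D)
    out, last_state = selective_scan_fn(u, delta, A, B, C, D, delta_softplus=True, return_last_state=True)
    out_ref = selective_scan_ref(u, delta, A, B, C, D, delta_softplus=True)
    # out: (batch, dim, seqlen); last_state: (batch, dim, dstate)
    torch.testing.assert_close(out, out_ref, rtol=1e-3, atol=1e-3)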
class MambaInnerFn(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(
ctx,
d_inner,
tp_pg,
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
checkpoint_lvl=1,
):
"""
xz: (batch, dim, seqlen)
"""
assert checkpoint_lvl in [0, 1]
batch, L = xz.shape[0], xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
if torch.is_autocast_enabled():
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
if xz.stride(-1) != 1:
xz = xz.contiguous()
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
if tp_pg.size() > 1:
x, z = xz.view(batch, d_inner // 2, 2, L).chunk(2, dim=2)
else:
x, z = xz.view(batch, d_inner, 2, L).chunk(2, dim=2)
x = x.squeeze(2)
z = z.squeeze(2)
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
# We're being very careful here about the layout, to avoid extra transposes.
# We want delta to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight) # (bl d)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
ctx.is_variable_B = B is None
ctx.is_variable_C = C is None
ctx.B_proj_bias_is_None = B_proj_bias is None
ctx.C_proj_bias_is_None = C_proj_bias is None
if B is None: # variable B
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if B.stride(-1) != 1:
B = B.contiguous()
if C is None: # variable C
C = x_dbl[:, -d_state:] # (bl dstate)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
else:
if C.stride(-1) != 1:
C = C.contiguous()
if D is not None:
D = D.contiguous()
out, scan_intermediates, out_z = selective_scan_cuda.fwd(
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
)
ctx.delta_softplus = delta_softplus
# ctx.out_proj_bias_is_None = out_proj_bias is None
ctx.checkpoint_lvl = checkpoint_lvl
if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
conv1d_out, delta = None, None
ctx.d_inner = d_inner
ctx.tp_pg = tp_pg
ctx.save_for_backward(
xz,
conv1d_weight,
conv1d_bias,
x_dbl,
x_proj_weight,
delta_proj_weight,
conv1d_out,
delta,
A,
B,
C,
D,
delta_bias,
scan_intermediates,
out,
)
return rearrange(out_z, "b d l -> b l d")
@staticmethod
@custom_bwd
def backward(ctx, dout):
# dout: (batch, seqlen, dim)
(
xz,
conv1d_weight,
conv1d_bias,
x_dbl,
x_proj_weight,
delta_proj_weight,
conv1d_out,
delta,
A,
B,
C,
D,
delta_bias,
scan_intermediates,
out,
) = ctx.saved_tensors
batch, L = xz.shape[0], xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
# x, z = xz.chunk(2, dim=1)
assert ctx.d_inner % ctx.tp_pg.size() == 0
x, z = xz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2)
x = x.squeeze(2)
z = z.squeeze(2)
if ctx.checkpoint_lvl == 1:
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
# dx, dz = dxz.chunk(2, dim=1)
assert ctx.d_inner % ctx.tp_pg.size() == 0
dx, dz = dxz.view(batch, ctx.d_inner // ctx.tp_pg.size(), 2, L).chunk(2, dim=2)
dx = dx.squeeze(2)
dz = dz.squeeze(2)
dout = rearrange(dout, "b l e -> b e l")
if dout.stride(-1) != 1:
dout = dout.contiguous()
(dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z,) = selective_scan_cuda.bwd(
conv1d_out,
delta,
A,
B,
C,
D,
z,
delta_bias,
dout,
scan_intermediates,
out,
dz,
ctx.delta_softplus,
True, # option to recompute out_z
)
dD = dD if D is not None else None
dx_dbl = torch.empty_like(x_dbl)
dB_proj_bias = None
if ctx.is_variable_B:
if not A.is_complex():
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
dB = None
dC_proj_bias = None
if ctx.is_variable_C:
if not A.is_complex():
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
else:
dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
dx_dbl[:, -d_state:] = dC # (bl d)
dC = None
ddelta = rearrange(ddelta, "b d l -> d (b l)")
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
# backward of conv1d with the backward of chunk).
dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
)
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
return (
None, # d_inner
None, # tp_pg
dxz,
dconv1d_weight,
dconv1d_bias,
dx_proj_weight,
ddelta_proj_weight,
dA,
dB,
dC,
dD,
ddelta_bias if delta_bias is not None else None,
dB_proj_bias,
dC_proj_bias,
None,
)
def mamba_inner_fn(
d_inner,
tp_pg,
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
):
return MambaInnerFn.apply(
d_inner,
tp_pg,
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
A,
B,
C,
D,
delta_bias,
B_proj_bias,
C_proj_bias,
delta_softplus,
)
def mamba_inner_ref(
xz,
conv1d_weight,
conv1d_bias,
x_proj_weight,
delta_proj_weight,
out_proj_weight,
out_proj_bias,
A,
B=None,
C=None,
D=None,
delta_bias=None,
B_proj_bias=None,
C_proj_bias=None,
delta_softplus=True,
):
L = xz.shape[-1]
delta_rank = delta_proj_weight.shape[1]
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
x, z = xz.chunk(2, dim=1)
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
# We're being very careful here about the layout, to avoid extra transposes.
# We want delta to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
delta = rearrange(delta, "d (b l) -> b d l", l=L)
if B is None: # variable B
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
if B_proj_bias is not None:
B = B + B_proj_bias.to(dtype=B.dtype)
if not A.is_complex():
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
else:
B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
if C is None:  # variable C
C = x_dbl[:, -d_state:] # (bl d)
if C_proj_bias is not None:
C = C + C_proj_bias.to(dtype=C.dtype)
if not A.is_complex():
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
else:
C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus)
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
import argparse
import os
import sys
from config import MambaModelConfig
from mamba import MambaForTraining
from trainer import MambaTrainer
from nanotron import logging
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from run_train import get_dataloader # noqa
logger = logging.get_logger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
config_file = args.config_file
# Load trainer and data
trainer = MambaTrainer(config_file, model_config_class=MambaModelConfig, model_class=MambaForTraining)
dataloader = get_dataloader(trainer)
# Train
trainer.train(dataloader)
#!/bin/bash
# Simple script to create a tiny mamba model and train it
set -e -x
# Create the YAML config file
EXAMPLE_PATH=$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P)
REPO_PATH=$(dirname $EXAMPLE_PATH)
python $EXAMPLE_PATH/create_config_mamba.py
# Setup from environment variables
export CUDA_DEVICE_MAX_CONNECTIONS=1
export FI_PROVIDER="efa"
python -u -m torch.distributed.run \
--nproc_per_node 8 \
--nnodes 1 \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
$REPO_PATH/mamba/train_mamba.py --config-file $EXAMPLE_PATH/config_mamba.yaml
from typing import Optional, Type, Union
from config import ExistingCheckpointInit, MambaConfig, MambaInit
from torch.nn.parallel import DistributedDataParallel
from nanotron import logging
from nanotron.trainer import DistributedTrainer
logger = logging.get_logger(__name__)
from nanotron import distributed as dist
from nanotron.config import ParallelismArgs
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelLinearMode,
TensorParallelRowLinear,
)
from nanotron.parallel.tied_parameters import (
create_pg_for_tied_weights,
get_tied_id_to_param,
tie_parameters,
)
from nanotron.serialize import load_weights, parse_ckpt_path
class MambaTrainer(DistributedTrainer):
def __init__(
self,
config_or_config_file: Union[MambaConfig, str],
config_class: Type[MambaConfig] = MambaConfig,
model_config_class: Optional[Type] = None,
model_class: Type[NanotronModel] = None,
):
assert config_class == MambaConfig
super().__init__(config_or_config_file, config_class, model_config_class, model_class)
def _mark_tied_parameters(
self,
model: NanotronModel,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs] = None,
):
# Tie embeddings
embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names()
if len(embeddings_lm_head_tied_names) > 0:
shared_embeddings = [
(
target,
(
parallel_context.world_rank_matrix[
dist.get_rank(parallel_context.expert_pg),
get_pp_rank_of(target, module=model),
dist.get_rank(parallel_context.dp_pg),
dist.get_rank(parallel_context.tp_pg),
],
),
)
for target in embeddings_lm_head_tied_names
]
tie_parameters(
root_module=model,
ties=shared_embeddings,
parallel_context=parallel_context,
reduce_op=dist.ReduceOp.SUM,
)
# Tie custom params
model.tie_custom_params()
# Sync all parameters that have the same name and that are not sharded
assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point"
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if isinstance(param, NanotronParameter) and (param.is_sharded or param.is_tied):
continue
if isinstance(module, TensorParallelRowLinear) and "bias" == param_name:
# bias for TensorParallelRowLinear only exists on TP=0 so we don't need to tie it
continue
shared_weights = [
(
name,
# sync across TP group
tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
)
]
if (
parallel_config is None
or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE
or hasattr(model.config.model.model_config, "is_mamba_config")
):
# We add `reduce_op=None` in order to signal that the weights are synced by design, without needing to reduce
# when TP=2 we have LN that is duplicated across TP, so by design it's tied
reduce_op = None
else:
reduce_op = dist.ReduceOp.SUM
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op
)
create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context)
def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model
# Load or initialize model weights
self.init_checkpoint_path = parse_ckpt_path(config=self.config)
reloaded_from_checkpoint = False
if self.init_checkpoint_path is not None:
# Reload from a training checkpoint
log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0)
self.param_shard_metadata = load_weights(
model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
)
reloaded_from_checkpoint = True
if not reloaded_from_checkpoint:
log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO)
if isinstance(self.config.model.init_method, ExistingCheckpointInit):
# Initialize model from a pretrained model checkpoint
self.param_shard_metadata = load_weights(
model=unwrapped_model,
parallel_context=self.parallel_context,
root_folder=self.config.model.init_method.path,
)
elif isinstance(self.config.model.init_method, MambaInit):
unwrapped_model.init_model_randomly(config=self.config)
# Synchronize parameters so that the model is consistent
# sync all params across dp
for name, param in sorted(model.named_parameters(), key=lambda x: x[0]):
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg)
# sync tied params across tied groups
for (_, group_ranks), param in sorted(
get_tied_id_to_param(
parameters=model.parameters(),
root_module=unwrapped_model,
).items(),
key=lambda x: x[0],
):
group = self.parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
else:
raise ValueError(f"Unsupported {self.config.model.init_method}")
return model
---
library_name: nanotron
---
# LlaMoE
Modeling code for LlaMoE to use with [Nanotron](https://github.com/huggingface/nanotron/)
## 🚀 Quickstart
```bash
# Generate a config file
python examples/moe/config_llamoe.py
# Install megablocks
pip install megablocks
# Run training
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=4 examples/moe/train_moe.py --config-file examples/moe/config_llamoe.yaml
```
## 🚀 Use your custom model
- Update the `LlaMoEConfig` class in `config_llamoe.py` to match your model's configuration
- Update the `LlaMoEForTraining` class in `modeling_llamoe.py` to match your model's architecture
- Pass the previous to the `DistributedTrainer` class in `train_moe.py`:
```python
trainer = DistributedTrainer(config_file, model_class=LlaMoEForTraining, model_config_class=LlaMoEConfig)
```
- Run training as usual (see the entry-point sketch below)
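For reference, here is a minimal sketch of what the `train_moe.py` entry point could look like. It assumes the MoE example wires data the same way as the Mamba example, i.e. via `run_train.get_dataloader` from the repository root; adjust the config path to your setup:
```python
from config_llamoe import LlaMoEConfig
from modeling_llamoe import LlaMoEForTraining

from nanotron.trainer import DistributedTrainer
from run_train import get_dataloader  # assumes the repository root is on sys.path

if __name__ == "__main__":
    trainer = DistributedTrainer(
        "examples/moe/config_llamoe.yaml",
        model_config_class=LlaMoEConfig,
        model_class=LlaMoEForTraining,
    )
    trainer.train(get_dataloader(trainer))
```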
## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
- https://github.com/stanford-futuredata/megablocks/blob/main/megablocks/layers/dmoe.py
""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information."""
import os
from dataclasses import dataclass
from typing import Optional
from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
DatasetStageArgs,
GeneralArgs,
LoggingArgs,
LRSchedulerArgs,
ModelArgs,
OptimizerArgs,
ParallelismArgs,
RandomInit,
TokenizerArgs,
TokensArgs,
)
from nanotron.config.config import PretrainDatasetsArgs
from nanotron.logging import human_format
@dataclass
class LlaMoEConfig:
"""Configuration for a LLAMA model
Be careful on having a coherent typing as we use it to reconstruct the model from yaml
"""
bos_token_id: int = 1
eos_token_id: int = 2
hidden_act: str = "silu"
hidden_size: int = 4096
initializer_range: float = 0.02
intermediate_size: int = 11008
is_llamoe_config: bool = True  # We use this to help differentiate models in YAML/Python conversion
max_position_embeddings: int = 2048
num_attention_heads: int = 32
num_hidden_layers: int = 32
num_key_value_heads: Optional[int] = None
pad_token_id: Optional[int] = None
pretraining_tp: int = 1
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
## MoE specific
# Number of experts per Sparse MLP layer.
moe_num_experts: int = 1
# The number of experts to route per token; can also be interpreted as the `top-k` routing parameter
num_experts_per_tok: int = 1
moe_capacity_factor: int = 1
def __post_init__(self):
# for backward compatibility
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
assert (
self.num_experts_per_tok <= self.moe_num_experts
), f"num_experts_per_tok ({self.num_experts_per_tok}) must be <= moe_num_experts ({self.moe_num_experts})"
model_config = LlaMoEConfig(
# Config for a 52M llama model
num_hidden_layers=1,
hidden_size=512,
num_attention_heads=8,
intermediate_size=512 * 4,
max_position_embeddings=128,
tie_word_embeddings=False,
vocab_size=32000,
moe_num_experts=4,
)
num_params = human_format(
model_config.vocab_size * model_config.hidden_size * 2
+ model_config.num_hidden_layers
* (
3 * model_config.hidden_size * model_config.intermediate_size
+ 4 * model_config.hidden_size * model_config.hidden_size
)
).replace(".", "p")
print(f"Model has {num_params} parameters")
SEED = 42
learning_rate = LRSchedulerArgs(
learning_rate=3e-4, lr_warmup_steps=100, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
)
optimizer = OptimizerArgs(
zero_stage=0,
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=False,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)
parallelism = ParallelismArgs(
dp=1,
pp=1,
tp=2,
expert_parallel_size=2,
pp_engine="1f1b",
tp_mode="ALL_REDUCE",
tp_linear_async_communication=False,
)
assert (
model_config.moe_num_experts % parallelism.expert_parallel_size == 0
), "Number of experts must be divisible by expert_parallel_size"
tokens = TokensArgs(sequence_length=256, train_steps=1918, micro_batch_size=256, batch_accumulation_per_replica=2)
data = DataArgs(
seed=SEED,
num_loading_workers=1,
# dataset=None
dataset=PretrainDatasetsArgs(
hf_dataset_or_datasets="roneneldan/TinyStories",
hf_dataset_splits="train",
text_column_name="text",
dataset_processing_num_proc_per_process=12,
),
)
checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)
config = Config(
general=GeneralArgs(project="moe", run="llamoe", seed=SEED),
checkpoints=CheckpointsArgs(
checkpoints_path=checkpoints_path,
checkpoint_interval=100000,
save_initial_state=True,
resume_checkpoint_path=checkpoints_path,
),
parallelism=parallelism,
model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
tokenizer=TokenizerArgs("meta-llama/Llama-2-7b-hf"),
optimizer=optimizer,
logging=LoggingArgs(),
tokens=tokens,
data_stages=[
DatasetStageArgs(name="Stable Training Stage", start_training_step=1, data=data),
DatasetStageArgs(name="Annealing Phase", start_training_step=10, data=data),
],
)
if __name__ == "__main__":
dir = os.path.dirname(__file__)
# Save config as YAML file
filename = os.path.basename(__file__).replace(".py", ".yaml")
config.save_as_yaml(f"{dir}/{filename}")
print(f"Config saved as {dir}/{filename}")
# You can now train a model with this config using `/run_train.py`
checkpoints:
checkpoint_interval: 100000
checkpoints_path: /fsx/nouamane/projects/nanotron/examples/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: /fsx/nouamane/projects/nanotron/examples/checkpoints
save_initial_state: true
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 12
hf_dataset_config_name: null
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: moe
run: llamoe
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 512
initializer_range: 0.02
intermediate_size: 2048
is_llamoe_config: true
max_position_embeddings: 128
moe_capacity_factor: 1
moe_num_experts: 4
num_attention_heads: 8
num_experts_per_tok: 1
num_hidden_layers: 1
num_key_value_heads: 8
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-06
rope_scaling: null
tie_word_embeddings: false
use_cache: true
vocab_size: 32000
optimizer:
accumulate_grad_in_fp32: false
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_starting_step: null
lr_decay_steps: 1818
lr_decay_style: cosine
lr_warmup_steps: 100
lr_warmup_style: linear
min_decay_lr: 1.0e-05
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 1
expert_parallel_size: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: meta-llama/Llama-2-7b-hf
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 2
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 256
sequence_length: 256
train_steps: 1918
val_check_interval: -1
# coding=utf-8
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMa MoE model."""
import math
from typing import Dict, Optional, Union, List
import torch
from config_llamoe import LlaMoEConfig
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from moe import dMoE
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import ParallelismArgs
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelLinearMode,
TensorParallelRowLinear,
)
from nanotron.random import RandomStates
from nanotron.utils import checkpoint_method
from torch import nn
from torch.nn import init
logger = logging.get_logger(__name__)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 10000.0):
super().__init__()
assert dim % 2 == 0
self.dim = dim
self.end = end
self.theta = theta
# TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
# TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex
self.freqs_cis: torch.Tensor
self._initialized_buffer = False
def init_rotary_embeddings(self):
if self._initialized_buffer is True:
# Buffer is already initialized
return
self.register_buffer(
"freqs_cis",
torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"),
persistent=False,
)
assert self.freqs_cis.device.type == "cuda"
# TODO @nouamane: Once we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
if self.freqs_cis.dtype != torch.float:
self.freqs_cis = self.freqs_cis.to(torch.float)
assert self.freqs_cis.dtype == torch.float
freqs = 1.0 / (
self.theta
** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim)
)
t = torch.arange(self.end, device="cuda")
freqs = torch.outer(t, freqs).float()
complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
freqs = torch.view_as_real(complex_freqs)
self.freqs_cis.copy_(freqs)
self._initialized_buffer = True
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
batch_size, seq_length, num_heads, inner_dim = x.shape
while (
position_ids is not None and position_ids[-1, -1] >= self.end
) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync
self.end *= 2
self._initialized_buffer = False
if self._initialized_buffer is False:
print(f"Initializing rotary embeddings with end={self.end}")
self.init_rotary_embeddings()
dtype = x.dtype
assert inner_dim % 2 == 0
x = x.view(
batch_size, seq_length, num_heads, inner_dim // 2, 2
) # [batch_size, q_length, num_heads, inner_dim // 2, 2]
if x.dtype == torch.bfloat16:
x = x.float()
complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2]
if position_ids is None:
freqs_cis = self.freqs_cis[None, :seq_length, None, :]
else:
# TODO(kunhao): Should None follow the num_heads dimension?
if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully
raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}")
freqs_cis = self.freqs_cis[position_ids][:, :, None, :]
complex_freqs = torch.view_as_complex(freqs_cis)
x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim)
return x_out.type(dtype)
class CoreAttention(nn.Module):
def __init__(self, config: LlaMoEConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
super().__init__()
assert (
config.hidden_size % config.num_attention_heads == 0
), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.checkpoint_attention = False # Because flash_attn already does checkpointing
@checkpoint_method(attr_name="checkpoint_attention")
def forward(
self,
query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim]
key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
):
# TODO @thomasw21: Compute once, instead of computing for each layers.
cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
# TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
# what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
causal = False if q_sequence_mask.shape[1] == 1 else True
attn_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_sequence_mask.shape[1],
max_seqlen_k=kv_sequence_mask.shape[1],
dropout_p=0.0,
softmax_scale=None, # This already defaults to the scale I'm interested in
causal=causal,
return_attn_probs=False,
)
return attn_output
def pad_to_right(tensor, mask, new_tensor=None):
"""Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
Args:
tensor: (batch_size, seqlen, d1, d2)
mask: (batch_size, seqlen)
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
Returns:
new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
right_padded_mask: (batch_size, seqlen)
"""
# First, we need to find the number of padding for each row
unpad_seqlens = mask.sum(1)
# Then, we need to find the maximum length of the tensor
max_seqlen = mask.shape[1]
# We can then create the indices to select the padded values
# The indices are the same for each row
indices = torch.arange(max_seqlen, device=mask.device)
# We can then create the mask for the padded values
right_padded_mask = indices < unpad_seqlens[:, None]
# We select the useful values
useful_values = tensor[mask]
# We create the new tensor (if not provided)
new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
# We fill the new tensor with the useful values
new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
return new_tensor, right_padded_mask
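# Illustrative example (hypothetical helper, not used by the model): with a left-padded
# batch, pad_to_right packs the valid positions to the front of each row.
def _pad_to_right_example():
    mask = torch.tensor([[False, False, True, True], [False, True, True, True]])
    tensor = torch.arange(8, dtype=torch.float).view(2, 4, 1, 1)
    right_padded, right_mask = pad_to_right(tensor, mask)
    # right_padded[..., 0, 0] -> [[2., 3., 0., 0.], [5., 6., 7., 0.]]
    # right_mask              -> [[True, True, False, False], [True, True, True, False]]
    return right_padded, right_mask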
class CausalSelfAttention(nn.Module, AttachableStore):
def __init__(
self,
config: LlaMoEConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
assert (
config.num_attention_heads % tp_pg.size() == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
try:
assert (
config.num_key_value_heads % tp_pg.size() == 0
), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})."
except AttributeError:
log_rank(
"WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads",
logger=logger,
level=logging.WARNING,
rank=0,
)
# If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads
config.num_key_value_heads = config.num_attention_heads
assert (
config.num_attention_heads % config.num_key_value_heads == 0
), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})."
self.n_local_q_heads = config.num_attention_heads // tp_pg.size()
self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size()
self.n_repeats = config.num_attention_heads // config.num_key_value_heads
self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not
self.d_qk = config.hidden_size // config.num_attention_heads
self.d_v = config.hidden_size // config.num_attention_heads
self.d_model = config.hidden_size
# TODO @thomasw21: refactor so that we store that default in a single place.
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
# build the slice config for self.qkv for save/load
# shard are done within the contiguous chunk
qkv_contiguous_chunks = (
config.num_attention_heads * self.d_qk, # shape of q
config.num_key_value_heads * self.d_qk, # shape of k
config.num_key_value_heads * self.d_qk, # shape of v
)
self.qkv_proj = TensorParallelColumnLinear(
self.d_model,
config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
)
# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, interleaved=True)
self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
)
self.attention = CoreAttention(
config,
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.prefill_kv_len = (
config.max_position_embeddings
) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
def forward(
self,
hidden_states, # [seq_length, batch_size, hidden_size]
sequence_mask, # [batch_size, seq_length]
):
qkv_states = self.qkv_proj(
hidden_states
) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
q_length, batch_size, _ = qkv_states.shape
if self.is_gqa:
query_states, key_states, value_states = torch.split(
qkv_states,
[
self.n_local_q_heads * self.d_qk,
self.n_local_kv_heads * self.d_qk,
self.n_local_kv_heads * self.d_qk,
],
dim=-1,
)
query_states = (
query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk)
)
key_states = (
key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
)
value_states = (
value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
)
else:
query_states, key_states, value_states = (
qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk)
.permute(2, 1, 0, 3, 4)
.contiguous()
) # [3, batch_size, seq_length, n_local_q_heads, d_qk]
store = self.get_local_store()
if store is not None: # Inference case
# Double check that we use store only at inference time
assert key_states.requires_grad is False
assert value_states.requires_grad is False
if "position_offsets" in store:
old_position_offsets = store["position_offsets"]
position_ids = old_position_offsets[:, None] + sequence_mask
else:
position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
position_offsets = position_ids[:, -1]
# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
query_states = self.rotary_embedding(query_states, position_ids=position_ids)
key_states = self.rotary_embedding(key_states, position_ids=position_ids)
if "key" not in store:
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
# but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
assert ~(
sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
# preallocate k_cache, v_cache to self.prefill_kv_len
k_cache = torch.zeros(
(
batch_size,
self.prefill_kv_len,
self.n_local_kv_heads,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
)
v_cache = torch.zeros(
(batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
dtype=query_states.dtype,
device=query_states.device,
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
output_unpad = flash_attn_varlen_func(
q=query_unpad, # (total_q, n_local_q_heads, d_qk)
k=key_unpad, # (total_kv, n_local_kv_heads, d_qk)
v=value_unpad, # (total_kv, n_local_kv_heads, d_v)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True, # True in prefill phase, False in subsequent phases
return_attn_probs=False,
) # (total_unpadded, n_local_q_heads, d_v)
attention_output = bert_padding.pad_input(
output_unpad, indices_q, batch_size, q_length
) # (batch_size, q_length, n_local_q_heads, d_v)
pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
v_cache = store["value"]
# NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
# Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
if self.rotary_embedding.end > old_rotary_embed_end:
k_cache = torch.cat(
[
k_cache,
torch.zeros(
(
batch_size,
self.rotary_embedding.end - old_rotary_embed_end,
self.n_local_kv_heads,
self.d_qk,
),
dtype=query_states.dtype,
device=query_states.device,
),
],
dim=1,
)
v_cache = torch.cat(
[
v_cache,
torch.zeros(
(
batch_size,
self.rotary_embedding.end - old_rotary_embed_end,
self.n_local_kv_heads,
self.d_v,
),
dtype=query_states.dtype,
device=query_states.device,
),
],
dim=1,
)
assert (
k_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
assert (
v_cache.shape[1] == self.rotary_embedding.end
), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
# [batch_size, seq_length, num_heads, d_qk]
query_states = query_states.view(
batch_size, q_length, self.n_local_q_heads, self.d_qk
) # [batch_size, q_length, self.n_heads, d_qk]
kv_length = key_states.shape[1]
key_states = key_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size, kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size, kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size, kv_length, self.n_heads, d_v]
attention_output = flash_attn_with_kvcache(
query_states,
k_cache,
v_cache,
key_states,
value_states,
rotary_cos=None,
rotary_sin=None,
# TODO @nouamane: seems like this doesn't help to indicate padding (for the first iteration it's just 0)
cache_seqlens=position_offsets.contiguous(),
softmax_scale=None,
causal=True,
rotary_interleaved=False, # GPT-NeoX style
)
store.update(
{
"key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
"value": v_cache,
"position_offsets": position_offsets,
}
)
else: # Training case
# Apply rotary embeddings to query/key states
# NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk]
# Here it is, [batch_size, seq_length, num_heads, d_qk]
# [2, batch_size, seq_length, num_heads, d_qk]
key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
# [batch_size, seq_length, 2, num_heads, d_qk]
key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()
query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
# [batch_size, seq_length, num_heads, d_qk]
key_states, value_states = torch.split(key_value_states, 1, dim=2)
q_sequence_mask = sequence_mask
kv_sequence_mask = sequence_mask
kv_length = key_states.shape[1]
# [batch_size, seq_length, num_heads, d_qk]
# Shape for the varlen (unpadded) flash-attn kernel: `flash_attn_varlen_func`
query_states = query_states.view(
batch_size * q_length, self.n_local_q_heads, self.d_qk
) # [batch_size * q_length, self.n_heads, d_qk]
key_states = key_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_qk
) # [batch_size * kv_length, self.n_heads, d_qk]
value_states = value_states.view(
batch_size * kv_length, self.n_local_kv_heads, self.d_v
) # [batch_size * kv_length, self.n_heads, d_v]
attention_output = self.attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
q_sequence_mask=q_sequence_mask,
kv_sequence_mask=kv_sequence_mask,
)
attention_output = (
attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
)
output = self.o_proj(attention_output)
return {"hidden_states": output, "sequence_mask": sequence_mask}
class LlaMoEDecoderLayer(nn.Module):
def __init__(
self,
config: LlaMoEConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
parallel_context: ParallelContext,
layer_idx: int,
):
super().__init__()
self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = CausalSelfAttention(
config=config,
parallel_config=parallel_config,
tp_pg=tp_pg,
layer_idx=layer_idx,
)
self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.block_sparse_moe = dMoE(
config,
parallel_context=parallel_context,
parallel_config=parallel_config,
)
def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states) # TODO @nouamane do we want to return router_logits?
hidden_states = hidden_states + residual
return {
"hidden_states": hidden_states,
"sequence_mask": output["sequence_mask"],
}
class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: LlaMoEConfig, parallel_config: Optional[ParallelismArgs]):
super().__init__()
self.token_embedding = TensorParallelEmbedding(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
padding_idx=config.pad_token_id,
pg=tp_pg,
mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
)
self.pg = tp_pg
def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
store = self.get_local_store()
if store is not None:
if "past_length" in store:
past_length = store["past_length"]
else:
past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
# Store new past_length in store
store["past_length"] = past_length + cumsum_mask[:, -1]
# Format input in `[seq_length, batch_size]` to support high TP with low batch_size
input_ids = input_ids.transpose(0, 1)
input_embeds = self.token_embedding(input_ids)
return {"input_embeds": input_embeds}
class LlaMoEModel(nn.Module):
"""Build pipeline graph"""
def __init__(
self,
config: LlaMoEConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
# Declare all the nodes
self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
self.config = config
self.parallel_config = parallel_config
self.parallel_context = parallel_context
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
self.token_position_embeddings = PipelineBlock(
p2p=self.p2p,
module_builder=Embedding,
module_kwargs={
"tp_pg": parallel_context.tp_pg,
"config": config,
"parallel_config": parallel_config,
},
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
p2p=self.p2p,
module_builder=LlaMoEDecoderLayer,
module_kwargs={
"config": config,
"parallel_config": parallel_config,
"tp_pg": parallel_context.tp_pg,
"parallel_context": parallel_context,
"layer_idx": layer_idx,
},
module_input_keys={"hidden_states", "sequence_mask"},
module_output_keys={"hidden_states", "sequence_mask"},
)
for layer_idx in range(config.num_hidden_layers)
]
)
self.final_layer_norm = PipelineBlock(
p2p=self.p2p,
module_builder=TritonRMSNorm,
module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
module_input_keys={"input"},
module_output_keys={"hidden_states"},
) # TODO
self.lm_head = PipelineBlock(
p2p=self.p2p,
# Understand that this means that we return sharded logits that are going to need to be gathered
module_builder=TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
"pg": parallel_context.tp_pg,
"bias": False,
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
},
module_input_keys={"x"},
module_output_keys={"logits"},
)
self.cast_to_fp32 = PipelineBlock(
p2p=self.p2p,
module_builder=lambda: lambda x: x.float(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]
def forward_with_hidden_states(
self,
input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
):
# all tensors are optional as most ranks don't need anything from the dataloader.
output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
hidden_encoder_states = {
"hidden_states": output["input_embeds"],
"sequence_mask": input_mask,
}
for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)
hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
sharded_logits = self.lm_head(x=hidden_states)["logits"]
fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
return fp32_sharded_logits, hidden_states
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
model_config = self.config
d_ff = model_config.intermediate_size
d_qkv = model_config.hidden_size // model_config.num_attention_heads
block_compute_costs = {
# CausalSelfAttention (qkv proj + attn out) + MLP
LlaMoEDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
# This is the last lm_head
TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
world_size = self.parallel_context.world_pg.size()
try:
num_key_values_heads = self.config.num_key_value_heads
except AttributeError:
num_key_values_heads = self.config.num_attention_heads
model_flops, hardware_flops = get_flops(
num_layers=self.config.num_hidden_layers,
hidden_size=self.config.hidden_size,
num_heads=self.config.num_attention_heads,
num_key_value_heads=num_key_values_heads,
vocab_size=self.config.vocab_size,
ffn_hidden_size=self.config.intermediate_size,
seq_len=sequence_length,
batch_size=global_batch_size,
)
model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
return model_flops_per_s, hardware_flops_per_s
@torch.jit.script
def masked_mean(loss, label_mask, dtype):
# type: (Tensor, Tensor, torch.dtype) -> Tensor
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
class Loss(nn.Module):
def __init__(self, tp_pg: dist.ProcessGroup):
super().__init__()
self.tp_pg = tp_pg
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
) -> Dict[str, torch.Tensor]:
# Megatron by default casts everything to fp32. `--f16-lm-cross-entropy` is an option you can use to keep the current precision.
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
loss = sharded_cross_entropy(
sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
).transpose(0, 1)
# TODO @thomasw21: It's unclear what kind of normalization we want to do.
loss = masked_mean(loss, label_mask, dtype=torch.float)
# I think indexing causes a sync we don't actually want
# loss = loss[label_mask].sum()
return {"loss": loss}
class LlaMoEForTraining(NanotronModel):
def __init__(
self,
config: LlaMoEConfig,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
random_states: Optional[RandomStates] = None,
):
super().__init__()
self.model = LlaMoEModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=Loss,
module_kwargs={"tp_pg": parallel_context.tp_pg},
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
},
module_output_keys={"loss"},
)
self.parallel_context = parallel_context
self.config = config
self.parallel_config = parallel_config
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
loss = self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
)["loss"]
return {"loss": loss}
@torch.no_grad()
def init_model_randomly(self, config):
"""Initialize model parameters randomly.
Note:
Layernorm weights are initialized to all 0 or 1 depending on `apply_layernorm_1p`
"""
model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
std = config.model.init_method.std
sigma = config.model.init_method.std
num_layers = config.model.model_config.num_hidden_layers
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)
module_name, param_name = param_name.rsplit(".", 1)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if full_param_name in initialized_parameters:
# Already initialized
continue
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TritonRMSNorm):
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, nn.Linear):
fan_in = None
if "weight" == param_name:
fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight)
init.kaiming_uniform_(module.weight, a=math.sqrt(5))
elif "bias" == param_name:
bound = 1 / math.sqrt(fan_in) if (fan_in is not None and fan_in > 0) else 0
init.uniform_(module.bias, -bound, bound)
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
raise Exception(f"Parameter {full_param_name} was not initialized")
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if param.is_tied
else name
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
return self.model.get_block_compute_costs()
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
def get_flops(
num_layers,
hidden_size,
num_heads,
num_key_value_heads,
vocab_size,
seq_len,
ffn_hidden_size,
batch_size=1,
):
"""Counts flops in an decoder-only model
Args:
num_layers: number of decoder layers
hidden_size: hidden size of the model
num_heads: number of heads in the model
num_key_value_heads: number of key/value heads in the model
ffn_hidden_size: hidden size of the FFN
vocab_size: size of the vocabulary
seq_len: sequence length of the decoder
batch_size: batch size
Returns:
model_flops: flops in the model (should be independent of the hardware and model implementation)
hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
"""
if num_key_value_heads is None:
num_key_value_heads = num_heads
hidden_size_per_head = hidden_size // num_heads
# In the following we mark the reduced dimension with parentheses
# decoder
# self attention
## qkv projection
decoder_qkv_proj_flops_fwd = (
2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head
+ 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head
)
## qk logits
decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len
## v logits
decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head
## attn out
decoder_attn_out_flops_fwd = (
2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size
)
# FF
## 1st layer
decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
## 2nd layer
decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
decoder_flops_fwd = (
decoder_qkv_proj_flops_fwd
+ decoder_qk_logits_flops_fwd
+ decoder_v_logits_flops_fwd
+ decoder_attn_out_flops_fwd
+ decoder_ffn_1_flops_fwd
+ decoder_ffn_2_flops_fwd
)
# lm head
lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
# The backward pass requires double the flops of the forward matmuls, to calculate the gradients with respect to
# both the input and the weight tensors
model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
hardware_flops = model_flops # TODO @nouamanetazi: add hardware flops
return model_flops, hardware_flops
""" LlaMa model with MoEs"""
import warnings
from functools import partial
from typing import Optional, Tuple
import numpy as np
import stk
import torch
import torch.nn.functional as F
from config_llamoe import LlaMoEConfig
from megablocks.layers import weight_parallel as wp
from megablocks.layers.activation_fn import act_fn
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import ParallelismArgs
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.sharded_parameters import SplitConfig, mark_all_parameters_in_module_as_sharded
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
from torch import nn
try:
import megablocks.ops as ops
from megablocks.layers.all_to_all import all_to_all
except ImportError:
warnings.warn("Please install megablocks to use MoEs: `pip install megablocks`")
logger = logging.get_logger(__name__)
class dMoE(torch.nn.Module):
def __init__(
self,
config: LlaMoEConfig,
parallel_context: "ParallelContext",
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
self.config = config
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
if self.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER:
logging.warn_once(
logger=logger,
msg="TensorParallelLinearMode.REDUCE_SCATTER is still experimental for MoEs. Use at your own risk.",
rank=0,
)
# Token router.
self.gate = LearnedRouter(config)
# Expert computation helper.
self.experts = ParallelDroplessMLP(
config,
use_bias=False,
parallel_context=parallel_context,
parallel_config=parallel_config,
)
def forward(self, x: torch.Tensor):
"""
Args:
x: input tensor of shape [sequence_length, batch_size, hidden_size]
"""
# Compute the expert scores and assignments.
# TODO: support sequence parallelism
sequence_length, batch_size, _ = x.size()  # matches the [sequence_length, batch_size, hidden_size] layout documented above
x = x.view(-1, self.config.hidden_size)
scores, expert_weights, top_experts = self.gate(x)
# Compute the experts.
x = self.experts(x, scores, expert_weights, top_experts)
return x.reshape(sequence_length, batch_size, -1)
# Adapted from megablocks.layers.router.LearnedRouter
class LearnedRouter(torch.nn.Module):
def __init__(self, config: LlaMoEConfig):
super().__init__()
self.layer = torch.nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
# TODO: initialization
self.config = config
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
router_logits = self.layer(x) # (batch * sequence_length, n_experts)
scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse?
if self.config.num_experts_per_tok == 1:
expert_weights, expert_indices = scores.max(dim=-1, keepdim=True)
else:
expert_weights, expert_indices = torch.topk(scores, self.config.num_experts_per_tok, dim=-1)
return scores, expert_weights, expert_indices.int()
# Adapted from megablocks.layers.mlp.ParallelDroplessMLP
class ParallelDroplessMLP(torch.nn.Module):
def __init__(
self,
config: LlaMoEConfig,
use_bias: bool,
parallel_context: "ParallelContext",
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
self.config = config
self.use_bias = use_bias
self.expert_pg_size = parallel_context.expert_pg.size()
self.expert_parallel_group = parallel_context.expert_pg
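# If there are more expert-parallel ranks than experts, the surplus ranks shard each expert along the hidden dimension;
# otherwise each rank owns moe_num_experts // expert_pg_size experts.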
self.hidden_sharding_degree = self.expert_pg_size // min(self.expert_pg_size, self.config.moe_num_experts)
self.experts_per_rank = self.config.moe_num_experts // min(self.expert_pg_size, self.config.moe_num_experts)
self.num_experts = config.moe_num_experts
self.num_experts_per_tok = self.config.num_experts_per_tok
# Calculate the number of bits needed to represent the expert indices
# so that we can pass it to radix sort.
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
if use_bias:
self.bias = torch.nn.Parameter(torch.empty(config.hidden_size)) # TODO: init
# Select the forward function for the operating mode.
self.forward_fn = self.parallel_forward_once if self.expert_pg_size > 1 else self.forward_once
self.blocking = 128
if self.experts_per_rank == 1:
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=parallel_context.tp_pg)
else:
self.mlp = SparseMLP(config=config, parallel_config=parallel_config, parallel_context=parallel_context)
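# The transposed block-sparse topology is built by radix-sorting column indices; compute how many bits that sort needs.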
max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking
self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1)
def indices_and_bins(self, top_expert):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
top_expert = top_expert.int()
bin_ids, indices = ops.sort(top_expert, self.sort_end_bit)
tokens_per_expert = ops.histogram(top_expert, self.num_experts)
# Calculate the bin bounds for the sorted tokens.
bins = inclusive_cumsum(tokens_per_expert, 0)
return indices, bin_ids, bins, tokens_per_expert
def indices_and_padded_bins(self, top_experts):
# Sort the expert ids to produce the scatter/gather
# indices for the permutation.
bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
# Histogram the expert ids to identify the number of
# tokens routed to each expert.
tokens_per_expert = ops.histogram(top_experts, self.num_experts)
# Round the token counts up to the block size used in
# the matrix multiplications. Calculate the starting
# position of each bin.
padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0)
# Calculate the bin bounds for the sorted tokens.
bins = inclusive_cumsum(tokens_per_expert, 0)
return indices, bin_ids, bins, padded_bins, tokens_per_expert
def forward_once(self, x, expert_weights, top_experts): # TODO: sparse
with torch.no_grad():
(
indices,
bin_ids,
bins,
padded_bins,
tokens_per_expert,
) = self.indices_and_padded_bins(top_experts)
# Route the tokens for MoE computation.
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok)
with torch.no_grad():
topo = self.topology(x, padded_bins)
x = self.mlp(x, topo)
# Un-route the data for the MoE output.
x = ops.padded_scatter(
x,
indices,
bin_ids,
expert_weights,
bins,
padded_bins,
self.num_experts_per_tok,
-1,
)
return x, tokens_per_expert
def parallel_forward_once(self, x, expert_weights, top_experts):
with torch.no_grad():
indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(top_experts)
repeated_tokens_per_expert = ops.repeat(tokens_per_expert, (self.hidden_sharding_degree,))
parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
tpe_handle = torch.distributed.all_to_all_single(
parallel_tokens_per_expert,
repeated_tokens_per_expert,
group=self.expert_parallel_group,
async_op=True,
)
x = ops.gather(x, indices, bin_ids, bins, self.num_experts_per_tok)
# Compute the number of tokens that will be received from each
# device and permute the input data across the devices.
with torch.no_grad():
tpe_handle.wait()
# Reshape to [expert_pg_size, num_experts_per_rank].
repeated_tokens_per_expert = repeated_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank)
parallel_tokens_per_expert = parallel_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank)
send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
# Convert the send/recv counts to lists.
send_counts = send_counts.tolist()
recv_counts = recv_counts.tolist()
tokens_received = sum(recv_counts)
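# Replicate the gathered tokens so every hidden shard of the same expert receives a copy before the all-to-all.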
x = ops.repeat(x, (self.hidden_sharding_degree, 1))
# Start the cross-device permutation asynchronously so we can
# overlap communication with computation.
parallel_x, parallel_x_handle = all_to_all(
x, recv_counts, send_counts, self.expert_parallel_group, async_op=True
)
with torch.no_grad():
replicate_bins = inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
# Construct the expert indices for the permuted tokens.
parallel_top_expert = torch.remainder(
torch.arange(
self.num_experts * self.hidden_sharding_degree,
dtype=torch.int32,
device=indices.device,
),
self.experts_per_rank,
)
parallel_top_expert = ops.replicate(
parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received
).flatten()
parallel_bin_ids, parallel_indices = ops.sort(parallel_top_expert, self.sort_end_bit)
# Calculate the bins boundaries from the token counts.
parallel_tokens_per_expert = parallel_tokens_per_expert.sum(dim=0, dtype=torch.int)
parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
# Locally permute the tokens and perform the expert computation.
# Block to make sure that the cross-device permutation is complete.
parallel_x_handle.wait()
parallel_x = self.permute_and_compute(
parallel_x,
parallel_tokens_per_expert,
parallel_indices,
parallel_bin_ids,
None, # expert_weights
parallel_bins,
num_experts_per_tok=1,
)
# Un-permute the tokens across the devices.
x, _ = all_to_all(parallel_x, send_counts, recv_counts, self.expert_parallel_group)
# Reduce along the hidden sharding to get the final outputs.
shape = (self.hidden_sharding_degree, -1, self.config.hidden_size)
x = ops.sum(x.view(shape), dim=0)
# Un-permute locally to setup for the next series of operations.
x = ops.scatter(
x,
indices,
bin_ids,
expert_weights,
bins,
self.num_experts_per_tok,
)
return x, tokens_per_expert.flatten()
def forward(self, x, scores, expert_weights, top_experts):
"""
Args:
x: input tensor of shape [sequence_length, batch_size, hidden_size]
scores: tensor of shape [sequence_length * batch_size, n_experts]
expert_weights: tensor of shape [sequence_length * batch_size, num_experts_per_tok]
top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok]
"""
# Compute the experts.
x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten())
if self.use_bias:
return x + self.bias
return x
def permute_and_compute(
self,
x,
tokens_per_expert,
indices,
bin_ids,
expert_weights,
bins,
num_experts_per_tok,
):
# Round the token counts up to the block size used in the matrix
# multiplication. Calculate the starting position of each bin.
padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking)
padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0)
# Route the tokens for MoE computation.
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, num_experts_per_tok)
# Perform the expert computation.
with torch.no_grad():
topo = self.topology(x, padded_bins)
x = self.mlp(x, topo)
# Un-route the data for the MoE output.
return ops.padded_scatter(x, indices, bin_ids, expert_weights, bins, padded_bins, num_experts_per_tok)
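# Builds index metadata for the transposed view of the block-sparse topology (stored in the stk.Matrix below).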
def sparse_transpose(self, size, row_indices, column_indices, offsets):
block_columns = size[1] // self.blocking
_, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit)
column_indices_t = row_indices.gather(0, gather_indices.long())
block_offsets_t = gather_indices.int()
zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
nnz_per_column = ops.histogram(column_indices, block_columns)
nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
offsets_t = torch.cat([zero, nnz_per_column])
return column_indices_t, offsets_t, block_offsets_t
def topology(self, x, padded_bins):
padded_tokens, _ = x.size()
assert padded_tokens % self.blocking == 0
assert self.config.intermediate_size % self.blocking == 0
# Offsets for the sparse matrix. All rows have the
# same number of nonzero blocks dictated by the
# dimensionality of a single expert.
block_rows = padded_tokens // self.blocking
blocks_per_row = self.config.intermediate_size // self.blocking
offsets = torch.arange(0, block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32, device=x.device)
# Indices for the sparse matrix. The indices for
# the intermediate matrix are dynamic depending
# on the mapping of tokens to experts.
column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row)
# TODO(tgale): This is unused. Remove the need for this in stk.
# For now, use meta init to save the device memory.
data = torch.empty(column_indices.numel(), self.blocking, self.blocking, dtype=x.dtype, device="meta")
shape = (padded_tokens, self.config.intermediate_size * self.experts_per_rank)
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
shape, row_indices, column_indices, offsets
)
return stk.Matrix(
shape, data, row_indices, column_indices, offsets, column_indices_t, offsets_t, block_offsets_t
)
class ScaleGradient(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, x, scale):
ctx.scale = scale
return x
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad):
return grad * ctx.scale, None
scale_gradient = ScaleGradient.apply
class ExpertParallel(nn.Module):
"""
ExpertParallel scales the gradients of the expert weights because, unlike DP, gradients are not averaged across the expert-parallel group.
"""
def __init__(self, module, expert_parallel_size: int):
super().__init__()
self.module = module
self.expert_parallel_size = expert_parallel_size
def forward(self, *args, **kwargs):
self.scale_gradients()
return self.module(*args, **kwargs)
def scale_gradients(self):
scale_gradient(self.module, 1 / self.expert_parallel_size)
class SparseMLP(nn.Module):
def __init__(
self,
config: LlaMoEConfig,
parallel_config: Optional[ParallelismArgs],
parallel_context: "ParallelContext",
):
super().__init__()
self.expert_pg_size = parallel_config.expert_parallel_size if parallel_config is not None else 1
self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts)
self.tp_pg = parallel_context.tp_pg
self.w1 = ExpertParallel(
nn.Linear(
config.hidden_size, config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), bias=False
),
expert_parallel_size=self.expert_pg_size,
)
self.w2 = ExpertParallel(
nn.Linear(
config.hidden_size, config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), bias=False
),
expert_parallel_size=self.expert_pg_size,
)
mark_all_parameters_in_module_as_sharded(
self,
pg=parallel_context.tp_and_expert_pg,
split_config=SplitConfig(split_dim=0),
)
if self.tp_pg.size() == 1:
self.w1.module.weight.data = self.w1.module.weight.data.T.contiguous()
# TODO @nouamane: jit
self.act = partial(F.gelu, approximate="tanh")
self.sdd = partial(wp.sdd_nt, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.sdd
self.dsd = partial(wp.dsd_nn, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.dsd
def forward(self, x, topo):
self.w1.scale_gradients()
self.w2.scale_gradients()
x = self.sdd(x.contiguous(), self.w1.module.weight, topo)
activation_fn_out = act_fn(x, self.act)
return self.dsd(activation_fn_out, self.w2.module.weight)
class MLP(nn.Module):
def __init__(
self,
config: LlaMoEConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
):
super().__init__()
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
self.expert_pg_size = parallel_config.expert_parallel_size
self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts)
assert self.experts_per_rank == 1, "moe.MLP only supports 1 expert per rank, otherwise use moe.SparseMLP"
self.w1 = ExpertParallel(
TensorParallelColumnLinear(
config.hidden_size,
config.intermediate_size * self.experts_per_rank,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
),
expert_parallel_size=self.expert_pg_size,
)
self.w2 = ExpertParallel(
TensorParallelRowLinear(
config.intermediate_size * self.experts_per_rank,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication
and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
),
expert_parallel_size=self.expert_pg_size,
)
# TODO @nouamane: jit
self.act = partial(F.gelu, approximate="tanh")
def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim]
merged_states = self.w1(hidden_states)
hidden_states = self.w2(self.act(merged_states))
return hidden_states
def inclusive_cumsum(x, dim):
scalar = ops.inclusive_cumsum(x, dim)
return scalar.view(1) if not len(scalar.size()) else scalar
stanford-stk>=0.0.6
megablocks==0.5.1
"""
You can run it using the following command:
```
python examples/moe/config_llamoe.py; USE_FAST=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/moe/train_moe.py --config-file examples/moe/config_llamoe.yaml
```
"""
import argparse
import os
import sys
from config_llamoe import LlaMoEConfig
from llamoe import LlaMoEForTraining
from nanotron import logging
from nanotron.trainer import DistributedTrainer
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from run_train import get_dataloader # noqa
logger = logging.get_logger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
parser.add_argument("--job-id", type=str, help="Optional job name")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
config_file = args.config_file
# Load trainer and data
trainer = DistributedTrainer(config_file, model_config_class=LlaMoEConfig, model_class=LlaMoEForTraining)
dataloader = get_dataloader(trainer)
# Train
trainer.train(dataloader)
OpenAI's 2020 scaling laws [[link]](https://arxiv.org/abs/2001.08361) showed that scale is one of the core ingredients in the success of LLMs. But naively stacking more layers can lead to unstable training due to exploding or vanishing gradients. In our implementation, the experimental results show that on a 350M LLaMA, spectral µTransfer matches the pretraining performance of the baseline (albeit with a training loss higher by 0.04). In a separate MLP-only experiment, µTransfer maintains a consistent activation L1 norm across widths and depths and allows scaling up to 2B parameters, while the SP baseline blows up and becomes untrainable.
# How to use Spectral µTransfer
In your Nanotron configuration, simply set `use_mup` to `true`. Nanotron will automatically determine the right standard deviation and learning rate for each parameter.
```diff
model:
...
init_method:
- std: 0.025
+ use_mup: true
```
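For intuition, here is a minimal sketch of the rule spectral µTransfer prescribes, not Nanotron's actual code path: a weight of shape `(fan_out, fan_in)` is initialized with standard deviation roughly `sqrt(1/fan_in) * min(1, sqrt(fan_out/fan_in))`, and its Adam learning rate is scaled by `fan_out/fan_in`. The helper name, `base_lr`, and the toy model below are hypothetical.
```python
import math

import torch
from torch import nn


def spectral_mup_scales(weight: torch.Tensor, base_lr: float = 1e-3):
    """Illustrative only: per-parameter init std and Adam LR under the spectral scaling rule."""
    fan_out, fan_in = weight.shape
    std = math.sqrt(1.0 / fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
    lr = base_lr * fan_out / fan_in
    return std, lr


# Hypothetical usage: build per-parameter Adam groups for the 2D weights of a toy model.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
param_groups = []
for param in model.parameters():
    if param.ndim == 2:
        std, lr = spectral_mup_scales(param)
        nn.init.normal_(param, mean=0.0, std=std)
        param_groups.append({"params": [param], "lr": lr})
    else:
        param_groups.append({"params": [param], "lr": 1e-3})  # biases stay at the base LR
optimizer = torch.optim.Adam(param_groups)
```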
# MLP Only Experiment
We ran a systematic experiment varying the number of layers from 8 to 32, the width from 128 to 8192, and the batch size from 32 to 2048, all on a logarithmic scale, on the CIFAR dataset, using an MSE training objective for 4 epochs with the Adam optimizer. [[Experiment Report]](https://wandb.ai/neuralink/exp14_mup_grid_search/reports/-Spectral-Transfer-MLP-s-Experiment-Results--Vmlldzo3NDQ0NTQw?accessToken=xe0mkunx3y8t0xzbzxu9caqcre57or5la58d9o209hinanlmzoaj7es24m4elvdj)
![Scale across widths](./assets/scale-across-width.png)
![Scale across depths](./assets/scale-across-depth.png)
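The consistent-activation-norm claim above can be checked with simple forward hooks; below is a minimal sketch under hypothetical model and width choices, not the actual experiment script.
```python
import torch
from torch import nn


def attach_l1_monitor(model: nn.Module) -> dict:
    """Record the mean absolute activation of every Linear layer during a forward pass."""
    stats = {}

    def make_hook(name):
        def hook(module, inputs, output):
            stats[name] = output.detach().abs().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            module.register_forward_hook(make_hook(name))
    return stats


# With a width-stable parametrization these numbers should stay roughly flat as width grows;
# under standard parametrization they tend to drift or blow up.
for width in (128, 512, 2048):
    mlp = nn.Sequential(nn.Linear(32 * 32 * 3, width), nn.ReLU(), nn.Linear(width, 10))
    stats = attach_l1_monitor(mlp)
    mlp(torch.randn(64, 32 * 32 * 3))
    print(width, stats)
```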
# On 350m LLaMA
We trained a 350M model with both spectral µTransfer and standard parametrization using Nanotron, with a global batch size of 1M tokens and a learning rate of 0.001. µTransfer matches the performance of standard parametrization, with a training loss higher by 0.04. [[Experiment Report]](https://api.wandb.ai/links/neuralink/i70nnpu9)
Please check the directory [[./examples/mup/configs]](/examples/mup/configs) for the configurations we used to reproduce the experiments.
![LLaMA](./assets/llama.png)
#### Thoughts
For spectral µP, we only used it in the MLP-only experiments [link] and on the 350m LLaMA [link] (there are links to the experiment configs in the mup README). However, when we tested it on 1B/8B models, the loss blew up for reasons we never pinned down, if we recall correctly. So we'd recommend trying µTransfer rather than spectral µTransfer.