Commit dfcb88ff authored by chenzk

v1.0.8
"""
Nanotron training script example using a custom dataloader.
Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml
```
"""
import argparse
from typing import Dict, cast
import datasets
import numpy as np
from nanotron import logging
from nanotron.config import (
DataArgs,
DatasetStageArgs,
PretrainDatasetsArgs,
)
from nanotron.dataloader import (
DataCollatorForCLM,
clm_process,
get_dataloader_worker_init,
get_datasets,
get_train_dataloader,
)
from nanotron.helpers import (
compute_remain_train_steps_of_a_data_stage_from_ckp,
get_consumed_train_samples_of_a_data_stage_from_ckp,
)
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.trainer import DistributedTrainer
from nanotron.utils import main_rank_first
from torch.utils.data import DataLoader
try:
from huggingface_hub import __version__ as hf_hub_version
from transformers import AutoTokenizer
from transformers import __version__ as tf_version
except ImportError:
hf_hub_version = None
tf_version = None
logger = logging.get_logger(__name__)
def get_dataloader_from_data_stage(
trainer: DistributedTrainer,
data: DataArgs,
consumed_train_samples: int,
num_remaining_train_steps: int,
):
"""
Returns a dataloader for a given data stage.
data: The data configuration for the current stage.
consumed_train_samples: The number of samples consumed by the model in this stage (each stage starts from zero).
num_remaining_train_steps: The number of remaining training steps for this stage.
"""
assert consumed_train_samples >= 0, "consumed_train_samples should be greater than or equal to 0"
assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than or equal to 0"
# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
# Case 1: custom data generator
if data.dataset is None:
log_rank("Using custom data generator", logger=logger, level=logging.INFO, rank=0)
###########################################################################################################
# This can be replaced with your own tokenized data generator
###########################################################################################################
train_dataset = datasets.Dataset.from_dict(
{
"input_ids": np.random.randint(
0,
trainer.config.model.model_config.vocab_size,
(trainer.global_batch_size * num_remaining_train_steps, trainer.sequence_length + 1),
),
}
)
###########################################################################################################
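# NOTE (illustrative sketch, not part of the original example): if you already have
# pre-tokenized data on disk, you could load it here instead of generating random tokens,
# e.g. assuming it was saved with `Dataset.save_to_disk` and has an "input_ids" column
# of length `sequence_length + 1`:
#   train_dataset = datasets.load_from_disk("path/to/your/tokenized_dataset")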
data_collator = DataCollatorForCLM(
sequence_length=trainer.sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=trainer.parallel_context,
)
return DataLoader(
train_dataset,
batch_size=trainer.micro_batch_size,
collate_fn=data_collator,
drop_last=True,
num_workers=0,
pin_memory=True,
worker_init_fn=get_dataloader_worker_init(dp_rank=trainer.parallel_context.dp_pg.rank()),
)
# Case 2: HuggingFace datasets
elif isinstance(data.dataset, PretrainDatasetsArgs):
log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
log_rank(
f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
logger=logger,
level=logging.INFO,
rank=0,
)
# We need the first device to process the dataset and cache it; the other devices then load from the cache
with main_rank_first(trainer.parallel_context.world_pg):
# We load the raw dataset
raw_dataset = get_datasets(
hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets,
hf_dataset_config_name=data.dataset.hf_dataset_config_name,
splits=data.dataset.hf_dataset_splits,
)["train"]
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
# We apply the Causal Language Modeling preprocessing
train_dataset = clm_process(
raw_dataset=raw_dataset,
tokenizer=tokenizer,
text_column_name=data.dataset.text_column_name,
dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process,
dataset_overwrite_cache=data.dataset.dataset_overwrite_cache,
sequence_length=trainer.sequence_length,
)
# We load the processed dataset on the ranks requiring it
dataloader = get_train_dataloader(
train_dataset=train_dataset,
sequence_length=trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=consumed_train_samples,
dataloader_num_workers=data.num_loading_workers,
seed_worker=data.seed,
dataloader_drop_last=True,
)
# Check if we have enough samples for train_steps
total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
num_tokens_needed_for_training = (
num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
)
assert num_tokens_needed_for_training <= total_tokens_dataset, (
f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
)
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
return dataloader
def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloaders = {}
for stage_idx, stage in enumerate(trainer.config.data_stages):
# NOTE: we only create the dataloader for the first stage;
# the dataloaders for the other stages are lazily initialized
stage = cast(DatasetStageArgs, stage)
consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata)
assert (
consumed_train_samples is not None
), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint"
num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, trainer.config, trainer.metadata
)
log_rank(
f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples",
logger=logger,
level=logging.INFO,
rank=0,
)
dataloader = (
get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
if stage_idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
)
dataloaders[stage.name] = dataloader
return dataloaders
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
config_file = args.config_file
# Load trainer and data
trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)
# Train
trainer.train(dataloader)
# DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining
Paper: https://arxiv.org/abs/2305.10429
You might think that the only ways to speed up pretraining are finding more high-quality data, increasing FLOPs, or changing the model architecture, but these are not the only options. DoReMi shows that, given the same source of training data, a model using an optimal data-mixing strategy can outperform its counterpart trained with random sampling on at least 70% of domains (or on all domains) and on downstream evaluations, without any knowledge of the downstream evaluation tasks.
In our implementation, DoReMi outperforms the baseline on 15 out of 22 domains on the test set and achieves a lower average cross-entropy test loss. Here is a comparison of the training losses between:
- 280M proxy and reference model [[link]](https://wandb.ai/neuralink/nanotron/reports/-DoReMi-280m-reference-vs-280m-proxy-s-training--Vmlldzo2NzYwNTU1)
- 2.5B reference and tuned weight models [[link]](https://wandb.ai/neuralink/nanotron/reports/-DoReMi-2-5B-tuned-weights-vs-2-5B-token-ratio-domain-weights-s-training--Vmlldzo2NzYwNzE2)
- And how the 280M proxy model's domain weights change during training [[link]](https://wandb.ai/neuralink/nanotron/runs/j9ojbso1?workspace=user-neuralink)
and a comparison of the cross-entropy loss between the two 2.5B models on the test set (the x-axis simply indexes batches sampled from the test set with the same checkpoint): [[link]](https://api.wandb.ai/links/neuralink/qvof4dfq).
![The domains in which we outperform](./assets/outperform.png)
![The domains in which we don't outperform](./assets/not_outperform.png)
![Domain weights comparison](./assets/domain_weights.png)
**Notes**: The graphs above show test losses, not validation losses (the label is a typo 🫠). The x-axis doesn't carry meaning; it simply indexes batches sampled from the test set with the same final checkpoint.
### How it works
- Step 0: `pip install -r examples/doremi/requirements.txt`
- Step 1: Train a small reference model using uniform sampling from each domain (for a given global batch size, you sample `x` samples equally across all domains). In some cases a domain has fewer samples than the others and runs out early, so you can instead enable automatic domain weights based on the token count (see the sketch after the command).
```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_280m_llama.yaml
```
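If you want to start from token-count-based domain weights instead of uniform sampling, here is a minimal sketch (not part of the repo) that assumes each domain folder is a tokenized `datasets` dataset with an `input_ids` column, following the layout described in the Dataset section below; the paths are placeholders:

```python
from datasets import load_from_disk

# Hypothetical domain folders, one tokenized dataset per domain.
domain_paths = ["dataset/domain_0", "dataset/domain_1", "dataset/domain_2"]

# Count tokens per domain and normalize into weights proportional to token count.
token_counts = [sum(len(ids) for ids in load_from_disk(path)["input_ids"]) for path in domain_paths]
total_tokens = sum(token_counts)
domain_weights = [count / total_tokens for count in token_counts]

# Print as the comma-separated string expected by the `domain_weights` config field.
print(", ".join(f"{w:.4f}" for w in domain_weights))
```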
- Step 2: Use the trained reference model from step 1 to train an identical model, and use its performance to dynamically tune the domain weights during training (the update rule is sketched after the command).
```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_doremi.py --config-file examples/doremi/configs/config_280m_llama_proxy.yaml
```
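For reference, the domain-weight update used in this implementation (`DomainLossForProxyTraining` in `loss.py`) is the exponentiated-gradient step from the paper followed by smoothing, where $\eta$ is `step_size`, $c$ is `smoothing_param`, $k$ is the number of domains, and $\bar{\ell}_t$ are the per-domain normalized excess losses (clamped at zero):

$$\tilde{\alpha}_t = \frac{\alpha_{t-1} \odot \exp\!\left(\eta\, \bar{\ell}_t\right)}{\left\lVert \alpha_{t-1} \odot \exp\!\left(\eta\, \bar{\ell}_t\right) \right\rVert_1}, \qquad \alpha_t = (1 - c)\,\tilde{\alpha}_t + \frac{c}{k}$$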
- Step 3: Nanotron saves the domain weights in the model checkpoint. Now, compute the optimal domain weights by averaging the domain weights across all training steps of the proxy run from step 2: $\bar{\alpha} = \frac{1}{T} \sum_{t=1}^{T} \alpha_t$.
```python
import torch
domain_weights = torch.load("checkpoints/doremi/proxy-280m-llama/doremi_domain_weights_100000.pt")
total_weights = sum(d["domain_weights"] for d in domain_weights)
avg_weights = total_weights / len(domain_weights)
```
Then, set these `avg_weights` in the config of the larger run in the `doremi` section.
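Since `DoReMiArgs` parses `domain_weights` from a comma-separated string (or a list of floats), you could format the averaged weights like this (a small sketch continuing the snippet above, assuming `avg_weights` is a 1-D tensor):

```python
# Format the averaged weights as the comma-separated string used by the
# `domain_weights` field in the `doremi` section of the config.
print(", ".join(f"{w:.4f}" for w in avg_weights.tolist()))
```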
- Step 4: Use the optimized domain weights from step 3 to train a larger model (could be 10x to 30x larger).
```bash
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_2.8b_llama_with_tuned_weights.yaml
```
### Dataset
We expect the dataset path to point to a folder that already contains tokenized data with the following structure:
```
dataset
domain_0
...
domain_1
...
domain_2
...
```
For each tokenized sample, we expect a column named `domain_ids` containing the index of the domain the sample belongs to. For example, a sample from the third domain should have `domain_ids` equal to 2. The folder names must match the domain names you provide in the DoReMi config.
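As an illustration, each domain folder could be produced with `Dataset.save_to_disk`, which is what the DoReMi dataloader reads back with `load_from_disk` (a sketch with hypothetical values; a real pipeline would write properly tokenized `input_ids` of length `sequence_length + 1`):

```python
import numpy as np
from datasets import Dataset

sequence_length = 1024

# Hypothetical tokenized samples for the third domain (domain index 2).
domain_2 = Dataset.from_dict(
    {
        "input_ids": np.random.randint(0, 49152, (8, sequence_length + 1)).tolist(),
        "domain_ids": [2] * 8,
    }
)
domain_2.save_to_disk("dataset/domain_2")
```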
### The Experiment
We first train a small 280M model for 70k steps on the Pile to obtain a reference model. Then we use the reference model to tune the domain weights of an identical model trained from scratch (the proxy training) for 70k steps.
The reference model's performance is used as a baseline to determine how difficult a domain is, so that the DoReMi algorithm can adjust the domain weights on the fly. Once we obtain the optimized weights, we use them to train a 2.5B model (9x larger than the reference model) for 70k steps, and we train another one with token-ratio domain weights (technically equivalent to random sampling, since the probability of a token occurring in the training data equals its token ratio).
For evaluation, we do uniform sampling on the test set to evaluate a 2.5B model with optimized domain weights and token ratio domain weights. For more details on hyperparameters, please check the config YAML. Here are the model checkpoints in the experiment:
- 280M LLaMA reference model: https://huggingface.co/nanotron/doremi-llama-280m-reference
- 280M LLaMA proxy model: https://huggingface.co/nanotron/doremi-llama-280m-proxy
- 2.5B LLaMA reference model: https://huggingface.co/nanotron/doremi-llama-2.5b-reference
- 2.5B LLaMA model trained using the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights
and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi
#### Thoughts
DoReMi is useful if you don't initially have an idea of what a good distribution for your training data would be, or if you want a quick way to find a better baseline than the uniform distribution before tuning the data distribution by hand. In my previous experiments, DoReMi matched the pretraining performance of the distribution used for Mamba training but couldn't outperform it. I suspect it doesn't work well when the differences are nuanced, i.e. when the gap between your known best distribution and a better distribution isn't significant.
checkpoints:
checkpoint_interval: 1000
checkpoints_path: checkpoints/doremi/big-run-02/reference-2.8b-llama
checkpoints_path_is_shared_file_system: true
resume_checkpoint_path: checkpoints/doremi/big-run-02/reference-2.8b-llama/70000
save_initial_state: false
doremi:
domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
# domain_weights: 0.1500, 0.1213, 0.0872, 0.0631, 0.0340, 0.0240, 0.0281, 0.0594, 0.1599, 0.0015, 0.0058, 0.0021, 0.0605, 0.1136, 0.0209, 0.0154, 0.0202, 0.0037, 0.0065, 0.0100, 0.0093, 0.0036
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: project_data/doremi/datasets/the_pile_raw/tokenized_data/train
num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: nanotron
run: train_2.8b_llama_reference
seed: 42
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 120
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
# NOTE: only change hidden_size, intermediate_size,
# num_attention_heads, num_key_value_heads and num_hidden_layers
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 24576
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 32
num_hidden_layers: 6
# num_hidden_layers: 1
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_steps: 8
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
# dp: 8
# # dp: 2
# pp: 1
# tp: 8
# # tp: 2
# NOTE: for running eval
dp: 8
pp: 1
tp: 2
pp_engine: 1f1b
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
# batch_accumulation_per_replica * micro_batch_size * dp = 4 * 8 * 16 = 512
# batch_accumulation_per_replica * micro_batch_size * dp = 16 * 8 * 4 = 512
# 240 * 1024 = 245760
# the DoReMi paper uses 500k tokens per batch
# batch_accumulation_per_replica: 16
# NOTE: some weird bug where, if you run with batch_accumulation_per_replica=16,
# it results in no samples from some domains
# NOTE: this causes some domain losses to be 0
# batch_accumulation_per_replica: 8
# micro_batch_size: 8
batch_accumulation_per_replica: 1
micro_batch_size: 64
limit_test_batches: 0
# NOTE: this is like the number of microbatches for validation
limit_val_batches: 1
sequence_length: 1024
# train_steps: 1000
# train_steps: 1579
train_steps: 70_000
val_check_interval: 2
checkpoints:
checkpoint_interval: 5000
checkpoints_path: checkpoints/doremi/big-run-02/reference-2.8b-llama-tuned-weights_with_100k_proxy
checkpoints_path_is_shared_file_system: true
resume_checkpoint_path: checkpoints/doremi/big-run-02/reference-2.8b-llama-tuned-weights_with_100k_proxy/70000
save_initial_state: false
doremi:
domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
# domain_weights: 0.2333, 0.0700, 0.1154, 0.0528, 0.0665, 0.0670, 0.0366, 0.0571, 0.0451, 0.0036, 0.0087, 0.0078, 0.0708, 0.0656, 0.0034, 0.0048, 0.0222, 0.0084, 0.0038, 0.0186, 0.0149, 0.0235
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: project_data/doremi/datasets/the_pile_raw/tokenized_data/train
num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: nanotron
run: train_tuned_2.8b_model
seed: 42
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 120
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 24576
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 32
# num_hidden_layers: 40
num_hidden_layers: 6
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_steps: 8
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
# dp: 8
# pp: 1
# tp: 8
# tp: 2
# NOTE: for running eval
dp: 1
pp: 1
tp: 8
pp_engine: 1f1b
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
# batch_accumulation_per_replica * micro_batch_size * dp = 4 * 8 * 16 = 512
# batch_accumulation_per_replica * micro_batch_size * dp = 16 * 8 * 4 = 512
# batch_accumulation_per_replica * micro_batch_size * dp = 8 * 8 * 8 = 512 (this one)
# 240 * 1024 = 245760
# the DoReMi paper uses 500k tokens per batch
# batch_accumulation_per_replica: 16
# NOTE: some weird bug where, if you run with batch_accumulation_per_replica=16,
# it results in no samples from some domains
# NOTE: this causes some domain losses to be 0
# batch_accumulation_per_replica: 8
# micro_batch_size: 8
batch_accumulation_per_replica: 1
micro_batch_size: 64
limit_test_batches: 0
limit_val_batches: 1
sequence_length: 1024
# train_steps: 1000
# train_steps: 70_000
# train_steps: 70_000
train_steps: 70_010
val_check_interval: -1
checkpoints:
checkpoint_interval: 1000
checkpoints_path: checkpoints/doremi/big-run-02/refrence-280m-llama
checkpoints_path_is_shared_file_system: true
# resume_checkpoint_path: checkpoints_test/
save_initial_state: false
doremi:
domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: project_data/doremi/datasets/the_pile_raw/tokenized_data/train
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: doremi
run: train_280m_reference_model
seed: 42
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 120
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 1024
initializer_range: 0.02
intermediate_size: 4096
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 8
num_hidden_layers: 10
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_steps: 8
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
# NOTE: batch_accumulation_per_replica * micro_batch_size * dp = 1 * 32 * 16 = 512
# 512 * 1024 = 524288 tokens per step
batch_accumulation_per_replica: 1
micro_batch_size: 32
limit_test_batches: 0
limit_val_batches: 0
sequence_length: 1024
# train_steps: 100_000
train_steps: 10
val_check_interval: -1
checkpoints:
checkpoint_interval: 1000
checkpoints_path: checkpoints/doremi/big-run-02/proxy-280m-llama_with_100k_reference
checkpoints_path_is_shared_file_system: true
# resume_checkpoint_path: checkpoints_test/
save_initial_state: false
doremi:
domain_names: Pile-CC, Github, OpenWebText2, StackExchange, Wikipedia (en), PubMed Abstracts, USPTO Backgrounds, FreeLaw, PubMed Central, Enron Emails, HackerNews, NIH ExPorter, Books3, ArXiv, DM Mathematics, OpenSubtitles, Gutenberg (PG-19), Ubuntu IRC, BookCorpus2, EuroParl, YoutubeSubtitles, PhilPapers
# domain_weights: 0.1500, 0.1213, 0.0872, 0.0631, 0.0340, 0.0240, 0.0281, 0.0594, 0.1599, 0.0015, 0.0058, 0.0021, 0.0605, 0.1136, 0.0209, 0.0154, 0.0202, 0.0037, 0.0065, 0.0100, 0.0093, 0.0036
ref_model_resume_checkpoint_path: checkpoints/1
data_stages:
- name: Stable Training Stage
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
# NOTE: this one works
# hf_dataset_or_datasets: vicgalle/alpaca-gpt4
# hf_dataset_splits: train
# text_column_name: instruction
# NOTE: too big
# hf_dataset_or_datasets: allenai/c4
# hf_dataset_splits: train
# text_column_name: text
# NOTE: good for testing
# hf_dataset_or_datasets: miam
# hf_dataset_splits: train
# text_column_name: Utterance
# hf_dataset_or_datasets: wikicorpus
# hf_dataset_splits: train
# text_column_name: text
# hf_dataset_or_datasets: mc4
# hf_dataset_splits: train
# text_column_name: text
# hf_dataset_or_datasets: leandro/the-pile-splitted
# hf_dataset_splits: train
# text_column_name: text
hf_dataset_or_datasets: /fsx/phuc/datasets/doremi/the_pile/testset
num_loading_workers: 1
seed: 42
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: doremi
run: train_280m_proxy_model
seed: 42
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 120
dtype: bfloat16
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 1024
initializer_range: 0.02
intermediate_size: 4096
is_llama_config: true
max_position_embeddings: 256
num_attention_heads: 8
num_hidden_layers: 10
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0003
lr_decay_steps: 8
lr_decay_style: cosine
lr_warmup_steps: 2
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
# dp: 16
dp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: gpt2
tokenizer_revision: null
tokens:
# batch_accumulation_per_replica * micro_batch_size * dp = 4 * 8 * 16 = 512
# 240 * 1024 = 245760
# the DoReMi paper uses 500k tokens per batch
# NOTE: this causes some domain losses to be 0
# batch_accumulation_per_replica: 4
# micro_batch_size: 8
batch_accumulation_per_replica: 1
micro_batch_size: 32
limit_test_batches: 0
limit_val_batches: 0
sequence_length: 1024
# train_steps: 1000
# train_steps: 1579
# train_steps: 100_000
train_steps: 10
val_check_interval: -1
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import torch
from nanotron.config import Config
@dataclass
class DoReMiArgs:
smoothing_param: float = 1e-3
step_size: float = 1.0
domain_names: Optional[Union[str, List[str]]] = None
domain_weights: Optional[Union[str, List[float]]] = None
# NOTE: the path where you want to load the
# reference model checkpoint for proxy training
ref_model_resume_checkpoint_path: Optional[Path] = None
def __post_init__(self):
assert self.domain_names is not None, "Domain names must be provided."
assert self.ref_model_resume_checkpoint_path is not None, "Reference model checkpoint path must be provided."
self.domain_names = [str(name.strip()) for name in self.domain_names.split(",")]
if self.domain_weights is not None:
if isinstance(self.domain_weights, str):
domain_weights = [float(weight.strip()) for weight in self.domain_weights.split(",")]
else:
domain_weights = self.domain_weights
assert torch.allclose(
torch.tensor(domain_weights).sum(), torch.tensor(1.0), rtol=1e-3
), "Domain weights must sum to 1.0."
self.domain_weights = domain_weights
self.ref_model_resume_checkpoint_path = Path(self.ref_model_resume_checkpoint_path)
@dataclass(kw_only=True) # pylint: disable=unexpected-keyword-arg
class DoReMiConfig(Config):
"""Configuration for DoReMi's Proxy Training."""
doremi: DoReMiArgs
import dataclasses
import math
import warnings
from typing import Dict, List, Union
import numpy as np
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.dataloader import get_dataloader_worker_init
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.trainer import DistributedTrainer
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from .doremi_context import DoReMiContext
try:
from datasets import Dataset, concatenate_datasets, load_from_disk
except ImportError:
warnings.warn("Datasets and/or Transformers not installed, you'll be unable to use the dataloader.")
logger = logging.get_logger(__name__)
class CombinedDataset(Dataset):
def __init__(self, datasets):
self.comebined_dataset = concatenate_datasets(datasets)
def __len__(self):
return len(self.comebined_dataset)
def __getitem__(self, batch):
if isinstance(batch, list) is False:
batch = [batch]
assert len(batch) > 0
if isinstance(batch[0], list):
# TODO(xrsrke): do a single index, then split the output
samples = [self.comebined_dataset[idxs] for idxs in batch]
return self._merge_dicts(samples)
return self.comebined_dataset[batch]
def _merge_dicts(self, data):
merged = {}
for key in data[0].keys():
merged[key] = np.concatenate([d[key] for d in data if key in d])
return merged
@dataclasses.dataclass
class DataCollatorForCLM:
"""
Data collator used for causal language modeling.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""
sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext
doremi_context: DoReMiContext
def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(self.input_pp_rank),
"input_mask": TensorPointer(self.input_pp_rank),
"label_ids": TensorPointer(self.output_pp_rank),
"label_mask": TensorPointer(self.output_pp_rank),
}
assert all(list(example.keys()) == ["input_ids", "domain_ids"] for example in examples)
input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape
result: Dict[str, Union[np.ndarray, TensorPointer]] = {}
result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
result["domain_idxs"] = np.vstack([examples[i]["domain_ids"] for i in range(len(examples))])
if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
# Cast np.array to torch.Tensor
result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()}
return result
class DistributedSamplerForDoReMi(DistributedSampler):
def __init__(
self,
datasets: List[Dataset],
batch_size: int,
num_microbatches: int,
num_replicas: int,
rank: int,
doremi_context: DoReMiContext,
parallel_context: ParallelContext,
shuffle: bool = False,
seed: int = 42,
drop_last: bool = False,
):
assert len(datasets) == len(
doremi_context.domain_weights
), "The number of datasets must equal to the number of domain weights"
super().__init__(datasets, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last)
self.datasets = datasets
self.batch_size = batch_size
self.num_microbatches = num_microbatches
self.doremi_context = doremi_context
self.parallel_context = parallel_context
self.total_size = self._calculate_total_size()
self.lengths = [len(d) for d in self.datasets]
self.offsets = np.cumsum([0] + self.lengths[:-1])
self.seed = seed
# self.global_batch_size = batch_size * dist.get_world_size(parallel_context.dp_pg) * num_microbatches
self.global_batch_size = batch_size * self.num_replicas * num_microbatches
# NOTE: Reset the seed of the generator for consistent randomness across epochs
self.generator = torch.Generator(device="cpu").manual_seed(
seed * (1 + dist.get_rank(self.parallel_context.dp_pg)) * (1 + dist.get_rank(self.parallel_context.pp_pg))
)
self.reset()
def _calculate_total_size(self):
total_samples = sum(len(d) for d in self.datasets)
return math.ceil(total_samples / self.batch_size) * self.batch_size
def __iter__(self):
return self
def _recompute_domain_batch_sizes(self, domain_weights):
domain_batch_sizes = [round(self.global_batch_size * weight.item()) for weight in domain_weights]
# NOTE: in some cases, the weight of a domain is too small
# resulting in a domain with 0 samples per global batch
# => zero loss for that domain => we no longer update the weights of that domain
# so we add a sample to that domain
domain_batch_sizes = [1 if x < 1 else x for x in domain_batch_sizes]
if sum(domain_batch_sizes) != self.global_batch_size:
# NOTE: randomly add a sample to round it up
domain_batch_sizes = self._round_up_domain_batch_sizes(
domain_batch_sizes,
target_total_size=self.global_batch_size,
)
assert all(x > 0 for x in domain_batch_sizes), "There is a domain with 0 samples per global batch"
return domain_batch_sizes
def __next__(self):
if self.microbatch_idx == 0:
# NOTE: because we randomly add samples to round up the domain batch sizes,
# it's better to recompute them at the start of every global batch
# so that we don't bias towards one domain (i.e. one domain getting more samples than the others)
self.domain_batch_sizes = self._recompute_domain_batch_sizes(
domain_weights=self.doremi_context.domain_weights,
)
self.batch = []
for domain_index, (idxs, domain_batch_size) in enumerate(
zip(self.domain_indices, self.domain_batch_sizes)
):
start_idx = self.domain_counters[domain_index]
end_idx = start_idx + domain_batch_size
if end_idx > len(idxs):
raise StopIteration(f"Domain {domain_index}-th ran out of samples")
assert self.domain_counters[domain_index] + domain_batch_size == end_idx
self.domain_counters[domain_index] = end_idx
global_batch_idxs = idxs[start_idx:end_idx]
self.batch.extend(global_batch_idxs)
num_samples_per_dp_rank = self.batch_size * self.num_microbatches
dp_start_idx = self.rank * num_samples_per_dp_rank
dp_end_idx = dp_start_idx + num_samples_per_dp_rank
if dp_end_idx > len(self.batch):
raise StopIteration(f"[DoReMi] Rank {self.rank} ran out of samples, len(batch)={len(self.batch)}")
dp_batch = self.batch[dp_start_idx:dp_end_idx]
microbatch_start_idx = self.microbatch_idx * self.batch_size
microbatch_end_idx = microbatch_start_idx + self.batch_size
if microbatch_end_idx > len(dp_batch):
raise StopIteration(
f"[DoReMi] Rank {self.rank}'s microbatch {self.microbatch_idx}-th ran out of samples, len(dp_batch)={len(dp_batch)}"
)
microbatch_idxs = dp_batch[microbatch_start_idx:microbatch_end_idx]
if self.microbatch_idx == self.num_microbatches - 1:
self.microbatch_idx = 0
else:
self.microbatch_idx += 1
return microbatch_idxs
def _recompute_global_batch(self):
self.domain_batch_sizes = self._recompute_domain_batch_sizes(
domain_weights=self.doremi_context.domain_weights,
)
for domain_index, (idxs, domain_batch_size) in enumerate(zip(self.domain_indices, self.domain_batch_sizes)):
start_idx = self.domain_counters[domain_index]
end_idx = start_idx + domain_batch_size
if end_idx > len(idxs):
raise StopIteration(f"Domain {domain_index}-th ran out of samples")
self.domain_counters[domain_index] = end_idx
global_batch_idxs = idxs[start_idx:end_idx]
self.batch.extend(global_batch_idxs)
def _round_up_domain_batch_sizes(self, domain_batch_sizes: List[int], target_total_size: int) -> List[int]:
"""
NOTE: Makes sum(domain_batch_sizes) == batch_size
"""
total_batch_size = sum(domain_batch_sizes)
while total_batch_size != target_total_size:
diff = target_total_size - total_batch_size
# NOTE: Randomly select a domain to increase/decrease by one sample
# to match the target_total_size
eligible_indices = torch.nonzero(torch.tensor(domain_batch_sizes) > 1).view(-1)
random_index = torch.randint(
low=0, high=len(eligible_indices), size=(1,), generator=self.generator, device="cpu"
).item()
selected_domain = eligible_indices[random_index].item()
if diff > 0:
domain_batch_sizes[selected_domain] += 1
elif diff < 0 and domain_batch_sizes[selected_domain] > 0:
domain_batch_sizes[selected_domain] -= 1
total_batch_size = sum(domain_batch_sizes)
return domain_batch_sizes
def reset(self):
"""Reset the state of the sampler for a new epoch."""
self.microbatch_idx = 0
self.domain_counters = [0 for _ in self.datasets]
self.total_samples_yielded = 0
self.out_of_samples = False
domain_indices = []
for i, dataset in enumerate(self.datasets):
local_indices = torch.arange(0, len(dataset), device="cpu").tolist()
# NOTE: align the indices across the combined dataset
global_indices = local_indices + self.offsets[i]
domain_indices.append(global_indices)
self.num_samples_per_global_step = self.batch_size * self.num_microbatches * self.num_replicas
self.domain_indices = domain_indices
self.expected_total_samples = sum([len(d) for d in domain_indices])
def get_datasets(paths):
datasets = []
for path in tqdm(paths, desc="Loading dataset from disk"):
d = load_from_disk(path)
datasets.append(d)
return datasets
def get_dataloader(trainer: DistributedTrainer, datasets) -> DataLoader:
doremi_context = trainer.doremi_context
parallel_context = trainer.parallel_context
datasets = [d.with_format(type="numpy", columns=["input_ids"], output_all_columns=True) for d in datasets]
# TODO(xrsrke): decouple trainer from dataloader
# TODO(xrsrke): decouple data collating from data loading
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
data_collator = DataCollatorForCLM(
sequence_length=trainer.sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
doremi_context=doremi_context,
)
sampler = DistributedSamplerForDoReMi(
datasets,
batch_size=trainer.micro_batch_size,
num_microbatches=trainer.n_micro_batches_per_batch,
num_replicas=parallel_context.dp_pg.size(),
rank=dist.get_rank(parallel_context.dp_pg),
seed=trainer.config.data_stages[0].data.seed,
drop_last=True,
doremi_context=doremi_context,
parallel_context=parallel_context,
)
comebined_dataset = CombinedDataset(datasets)
dataloader = DataLoader(
comebined_dataset,
batch_sampler=sampler,
collate_fn=data_collator,
num_workers=trainer.config.data_stages[0].data.num_loading_workers,
pin_memory=True,
worker_init_fn=get_dataloader_worker_init(dp_rank=dist.get_rank(parallel_context.dp_pg)),
)
def _data_generator(dataloader):
def inner():
for batch in dataloader:
batch = {k: v.to("cuda") for k, v in batch.items()}
# NOTE: because the inference model doesn't take `domain_idxs`
# as input, we need to remove it from the batch
batch_for_inference = {k: v for k, v in batch.items() if k != "domain_idxs"}
ref_losses = trainer.ref_model(**batch_for_inference)["losses"]
batch["ref_losses"] = ref_losses
yield batch
return inner
dataloader = _data_generator(dataloader) if doremi_context.is_proxy is True else dataloader
# NOTE: we need to call the dataloader to generate reference losses
# if the model is a proxy model
dataloader = dataloader() if doremi_context.is_proxy is True else dataloader
return dataloader
from dataclasses import dataclass, field
from typing import List, TypedDict
import torch
class WeightHistory(TypedDict):
step: int
weight: torch.Tensor
@dataclass
class DoReMiContext:
# NOTE: this is the current domain weights
domain_keys: List[str]
is_proxy: bool
step_size: float = 1
smoothing_param: float = 1e-3
domain_weight_history: WeightHistory = field(default_factory=list)
@property
def num_domains(self) -> int:
return len(self.domain_keys)
def get_domain_name(self, domain_idx: int) -> str:
return self.domain_keys[domain_idx]
def __post_init__(self):
# NOTE: by default, we do uniform sampling for DoReMi
self.domain_weights = torch.ones(self.num_domains) / self.num_domains
assert torch.allclose(
self.domain_weights.sum(dim=-1), torch.tensor(1.0), rtol=0.001
), "Domain weights must sum up to 1."
assert (
self.domain_weights.shape[0] == self.num_domains
), "The length of domain_weights must be equal to the number of domains"
self.add_weight_with_history(self.domain_weights, 0)
def add_weight_with_history(self, domain_weights: torch.Tensor, step: int):
assert step >= 0, "Step must be a non-negative integer"
self.domain_weight_history.append(WeightHistory(step=step, weight=domain_weights.cpu()))
import math
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from nanotron import logging
from nanotron.config import ParallelismArgs
from nanotron.models import NanotronModel
from nanotron.models.llama import LlamaModel
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from transformers import LlamaConfig
from .doremi_context import DoReMiContext
from .loss import CrossEntropyWithPerDomainLoss, DoReMiLossForProxyTraining
logger = logging.get_logger(__name__)
class BaseLLaMa(NanotronModel):
@torch.no_grad()
def init_model_randomly(self, config):
"""Initialize model parameters randomly.
Note:
Layernorm weight all 0 or 1 depending on `apply_layernorm_1p`
"""
model = self
initialized_parameters = set()
# Handle tensor parallelism
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
std = config.model.init_method.std
sigma = config.model.init_method.std
num_layers = config.model.model_config.num_hidden_layers
for param_name, param in model.named_parameters():
assert isinstance(param, NanotronParameter)
module_name, param_name = param_name.rsplit(".", 1)
if param.is_tied:
tied_info = param.get_tied_info()
full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
module_id_to_prefix=module_id_to_prefix
)
else:
full_param_name = f"{module_name}.{param_name}"
if full_param_name in initialized_parameters:
# Already initialized
continue
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=std)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelRowLinear):
if "weight" == param_name:
nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers))
elif "bias" == param_name:
param.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TritonRMSNorm):
if "weight" == param_name:
# TODO @thomasw21: Sometimes we actually want 0
module.weight.fill_(1)
elif "bias" == param_name:
module.bias.zero_()
else:
raise ValueError(f"Who the fuck is {param_name}?")
elif isinstance(module, TensorParallelEmbedding):
nn.init.normal_(module.weight, mean=0.0, std=std)
else:
raise Exception(f"Parameter {full_param_name} was not initialized")
assert full_param_name not in initialized_parameters
initialized_parameters.add(full_param_name)
assert initialized_parameters == {
param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
if param.is_tied
else name
for name, param in model.named_parameters()
}, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
return self.model.get_block_compute_costs()
def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""Get flops per second for a given model"""
return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
class LLaMaForInference(BaseLLaMa):
def __init__(
self,
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
parallel_context: ParallelContext,
):
super().__init__()
self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
self.parallel_context = parallel_context
self.config = config
self.parallel_config = parallel_config
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
loss = sharded_cross_entropy(
sharded_logits,
label_ids.transpose(0, 1).contiguous(),
group=self.parallel_context.tp_pg,
dtype=torch.float,
).transpose(0, 1)
return {"losses": loss}
class LlamaForDoReMiTraining(BaseLLaMa):
def __init__(
self,
config: LlamaConfig,
parallel_context: ParallelContext,
doremi_context: DoReMiContext,
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=DoReMiLossForProxyTraining,
module_kwargs={
"parallel_context": parallel_context,
"doremi_context": doremi_context,
},
module_input_keys={
"sharded_logits",
"label_ids",
"label_mask",
"domain_idxs",
"ref_losses",
},
module_output_keys={
"loss",
"ce_loss",
"domain_losses",
"domain_weights",
"samples_per_domain",
},
)
self.parallel_context = parallel_context
self.config = config
self.parallel_config = parallel_config
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
domain_idxs: Optional[Union[torch.Tensor, TensorPointer]],
ref_losses: Optional[Union[torch.Tensor, TensorPointer]],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
sharded_logits = sharded_logits.transpose(0, 1).contiguous()
outputs = self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
domain_idxs=domain_idxs,
ref_losses=ref_losses,
)
return outputs
class LlamaReferenceForTrainingWithPerDomainLoss(BaseLLaMa):
def __init__(
self,
config: LlamaConfig,
doremi_context: DoReMiContext,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs],
):
super().__init__()
self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
self.loss = PipelineBlock(
p2p=self.model.p2p,
module_builder=CrossEntropyWithPerDomainLoss,
module_kwargs={
"doremi_context": doremi_context,
"parallel_context": parallel_context,
},
module_input_keys={"sharded_logits", "label_ids", "label_mask", "domain_idxs"},
module_output_keys={"loss", "domain_losses", "samples_per_domain"},
)
self.parallel_context = parallel_context
self.config = config
self.parallel_config = parallel_config
def forward(
self,
input_ids: Union[torch.Tensor, TensorPointer],
input_mask: Union[torch.Tensor, TensorPointer],
label_ids: Union[torch.Tensor, TensorPointer],
label_mask: Union[torch.Tensor, TensorPointer],
domain_idxs: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
sharded_logits = self.model(
input_ids=input_ids,
input_mask=input_mask,
)
sharded_logits = sharded_logits.transpose(0, 1).contiguous()
outputs = self.loss(
sharded_logits=sharded_logits,
label_ids=label_ids,
label_mask=label_mask,
domain_idxs=domain_idxs,
)
return {
"loss": outputs["loss"],
"domain_losses": outputs["domain_losses"],
"samples_per_domain": outputs["samples_per_domain"],
}
from typing import Dict, Tuple
import torch
import torch.distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from torch import nn
from .doremi_context import DoReMiContext
from .utils import masked_mean
def compute_per_domain_loss(
losses: torch.Tensor, domain_idxs: torch.Tensor, doremi_context: DoReMiContext, parallel_context: ParallelContext
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
dp_size = dist.get_world_size(parallel_context.dp_pg)
dp_pg = parallel_context.dp_pg
# NOTE: can't do allgather([tensor_list], [tensor]) if a tensor in tensor_list is not contiguous
losses_dp = [
torch.empty_like(losses, device="cuda", memory_format=torch.contiguous_format) for _ in range(dp_size)
]
dist.all_gather(losses_dp, losses.contiguous(), group=dp_pg)
losses_dp = torch.cat(losses_dp, dim=0)
domain_ids_dp = [
torch.empty_like(domain_idxs, device="cuda", memory_format=torch.contiguous_format) for _ in range(dp_size)
]
dist.all_gather(domain_ids_dp, domain_idxs.contiguous(), group=dp_pg)
domain_ids_dp = torch.cat(domain_ids_dp, dim=0)
# NOTE: Calculate total loss per domain
n_domains = doremi_context.num_domains
domain_losses = torch.zeros(n_domains, device="cuda")
domain_ids_dp = domain_ids_dp.view(-1)
assert losses_dp.shape[0] == domain_ids_dp.shape[0]
GLOBAL_BATCH_SIZE = losses_dp.shape[0]
for i in range(GLOBAL_BATCH_SIZE):
# NOTE: sum the losses of all tokens in each sample,
# then add it to the domain loss of the corresponding domain
domain_losses[domain_ids_dp[i]] += losses_dp[i].sum(dim=-1)
# NOTE: Normalize and smooth domain weights
samples_per_domain = torch.bincount(domain_ids_dp, minlength=n_domains)
SEQ_LEN = losses.shape[1]
normalized_domain_losses = domain_losses / (samples_per_domain * SEQ_LEN)
# NOTE: if the domain loss is zero, then the normalized domain loss is NaN
normalized_domain_losses[torch.isnan(normalized_domain_losses)] = 0.0
return losses_dp, normalized_domain_losses, samples_per_domain
def compute_domain_loss_per_replicas(
losses: torch.Tensor, domain_idxs: torch.Tensor, doremi_context: DoReMiContext
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
domain_idxs = domain_idxs.view(-1)
# NOTE: Calculate total loss per domain
n_domains = doremi_context.num_domains
domain_losses = torch.zeros(n_domains, device="cuda")
assert losses.shape[0] == domain_idxs.shape[0]
GLOBAL_BATCH_SIZE = domain_idxs.shape[0]
for i in range(GLOBAL_BATCH_SIZE):
# NOTE: sum the excess losses of all tokens in the batch
# then add it to the domain loss of the corresponding domain
domain_losses[domain_idxs[i]] += losses[i].sum(dim=-1)
# NOTE: Normalize domain weights
SEQ_LEN = losses.shape[1]
samples_per_domain = torch.bincount(domain_idxs, minlength=n_domains)
normalized_domain_losses = domain_losses / (samples_per_domain * SEQ_LEN)
# NOTE: if the domain loss is zero, then the normalized domain loss is NaN
normalized_domain_losses[torch.isnan(normalized_domain_losses)] = 0.0
return normalized_domain_losses, samples_per_domain
class DomainLossForProxyTraining:
def __init__(self, doremi_context: DoReMiContext, parallel_context: ParallelContext):
self.doremi_context = doremi_context
self.parallel_context = parallel_context
def __call__(self, losses: torch.Tensor, ref_losses: torch.Tensor, domain_idxs: torch.Tensor):
assert losses.shape == ref_losses.shape, "losses and ref_losses must have the same shape"
assert (
domain_idxs.shape[0] == losses.shape[0]
), "the batch size of domain_idxs must match the batch size of losses"
# NOTE: sometimes you'll see domain losses equal to zero.
# This doesn't mean there is a bug; it just means that, in that case,
# the proxy model is performing better than the reference model
# => clamp(lower loss - higher loss, 0) = clamp(negative, 0) = 0.
excess_losses = (losses - ref_losses).clamp(min=0)
normalized_domain_losses, samples_per_domain = compute_domain_loss_per_replicas(
excess_losses, domain_idxs, self.doremi_context
)
domain_weights = self.doremi_context.domain_weights
step_size = self.doremi_context.step_size
smoothing_param = self.doremi_context.smoothing_param
log_new_domain_weights = torch.log(domain_weights) + step_size * normalized_domain_losses
log_new_domain_weights = log_new_domain_weights - torch.logsumexp(log_new_domain_weights, dim=0)
train_domain_weights = (1 - smoothing_param) * torch.exp(log_new_domain_weights) + smoothing_param / len(
log_new_domain_weights
)
dro_loss = (train_domain_weights * normalized_domain_losses).sum(dim=-1)
return {
"dro_loss": dro_loss,
"domain_losses": normalized_domain_losses,
"domain_weights": train_domain_weights,
"samples_per_domain": samples_per_domain,
}
class CrossEntropyWithPerDomainLoss(nn.Module):
def __init__(self, doremi_context: DoReMiContext, parallel_context: ParallelContext):
super().__init__()
self.doremi_context = doremi_context
self.parallel_context = parallel_context
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
domain_idxs: torch.Tensor,
) -> Dict[str, torch.Tensor]:
per_token_loss = sharded_cross_entropy(
sharded_logits, label_ids, group=self.parallel_context.tp_pg, dtype=torch.float
)
ce_loss = masked_mean(per_token_loss, label_mask, dtype=torch.float)
_, domain_losses, samples_per_domain = compute_per_domain_loss(
per_token_loss, domain_idxs, self.doremi_context, self.parallel_context
)
return {"ce_loss": ce_loss, "domain_losses": domain_losses, "samples_per_domain": samples_per_domain}
class DoReMiLossForProxyTraining(nn.Module):
def __init__(self, doremi_context: DoReMiContext, parallel_context: ParallelContext):
super().__init__()
self.parallel_context = parallel_context
self.doremi_loss = DomainLossForProxyTraining(doremi_context, parallel_context)
def forward(
self,
sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
label_ids: torch.Tensor, # [batch_size, seq_length]
label_mask: torch.Tensor, # [batch_size, seq_length]
domain_idxs: torch.Tensor,
ref_losses: torch.Tensor,
) -> Dict[str, torch.Tensor]:
loss = sharded_cross_entropy(
sharded_logits,
label_ids,
group=self.parallel_context.tp_pg,
dtype=torch.float,
)
ce_loss = masked_mean(loss, label_mask, dtype=torch.float)
doremi_loss_outputs = self.doremi_loss(loss, ref_losses, domain_idxs)
return {
"ce_loss": ce_loss,
"loss": doremi_loss_outputs["dro_loss"], # NOTE: this is the one we optimize
"domain_losses": doremi_loss_outputs["domain_losses"],
"domain_weights": doremi_loss_outputs["domain_weights"],
"samples_per_domain": doremi_loss_outputs["samples_per_domain"],
}
from typing import Dict, Iterable, List, Optional, Type, Union
import torch
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config, get_config_from_file
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.sanity_checks import assert_tensor_synced_across_pg
from nanotron.serialize import load_weights
from nanotron.trainer import DistributedTrainer
from torch.nn.parallel import DistributedDataParallel
from .config import DoReMiConfig
from .doremi_context import DoReMiContext
from .llama import (
LlamaForDoReMiTraining,
LLaMaForInference,
LlamaReferenceForTrainingWithPerDomainLoss,
)
try:
import wandb
except ImportError:
wandb = None
logger = logging.get_logger(__name__)
def print_array_for_human(arr: List[float], precision: int = 5) -> str:
formatted_elements = [f"{x:.{precision}f}" for x in arr]
return "[" + ", ".join(formatted_elements) + "]"
class DoReMiTrainer(DistributedTrainer):
def __init__(
self,
config_or_config_file: Union[Config, str],
config_class: Type[Config] = Config,
):
# NOTE: save the initial domain_weights
config: DoReMiConfig = get_config_from_file(config_or_config_file, config_class=config_class)
assert (
config.doremi.ref_model_resume_checkpoint_path is not None
), "You must provide a reference model checkpoint path for DoReMi training."
self.doremi_context = DoReMiContext(
config.doremi.domain_names,
is_proxy=True,
step_size=config.doremi.step_size,
smoothing_param=config.doremi.smoothing_param,
)
self.ref_checkpoint_path = config.doremi.ref_model_resume_checkpoint_path
super().__init__(config_or_config_file, config_class)
def _init_model_instance(self) -> Union[NanotronModel, DistributedDataParallel]:
assert (
self.ref_checkpoint_path is not None
), "You must provide a reference model checkpoint path for DoReMi's proxy training."
        # NOTE: now that the parallel context is initialized, move the domain weights
        # to the GPU of the current rank
self.doremi_context.domain_weights = self.doremi_context.domain_weights.to("cuda")
# NOTE: SANITY CHECKS: make sure all ranks have the same domain weights
assert_tensor_synced_across_pg(
tensor=self.doremi_context.domain_weights,
pg=self.parallel_context.world_pg,
msg=lambda err: f"Domain weights are not synced across ranks {err}",
)
log_rank(
f"""[DoReMi] In DoReMi's proxy training, please note that 'loss' represents DRO loss, and 'ce_loss' represent cross entropy loss.
[DoReMi] Sampling weights: {self.doremi_context.domain_weights}""",
logger=logger,
level=logging.INFO,
)
model = self._init_model(
model_builder=lambda: LlamaForDoReMiTraining(
config=self.model_config,
parallel_context=self.parallel_context,
parallel_config=self.config.parallelism,
doremi_context=self.doremi_context,
),
)
log_rank("[DoReMi] Initializing reference model for DoReMi training", logger=logger, level=logging.INFO)
self.ref_model = self._init_model(
model_builder=lambda: LLaMaForInference(
config=self.model_config,
parallel_config=self.config.parallelism,
parallel_context=self.parallel_context,
),
)
        # NOTE: unwrap the reference model if it was wrapped in DistributedDataParallel
        normalized_ref_model = (
            self.ref_model.module if isinstance(self.ref_model, DistributedDataParallel) else self.ref_model
        )
log_rank(
f"Loading weights from {self.ref_checkpoint_path} for reference model",
logger=logger,
level=logging.INFO,
rank=0,
)
load_weights(
model=normalized_ref_model,
parallel_context=self.parallel_context,
root_folder=self.ref_checkpoint_path,
)
return model
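    # NOTE: `self.ref_model` is only used to produce the per-token reference losses that
    # enter the excess-loss computation; its weights are loaded from
    # `ref_model_resume_checkpoint_path` above and are presumably never updated during
    # proxy training.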
def train_step_logs(
self,
outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
loss_avg: Optional[torch.Tensor],
):
domain_weights = outputs[0]["domain_weights"]
domain_losses = outputs[0]["domain_losses"]
samples_per_domain = outputs[0]["samples_per_domain"]
        # NOTE: cross-entropy loss accumulated over this step's micro-batches
        ce_loss_avg = torch.stack([output["ce_loss"] for output in outputs]).sum()
handle_weight = dist.all_reduce(
domain_weights, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG
)
handle_loss = dist.all_reduce(
domain_losses, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG
)
# NOTE: sum the total samples per domain across dp replicas
handle_samples_per_domain = dist.all_reduce(
samples_per_domain, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.SUM
)
handle_ce_loss = dist.all_reduce(
ce_loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG
)
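        # NOTE: the four all-reduces above are launched with `async_op=True` so they can
        # overlap with the base-class logging below; `.wait()` is called on each handle
        # before the reduced tensors are read.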
super().train_step_logs(outputs, loss_avg)
handle_weight.wait()
handle_loss.wait()
handle_samples_per_domain.wait()
handle_ce_loss.wait()
self.doremi_context.add_weight_with_history(domain_weights, self.iteration_step)
domain_weights = domain_weights.cpu().detach().numpy()
domain_losses = domain_losses.cpu().detach().numpy()
        # NOTE: the domain weights logged here are not the sampling weights but the
        # in-flight DoReMi weights of the current step; sampling itself uses fixed
        # uniform weights.
log_rank(
f"""[DoReMi] Domain weights: {print_array_for_human(domain_weights)}
[DoReMi] Domain losses: {print_array_for_human(domain_losses)}
[DoReMi] Samples per domain: {str(samples_per_domain)}
""",
logger=logger,
level=logging.INFO,
rank=0,
group=self.parallel_context.dp_pg,
)
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0]:
if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0:
checkpoints_path = self.config.checkpoints.checkpoints_path
checkpoint_path = checkpoints_path / f"doremi_domain_weights_{self.iteration_step}.pt"
torch.save(self.doremi_context.domain_weight_history, checkpoint_path)
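                # NOTE: the saved history is a list of {"step": ..., "weight": ...}
                # entries (see `add_weight_with_history`), so it can be reloaded for
                # offline analysis with `torch.load(checkpoint_path)`.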
if wandb is not None:
weight_logs = {
f"weight_domain_{self.doremi_context.get_domain_name(i)}": weight
for i, weight in enumerate(domain_weights)
}
loss_logs = {
f"loss_domain_{self.doremi_context.get_domain_name(i)}": loss
for i, loss in enumerate(domain_losses)
}
samples_per_domain_logs = {
f"samples_per_domain_{self.doremi_context.get_domain_name(i)}": samples
for i, samples in enumerate(samples_per_domain)
}
wandb.log(
{
**weight_logs,
**loss_logs,
**samples_per_domain_logs,
"ce_loss": ce_loss_avg.cpu().detach().numpy(),
"iteration_step": self.iteration_step,
}
)
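# NOTE: a rough launch sketch with hypothetical helper names (the real entrypoint lives
# in the example's training script), assuming the standard nanotron flow:
#   trainer = DoReMiTrainer(config_file, config_class=DoReMiConfig)
#   dataloader = build_doremi_dataloader(trainer)  # hypothetical: builds the per-domain dataloaders
#   trainer.train(dataloader)                      # assumes DistributedTrainer.train(dataloader)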
class ReferenceTrainer(DistributedTrainer):
def __init__(self, domain_weights: torch.Tensor, domain_keys: List[str], *args, **kwargs):
        self.doremi_context = DoReMiContext(domain_keys, is_proxy=False)
        self.doremi_context.domain_weights = domain_weights
self.valid_dataloader = None
super().__init__(*args, **kwargs)
self.doremi_context.domain_weights = self.doremi_context.domain_weights.to("cuda")
# NOTE: SANITY CHECKS: make sure all ranks have the same domain weights
assert_tensor_synced_across_pg(
tensor=self.doremi_context.domain_weights,
pg=self.parallel_context.world_pg,
msg=lambda err: f"Domain weights are not synced across ranks {err}",
)
log_rank(
f"[DoReMi] Initial domain weights: {self.doremi_context.domain_weights}", logger=logger, level=logging.INFO
)
def _init_model_instance(self) -> Union[NanotronModel, DistributedDataParallel]:
model = self._init_model(
model_builder=lambda: LlamaReferenceForTrainingWithPerDomainLoss(
config=self.model_config,
doremi_context=self.doremi_context,
parallel_context=self.parallel_context,
parallel_config=self.config.parallelism,
),
)
return model
def train_step_logs(
self,
outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
loss_avg: Optional[torch.Tensor],
):
super().train_step_logs(outputs, loss_avg)
domain_losses = outputs[0]["domain_losses"].tolist()
samples_per_domain = outputs[0]["samples_per_domain"].tolist()
log_rank(
f"[DoReMi][Train] Domain loss: {print_array_for_human(domain_losses)}",
logger=logger,
level=logging.INFO,
rank=0,
)
log_rank(
f"[DoReMi][Train] Samples per domain: {str(samples_per_domain)}",
logger=logger,
level=logging.INFO,
rank=0,
)
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
loss_logs = {
f"loss_domain_{self.doremi_context.get_domain_name(i)}": loss for i, loss in enumerate(domain_losses)
}
samples_per_domain_logs = {
f"samples_per_domain_{self.doremi_context.get_domain_name(i)}": n_samples
for i, n_samples in enumerate(samples_per_domain)
}
wandb.log(
{
**loss_logs,
**samples_per_domain_logs,
"loss_avg": loss_avg.item(),
"iteration_step": self.iteration_step,
}
)
from typing import List
import torch
from torch.utils.data import Dataset
@torch.jit.script
def masked_mean(loss: torch.Tensor, label_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
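# Example: masked_mean(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 1.0, 0.0]), torch.float) -> 1.5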
def compute_domain_weights_based_on_token_count(datasets: List[Dataset]) -> torch.Tensor:
    # NOTE: weights are proportional to the number of samples in each domain, which
    # matches token counts only when every sample has the same fixed sequence length.
    num_samples_per_domain = [len(d) for d in datasets]
    total_samples = sum(num_samples_per_domain)
    weights = torch.tensor([num_sample / total_samples for num_sample in num_samples_per_domain])
    return weights
from copy import deepcopy
import torch
from utils import set_system_path
set_system_path()
from examples.doremi.doremi.doremi_context import DoReMiContext
def test_initialization():
domain_keys = ["domain1", "domain2"]
step_size, smoothing_param = 0.01, 0.001
is_proxy = False
doremi_context = DoReMiContext(domain_keys, is_proxy, step_size, smoothing_param=smoothing_param)
assert torch.equal(doremi_context.domain_weights, torch.tensor([0.5, 0.5]))
assert doremi_context.domain_keys == domain_keys
assert doremi_context.is_proxy == is_proxy
assert doremi_context.step_size == step_size
assert doremi_context.smoothing_param == smoothing_param
def test_num_domains():
domain_keys = ["domain1", "domain2"]
context = DoReMiContext(domain_keys, False)
assert context.num_domains == 2
def test_get_domain_name():
domain_keys = ["domain1", "domain2"]
context = DoReMiContext(domain_keys, False)
assert context.get_domain_name(0) == "domain1"
assert context.get_domain_name(1) == "domain2"
def test_record_domain_weights_history():
domain_weights = [torch.tensor([0.1, 0.3, 0.6]), torch.tensor([0.2, 0.3, 0.5])]
domain_keys = ["domain1", "domain2", "domain3"]
doremi_context = DoReMiContext(domain_keys, False)
initial_domain_weights = deepcopy(doremi_context.domain_weights)
doremi_context.add_weight_with_history(domain_weights[0], step=1)
doremi_context.add_weight_with_history(domain_weights[1], step=2)
assert torch.equal(initial_domain_weights, doremi_context.domain_weights)
expected_weight_history = [initial_domain_weights, *domain_weights]
for i, history in enumerate(doremi_context.domain_weight_history):
assert history["step"] == i
assert torch.equal(history["weight"], expected_weight_history[i])