"vscode:/vscode.git/clone" did not exist on "45cb1c0fcb2079fff5bca4e1e1bef4ee7f4e9378"
Commit dfcb88ff authored by chenzk's avatar chenzk
Browse files

v1.0.8

parents
"""
This script is used to log evaluation results to wandb.
python3 log_eval_results_to_wandb.py --eval-path /path/to/eval/results --wandb-project project_name --wandb-name run_name
The folder that contains the evaluation results should have the following structure:
- 5000:
    results_x.json  # where x is lighteval's evaluation number
- 10000:
...
...
"""
import argparse
import json
import os
from pathlib import Path
import wandb
def run(current_path: Path):
def compute_avg_acc_of_a_benchmark(data, benchmark_prefix):
sum_acc, sum_acc_norm, sum_acc_stderr, sum_acc_norm_stderr, count = 0, 0, 0, 0, 0
for key, values in data.items():
if f"{benchmark_prefix}:" in key:
sum_acc += values["acc"]
sum_acc_norm += values["acc_norm"]
sum_acc_stderr += values["acc_stderr"]
sum_acc_norm_stderr += values["acc_norm_stderr"]
count += 1
average_acc = sum_acc / count if count else 0
return average_acc
def compute_avg_acc_of_all_tasks(data):
sum_acc, count = 0, 0
for _, values in data.items():
sum_acc += values["acc"]
count += 1
average_acc = sum_acc / count if count else 0
return average_acc
list_checkpoints = os.listdir(current_path)
sorted_list_checkpoints = sorted(list_checkpoints, key=int)
for item in sorted_list_checkpoints:
item_path = os.path.join(current_path, item)
if os.path.isdir(item_path):
json_files = [f for f in os.listdir(item_path) if f.endswith(".json")]
if len(json_files) == 1:
json_file_path = os.path.join(item_path, json_files[0])
with open(json_file_path, "r") as file:
eval_data = json.load(file)
iteration_step = eval_data["config_general"]["config"]["general"]["step"]
consumed_train_samples = eval_data["config_general"]["config"]["general"]["consumed_train_samples"]
logging_results = {}
for name, data in eval_data["results"].items():
logging_results[f"{name}_acc"] = data["acc"]
logging_results["mmlu:average_acc"] = compute_avg_acc_of_a_benchmark(eval_data["results"], "mmlu")
logging_results["arc:average_acc"] = compute_avg_acc_of_a_benchmark(eval_data["results"], "arc")
logging_results["all:average_acc"] = compute_avg_acc_of_all_tasks(eval_data["results"])
wandb.log(
{
**logging_results,
"iteration_step": iteration_step,
"consumed_train_samples": consumed_train_samples,
}
)
elif len(json_files) > 1:
print(f"More than one JSON file found in {item_path}. Skipping.")
else:
print(f"No JSON file found in {item_path}.")
print(f"Checkpoint {item} is done. /n")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--eval-path", type=str, required=True, help="Path of the lighteval's evaluation results")
    parser.add_argument(
        "--wandb-project", type=str, help="Name of the wandb project", default="nanotron_evals"
    )
    parser.add_argument(
        "--wandb-name",
        type=str,
        required=True,
        help="Name of the wandb run",
    )
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
eval_path = args.eval_path
wandb_project = args.wandb_project
wandb_name = args.wandb_name
wandb.init(
project=wandb_project,
name=wandb_name,
config={"eval_path": eval_path},
)
run(eval_path)
# Pre-training
We use the [nanotron](https://github.com/huggingface/nanotron/) library to train the SmolLM and SmolLM2 base models.
The scripts for training SmolLM v1 can be found in the `smollm1` folder. SmolLM2 has a similar architecture and setup, but uses an improved data mixture that we curated and significantly longer training (11 trillion tokens for the 1.7B, 4 trillion for the 360M, and 2 trillion for the 135M). We will upload the SmolLM2 configs soon.
## Setup
Please refer to [nanotron](https://github.com/huggingface/nanotron/) for detailed instructions on setting up your training environment and launching jobs.
After setting up the environment and tokenizing the training datasets with [datatrove](https://github.com/huggingface/datatrove) (instructions available [here](https://github.com/huggingface/nanotron/blob/main/docs/nanoset.md#nanosets)), you can modify the configurations to match your number of nodes and local paths.
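For a single-node run, the main fields to adjust are the parallelism degree and the dataset paths. Below is a minimal sketch of the relevant parts of `smollm1/config_smollm1_135M.yaml` (paths and values are illustrative; the released config uses `dp: 32` across 4 nodes):

```yaml
parallelism:
  dp: 8   # one data-parallel replica per GPU on a single 8-GPU node
  pp: 1
  tp: 1
data_stages:
- data:
    dataset:
      dataset_folder: # your local paths to the datatrove-tokenized datasets
      - datasets/fineweb-edu-dedup
      - datasets/cosmopedia-v2
tokens:
  micro_batch_size: 8
  batch_accumulation_per_replica: 2 # increase this (or micro_batch_size) to keep the same global batch size with fewer replicas
```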
Below is an example of launching SmolLM1 135M training on a single node: set the DP value to 8 in the config (as sketched above) and adjust the batch size accordingly, then run:
```bash
git clone https://github.com/huggingface/nanotron
cd nanotron
# follow installation
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config-file smollm1/config_smollm1_135M.yaml
```
If you are working on a Slurm cluster, you can modify `launch.slurm` and launch the training with:
```bash
sbatch launch.slurm
```
> [!NOTE]
> Don't forget to create the logs directory before launching the job:
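> For example, if your `launch.slurm` writes its output to a `logs/` folder (adjust the path to match your script):
> ```bash
> mkdir -p logs
> ```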
# SmolLM1 135M trained on 600B tokens
checkpoints:
checkpoint_interval: 2000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_final_state: false
save_initial_state: false
data_stages:
- data:
dataset:
dataset_folder: # paths to tokenized datasets
- datasets/fineweb-edu-dedup
- datasets/cosmopedia-v2
- datasets/python-edu
- datasets/open-web-math
- datasets/stackoverflow
dataset_weights:
- 0.7
- 0.15
- 0.08
- 0.06
- 0.01
num_loading_workers: 1
seed: 42
name: training stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: smollm
run: smollm-135M
seed: 8
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.0416 # 1/sqrt(hidden_size)
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 0
eos_token_id: 0
hidden_act: silu
hidden_size: 576
initializer_range: 0.02
intermediate_size: 1536
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 9
num_hidden_layers: 30
num_key_value_heads: 3
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
rope_theta: 10000.0
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.003
lr_decay_starting_step: 250000
lr_decay_steps: 50000
lr_decay_style: 1-sqrt
lr_warmup_steps: 2500
lr_warmup_style: linear
min_decay_lr: 0
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 32 # 4 nodes
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 2
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 8 # GBS = 8*2*32*sequence_length = 512*sequence_length = 1M tokens
sequence_length: 2048
train_steps: 600000
val_check_interval: -1
# SmolLM1 135M trained on 600B tokens
checkpoints:
checkpoint_interval: 2000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_final_state: false
save_initial_state: false
data_stages:
- data:
dataset:
dataset_folder: # paths to tokenized datasets
- datasets/fineweb-edu-dedup-ds
- datasets/fineweb-edu-dedup-ds
- datasets/fineweb-edu-dedup-ds
- datasets/fineweb-edu-dedup-ds
- datasets/fineweb-edu-dedup-ds
dataset_weights:
- 0.7
- 0.15
- 0.08
- 0.06
- 0.01
num_loading_workers: 1
seed: 42
name: training stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: smollm
run: smollm-135M
seed: 8
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.0416 # 1/sqrt(hidden_size)
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 0
eos_token_id: 0
hidden_act: silu
hidden_size: 576
initializer_range: 0.02
intermediate_size: 1536
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 9
num_hidden_layers: 30
num_key_value_heads: 3
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
rope_theta: 10000.0
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.003
lr_decay_starting_step: 250000
lr_decay_steps: 50000
lr_decay_style: 1-sqrt
lr_warmup_steps: 2500
lr_warmup_style: linear
min_decay_lr: 0
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
  dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 2
limit_test_batches: 0
limit_val_batches: 0
  micro_batch_size: 8 # GBS = 8*2*1*sequence_length = 16*sequence_length = 32k tokens
sequence_length: 2048
train_steps: 2000
val_check_interval: -1
# SmolLM1 135M trained on 600B tokens
checkpoints:
checkpoint_interval: 2000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_final_state: false
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: datasets/fineweb-edu-dedup # paths to tokenized datasets
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: training stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: smollm
run: smollm-135M
seed: 8
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.0416 # 1/sqrt(hidden_size)
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 0
eos_token_id: 0
hidden_act: silu
hidden_size: 576
initializer_range: 0.02
intermediate_size: 1536
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 9
num_hidden_layers: 30
num_key_value_heads: 3
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
rope_theta: 10000.0
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.003
lr_decay_starting_step: 250000
lr_decay_steps: 50000
lr_decay_style: 1-sqrt
lr_warmup_steps: 2500
lr_warmup_style: linear
min_decay_lr: 0
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
  dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 2
limit_test_batches: 0
limit_val_batches: 0
  micro_batch_size: 8 # GBS = 8*2*1*sequence_length = 16*sequence_length = 32k tokens
sequence_length: 2048
train_steps: 2000
val_check_interval: -1
# SmolLM1 1.7B trained on 1T tokens
checkpoints:
checkpoint_interval: 2000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_final_state: false
save_initial_state: false
data_stages:
- data:
dataset:
dataset_folder: # paths to tokenized datasets
- datasets/fineweb-edu-dedup
- datasets/cosmopedia-v2
- datasets/open-web-math
- datasets/starcoderdata-python
- datasets/stackoverflow
dataset_weights:
- 0.7
- 0.15
- 0.06
- 0.08
- 0.01
num_loading_workers: 1
seed: 42
name: training stage
start_training_step: 1
- data:
dataset: # we change data mixture to use python-edu
dataset_folder:
- datasets/fineweb-edu-dedup
- datasets/cosmopedia-v2
- datasets/open-web-math
- datasets/python-edu
- datasets/stackoverflow
- datasets/deepmind_mathematics
dataset_weights:
- 0.7
- 0.15
- 0.055
- 0.08
- 0.01
- 0.005
num_loading_workers: 1
seed: 42
name: training stage 2
start_training_step: 300000
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: smollm
run: smollm-1700M
seed: 8
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.022097086912079608
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 0
eos_token_id: 0
hidden_act: silu
hidden_size: 2048
initializer_range: 0.02
intermediate_size: 8192
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 32
num_hidden_layers: 24
num_key_value_heads: 32
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
rope_theta: 10000.0
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.0005
lr_decay_starting_step: 400000
lr_decay_steps: 100000
lr_decay_style: 1-sqrt
lr_warmup_steps: 2000
lr_warmup_style: linear
min_decay_lr: 0
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 64 # 8 nodes
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 4
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 4 # GBS = 4*4*64*sequence_length = 1024*sequence_length = 2.1M tokens
sequence_length: 2048
train_steps: 500000
val_check_interval: -1
# SmolLM1 360M trained on 600B tokens
checkpoints:
checkpoint_interval: 2000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_final_state: false
save_initial_state: false
data_stages:
- data:
dataset:
dataset_folder: # paths to tokenized datasets
- datasets/fineweb-edu-dedup
- datasets/cosmopedia-v2
- datasets/python-edu
- datasets/open-web-math
- datasets/stackoverflow
dataset_weights:
- 0.7
- 0.15
- 0.08
- 0.06
- 0.01
num_loading_workers: 1
seed: 42
name: training stage
start_training_step: 1
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: smollm
run: smollm-360M
seed: 8
step: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
std: 0.03227486121839514
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 0
eos_token_id: 0
hidden_act: silu
hidden_size: 960
initializer_range: 0.02
intermediate_size: 2560
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 15
num_hidden_layers: 32
num_key_value_heads: 5
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
rope_theta: 10000.0
tie_word_embeddings: true
use_cache: true
vocab_size: 49152
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.003
lr_decay_starting_step: 500000
lr_decay_steps: 100000
lr_decay_style: 1-sqrt
lr_warmup_steps: 5000
lr_warmup_style: linear
min_decay_lr: 0
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 32
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
recompute_layer: false
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_recompute_allgather: true
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: HuggingFaceTB/cosmo2-tokenizer
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 2
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 8
sequence_length: 2048
train_steps: 600000
val_check_interval: -1
__version__ = "0.4"
# flake8: noqa
from nanotron.config.config import *
from nanotron.config.models_config import *
from nanotron.config.utils_config import *
from nanotron.config.lighteval_config import *
import datetime
import os
from dataclasses import dataclass, fields
from pathlib import Path
from typing import List, Optional, Type, Union
import dacite
import torch
import yaml
from dacite import from_dict
from datasets.download.streaming_download_manager import xPath
from yaml.loader import SafeLoader
from nanotron.config.lighteval_config import LightEvalConfig
from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit
from nanotron.config.parallelism_config import ParallelismArgs
from nanotron.config.utils_config import (
RecomputeGranularity,
cast_str_to_pipeline_engine,
cast_str_to_torch_dtype,
serialize,
)
from nanotron.generation.sampler import SamplerType
from nanotron.logging import get_logger
from nanotron.parallel.pipeline_parallel.engine import PipelineEngine
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
logger = get_logger(__name__)
DEFAULT_SEED = 42
@dataclass
class BenchArgs:
model_name: str
sequence_length: int
micro_batch_size: int
batch_accumulation_per_replica: int
benchmark_csv_path: str
@dataclass
class LoggingArgs:
"""Arguments related to logging"""
log_level: Optional[str] = None
log_level_replica: Optional[str] = None
iteration_step_info_interval: Optional[int] = 1
def __post_init__(self):
if self.log_level is None:
self.log_level = "info"
if self.log_level not in [
"debug",
"info",
"warning",
"error",
"critical",
"passive",
]:
raise ValueError(
f"log_level should be a string selected in ['debug', 'info', 'warning', 'error', 'critical', 'passive'] and not {self.log_level}"
)
if self.log_level_replica is None:
self.log_level_replica = "info"
if self.log_level_replica not in [
"debug",
"info",
"warning",
"error",
"critical",
"passive",
]:
raise ValueError(
f"log_level_replica should be a string selected in ['debug', 'info', 'warning', 'error', 'critical', 'passive'] and not {self.log_level_replica}"
)
@dataclass
class PretrainDatasetsArgs:
hf_dataset_or_datasets: Union[str, list, dict]
hf_dataset_splits: Optional[Union[str, list]] = None
hf_dataset_config_name: Optional[str] = None
dataset_processing_num_proc_per_process: Optional[int] = 1
dataset_overwrite_cache: Optional[bool] = False
text_column_name: Optional[str] = None
def __post_init__(self):
if self.text_column_name is None:
self.text_column_name = "text"
if self.hf_dataset_splits is None:
self.hf_dataset_splits = "train"
@dataclass
class S3UploadArgs:
"""Arguments related to uploading checkpoints on s3"""
upload_s3_path: xPath
remove_after_upload: bool
s5cmd_numworkers: Optional[int]
s5cmd_concurrency: Optional[int]
s5cmd_path: Optional[xPath]
def __post_init__(self):
if isinstance(self.upload_s3_path, str):
self.upload_s3_path = xPath(self.upload_s3_path)
if isinstance(self.s5cmd_path, str):
self.s5cmd_path = xPath(self.s5cmd_path)
@dataclass
class NanosetDatasetsArgs:
dataset_folder: Union[str, List[str]]
dataset_weights: Optional[List[float]] = None
def __post_init__(self):
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder
self.dataset_folder = [self.dataset_folder]
self.dataset_weights = [1]
@dataclass
class DataArgs:
"""Arguments related to the data and data files processing"""
dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
seed: Optional[int]
num_loading_workers: Optional[int] = 1
def __post_init__(self):
if self.seed is None:
self.seed = DEFAULT_SEED
@dataclass
class DatasetStageArgs:
"""Arguments for loading dataset in different stages of the training process"""
name: str
start_training_step: int
data: DataArgs
def __post_init__(self):
if self.start_training_step < 0:
raise ValueError(f"training_steps should be a positive integer and not {self.start_training_step}")
@dataclass
class CheckpointsArgs:
"""Arguments related to checkpoints:
checkpoints_path: where to save the checkpoints
checkpoint_interval: how often to save the checkpoints
resume_checkpoint_path: if you want to load from a specific checkpoint path
"""
checkpoints_path: Path
checkpoint_interval: int
save_initial_state: Optional[bool] = False
save_final_state: Optional[bool] = False
resume_checkpoint_path: Optional[xPath] = None
load_lr_scheduler: Optional[bool] = True
load_optimizer: Optional[bool] = True
checkpoints_path_is_shared_file_system: Optional[bool] = False
def __post_init__(self):
if isinstance(self.checkpoints_path, str):
self.checkpoints_path = xPath(self.checkpoints_path)
if isinstance(self.resume_checkpoint_path, str):
self.resume_checkpoint_path = xPath(self.resume_checkpoint_path)
@dataclass
class GeneralArgs:
"""General training experiment arguments
Args:
        project: Name of the project (a project gathers several runs in common tensorboard/hub folders)
run: Name of the run
step: Global step (updated when we save the checkpoint)
        consumed_train_samples: Number of samples consumed during training (in practice, step * global batch size)
ignore_sanity_checks: Whether to ignore sanity checks
"""
project: str
run: Optional[str] = None
seed: Optional[int] = None
step: Optional[int] = None
consumed_train_samples: Optional[int] = None
benchmark_csv_path: Optional[Path] = None
ignore_sanity_checks: bool = True
def __post_init__(self):
if self.seed is None:
self.seed = DEFAULT_SEED
if self.benchmark_csv_path is not None:
assert (
os.environ.get("NANOTRON_BENCHMARK", None) is not None
), f"Please set NANOTRON_BENCHMARK to 1 when using benchmark_csv_path. Got {os.environ.get('NANOTRON_BENCHMARK', None)}"
if self.run is None:
self.run = "%date_%jobid"
self.run.replace("%date", datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
self.run.replace("%jobid", os.environ.get("SLURM_JOB_ID", "local"))
@dataclass
class ProfilerArgs:
"""Arguments related to profiling"""
profiler_export_path: Optional[Path]
@dataclass
class ModelArgs:
"""Arguments related to model architecture"""
model_config: NanotronConfigs
init_method: Union[RandomInit, SpectralMupInit, ExistingCheckpointInit]
dtype: Optional[torch.dtype] = None
make_vocab_size_divisible_by: int = 1
ddp_bucket_cap_mb: int = 25
def __post_init__(self):
if self.dtype is None:
self.dtype = torch.bfloat16
if isinstance(self.dtype, str):
self.dtype = cast_str_to_torch_dtype(self.dtype)
self.model_config._is_using_mup = isinstance(self.init_method, SpectralMupInit)
# if self.model_config.max_position_embeddings is None:
# self.model_config.max_position_embeddings = 0
@dataclass
class TokenizerArgs:
"""Arguments related to the tokenizer"""
tokenizer_name_or_path: Optional[str] = None
tokenizer_revision: Optional[str] = None
tokenizer_max_length: Optional[int] = None
@dataclass
class TokensArgs:
"""Arguments related to the tokens, sequence, batch and steps of the training"""
sequence_length: int
train_steps: int
micro_batch_size: int
batch_accumulation_per_replica: int
val_check_interval: Optional[int] = -1
limit_val_batches: Optional[int] = 0
limit_test_batches: Optional[int] = 0
@dataclass
class LRSchedulerArgs:
"""Arguments related to the learning rate scheduler
lr_warmup_steps: number of steps to warmup the learning rate
lr_warmup_style: linear or constant
lr_decay_style: linear, cosine or 1-sqrt
min_decay_lr: minimum learning rate after decay
    lr_decay_steps: optional number of steps over which to decay the learning rate; defaults to train_steps - lr_warmup_steps
    lr_decay_starting_step: optional training step at which to start decaying the learning rate
"""
learning_rate: float
lr_warmup_steps: int = 0
lr_warmup_style: str = None
lr_decay_style: str = None
lr_decay_steps: Optional[int] = None
lr_decay_starting_step: Optional[int] = None
min_decay_lr: float = None
def __post_init__(self):
        if self.lr_warmup_style is None:
            self.lr_warmup_style = "linear"
        if self.lr_warmup_style not in ["linear", "constant"]:
            raise ValueError(
                f"lr_warmup_style should be a string selected in ['linear', 'constant'] and not {self.lr_warmup_style}"
            )
if self.lr_decay_style is None:
self.lr_decay_style = "linear"
if self.lr_decay_style not in ["linear", "cosine", "1-sqrt"]:
raise ValueError(
f"lr_decay_style should be a string selected in ['linear', 'cosine', '1-sqrt'] and not {self.lr_decay_style}"
)
if self.min_decay_lr is None:
self.min_decay_lr = self.learning_rate
@dataclass
class SGDOptimizerArgs:
name: str = "sgd"
@dataclass
class AdamWOptimizerArgs:
adam_eps: float
adam_beta1: float
adam_beta2: float
torch_adam_is_fused: bool
name: str = "adamW"
@dataclass
class OptimizerArgs:
"""Arguments related to the optimizer and learning rate"""
optimizer_factory: Union[SGDOptimizerArgs, AdamWOptimizerArgs]
zero_stage: int
weight_decay: float
clip_grad: Optional[float]
accumulate_grad_in_fp32: bool
learning_rate_scheduler: LRSchedulerArgs
@dataclass
class GenerationArgs:
sampler: Optional[Union[str, SamplerType]] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
n_samples: Optional[int] = None
eos: Optional[str] = None
seed: Optional[int] = None
use_cache: Optional[bool] = False
def __post_init__(self):
if isinstance(self.sampler, str):
self.sampler = SamplerType[self.sampler.upper()]
if self.seed is None:
self.seed = DEFAULT_SEED
@dataclass
class Config:
"""Main configuration class"""
general: GeneralArgs
parallelism: ParallelismArgs
model: ModelArgs
tokenizer: TokenizerArgs
checkpoints: Optional[CheckpointsArgs] = None
logging: Optional[LoggingArgs] = None
tokens: Optional[TokensArgs] = None
optimizer: Optional[OptimizerArgs] = None
data_stages: Optional[List[DatasetStageArgs]] = None
profiler: Optional[ProfilerArgs] = None
lighteval: Optional[LightEvalConfig] = None
s3_upload: Optional[S3UploadArgs] = None
@classmethod
def create_empty(cls):
cls_fields = fields(cls)
return cls(**{f.name: None for f in cls_fields})
def __post_init__(self):
if self.s3_upload is not None:
self.s3_upload.__post_init__()
# Some final sanity checks across separate arguments sections:
if self.profiler is not None and self.profiler.profiler_export_path is not None:
assert self.tokens.train_steps < 10
if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
self.optimizer.learning_rate_scheduler.lr_decay_steps = (
self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps
)
if self.data_stages is not None:
self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step)
names = [stage.name for stage in self.data_stages]
training_steps = [stage.start_training_step for stage in self.data_stages]
assert any(
stage.start_training_step == 1 for stage in self.data_stages
), "You must have a training stage starting at 1 in the config's data_stages"
for stage in self.data_stages:
if names.count(stage.name) > 1:
raise ValueError(f"Each stage should have unique names and not {names}")
if training_steps.count(stage.start_training_step) > 1:
raise ValueError(
f"Each stage should have unique starting training step, please change the starting training step for stage {stage.name}"
)
# NOTE: must order the stages by start_training_step from lowest to highest
assert all(
self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
for i in range(len(self.data_stages) - 1)
), "The stages are not sorted by start_training_step in increasing order"
# # if lighteval, we need tokenizer to be defined
# if self.checkpoints.lighteval is not None:
# assert self.tokenizer.tokenizer_name_or_path is not None
@property
def global_batch_size(self):
return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp
def save_as_yaml(self, file_path: str):
config_dict = serialize(self)
file_path = str(file_path)
with open(file_path, "w") as f:
yaml.dump(config_dict, f)
# Sanity test config can be reloaded
_ = get_config_from_file(file_path, config_class=self.__class__)
def as_dict(self) -> dict:
return serialize(self)
def get_config_from_dict(
config_dict: dict, config_class: Type = Config, skip_unused_config_keys: bool = False, skip_null_keys: bool = False
):
"""Get a config object from a dictionary
Args:
args: dictionary of arguments
config_class: type of the config object to get as a ConfigTypes (Config, LightevalConfig, LightevalSlurm) or str
skip_unused_config_keys: whether to skip unused first-nesting-level keys in the config file (for config with additional sections)
skip_null_keys: whether to skip keys with value None at first and second nesting level
"""
if skip_unused_config_keys:
logger.warning("skip_unused_config_keys set")
config_dict = {
field.name: config_dict[field.name] for field in fields(config_class) if field.name in config_dict
}
if skip_null_keys:
logger.warning("Skip_null_keys set")
config_dict = {
k: {kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v
for k, v in config_dict.items()
if v is not None
}
return from_dict(
data_class=config_class,
data=config_dict,
config=dacite.Config(
cast=[Path],
type_hooks={
torch.dtype: cast_str_to_torch_dtype,
PipelineEngine: cast_str_to_pipeline_engine,
TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()],
RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()],
SamplerType: lambda x: SamplerType[x.upper()],
},
# strict_unions_match=True,
strict=True,
),
)
def get_config_from_file(
config_path: str,
config_class: Type = Config,
model_config_class: Optional[Type] = None,
skip_unused_config_keys: bool = False,
skip_null_keys: bool = False,
) -> Config:
"""Get a config object from a file (python or YAML)
Args:
config_path: path to the config file
        config_class: if the file is a python file, type of the config object to get as a
ConfigTypes (Config, LightevalConfig, LightevalSlurm) or str
if None, will default to Config
skip_unused_config_keys: whether to skip unused first-nesting-level keys in the config file (for config with additional sections)
skip_null_keys: whether to skip keys with value None at first and second nesting level
"""
# Open the file and load the file
with open(config_path) as f:
config_dict = yaml.load(f, Loader=SafeLoader)
config = get_config_from_dict(
config_dict,
config_class=config_class,
skip_unused_config_keys=skip_unused_config_keys,
skip_null_keys=skip_null_keys,
)
if model_config_class is not None:
if not isinstance(config.model.model_config, (dict, model_config_class)):
raise ValueError(
f"model_config should be a dictionary or a {model_config_class} and not {config.model.model_config}"
)
config.model.model_config = model_config_class(**config.model.model_config)
return config
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Union
from nanotron.config.parallelism_config import ParallelismArgs
from nanotron.generation.sampler import SamplerType
from nanotron.logging import get_logger
logger = get_logger(__name__)
DEFAULT_GENERATION_SEED = 42
@dataclass
class GenerationArgs:
sampler: Optional[Union[str, SamplerType]] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
n_samples: Optional[int] = None
eos: Optional[str] = None
seed: Optional[int] = None
use_cache: Optional[bool] = False
def __post_init__(self):
if isinstance(self.sampler, str):
self.sampler = SamplerType[self.sampler.upper()]
if self.seed is None:
self.seed = DEFAULT_GENERATION_SEED
@dataclass
class LightEvalLoggingArgs:
"""Arguments related to logging for LightEval"""
local_output_path: Optional[Path] = None
push_results_to_hub: Optional[bool] = None
push_details_to_hub: Optional[bool] = None
push_results_to_tensorboard: Optional[bool] = None
hub_repo_results: Optional[str] = None
hub_repo_details: Optional[str] = None
hub_repo_tensorboard: Optional[str] = None
tensorboard_metric_prefix: Optional[str] = None
def __post_init__(self):
if isinstance(self.local_output_path, str):
self.local_output_path = Path(self.local_output_path)
@dataclass
class LightEvalTasksArgs:
"""Arguments related to tasks for LightEval"""
tasks: Optional[str] = None
custom_tasks: Optional[str] = None
max_samples: Optional[int] = None
num_fewshot_seeds: Optional[int] = None
dataset_loading_processes: Optional[int] = 8
multichoice_continuations_start_space: Optional[bool] = None
no_multichoice_continuations_start_space: Optional[bool] = None
@dataclass
class LightEvalWandbLoggerConfig:
"""Arguments related to the local Wandb logger"""
wandb_project: str = ""
wandb_entity: Optional[str] = None
wandb_run_name: Optional[str] = None
def __post_init__(self):
assert self.wandb_project != "", "Please specify a wandb_project"
@dataclass
class LightEvalConfig:
"""Arguments related to running LightEval on checkpoints.
All is optional because you can also use this class to later supply arguments to override
the saved config when running LightEval after training.
"""
slurm_template: Optional[str] = None
slurm_script_dir: Optional[str] = None
checkpoints_path: Optional[str] = None
parallelism: Optional[ParallelismArgs] = None
batch_size: Optional[int] = None
generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None
tasks: Optional[LightEvalTasksArgs] = None
logging: Optional[LightEvalLoggingArgs] = None
wandb: Optional[LightEvalWandbLoggerConfig] = None
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional, Union
@dataclass
class RandomInit:
std: float
@dataclass
class SpectralMupInit:
"""This is used to initialize the model with spectral mup. Set it to True to use it."""
use_mup: bool
def __post_init__(self):
assert self.use_mup, "Remove `use_mup` if you don't want to use it"
@dataclass
class ExistingCheckpointInit:
"""This is used to initialize from an already existing model (without optimizer, lr_scheduler...)"""
path: Path
@dataclass
class LlamaConfig:
"""Configuration for a LLAMA model
Be careful on having a coherent typing as we use it to reconstruct the model from yaml
"""
bos_token_id: int = 1
eos_token_id: int = 2
hidden_act: str = "silu"
hidden_size: int = 4096
initializer_range: float = 0.02
intermediate_size: int = 11008
    is_llama_config: bool = True  # We use this to help differentiate models in yaml/python conversion
max_position_embeddings: int = 2048
num_attention_heads: int = 32
num_hidden_layers: int = 32
num_key_value_heads: Optional[int] = None
pad_token_id: Optional[int] = None
pretraining_tp: int = 1
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
False # The default value has been True, but for loading Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
def __post_init__(self):
        # NOTE: users don't set self._init_method; ModelArgs will set it
# then we only pass LlamaConfig around
self._is_using_mup: bool = False
# self._init_method: Optional[Union[RandomInit, SpectralMupInit, ExistingCheckpointInit]] = None
# for backward compatibility
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
@property
def is_using_mup(self) -> bool:
return self._is_using_mup
@dataclass
class Starcoder2Config:
"""Configuration for a Starcoder2 model
Be careful on having a coherent typing as we use it to reconstruct the model from yaml
"""
activation_function: str = "gelu_pytorch_tanh"
attention_softmax_in_fp32: bool = True # TODO: not used
attn_pdrop: float = 0.1
bos_token_id: int = 49152 # TODO: not used
embd_pdrop: float = 0.1
eos_token_id: int = 49152
global_attn_layers: List[int] = field(default_factory=list)
grouped_query: bool = False # GQA
hidden_size: int = 2048
initializer_range: float = 0.02 # TODO: not used
intermediate_size: Optional[int] = None
    is_starcoder2_config: bool = True  # We use this to help differentiate models in yaml/python conversion
layer_norm_epsilon: float = 1e-05
max_position_embeddings: int = 4096
multi_query: bool = False # MQA
num_attention_heads: int = 16
num_hidden_layers: int = 24
num_kv_heads: Optional[int] = None
resid_pdrop: float = 0.1
rope_theta: Optional[int] = 10000
scale_attention_softmax_in_fp32: bool = True
scale_attn_weights: bool = True
sliding_window_size: Optional[int] = None
use_position_embeddings: bool = False # TODO @nouamane this is not used
use_rotary_embeddings: bool = True
vocab_size: int = 49280
def __post_init__(self):
if self.global_attn_layers is None:
self.global_attn_layers = []
if self.grouped_query:
assert self.num_kv_heads is not None, "num_kv_heads must be specified for grouped query"
assert self.multi_query is False, "Cannot use both multi_query and grouped_query"
if not self.multi_query and not self.grouped_query:
self.multi_query = True
@property
def n_embed(self):
return self.hidden_size
@property
def n_head(self):
return self.num_attention_heads
@property
def n_layer(self):
return self.num_hidden_layers
@property
def n_positions(self):
return self.max_position_embeddings
@property
def n_inner(self):
return self.intermediate_size
NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any]
from dataclasses import dataclass
from typing import Optional
from nanotron.config.utils_config import (
cast_str_to_pipeline_engine,
)
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
@dataclass
class ParallelismArgs:
"""Arguments related to TP/PP/DP
Args:
dp: Number of DP replicas
pp: Number of PP stages
tp: Number of TP replicas
expert_parallel_size: Number of expert parallel replicas (used only for MoEs)
pp_engine: Pipeline engine to use between "1f1b" and "afab"
        tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activates sequence parallelism
tp_linear_async_communication: Whether to use async communication in TP linear layers
recompute_layer: Whether to recompute each Transformer layer to save memory.
"""
dp: int
pp: int
tp: int
pp_engine: Optional[PipelineEngine] = None
tp_mode: Optional[TensorParallelLinearMode] = None
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False
tp_recompute_allgather: bool = True
expert_parallel_size: int = 1
def __post_init__(self):
# Conservative defaults
if self.pp_engine is None:
self.pp_engine = AllForwardAllBackwardPipelineEngine()
if self.tp_mode is None:
self.tp_mode = TensorParallelLinearMode.ALL_REDUCE
if self.tp_linear_async_communication is None:
self.tp_linear_async_communication = False
if isinstance(self.pp_engine, str):
self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
if isinstance(self.tp_mode, str):
self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]
from dataclasses import fields
from enum import Enum, auto
from pathlib import Path
import torch
from nanotron.generation.sampler import SamplerType
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
OneForwardOneBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
class RecomputeGranularity(Enum):
SELECTIVE = auto()
FULL = auto()
def serialize(data) -> dict:
"""Recursively serialize a nested dataclass to a dict - do some type conversions along the way"""
if data is None:
return None
if not hasattr(data, "__dataclass_fields__"):
return data
result = {}
for field in fields(data):
value = getattr(data, field.name)
if hasattr(value, "__dataclass_fields__"):
result[field.name] = serialize(value)
elif isinstance(value, Path):
result[field.name] = str(value)
elif isinstance(value, PipelineEngine):
result[field.name] = cast_pipeline_engine_to_str(value)
elif isinstance(value, TensorParallelLinearMode):
result[field.name] = value.name
elif isinstance(value, RecomputeGranularity):
result[field.name] = value.name
elif isinstance(value, SamplerType):
result[field.name] = value.name
elif isinstance(value, torch.dtype):
result[field.name] = dtype_to_str[value]
elif isinstance(value, (list, tuple)):
result[field.name] = [serialize(v) for v in value]
elif isinstance(value, dict) and not value:
result[field.name] = None # So we can serialize empty dicts without issue with `datasets` in particular
else:
result[field.name] = value
return result
str_to_dtype = {
"float32": torch.float32,
"float64": torch.float64,
"complex64": torch.complex64,
"complex128": torch.complex128,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"bool": torch.bool,
}
dtype_to_str = {
torch.float32: "float32",
torch.float64: "float64",
torch.complex64: "complex64",
torch.complex128: "complex128",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.uint8: "uint8",
torch.int8: "int8",
torch.int16: "int16",
torch.int32: "int32",
torch.int64: "int64",
torch.bool: "bool",
}
def cast_str_to_torch_dtype(str_dtype: str):
if str_dtype in str_to_dtype:
return str_to_dtype[str_dtype]
else:
raise ValueError(f"dtype should be a string selected in {str_to_dtype.keys()} and not {str_dtype}")
def cast_str_to_pipeline_engine(str_pp_engine: str) -> PipelineEngine:
if str_pp_engine == "afab":
return AllForwardAllBackwardPipelineEngine()
elif str_pp_engine == "1f1b":
return OneForwardOneBackwardPipelineEngine()
else:
raise ValueError(f"pp_engine should be a string selected in ['afab', '1f1b'] and not {str_pp_engine}")
def cast_pipeline_engine_to_str(pp_engine: PipelineEngine) -> str:
if isinstance(pp_engine, AllForwardAllBackwardPipelineEngine):
return "afab"
elif isinstance(pp_engine, OneForwardOneBackwardPipelineEngine):
return "1f1b"
else:
raise ValueError(
f"pp_engine should be aan instance of AllForwardAllBackwardPipelineEngine or OneForwardOneBackwardPipelineEngine, not {type(pp_engine)}"
)
import platform
from packaging.version import Version, parse
CHECKPOINT_VERSION = Version("1.4")
PY_VERSION = parse(platform.python_version())
#### FOR SERIALIZATION ####
CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"
import dataclasses
from typing import Dict, List, Union
import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
@dataclasses.dataclass
class NanosetDataCollatorForCLM:
"""
Data collator used for causal language modeling with Nanosets dataset.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""
sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext
def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}
# Make sure we load only what's necessary, ie we only load a `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)
# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape
result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}
result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
return result
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.dataloader import (
EmptyInfiniteDataset,
get_dataloader_worker_init,
get_sampler,
)
from nanotron.parallel import ParallelContext
from torch.utils.data import DataLoader
logger = logging.get_logger(__name__)
def build_nanoset_dataloader(
dataset,
sequence_length: int,
parallel_context: ParallelContext,
input_pp_rank: int,
output_pp_rank: int,
micro_batch_size: int,
dataloader_num_workers: int,
consumed_train_samples: int = 0,
dataloader_drop_last: bool = True,
dataloader_pin_memory: bool = True,
) -> DataLoader:
    # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do its job
if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]:
dataset_length = len(dataset)
dataset = EmptyInfiniteDataset(length=dataset_length)
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0
data_collator = NanosetDataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
)
# Compute size and rank of dataloader workers
dp_ranks_size = parallel_context.dp_pg.size()
dp_rank = parallel_context.dp_pg.rank()
sampler = get_sampler(
train_dataset=dataset,
dl_ranks_size=dp_ranks_size,
dl_rank=dp_rank,
drop_last=dataloader_drop_last,
consumed_train_samples=consumed_train_samples,
shuffle=False,
)
return DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
collate_fn=data_collator,
drop_last=dataloader_drop_last,
num_workers=dataloader_num_workers,
pin_memory=dataloader_pin_memory,
worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank),
)
import os
import warnings
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
from datatrove.utils.dataset import DatatroveFolderDataset
from nanotron import logging
from nanotron.data.utils import count_dataset_indexes, normalize
from nanotron.logging import log_rank
from numba import jit
logger = logging.get_logger(__name__)
class Nanoset(torch.utils.data.Dataset):
"""
The Nanoset dataset
Args:
dataset_folders (List[str]): List of folders with tokenized datasets
dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__
sequence_length (int): Sequence length of the built samples
token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise
train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size
"""
def __init__(
self,
dataset_folders: List[str],
sequence_length: int,
token_size: int,
train_split_num_samples: int,
dataset_weights: Union[List[float], None] = None,
random_seed: int = 1234,
) -> None:
# Checks
if isinstance(dataset_folders, str):
warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
dataset_folders = [dataset_folders]
# Init
self.dataset_folders = dataset_folders
self.sequence_length = sequence_length
self.token_size = token_size
self.train_split_num_samples = train_split_num_samples
self.random_seed = random_seed
self.datatrove_datasets = []
for dataset_folder in self.dataset_folders:
self.datatrove_datasets.append(
DatatroveFolderDataset(
folder_path=dataset_folder,
filename_pattern=os.path.join(dataset_folder, "*.ds"),
seq_len=sequence_length,
recursive=False,
token_size=token_size,
shuffle=True,
)
)
# Build Nanoset Index
## To build the index we need the length of each dataset
self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
## Set dataset weights
if (
dataset_weights is None
        ):  # Case of training with more than one dataset without weighting them: consume all datasets entirely on each epoch
self.dataset_weights = normalize(self.dataset_lengths)
else:
self.dataset_weights = normalize(dataset_weights)
assert len(dataset_folders) == len(
self.dataset_weights
), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
## Build dataset index and dataset sample index
self.dataset_index, self.dataset_sample_index = self.build_nanoset_index()
self.print_nanoset_info()
def __len__(self) -> int:
"""
Returns:
int: The number of samples of the Nanoset
"""
return len(self.dataset_index)
def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"""
Returns sequence_length + 1 tokens from the memmap dataset
Args:
idx (int): The index into the dataset
Returns:
Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary
"""
dataset = self.dataset_index[idx]
dataset_sample = self.dataset_sample_index[idx]
return self.datatrove_datasets[dataset][dataset_sample]
def build_nanoset_index(self) -> np.ndarray:
"""
Build dataset index and dataset sample index
"""
# Compute samples per epoch and number of epochs
samples_per_epoch = sum(self.dataset_lengths)
num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1
# Build the dataset indexes for 1 epoch
dataset_index, dataset_sample_index = build_nanoset_index_helper(
n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths
)
# Shuffle the indexes the same way
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_index)
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_sample_index)
# Concatenate num_epochs the shuffled indexes
dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)])
dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)])
# Just keep the necessary samples
dataset_index = dataset_index[: self.train_split_num_samples]
dataset_sample_index = dataset_sample_index[: self.train_split_num_samples]
return dataset_index, dataset_sample_index
def print_nanoset_info(self):
log_rank(f"> Total number of samples: {len(self)}", logger=logger, level=logging.INFO, rank=0)
log_rank(
f"> Total number of tokens: {len(self) * self.sequence_length}", logger=logger, level=logging.INFO, rank=0
)
# Print samples from each dataset + weight
dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
for index, sample_count in enumerate(dataset_sample_count):
log_rank(
f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
logger=logger,
level=logging.INFO,
rank=0,
)
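# Usage sketch (illustrative paths and sizes, based on the constructor and docstring above):
#
#   nanoset = Nanoset(
#       dataset_folders=["datasets/fineweb-edu-dedup", "datasets/cosmopedia-v2"],
#       sequence_length=2048,
#       token_size=2,  # 2 bytes per token for vocab sizes < 65535
#       train_split_num_samples=1_000_000,  # training steps * global batch size
#       dataset_weights=[0.8, 0.2],
#   )
#   sample = nanoset[0]  # dict of input ids, sequence_length + 1 tokens long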
@jit(nopython=True, cache=True)
def build_nanoset_index_helper(
n_samples: int, weights: np.ndarray, dataset_sizes: List[int]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Given multiple datasets and a weighting array, build samples indexes
such that it follows those weights
"""
# Create empty arrays for dataset indices and dataset sample indices
dataset_index = np.empty((n_samples,), dtype="uint")
dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples
# Initialize buffer for number of samples used for each dataset
current_samples = np.zeros((len(weights),), dtype="long")
# Iterate over all samples
for sample_idx in range(n_samples):
# Convert sample index to float for comparison against weights
sample_idx_float = max(sample_idx, 1.0)
# Find the dataset with the highest error
errors = weights * sample_idx_float - current_samples
max_error_index = np.argmax(errors)
# Assign the dataset index and update the sample index
dataset_index[sample_idx] = max_error_index
dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]
# Update the total samples for the selected dataset
current_samples[max_error_index] += 1
return dataset_index, dataset_sample_index
from typing import List
import numpy as np
def normalize(weights: List[float]) -> List[np.array]:
"""
Normalize elements of a list
Args:
weights (List[float]): The weights
Returns:
List[numpy.array]: The normalized weights
"""
w = np.array(weights, dtype=np.float64)
w_sum = np.sum(w)
w = w / w_sum
return w
def count_dataset_indexes(dataset_idx: np.ndarray, n_datasets: int):
counts = []
for dataset in range(n_datasets):
counts.append(np.count_nonzero(dataset_idx == dataset))
return counts
import dataclasses
import warnings
from typing import Dict, Generator, Iterator, List, Optional, Union
import numpy as np
import torch
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.random import set_random_seed
from nanotron.sanity_checks import (
assert_fail_except_rank_with,
assert_tensor_synced_across_pg,
)
try:
import datasets
from datasets import (
Dataset,
DatasetDict,
Features,
Sequence,
Value,
concatenate_datasets,
load_dataset,
)
from transformers import PreTrainedTokenizerBase
from transformers.trainer_pt_utils import DistributedSamplerWithLoop
except ImportError:
warnings.warn("Datasets and/or Transformers not installed, you'll be unable to use the dataloader.")
logger = logging.get_logger(__name__)
def sanity_check_dataloader(
dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]],
parallel_context: ParallelContext,
config: Config,
) -> Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]:
for batch in dataloader:
micro_batch = {
k: v if isinstance(v, TensorPointer) else v.to("cuda", memory_format=torch.contiguous_format)
for k, v in batch.items()
}
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check input are not the same across DP
for key, value in sorted(micro_batch.items(), key=lambda x: x[0]):
if isinstance(value, TensorPointer):
continue
if "mask" in key:
# It's fine if mask is the same across DP
continue
with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg):
assert_tensor_synced_across_pg(
tensor=value, pg=parallel_context.dp_pg, msg=lambda err: f"{key} {err}"
)
# SANITY CHECK: Check input are synchronized throughout TP
for key, value in sorted(micro_batch.items(), key=lambda x: x[0]):
if isinstance(value, TensorPointer):
continue
assert_tensor_synced_across_pg(
tensor=value,
pg=parallel_context.tp_pg,
msg=lambda err: f"{key} are not synchronized throughout TP {err}",
)
# SANITY CHECK: Check that input are synchronized throughout PP
# TODO @thomasw21: That's really hard to test as input gets sharded across the PP, let's assume it works for now.
# SANITY CHECK: Check that an input only exists on the PP rank responsible for it
# TODO @nouamanetazi: add this test
yield micro_batch
# Adapted from h4/src/h4/data/loading.py
def get_datasets(
hf_dataset_or_datasets: Union[dict, str],
hf_dataset_config_name: str,
splits: Optional[Union[List[str], str]] = ["train", "test"],
) -> "DatasetDict":
"""
    Function to load one or several datasets, from a single dataset name or from a dict mapping dataset names to fractions.
    Args:
        hf_dataset_or_datasets (Union[dict, str]): a single dataset name, or a dict mapping dataset names to the fraction of their training split to keep (see `_get_dataset_mix`).
        hf_dataset_config_name (str): dataset configuration name forwarded to `load_dataset` (only used in the single-dataset case).
        splits (Optional[Union[List[str], str]], optional): Splits of the dataset to load, defaults to ["train", "test"]
            Can be one of `train_ift`, `test_rl`, or `..._rm` etc. H4 datasets are divided into 6 subsets for training / testing.
    Returns:
        DatasetDict: DatasetDict object containing the requested splits (e.g. train and test).
"""
if isinstance(splits, str):
splits = [splits]
if isinstance(hf_dataset_or_datasets, dict):
# Structure of the config to read the datasets and their mix
# datasets_mixer:
# - 'dataset1': 0.5
# - 'dataset2': 0.3
# - 'dataset3': 0.2
raw_datasets = _get_dataset_mix(hf_dataset_or_datasets, splits=splits)
elif isinstance(hf_dataset_or_datasets, str):
# e.g. Dataset = "HuggingFaceH4/testing_alpaca_small"
# Note this returns things other than just train/test, which may not be intended
raw_datasets = DatasetDict()
for split in splits:
raw_datasets[split] = load_dataset(
hf_dataset_or_datasets,
hf_dataset_config_name,
split=split,
)
else:
raise ValueError(f"hf_dataset_or_datasets must be a dict or string but is {type(hf_dataset_or_datasets)}")
return raw_datasets
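# Usage sketch (not part of the original file): the two input shapes `get_datasets` accepts.
# `_demo_get_datasets` is a hypothetical helper and the dataset names below are placeholders,
# not real resources; the function is defined here only for illustration and never invoked.
def _demo_get_datasets():
    # A single dataset name: each requested split is loaded as-is.
    single = get_datasets("my-org/my-dataset", hf_dataset_config_name=None, splits="train")
    # A dict of {dataset_name: fraction}: splits are loaded per dataset and mixed by `_get_dataset_mix`.
    mixed = get_datasets(
        {"my-org/dataset-a": 0.7, "my-org/dataset-b": 0.3},
        hf_dataset_config_name=None,
        splits=["train", "test"],
    )
    return single, mixed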
# Adapted from h4/src/h4/data/loading.py
def _get_dataset_mix(dataset_dict: dict, splits: List[str] = None, seed=42) -> "DatasetDict":
"""
Helper function to load dataset mix from dict configuration.
Args:
dataset_dict (dict): Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
splits (Optional[List[str]], optional): Section of the dataset to load, defaults to "train", "test"
Can be one of `train_{ift,rm,rl}` and `test_{ift,rm,rl}`. Our datasets are typically divided into 6 subsets for training / testing.
"""
raw_datasets = DatasetDict()
raw_train_datasets = []
raw_test_datasets = []
fracs = []
for ds, frac in dataset_dict.items():
if frac < 0:
raise ValueError(f"Dataset fraction for dataset {ds} is negative. (= {frac})")
fracs.append(frac)
for split in splits:
if "train" in split:
raw_train_datasets.append(
load_dataset(
ds,
split=split,
)
)
elif "test" in split:
raw_test_datasets.append(
load_dataset(
ds,
split=split,
)
)
else:
raise ValueError(f"Split type {split} not recognized as one of test or train.")
if len(raw_train_datasets) > 0:
train_subsets = []
for dataset, frac in zip(raw_train_datasets, fracs):
train_subset = dataset.select(range(int(frac * len(dataset))))
train_subsets.append(train_subset)
raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=seed)
# No subsampling for test datasets to enable fair comparison across models
if len(raw_test_datasets) > 0:
raw_datasets["test"] = concatenate_datasets(raw_test_datasets).shuffle(seed=seed)
if len(raw_datasets) == 0:
        raise ValueError(
            f"Dataset {dataset_dict} not recognized with splits {splits}. Check the dataset has been correctly formatted."
        )
return raw_datasets
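# Illustrative sketch (not part of the original file): the core mixing steps of
# `_get_dataset_mix` (fraction-based subsampling, concatenation, shuffling), applied to two
# tiny in-memory datasets instead of hub datasets. `_demo_dataset_mix` is a hypothetical helper.
def _demo_dataset_mix():
    ds_a = Dataset.from_dict({"text": [f"a{i}" for i in range(10)]})
    ds_b = Dataset.from_dict({"text": [f"b{i}" for i in range(10)]})
    fracs = [0.5, 0.2]  # keep 5 rows of ds_a and 2 rows of ds_b
    subsets = [ds.select(range(int(frac * len(ds)))) for ds, frac in zip([ds_a, ds_b], fracs)]
    mixed = concatenate_datasets(subsets).shuffle(seed=42)
    return mixed  # 7 rows drawn from both datasets, shuffled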
def dummy_infinite_data_generator(
micro_batch_size: int,
sequence_length: int,
input_pp_rank: int,
output_pp_rank: int,
vocab_size: int,
seed: int,
parallel_context: ParallelContext,
):
def data_generator() -> Generator[Dict[str, Union[torch.Tensor, TensorPointer]], None, None]:
# Random generator
generator = torch.Generator(device="cuda")
# Make sure that TP are synced always
generator.manual_seed(
seed * (1 + dist.get_rank(parallel_context.dp_pg)) * (1 + dist.get_rank(parallel_context.pp_pg))
)
while True:
yield {
"input_ids": torch.randint(
0,
vocab_size,
(micro_batch_size, sequence_length),
dtype=torch.long,
device="cuda",
generator=generator,
)
if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
else TensorPointer(group_rank=input_pp_rank),
"input_mask": torch.ones(
micro_batch_size,
sequence_length,
dtype=torch.bool,
device="cuda",
)
if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
else TensorPointer(group_rank=input_pp_rank),
"label_ids": torch.randint(
0,
vocab_size,
(micro_batch_size, sequence_length),
dtype=torch.long,
device="cuda",
generator=generator,
)
if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
else TensorPointer(group_rank=output_pp_rank),
"label_mask": torch.ones(
micro_batch_size,
sequence_length,
dtype=torch.bool,
device="cuda",
)
if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
else TensorPointer(group_rank=output_pp_rank),
}
return data_generator
# Adapted from https://github.com/huggingface/accelerate/blob/a73898027a211c3f6dc4460351b0ec246aa824aa/src/accelerate/data_loader.py#L781C1-L824C28
class SkipBatchSampler(BatchSampler):
"""
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
    Note that in case of DDP, `skip_batches` counts batches globally across DP ranks, so each rank only skips `skip_batches // dp_size` batches.
"""
def __init__(self, batch_sampler: BatchSampler, skip_batches: int, dp_size: int):
self.batch_sampler = batch_sampler
        # `skip_batches` is the global number of batches already consumed across DP ranks, so each rank skips `skip_batches // dp_size` of them
self.skip_batches = skip_batches // dp_size
def __iter__(self):
for index, samples in enumerate(self.batch_sampler):
if index >= self.skip_batches:
yield samples
@property
def total_length(self):
return len(self.batch_sampler)
def __len__(self):
return len(self.batch_sampler) - self.skip_batches
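# Illustrative sketch (not part of the original file): resuming with `SkipBatchSampler`.
# With 2 DP ranks and 4 already-consumed batches globally, each rank skips 4 // 2 = 2 batches.
# `_demo_skip_batch_sampler` is a hypothetical helper added only for illustration.
def _demo_skip_batch_sampler():
    from torch.utils.data import SequentialSampler

    base = BatchSampler(SequentialSampler(range(16)), batch_size=2, drop_last=True)  # 8 batches
    skipping = SkipBatchSampler(base, skip_batches=4, dp_size=2)
    assert len(skipping) == len(base) - 2
    return list(skipping)  # starts at the 3rd batch: [[4, 5], [6, 7], ...]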
def set_tensor_pointers(
input_dict: Dict[str, Union[torch.Tensor, TensorPointer]], group: dist.ProcessGroup, group_rank: int
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
"""Make sure only the group_rank rank has the data, others have TensorPointers."""
return {
k: v if dist.get_rank(group) == group_rank else TensorPointer(group_rank=group_rank)
for k, v in input_dict.items()
}
### CAUSAL LANGUAGE MODELING ###
def clm_process(
raw_dataset: "Dataset",
tokenizer: "PreTrainedTokenizerBase",
text_column_name: str,
dataset_processing_num_proc_per_process: int,
dataset_overwrite_cache: bool,
sequence_length: int,
):
"""Concatenate all texts from raw_dataset and generate chunks of `sequence_length + 1`, where chunks overlap by a single token."""
# Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/examples/pytorch/language-modeling/run_clm.py#L391-L439
def group_texts(examples: Dict[str, List[np.ndarray]]) -> Dict[str, List[np.ndarray]]:
# Concatenate all texts.
concatenated_examples = {k: np.concatenate(v) for k, v in examples.items()}
total_length = len(concatenated_examples[next(iter(examples.keys()))])
# WARNING: We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= sequence_length + 1:
total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
# Split by chunks of sequence_length.
result = {
k: [
t[i : i + sequence_length + 1] for i in range(0, total_length - (sequence_length + 1), sequence_length)
]
for k, t in concatenated_examples.items()
}
return result
def _tokenize_and_group_texts(texts: List[str]) -> Dict[str, List[np.ndarray]]:
tokenized_batch = tokenizer.batch_encode_plus(texts, return_attention_mask=False, return_token_type_ids=False)
tokenized_batch = {k: [np.array(tokenized_texts) for tokenized_texts in v] for k, v in tokenized_batch.items()}
return group_texts(tokenized_batch)
train_dataset = raw_dataset.map(
_tokenize_and_group_texts,
input_columns=text_column_name,
remove_columns=raw_dataset.column_names,
features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)}),
batched=True,
num_proc=dataset_processing_num_proc_per_process,
load_from_cache_file=not dataset_overwrite_cache,
desc=f"Grouping texts in chunks of {sequence_length+1}",
)
return train_dataset
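# Illustrative sketch (not part of the original file): the chunking performed by `group_texts`
# above, shown on a toy token stream with sequence_length = 4. Each chunk has
# sequence_length + 1 tokens and consecutive chunks share one token, so the collator can later
# derive inputs (first 4 tokens) and labels (last 4 tokens) from every chunk.
# `_demo_group_texts` is a hypothetical helper that mirrors the slicing logic above.
def _demo_group_texts(sequence_length: int = 4):
    tokens = np.arange(14)  # pretend token ids from several concatenated documents
    total_length = ((len(tokens) - 1) // sequence_length) * sequence_length + 1
    chunks = [
        tokens[i : i + sequence_length + 1]
        for i in range(0, total_length - (sequence_length + 1), sequence_length)
    ]
    return chunks  # [array([0, 1, 2, 3, 4]), array([4, 5, 6, 7, 8])]; trailing tokens are dropped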
# Adapted from: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/data/data_collator.py#L607
@dataclasses.dataclass
class DataCollatorForCLM:
"""
Data collator used for causal language modeling.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""
sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext
def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}
# Make sure we load only what's necessary, ie we only load a `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)
# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape
result: Dict[str, Union[np.ndarray, TensorPointer]] = {}
result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
            raise ValueError(
                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {result['input_ids'].shape[-1]}, but should be"
                f" {self.sequence_length}."
            )
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
# Cast np.array to torch.Tensor
result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()}
return result
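# Illustrative sketch (not part of the original file): the shift performed by
# `DataCollatorForCLM` on a single chunk of sequence_length + 1 token ids. On the input rank
# the last token is dropped; on the output rank the first token is dropped, so label position t
# holds the token that follows input position t. `_demo_clm_shift` is a hypothetical helper.
def _demo_clm_shift(sequence_length: int = 4):
    input_ids = np.array([[10, 11, 12, 13, 14]])  # shape (batch=1, sequence_length + 1)
    inputs = input_ids[:, :-1]  # [[10, 11, 12, 13]] -> what the model sees
    labels = input_ids[:, 1:]  # [[11, 12, 13, 14]] -> what the model predicts
    assert inputs.shape[-1] == labels.shape[-1] == sequence_length
    return inputs, labels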
# Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L763-L835
def get_sampler(
dl_ranks_size: int,
dl_rank: int,
train_dataset: Union["Dataset", torch.utils.data.Dataset],
consumed_train_samples: int,
seed: int = 42,
use_loop_to_round_batch_size: bool = False,
micro_batch_size: Optional[int] = None,
drop_last: Optional[bool] = True,
shuffle: bool = True,
) -> Optional[torch.utils.data.Sampler]:
    """Returns a sampler that restricts data loading to the subset of the dataset proper to the DP rank"""
# Build the sampler.
# TODO @nouamanetazi: Support group_by_length: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L783-L810
if use_loop_to_round_batch_size:
assert micro_batch_size is not None
# loops at the end back to the beginning of the shuffled samples to make each process have a round multiple of batch_size samples.
sampler = DistributedSamplerWithLoop(
train_dataset,
batch_size=micro_batch_size,
num_replicas=dl_ranks_size,
rank=dl_rank,
seed=seed,
drop_last=drop_last,
)
else:
sampler = DistributedSampler(
train_dataset, num_replicas=dl_ranks_size, rank=dl_rank, seed=seed, drop_last=drop_last, shuffle=shuffle
)
if consumed_train_samples > 0:
sampler = SkipBatchSampler(sampler, skip_batches=consumed_train_samples, dp_size=dl_ranks_size)
return sampler
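# Illustrative sketch (not part of the original file): what `get_sampler` builds when resuming,
# here for DP rank 0 of 2 with 8 samples already consumed globally. As in `get_sampler`, the
# skipper wraps the per-sample DistributedSampler directly, so the "batches" it skips are
# individual sample indices. `_demo_get_sampler_resume` is a hypothetical helper.
def _demo_get_sampler_resume():
    dataset = list(range(32))  # stand-in for a tokenized dataset
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, seed=42, drop_last=True, shuffle=False)
    resumed = SkipBatchSampler(sampler, skip_batches=8, dp_size=2)  # this rank skips 8 // 2 = 4 samples
    return list(resumed)  # the first 4 indices assigned to this rank are skipped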
# Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L837
def get_train_dataloader(
train_dataset: "Dataset",
sequence_length: int,
parallel_context: ParallelContext,
input_pp_rank: int,
output_pp_rank: int,
micro_batch_size: int,
consumed_train_samples: int,
dataloader_num_workers: int,
seed_worker: int,
dataloader_drop_last: bool = True,
dataloader_pin_memory: bool = True,
use_loop_to_round_batch_size: bool = False,
) -> DataLoader:
if not isinstance(train_dataset, datasets.Dataset):
raise ValueError(f"training requires a datasets.Dataset, but got {type(train_dataset)}")
# Case of ranks requiring data
if dist.get_rank(parallel_context.pp_pg) in [
input_pp_rank,
output_pp_rank,
]:
train_dataset = train_dataset.with_format(type="numpy", columns=["input_ids"], output_all_columns=True)
# Case of ranks not requiring data. We give them an infinite dummy dataloader
else:
assert train_dataset.column_names == ["input_ids"], (
f"Dataset has to have a single column, with `input_ids` as the column name. "
f"Current dataset: {train_dataset}"
)
dataset_length = len(train_dataset)
train_dataset = train_dataset.remove_columns(column_names="input_ids")
assert (
len(train_dataset) == 0
), f"Dataset has to be empty after removing the `input_ids` column. Current dataset: {train_dataset}"
        # HACK: if we remove the last column of a train_dataset, it becomes empty and its number of rows becomes 0.
train_dataset = EmptyInfiniteDataset(length=dataset_length)
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0
data_collator = DataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
)
# Compute size and rank of dataloader workers
dp_ranks_size = parallel_context.dp_pg.size()
dp_rank = parallel_context.dp_pg.rank()
# TODO @nouamanetazi: Remove unused columns: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L852
# TODO @nouamanetazi: Support torch.utils.data.IterableDataset: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L855-L872
train_sampler = get_sampler(
dl_rank=dp_rank,
dl_ranks_size=dp_ranks_size,
train_dataset=train_dataset,
seed=seed_worker,
use_loop_to_round_batch_size=use_loop_to_round_batch_size,
micro_batch_size=micro_batch_size,
drop_last=dataloader_drop_last,
consumed_train_samples=consumed_train_samples,
)
return DataLoader(
train_dataset,
batch_size=micro_batch_size,
sampler=train_sampler,
collate_fn=data_collator,
drop_last=dataloader_drop_last, # we also drop_last in `clm_process()`
num_workers=dataloader_num_workers,
pin_memory=dataloader_pin_memory,
worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank),
# TODO @thomasw21: I'm not sure but this doesn't seem to work at all.
# pin_memory_device="cuda",
)
def get_dataloader_worker_init(dp_rank: int):
    """Creates a random state for each dataloader worker so that every worker gets a different state"""
def dataloader_worker_init(worker_id):
# Dataloader is TP/PP synced in random states
seed = 2 ** (1 + worker_id) * 3 ** (1 + dp_rank) % (2**32)
set_random_seed(seed)
return dataloader_worker_init
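# Illustrative sketch (not part of the original file): the seeds produced by the formula above
# for a few (dp_rank, worker_id) pairs, showing that every pair gets a distinct seed.
# `_demo_worker_seeds` is a hypothetical helper added only for illustration.
def _demo_worker_seeds(n_dp_ranks: int = 2, n_workers: int = 3):
    seeds = {
        (dp_rank, worker_id): 2 ** (1 + worker_id) * 3 ** (1 + dp_rank) % (2**32)
        for dp_rank in range(n_dp_ranks)
        for worker_id in range(n_workers)
    }
    assert len(set(seeds.values())) == n_dp_ranks * n_workers  # all seeds are distinct
    return seeds  # e.g. {(0, 0): 6, (0, 1): 12, (0, 2): 24, (1, 0): 18, ...}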
class EmptyInfiniteDataset:
"""Hack as removing all columns from a datasets.Dataset makes the number of rows 0."""
def __init__(self, length: int):
self._length = length
def __getitem__(self, item) -> Dict:
if isinstance(item, int):
return {}
raise NotImplementedError(f"{item} of type {type(item)} is not supported yet")
def __len__(self) -> int:
return self._length
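# Illustrative sketch (not part of the original file): the dummy dataset keeps a length so the
# sampler and dataloader still work on ranks that hold no data, while every item is an empty
# dict that `DataCollatorForCLM` turns into `TensorPointer`s. `_demo_empty_dataset` is a
# hypothetical helper added only for illustration.
def _demo_empty_dataset():
    ds = EmptyInfiniteDataset(length=5)
    assert len(ds) == 5 and ds[3] == {}
    return [ds[i] for i in range(len(ds))]  # [{}, {}, {}, {}, {}]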