""" Nanotron training script example using a custom dataloader. Usage: ``` export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml ``` """ import argparse from typing import Dict, cast import datasets import numpy as np from nanotron import logging from nanotron.config import ( DataArgs, DatasetStageArgs, PretrainDatasetsArgs, ) from nanotron.dataloader import ( DataCollatorForCLM, clm_process, get_dataloader_worker_init, get_datasets, get_train_dataloader, ) from nanotron.helpers import ( compute_remain_train_steps_of_a_data_stage_from_ckp, get_consumed_train_samples_of_a_data_stage_from_ckp, ) from nanotron.logging import log_rank from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks from nanotron.trainer import DistributedTrainer from nanotron.utils import main_rank_first from torch.utils.data import DataLoader try: from huggingface_hub import __version__ as hf_hub_version from transformers import AutoTokenizer from transformers import __version__ as tf_version except ImportError: hf_hub_version = None tf_version = None logger = logging.get_logger(__name__) def get_dataloader_from_data_stage( trainer: DistributedTrainer, data: DataArgs, consumed_train_samples: int, num_remaining_train_steps: int, ): """ Returns a dataloader for a given data stage. data: The data configuration for the current stage. consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). num_remaining_train_steps: The number of remaining training steps for this stage. """ assert consumed_train_samples >= 0, "consumed_train_samples should be greater than 0" assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than 0" # First, we need to know which ranks to feed the dataloader to input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) # Case 1: custom data generator if data.dataset is None: log_rank("Using custom data generator", logger=logger, level=logging.INFO, rank=0) ########################################################################################################### # This can be replaced with your own tokenized data generator ########################################################################################################### train_dataset = datasets.Dataset.from_dict( { "input_ids": np.random.randint( 0, trainer.config.model.model_config.vocab_size, (trainer.global_batch_size * num_remaining_train_steps, trainer.sequence_length + 1), ), } ) ########################################################################################################### data_collator = DataCollatorForCLM( sequence_length=trainer.sequence_length, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, parallel_context=trainer.parallel_context, ) return DataLoader( train_dataset, batch_size=trainer.micro_batch_size, collate_fn=data_collator, drop_last=True, num_workers=0, pin_memory=True, worker_init_fn=get_dataloader_worker_init(dp_rank=trainer.parallel_context.dp_pg.rank()), ) # Case 2: HuggingFace datasets elif isinstance(data.dataset, PretrainDatasetsArgs): log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0) tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path log_rank( f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}", logger=logger, 


def get_dataloader_from_data_stage(
    trainer: DistributedTrainer,
    data: DataArgs,
    consumed_train_samples: int,
    num_remaining_train_steps: int,
):
    """
    Returns a dataloader for a given data stage.

    data: The data configuration for the current stage.
    consumed_train_samples: The number of samples consumed by the model in this stage (each stage starts from zero).
    num_remaining_train_steps: The number of remaining training steps for this stage.
    """
    assert consumed_train_samples >= 0, "consumed_train_samples should be non-negative"
    assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be non-negative"

    # First, we need to know which ranks to feed the dataloader to
    input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

    # Case 1: custom data generator
    if data.dataset is None:
        log_rank("Using custom data generator", logger=logger, level=logging.INFO, rank=0)

        ###########################################################################################################
        # This can be replaced with your own tokenized data generator
        ###########################################################################################################
        train_dataset = datasets.Dataset.from_dict(
            {
                "input_ids": np.random.randint(
                    0,
                    trainer.config.model.model_config.vocab_size,
                    (trainer.global_batch_size * num_remaining_train_steps, trainer.sequence_length + 1),
                ),
            }
        )
        ###########################################################################################################

        data_collator = DataCollatorForCLM(
            sequence_length=trainer.sequence_length,
            input_pp_rank=input_pp_rank,
            output_pp_rank=output_pp_rank,
            parallel_context=trainer.parallel_context,
        )

        return DataLoader(
            train_dataset,
            batch_size=trainer.micro_batch_size,
            collate_fn=data_collator,
            drop_last=True,
            num_workers=0,
            pin_memory=True,
            worker_init_fn=get_dataloader_worker_init(dp_rank=trainer.parallel_context.dp_pg.rank()),
        )

    # Case 2: HuggingFace datasets
    elif isinstance(data.dataset, PretrainDatasetsArgs):
        log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)

        tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
        log_rank(
            f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )

        # We need the 1st device to process the dataset and cache it, then the other devices load it from the cache
        with main_rank_first(trainer.parallel_context.world_pg):
            # We load the raw dataset
            raw_dataset = get_datasets(
                hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets,
                hf_dataset_config_name=data.dataset.hf_dataset_config_name,
                splits=data.dataset.hf_dataset_splits,
            )["train"]

            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"

            # We apply the Causal Language Modeling preprocessing
            train_dataset = clm_process(
                raw_dataset=raw_dataset,
                tokenizer=tokenizer,
                text_column_name=data.dataset.text_column_name,
                dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process,
                dataset_overwrite_cache=data.dataset.dataset_overwrite_cache,
                sequence_length=trainer.sequence_length,
            )

            # We load the processed dataset on the ranks requiring it
            dataloader = get_train_dataloader(
                train_dataset=train_dataset,
                sequence_length=trainer.sequence_length,
                parallel_context=trainer.parallel_context,
                input_pp_rank=input_pp_rank,
                output_pp_rank=output_pp_rank,
                micro_batch_size=trainer.micro_batch_size,
                consumed_train_samples=consumed_train_samples,
                dataloader_num_workers=data.num_loading_workers,
                seed_worker=data.seed,
                dataloader_drop_last=True,
            )

            # Check if we have enough samples for train_steps
            total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
            num_tokens_needed_for_training = (
                num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
            )
            assert num_tokens_needed_for_training <= total_tokens_dataset, (
                f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
                f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
            )
    else:
        raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")

    return dataloader
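

# ---------------------------------------------------------------------------------------------------------------
# For reference: which case runs above is decided by the `dataset` field of each stage's `data` section in the
# YAML config. The sketch below is only an approximation inferred from the attributes this script reads
# (`stage.name`, `stage.start_training_step`, `data.dataset`, `data.num_loading_workers`, `data.seed`); see
# examples/custom-dataloader/config_custom_dl.yaml for the actual file.
#
#   data_stages:
#     - name: Stable Training Stage
#       start_training_step: 1
#       data:
#         dataset:               # leave empty (null) to trigger Case 1, the custom data generator
#         num_loading_workers: 1
#         seed: 42
#
# Filling `dataset:` with an `hf_dataset_or_datasets:` entry routes to Case 2 instead.
# ---------------------------------------------------------------------------------------------------------------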


def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
    dataloaders = {}

    for stage_idx, stage in enumerate(trainer.config.data_stages):
        # NOTE: we only create the dataloader for the first stage,
        # then we lazy initialize the dataloader for the other stages
        stage = cast(DatasetStageArgs, stage)
        consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata)
        assert (
            consumed_train_samples is not None
        ), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint"

        num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
            stage, trainer.config, trainer.metadata
        )
        log_rank(
            f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples",
            logger=logger,
            level=logging.INFO,
            rank=0,
        )

        dataloader = (
            get_dataloader_from_data_stage(
                trainer,
                stage.data,
                consumed_train_samples=consumed_train_samples,
                num_remaining_train_steps=num_remaining_train_steps,
            )
            if stage_idx == 0
            else lambda stage=stage: get_dataloader_from_data_stage(
                trainer,
                stage.data,
                consumed_train_samples=consumed_train_samples,
                num_remaining_train_steps=num_remaining_train_steps,
            )
        )
        dataloaders[stage.name] = dataloader
    return dataloaders


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    config_file = args.config_file

    # Load trainer and data
    trainer = DistributedTrainer(config_file)
    dataloader = get_dataloader(trainer)

    # Train
    trainer.train(dataloader)