Commit 61e92904 authored by chenzk
v1.0
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional, Union
@dataclass
class RandomInit:
std: float
@dataclass
class SpectralMupInit:
"""This is used to initialize the model with spectral mup. Set it to True to use it."""
use_mup: bool
def __post_init__(self):
assert self.use_mup, "Remove `use_mup` if you don't want to use it"
@dataclass
class ExistingCheckpointInit:
"""This is used to initialize from an already existing model (without optimizer, lr_scheduler...)"""
path: Path
@dataclass
class LlamaConfig:
"""Configuration for a LLAMA model
Be careful to keep the typing coherent, as we use it to reconstruct the model from yaml
"""
bos_token_id: int = 1
eos_token_id: int = 2
hidden_act: str = "silu"
hidden_size: int = 4096
initializer_range: float = 0.02
intermediate_size: int = 11008
is_llama_config: bool = True  # We use this to help differentiate models in yaml/python conversion
max_position_embeddings: int = 2048
num_attention_heads: int = 32
num_hidden_layers: int = 32
num_key_value_heads: Optional[int] = None
pad_token_id: Optional[int] = None
pretraining_tp: int = 1
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
False  # The default value used to be True, but to load Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
def __post_init__(self):
# NOTE: users don't set self._init_method, ModelArgs will set it
# then we only pass LlamaConfig around
self._is_using_mup: bool = False
# self._init_method: Optional[Union[RandomInit, SpectralMupInit, ExistingCheckpointInit]] = None
# for backward compatibility
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
@property
def is_using_mup(self) -> bool:
return self._is_using_mup
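# --- Illustrative sketch (not part of the original file) ---
# A minimal example of how LlamaConfig's __post_init__ backfills
# `num_key_value_heads` when it is left unset; the sizes below are arbitrary.
_example_llama_config = LlamaConfig(hidden_size=512, num_attention_heads=8, num_hidden_layers=4)
assert _example_llama_config.num_key_value_heads == 8  # defaults to num_attention_heads
assert _example_llama_config.is_using_mup is False  # set by __post_init__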
@dataclass
class Starcoder2Config:
"""Configuration for a Starcoder2 model
Be careful to keep the typing coherent, as we use it to reconstruct the model from yaml
"""
activation_function: str = "gelu_pytorch_tanh"
attention_softmax_in_fp32: bool = True # TODO: not used
attn_pdrop: float = 0.1
bos_token_id: int = 49152 # TODO: not used
embd_pdrop: float = 0.1
eos_token_id: int = 49152
global_attn_layers: List[int] = field(default_factory=list)
grouped_query: bool = False # GQA
hidden_size: int = 2048
initializer_range: float = 0.02 # TODO: not used
intermediate_size: Optional[int] = None
is_starcoder2_config: bool = True  # We use this to help differentiate models in yaml/python conversion
layer_norm_epsilon: float = 1e-05
max_position_embeddings: int = 4096
multi_query: bool = False # MQA
num_attention_heads: int = 16
num_hidden_layers: int = 24
num_kv_heads: Optional[int] = None
resid_pdrop: float = 0.1
rope_theta: Optional[int] = 10000
scale_attention_softmax_in_fp32: bool = True
scale_attn_weights: bool = True
sliding_window_size: Optional[int] = None
use_position_embeddings: bool = False # TODO @nouamane this is not used
use_rotary_embeddings: bool = True
vocab_size: int = 49280
def __post_init__(self):
if self.global_attn_layers is None:
self.global_attn_layers = []
if self.grouped_query:
assert self.num_kv_heads is not None, "num_kv_heads must be specified for grouped query"
assert self.multi_query is False, "Cannot use both multi_query and grouped_query"
if not self.multi_query and not self.grouped_query:
self.multi_query = True
@property
def n_embed(self):
return self.hidden_size
@property
def n_head(self):
return self.num_attention_heads
@property
def n_layer(self):
return self.num_hidden_layers
@property
def n_positions(self):
return self.max_position_embeddings
@property
def n_inner(self):
return self.intermediate_size
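# --- Illustrative sketch (not part of the original file) ---
# With neither multi_query nor grouped_query set, __post_init__ falls back to
# MQA, and the n_* properties simply alias the HF-style field names.
_example_starcoder2_config = Starcoder2Config()
assert _example_starcoder2_config.multi_query is True
assert _example_starcoder2_config.n_embed == _example_starcoder2_config.hidden_size
assert _example_starcoder2_config.n_layer == 24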
NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any]
from dataclasses import dataclass
from typing import Optional
from nanotron.config.utils_config import (
cast_str_to_pipeline_engine,
)
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
@dataclass
class ParallelismArgs:
"""Arguments related to TP/PP/DP
Args:
dp: Number of DP replicas
pp: Number of PP stages
tp: Number of TP replicas
expert_parallel_size: Number of expert parallel replicas (used only for MoEs)
pp_engine: Pipeline engine to use between "1f1b" and "afab"
tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is the standard mode, reduce_scatter activates sequence parallelism
tp_linear_async_communication: Whether to use async communication in TP linear layers
recompute_layer: Whether to recompute each Transformer layer to save memory.
"""
dp: int
pp: int
tp: int
pp_engine: Optional[PipelineEngine] = None
tp_mode: Optional[TensorParallelLinearMode] = None
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False
tp_recompute_allgather: bool = True
expert_parallel_size: int = 1
def __post_init__(self):
# Conservative defaults
if self.pp_engine is None:
self.pp_engine = AllForwardAllBackwardPipelineEngine()
if self.tp_mode is None:
self.tp_mode = TensorParallelLinearMode.ALL_REDUCE
if self.tp_linear_async_communication is None:
self.tp_linear_async_communication = False
if isinstance(self.pp_engine, str):
self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
if isinstance(self.tp_mode, str):
self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]
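# --- Illustrative sketch (not part of the original file) ---
# ParallelismArgs accepts strings for `pp_engine` and `tp_mode` (e.g. when
# parsed from YAML) and casts them in __post_init__. The sizes are arbitrary,
# and the REDUCE_SCATTER member name is assumed from the docstring above.
_example_parallelism = ParallelismArgs(dp=2, pp=1, tp=4, pp_engine="afab", tp_mode="reduce_scatter")
assert isinstance(_example_parallelism.pp_engine, AllForwardAllBackwardPipelineEngine)
assert _example_parallelism.tp_mode is TensorParallelLinearMode.REDUCE_SCATTER
assert _example_parallelism.tp_linear_async_communication is False  # conservative default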
from dataclasses import fields
from enum import Enum, auto
from pathlib import Path
import torch
from nanotron.generation.sampler import SamplerType
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
OneForwardOneBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
class RecomputeGranularity(Enum):
SELECTIVE = auto()
FULL = auto()
def serialize(data) -> dict:
"""Recursively serialize a nested dataclass to a dict - do some type conversions along the way"""
if data is None:
return None
if not hasattr(data, "__dataclass_fields__"):
return data
result = {}
for field in fields(data):
value = getattr(data, field.name)
if hasattr(value, "__dataclass_fields__"):
result[field.name] = serialize(value)
elif isinstance(value, Path):
result[field.name] = str(value)
elif isinstance(value, PipelineEngine):
result[field.name] = cast_pipeline_engine_to_str(value)
elif isinstance(value, TensorParallelLinearMode):
result[field.name] = value.name
elif isinstance(value, RecomputeGranularity):
result[field.name] = value.name
elif isinstance(value, SamplerType):
result[field.name] = value.name
elif isinstance(value, torch.dtype):
result[field.name] = dtype_to_str[value]
elif isinstance(value, (list, tuple)):
result[field.name] = [serialize(v) for v in value]
elif isinstance(value, dict) and not value:
result[field.name] = None # So we can serialize empty dicts without issue with `datasets` in particular
else:
result[field.name] = value
return result
str_to_dtype = {
"float32": torch.float32,
"float64": torch.float64,
"complex64": torch.complex64,
"complex128": torch.complex128,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"bool": torch.bool,
}
dtype_to_str = {
torch.float32: "float32",
torch.float64: "float64",
torch.complex64: "complex64",
torch.complex128: "complex128",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.uint8: "uint8",
torch.int8: "int8",
torch.int16: "int16",
torch.int32: "int32",
torch.int64: "int64",
torch.bool: "bool",
}
def cast_str_to_torch_dtype(str_dtype: str):
if str_dtype in str_to_dtype:
return str_to_dtype[str_dtype]
else:
raise ValueError(f"dtype should be a string selected in {str_to_dtype.keys()} and not {str_dtype}")
def cast_str_to_pipeline_engine(str_pp_engine: str) -> PipelineEngine:
if str_pp_engine == "afab":
return AllForwardAllBackwardPipelineEngine()
elif str_pp_engine == "1f1b":
return OneForwardOneBackwardPipelineEngine()
else:
raise ValueError(f"pp_engine should be a string selected in ['afab', '1f1b'] and not {str_pp_engine}")
def cast_pipeline_engine_to_str(pp_engine: PipelineEngine) -> str:
if isinstance(pp_engine, AllForwardAllBackwardPipelineEngine):
return "afab"
elif isinstance(pp_engine, OneForwardOneBackwardPipelineEngine):
return "1f1b"
else:
raise ValueError(
f"pp_engine should be an instance of AllForwardAllBackwardPipelineEngine or OneForwardOneBackwardPipelineEngine, not {type(pp_engine)}"
)
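# --- Illustrative sketch (not part of the original file) ---
# serialize() converts special field types (Path, torch.dtype, ...) to plain
# values, and the cast_* helpers round-trip between strings and objects.
# `_ExampleCheckpointArgs` is a hypothetical dataclass used only here.
from dataclasses import dataclass as _dataclass


@_dataclass
class _ExampleCheckpointArgs:
    path: Path
    dtype: torch.dtype


assert serialize(_ExampleCheckpointArgs(path=Path("checkpoints"), dtype=torch.bfloat16)) == {
    "path": "checkpoints",
    "dtype": "bfloat16",
}
assert cast_pipeline_engine_to_str(cast_str_to_pipeline_engine("1f1b")) == "1f1b"
assert cast_str_to_torch_dtype("bfloat16") is torch.bfloat16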
import platform
from packaging.version import Version, parse
CHECKPOINT_VERSION = Version("1.4")
PY_VERSION = parse(platform.python_version())
#### FOR SERIALIZATION ####
CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"
import dataclasses
from typing import Dict, List, Union
import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
@dataclasses.dataclass
class NanosetDataCollatorForCLM:
"""
Data collator used for causal language modeling with Nanosets dataset.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""
sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext
def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}
# Make sure we load only what's necessary, i.e. we only load an `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)
# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape
result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}
result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`input_ids` are incorrectly preprocessed. `input_ids` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
return result
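# --- Illustrative sketch (not part of the original file) ---
# The collator receives (batch, sequence_length + 1) token blocks and produces
# inputs (all but the last token) and labels (all but the first token), so the
# label at position i is the token following input position i.
_example_tokens = torch.arange(10).reshape(2, 5)  # pretend sequence_length == 4
assert torch.equal(_example_tokens[:, :-1], torch.tensor([[0, 1, 2, 3], [5, 6, 7, 8]]))
assert torch.equal(_example_tokens[:, 1:], torch.tensor([[1, 2, 3, 4], [6, 7, 8, 9]]))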
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.dataloader import (
EmptyInfiniteDataset,
get_dataloader_worker_init,
get_sampler,
)
from nanotron.parallel import ParallelContext
from torch.utils.data import DataLoader
logger = logging.get_logger(__name__)
def build_nanoset_dataloader(
dataset,
sequence_length: int,
parallel_context: ParallelContext,
input_pp_rank: int,
output_pp_rank: int,
micro_batch_size: int,
dataloader_num_workers: int,
consumed_train_samples: int = 0,
dataloader_drop_last: bool = True,
dataloader_pin_memory: bool = True,
) -> DataLoader:
# Case of ranks not requiring data. We give them a dummy dataset, then the collator will do its job
if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]:
dataset_length = len(dataset)
dataset = EmptyInfiniteDataset(length=dataset_length)
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0
data_collator = NanosetDataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
)
# Compute size and rank of dataloader workers
dp_ranks_size = parallel_context.dp_pg.size()
dp_rank = parallel_context.dp_pg.rank()
sampler = get_sampler(
train_dataset=dataset,
dl_ranks_size=dp_ranks_size,
dl_rank=dp_rank,
drop_last=dataloader_drop_last,
consumed_train_samples=consumed_train_samples,
shuffle=False,
)
return DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
collate_fn=data_collator,
drop_last=dataloader_drop_last,
num_workers=dataloader_num_workers,
pin_memory=dataloader_pin_memory,
worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank),
)
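# --- Illustrative sketch (not part of the original file) ---
# How a training script might build the dataloader; the concrete values are
# hypothetical and an initialized ParallelContext is required.
def _example_build_dataloader(dataset, parallel_context: ParallelContext) -> DataLoader:
    return build_nanoset_dataloader(
        dataset,
        sequence_length=2048,
        parallel_context=parallel_context,
        input_pp_rank=0,
        output_pp_rank=parallel_context.pp_pg.size() - 1,
        micro_batch_size=4,
        dataloader_num_workers=1,
    )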
import os
import warnings
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
from datatrove.utils.dataset import DatatroveFolderDataset
from nanotron import logging
from nanotron.data.utils import count_dataset_indexes, normalize
from nanotron.logging import log_rank
from numba import jit
logger = logging.get_logger(__name__)
class Nanoset(torch.utils.data.Dataset):
"""
The Nanoset dataset
Args:
dataset_folders (List[str]): List of folders with tokenized datasets
dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__
sequence_length (int): Sequence length of the built samples
token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise
train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size
"""
def __init__(
self,
dataset_folders: List[str],
sequence_length: int,
token_size: int,
train_split_num_samples: int,
dataset_weights: Union[List[float], None] = None,
random_seed: int = 1234,
) -> None:
# Checks
if isinstance(dataset_folders, str):
warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
dataset_folders = [dataset_folders]
# Init
self.dataset_folders = dataset_folders
self.sequence_length = sequence_length
self.token_size = token_size
self.train_split_num_samples = train_split_num_samples
self.random_seed = random_seed
self.datatrove_datasets = []
for dataset_folder in self.dataset_folders:
self.datatrove_datasets.append(
DatatroveFolderDataset(
folder_path=dataset_folder,
filename_pattern=os.path.join(dataset_folder, "*.ds"),
seq_len=sequence_length,
recursive=False,
token_size=token_size,
shuffle=True,
)
)
# Build Nanoset Index
## To build the index we need the length of each dataset
self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
## Set dataset weights
if (
dataset_weights is None
): # Case of training with > 1 datasets without weighting them: consume all datasets entirely on each epoch
self.dataset_weights = normalize(self.dataset_lengths)
else:
self.dataset_weights = normalize(dataset_weights)
assert len(dataset_folders) == len(
self.dataset_weights
), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
## Build dataset index and dataset sample index
self.dataset_index, self.dataset_sample_index = self.build_nanoset_index()
self.print_nanoset_info()
def __len__(self) -> int:
"""
Returns:
int: The number of samples of the Nanoset
"""
return len(self.dataset_index)
def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"""
Returns sequence_length + 1 tokens from the memmap dataset
Args:
idx (int): The index into the dataset
Returns:
Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary
"""
dataset = self.dataset_index[idx]
dataset_sample = self.dataset_sample_index[idx]
return self.datatrove_datasets[dataset][dataset_sample]
def build_nanoset_index(self) -> Tuple[np.ndarray, np.ndarray]:
"""
Build dataset index and dataset sample index
"""
# Compute samples per epoch and number of epochs
samples_per_epoch = sum(self.dataset_lengths)
num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1
# Build the dataset indexes for 1 epoch
dataset_index, dataset_sample_index = build_nanoset_index_helper(
n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths
)
# Shuffle the indexes the same way
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_index)
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_sample_index)
# Concatenate the shuffled indexes num_epochs times
dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)])
dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)])
# Just keep the necessary samples
dataset_index = dataset_index[: self.train_split_num_samples]
dataset_sample_index = dataset_sample_index[: self.train_split_num_samples]
return dataset_index, dataset_sample_index
def print_nanoset_info(self):
log_rank(f"> Total number of samples: {len(self)}", logger=logger, level=logging.INFO, rank=0)
log_rank(
f"> Total number of tokens: {len(self) * self.sequence_length}", logger=logger, level=logging.INFO, rank=0
)
# Print samples from each dataset + weight
dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
for index, sample_count in enumerate(dataset_sample_count):
log_rank(
f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
logger=logger,
level=logging.INFO,
rank=0,
)
@jit(nopython=True, cache=True)
def build_nanoset_index_helper(
n_samples: int, weights: np.ndarray, dataset_sizes: List[int]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Given multiple datasets and a weighting array, build sample indexes
such that the sampling follows those weights
"""
# Create empty arrays for dataset indices and dataset sample indices
dataset_index = np.empty((n_samples,), dtype="uint")
dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports datasets with up to 2**63 samples
# Initialize buffer for number of samples used for each dataset
current_samples = np.zeros((len(weights),), dtype="long")
# Iterate over all samples
for sample_idx in range(n_samples):
# Convert sample index to float for comparison against weights
sample_idx_float = max(sample_idx, 1.0)
# Find the dataset with the highest error
errors = weights * sample_idx_float - current_samples
max_error_index = np.argmax(errors)
# Assign the dataset index and update the sample index
dataset_index[sample_idx] = max_error_index
dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]
# Update the total samples for the selected dataset
current_samples[max_error_index] += 1
return dataset_index, dataset_sample_index
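# --- Illustrative sketch (not part of the original file) ---
# With weights [0.75, 0.25], six of every eight consecutive samples come from
# dataset 0. Arrays are passed instead of Python lists only to keep the
# numba-jitted helper happy; the dataset sizes are arbitrary.
_example_di, _example_dsi = build_nanoset_index_helper(
    n_samples=8, weights=np.array([0.75, 0.25]), dataset_sizes=np.array([100, 100])
)
assert list(_example_di).count(0) == 6 and list(_example_di).count(1) == 2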
from typing import List
import numpy as np
def normalize(weights: List[float]) -> np.ndarray:
"""
Normalize the elements of a list so that they sum to 1
Args:
weights (List[float]): The weights
Returns:
numpy.ndarray: The normalized weights
"""
w = np.array(weights, dtype=np.float64)
w_sum = np.sum(w)
w = w / w_sum
return w
def count_dataset_indexes(dataset_idx: np.ndarray, n_datasets: int):
counts = []
for dataset in range(n_datasets):
counts.append(np.count_nonzero(dataset_idx == dataset))
return counts
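# --- Illustrative sketch (not part of the original file) ---
# normalize() rescales weights so they sum to 1, and count_dataset_indexes()
# counts how often each dataset id appears in a Nanoset index.
assert np.allclose(normalize([2.0, 6.0]), [0.25, 0.75])
assert count_dataset_indexes(np.array([0, 1, 1, 0, 1]), n_datasets=2) == [2, 3]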
import dataclasses
import warnings
from typing import Dict, Generator, Iterator, List, Optional, Union
import numpy as np
import torch
from torch.utils.data import BatchSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import Config
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.random import set_random_seed
from nanotron.sanity_checks import (
assert_fail_except_rank_with,
assert_tensor_synced_across_pg,
)
try:
import datasets
from datasets import (
Dataset,
DatasetDict,
Features,
Sequence,
Value,
concatenate_datasets,
load_dataset,
)
from transformers import PreTrainedTokenizerBase
from transformers.trainer_pt_utils import DistributedSamplerWithLoop
except ImportError:
warnings.warn("Datasets and/or Transformers not installed, you'll be unable to use the dataloader.")
logger = logging.get_logger(__name__)
def sanity_check_dataloader(
dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]],
parallel_context: ParallelContext,
config: Config,
) -> Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]:
for batch in dataloader:
micro_batch = {
k: v if isinstance(v, TensorPointer) else v.to("cuda", memory_format=torch.contiguous_format)
for k, v in batch.items()
}
if not config.general.ignore_sanity_checks:
# SANITY CHECK: Check inputs are not the same across DP
for key, value in sorted(micro_batch.items(), key=lambda x: x[0]):
if isinstance(value, TensorPointer):
continue
if "mask" in key:
# It's fine if mask is the same across DP
continue
with assert_fail_except_rank_with(AssertionError, rank_exception=0, pg=parallel_context.dp_pg):
assert_tensor_synced_across_pg(
tensor=value, pg=parallel_context.dp_pg, msg=lambda err: f"{key} {err}"
)
# SANITY CHECK: Check inputs are synchronized throughout TP
for key, value in sorted(micro_batch.items(), key=lambda x: x[0]):
if isinstance(value, TensorPointer):
continue
assert_tensor_synced_across_pg(
tensor=value,
pg=parallel_context.tp_pg,
msg=lambda err: f"{key} are not synchronized throughout TP {err}",
)
# SANITY CHECK: Check that inputs are synchronized throughout PP
# TODO @thomasw21: That's really hard to test as input gets sharded across the PP, let's assume it works for now.
# SANITY CHECK: Check that an input only exists on the PP rank responsible for it
# TODO @nouamanetazi: add this test
yield micro_batch
# Adapted from h4/src/h4/data/loading.py
def get_datasets(
hf_dataset_or_datasets: Union[dict, str],
hf_dataset_config_name: str,
splits: Optional[Union[List[str], str]] = ["train", "test"],
) -> "DatasetDict":
"""
Function to load dataset directly from DataArguments.
Args:
hf_dataset_or_datasets (Union[dict, str]): dict or string. When all probabilities are 1, we concatenate the datasets instead of sampling from them.
splits (Optional[List[str]], optional): Section of the dataset to load, defaults to "train", "test"
Can be one of `train_ift`, `test_rl`, or `..._rm` etc. H4 datasets are divided into 6 subsets for training / testing.
Returns:
DatasetDict: DatasetDict object containing the dataset of the appropriate section with test + train parts.
"""
if isinstance(splits, str):
splits = [splits]
if isinstance(hf_dataset_or_datasets, dict):
# Structure of the config to read the datasets and their mix
# datasets_mixer:
# - 'dataset1': 0.5
# - 'dataset2': 0.3
# - 'dataset3': 0.2
raw_datasets = _get_dataset_mix(hf_dataset_or_datasets, splits=splits)
elif isinstance(hf_dataset_or_datasets, str):
# e.g. Dataset = "HuggingFaceH4/testing_alpaca_small"
# Note this returns things other than just train/test, which may not be intended
raw_datasets = DatasetDict()
for split in splits:
raw_datasets[split] = load_dataset(
hf_dataset_or_datasets,
hf_dataset_config_name,
split=split,
cache_dir=".",
)
else:
raise ValueError(f"hf_dataset_or_datasets must be a dict or string but is {type(hf_dataset_or_datasets)}")
return raw_datasets
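# --- Illustrative sketch (not part of the original file) ---
# The dict form maps dataset names to training fractions (the `datasets_mixer`
# structure above). The dataset name is taken from the comment in the string
# branch; loading it requires network access, so the helper is not called here.
def _example_get_datasets() -> "DatasetDict":
    return get_datasets(
        {"HuggingFaceH4/testing_alpaca_small": 1.0},  # dataset -> training fraction
        hf_dataset_config_name=None,
        splits=["train", "test"],
    )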
# Adapted from h4/src/h4/data/loading.py
def _get_dataset_mix(dataset_dict: dict, splits: List[str] = None, seed=42) -> "DatasetDict":
"""
Helper function to load dataset mix from dict configuration.
Args:
dataset_dict (dict): Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
splits (Optional[List[str]], optional): Section of the dataset to load, defaults to "train", "test"
Can be one of `train_{ift,rm,rl}` and `test_{ift,rm,rl}`. Our datasets are typically divided into 6 subsets for training / testing.
"""
raw_datasets = DatasetDict()
raw_train_datasets = []
raw_test_datasets = []
fracs = []
for ds, frac in dataset_dict.items():
if frac < 0:
raise ValueError(f"Dataset fraction for dataset {ds} is negative. (= {frac})")
fracs.append(frac)
for split in splits:
if "train" in split:
raw_train_datasets.append(
load_dataset(
ds,
split=split,
)
)
elif "test" in split:
raw_test_datasets.append(
load_dataset(
ds,
split=split,
)
)
else:
raise ValueError(f"Split type {split} not recognized as one of test or train.")
if len(raw_train_datasets) > 0:
train_subsets = []
for dataset, frac in zip(raw_train_datasets, fracs):
train_subset = dataset.select(range(int(frac * len(dataset))))
train_subsets.append(train_subset)
raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=seed)
# No subsampling for test datasets to enable fair comparison across models
if len(raw_test_datasets) > 0:
raw_datasets["test"] = concatenate_datasets(raw_test_datasets).shuffle(seed=seed)
if len(raw_datasets) == 0:
raise ValueError(
f"Dataset {dataset_dict} not recognized with split {split}. Check the dataset has been correctly formatted."
)
return raw_datasets
def dummy_infinite_data_generator(
micro_batch_size: int,
sequence_length: int,
input_pp_rank: int,
output_pp_rank: int,
vocab_size: int,
seed: int,
parallel_context: ParallelContext,
):
def data_generator() -> Generator[Dict[str, Union[torch.Tensor, TensorPointer]], None, None]:
# Random generator
generator = torch.Generator(device="cuda")
# Make sure that TP are synced always
generator.manual_seed(
seed * (1 + dist.get_rank(parallel_context.dp_pg)) * (1 + dist.get_rank(parallel_context.pp_pg))
)
while True:
yield {
"input_ids": torch.randint(
0,
vocab_size,
(micro_batch_size, sequence_length),
dtype=torch.long,
device="cuda",
generator=generator,
)
if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
else TensorPointer(group_rank=input_pp_rank),
"input_mask": torch.ones(
micro_batch_size,
sequence_length,
dtype=torch.bool,
device="cuda",
)
if dist.get_rank(parallel_context.pp_pg) == input_pp_rank
else TensorPointer(group_rank=input_pp_rank),
"label_ids": torch.randint(
0,
vocab_size,
(micro_batch_size, sequence_length),
dtype=torch.long,
device="cuda",
generator=generator,
)
if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
else TensorPointer(group_rank=output_pp_rank),
"label_mask": torch.ones(
micro_batch_size,
sequence_length,
dtype=torch.bool,
device="cuda",
)
if dist.get_rank(parallel_context.pp_pg) == output_pp_rank
else TensorPointer(group_rank=output_pp_rank),
}
return data_generator
# Adapted from https://github.com/huggingface/accelerate/blob/a73898027a211c3f6dc4460351b0ec246aa824aa/src/accelerate/data_loader.py#L781C1-L824C28
class SkipBatchSampler(BatchSampler):
"""
A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
Note that in the case of DDP, batches are skipped on each rank, so `skip_batches` is divided by the DP world size to get the per-rank skip count.
"""
def __init__(self, batch_sampler: BatchSampler, skip_batches: int, dp_size: int):
self.batch_sampler = batch_sampler
# In case of DDP, each rank skips its share of `skip_batches`, hence the division by dp_size
self.skip_batches = skip_batches // dp_size
def __iter__(self):
for index, samples in enumerate(self.batch_sampler):
if index >= self.skip_batches:
yield samples
@property
def total_length(self):
return len(self.batch_sampler)
def __len__(self):
return len(self.batch_sampler) - self.skip_batches
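# --- Illustrative sketch (not part of the original file) ---
# With 10 indices batched by 2, skip_batches=4 and dp_size=2 means each rank
# skips 4 // 2 = 2 batches, so iteration starts at the third batch.
from torch.utils.data import SequentialSampler

_example_base = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)
_example_skipping = SkipBatchSampler(_example_base, skip_batches=4, dp_size=2)
assert len(_example_skipping) == len(_example_base) - 2
assert next(iter(_example_skipping)) == [4, 5]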
def set_tensor_pointers(
input_dict: Dict[str, Union[torch.Tensor, TensorPointer]], group: dist.ProcessGroup, group_rank: int
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
"""Make sure only the group_rank rank has the data, others have TensorPointers."""
return {
k: v if dist.get_rank(group) == group_rank else TensorPointer(group_rank=group_rank)
for k, v in input_dict.items()
}
### CAUSAL LANGUAGE MODELING ###
def clm_process(
raw_dataset: "Dataset",
tokenizer: "PreTrainedTokenizerBase",
text_column_name: str,
dataset_processing_num_proc_per_process: int,
dataset_overwrite_cache: bool,
sequence_length: int,
):
"""Concatenate all texts from raw_dataset and generate chunks of `sequence_length + 1`, where chunks overlap by a single token."""
# Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/examples/pytorch/language-modeling/run_clm.py#L391-L439
def group_texts(examples: Dict[str, List[np.ndarray]]) -> Dict[str, List[np.ndarray]]:
# Concatenate all texts.
concatenated_examples = {k: np.concatenate(v) for k, v in examples.items()}
total_length = len(concatenated_examples[next(iter(examples.keys()))])
# WARNING: We drop the small remainder; we could pad instead of dropping if the model supported it.
# You can customize this part to your needs.
if total_length >= sequence_length + 1:
total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
# Split by chunks of sequence_length.
result = {
k: [
t[i : i + sequence_length + 1] for i in range(0, total_length - (sequence_length + 1), sequence_length)
]
for k, t in concatenated_examples.items()
}
return result
def _tokenize_and_group_texts(texts: List[str]) -> Dict[str, List[np.ndarray]]:
tokenized_batch = tokenizer.batch_encode_plus(texts, return_attention_mask=False, return_token_type_ids=False)
tokenized_batch = {k: [np.array(tokenized_texts) for tokenized_texts in v] for k, v in tokenized_batch.items()}
return group_texts(tokenized_batch)
train_dataset = raw_dataset.map(
_tokenize_and_group_texts,
input_columns=text_column_name,
remove_columns=raw_dataset.column_names,
features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)}),
batched=True,
num_proc=dataset_processing_num_proc_per_process,
load_from_cache_file=not dataset_overwrite_cache,
desc=f"Grouping texts in chunks of {sequence_length+1}",
)
return train_dataset
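# --- Illustrative sketch (not part of the original file) ---
# The chunking arithmetic in group_texts: with sequence_length = 4, a stream of
# 13 concatenated tokens is trimmed to ((13 - 1) // 4) * 4 + 1 = 13 tokens and
# cut into chunks of length 5 that share one boundary token.
_example_stream = np.arange(13)
_example_seq_len = 4
_example_total = ((len(_example_stream) - 1) // _example_seq_len) * _example_seq_len + 1
_example_chunks = [
    _example_stream[i : i + _example_seq_len + 1].tolist()
    for i in range(0, _example_total - (_example_seq_len + 1), _example_seq_len)
]
assert _example_chunks == [[0, 1, 2, 3, 4], [4, 5, 6, 7, 8]]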
# Adapted from: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/data/data_collator.py#L607
@dataclasses.dataclass
class DataCollatorForCLM:
"""
Data collator used for causal language modeling.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""
sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext
def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}
# Make sure we load only what's necessary, i.e. we only load an `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)
# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape
result: Dict[str, Union[np.ndarray, TensorPointer]] = {}
result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`input_ids` are incorrectly preprocessed. `input_ids` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
# Cast np.array to torch.Tensor
result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()}
return result
# Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L763-L835
def get_sampler(
dl_ranks_size: int,
dl_rank: int,
train_dataset: Union["Dataset", torch.utils.data.Dataset],
consumed_train_samples: int,
seed: int = 42,
use_loop_to_round_batch_size: bool = False,
micro_batch_size: Optional[int] = None,
drop_last: Optional[bool] = True,
shuffle: bool = True,
) -> Optional[torch.utils.data.Sampler]:
"""Returns a sampler that restricts data loading to the subset of the dataset owned by this DP rank"""
# Build the sampler.
# TODO @nouamanetazi: Support group_by_length: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L783-L810
if use_loop_to_round_batch_size:
assert micro_batch_size is not None
# loops at the end back to the beginning of the shuffled samples to make each process have a round multiple of batch_size samples.
sampler = DistributedSamplerWithLoop(
train_dataset,
batch_size=micro_batch_size,
num_replicas=dl_ranks_size,
rank=dl_rank,
seed=seed,
drop_last=drop_last,
)
else:
sampler = DistributedSampler(
train_dataset, num_replicas=dl_ranks_size, rank=dl_rank, seed=seed, drop_last=drop_last, shuffle=shuffle
)
if consumed_train_samples > 0:
sampler = SkipBatchSampler(sampler, skip_batches=consumed_train_samples, dp_size=dl_ranks_size)
return sampler
# Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L837
def get_train_dataloader(
train_dataset: "Dataset",
sequence_length: int,
parallel_context: ParallelContext,
input_pp_rank: int,
output_pp_rank: int,
micro_batch_size: int,
consumed_train_samples: int,
dataloader_num_workers: int,
seed_worker: int,
dataloader_drop_last: bool = True,
dataloader_pin_memory: bool = True,
use_loop_to_round_batch_size: bool = False,
) -> DataLoader:
if not isinstance(train_dataset, datasets.Dataset):
raise ValueError(f"training requires a datasets.Dataset, but got {type(train_dataset)}")
# Case of ranks requiring data
if dist.get_rank(parallel_context.pp_pg) in [
input_pp_rank,
output_pp_rank,
]:
train_dataset = train_dataset.with_format(type="numpy", columns=["input_ids"], output_all_columns=True)
# Case of ranks not requiring data. We give them an infinite dummy dataloader
else:
#
assert train_dataset.column_names == ["input_ids"], (
f"Dataset has to have a single column, with `input_ids` as the column name. "
f"Current dataset: {train_dataset}"
)
dataset_length = len(train_dataset)
train_dataset = train_dataset.remove_columns(column_names="input_ids")
assert (
len(train_dataset) == 0
), f"Dataset has to be empty after removing the `input_ids` column. Current dataset: {train_dataset}"
# HACK: if we remove the last column of a train_dataset, it becomes empty and its number of rows becomes 0.
train_dataset = EmptyInfiniteDataset(length=dataset_length)
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0
data_collator = DataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
)
# Compute size and rank of dataloader workers
dp_ranks_size = parallel_context.dp_pg.size()
dp_rank = parallel_context.dp_pg.rank()
# TODO @nouamanetazi: Remove unused columns: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L852
# TODO @nouamanetazi: Support torch.utils.data.IterableDataset: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L855-L872
train_sampler = get_sampler(
dl_rank=dp_rank,
dl_ranks_size=dp_ranks_size,
train_dataset=train_dataset,
seed=seed_worker,
use_loop_to_round_batch_size=use_loop_to_round_batch_size,
micro_batch_size=micro_batch_size,
drop_last=dataloader_drop_last,
consumed_train_samples=consumed_train_samples,
)
return DataLoader(
train_dataset,
batch_size=micro_batch_size,
sampler=train_sampler,
collate_fn=data_collator,
drop_last=dataloader_drop_last, # we also drop_last in `clm_process()`
num_workers=dataloader_num_workers,
pin_memory=dataloader_pin_memory,
worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank),
# TODO @thomasw21: I'm not sure but this doesn't seem to work at all.
# pin_memory_device="cuda",
)
def get_dataloader_worker_init(dp_rank: int):
"""Creates a random state for each dataloader worker so that each worker gets a different state"""
def dataloader_worker_init(worker_id):
# Dataloader is TP/PP synced in random states
seed = 2 ** (1 + worker_id) * 3 ** (1 + dp_rank) % (2**32)
set_random_seed(seed)
return dataloader_worker_init
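# --- Illustrative sketch (not part of the original file) ---
# The per-worker seed mixes worker_id and dp_rank through coprime bases, so
# different (worker_id, dp_rank) pairs land on different seeds.
assert 2 ** (1 + 0) * 3 ** (1 + 1) % (2**32) == 18  # worker 0 on dp_rank 1
assert 2 ** (1 + 1) * 3 ** (1 + 0) % (2**32) == 12  # worker 1 on dp_rank 0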
class EmptyInfiniteDataset:
"""Hack as removing all columns from a datasets.Dataset makes the number of rows 0."""
def __init__(self, length: int):
self._length = length
def __getitem__(self, item) -> Dict:
if isinstance(item, int):
return {}
raise NotImplementedError(f"{item} of type {type(item)} is not supported yet")
def __len__(self) -> int:
return self._length
import datetime
import os
from functools import cache, lru_cache
from typing import List, Optional, Tuple
import torch
from packaging import version
from torch import distributed as dist
from torch.distributed import * # noqa
from torch.distributed.distributed_c10d import ProcessGroup
from nanotron.utils import find_free_port
torch_version_above_1_13 = version.parse(torch.__version__) >= version.parse("1.13.0")
Work = dist.Work if torch_version_above_1_13 else dist._Work
default_pg_timeout = datetime.timedelta(minutes=10)
def new_group( # pylint: disable=function-redefined
ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None
) -> ProcessGroup:
if len(ranks) == 0:
raise ValueError("Cannot create a group with no ranks in it")
return dist.new_group(ranks=ranks, timeout=timeout, backend=backend, pg_options=pg_options)
def reduce_scatter_tensor( # pylint: disable=function-redefined
output: torch.Tensor,
input: torch.Tensor,
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[Work]:
if group is None:
group = dist.torch_dist.distributed_c10d._get_default_group()
assert (
group.size() > 1
), "You should probably not call `reduce_scatter_tensor` with a single rank, as it copies data over"
if torch_version_above_1_13:
return dist.reduce_scatter_tensor(output=output, input=input, group=group, op=op, async_op=async_op)
else:
# Support pytorch 1.12
return dist._reduce_scatter_base(output=output, input=input, group=group, op=op, async_op=async_op)
def all_gather_into_tensor( # pylint: disable=function-redefined
output_tensor, input_tensor, group: Optional[ProcessGroup] = None, async_op: bool = False
) -> Optional[Work]:
if group is None:
group = dist.torch_dist.distributed_c10d._get_default_group()
assert (
group.size() > 1
), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over"
if torch_version_above_1_13:
return dist.all_gather_into_tensor(
output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op
)
else:
# Support Pytorch 1.12
return dist.distributed_c10d._all_gather_base(
output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op
)
def reduce_scatter_coalesced(
output_tensor_list: List[torch.Tensor],
input_tensor_lists: List[List[torch.Tensor]],
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[torch._C.Future]:
"""
Reduces, then scatters a list of tensors to all processes in a group.
Args:
output_tensor_list (list[Tensor]): Output tensor.
input_tensor_lists (list[list[Tensor]]): List of tensors to reduce and scatter.
op (optional): One of the values from
``torch.distributed.ReduceOp``
enum. Specifies an operation used for element-wise reductions.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op.
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group.
"""
assert len(output_tensor_list) > 0
assert len(input_tensor_lists) == len(output_tensor_list)
device = output_tensor_list[0].device
dtype = output_tensor_list[0].dtype
group_size = len(input_tensor_lists[0])
assert (
group_size > 1
), "You should probably not call `reduce_scatter_coalesced` with a single rank, as it copies data over"
for output_tensor in output_tensor_list:
assert device == output_tensor.device
assert dtype == output_tensor.dtype
for input_tensor_list in input_tensor_lists:
assert len(input_tensor_list) == group_size, f"Expected {len(input_tensor_list)} == {group_size}"
for input_tensor in input_tensor_list:
assert device == input_tensor.device
assert dtype == input_tensor.dtype
output_tensor_buffer = torch._utils._flatten_dense_tensors(output_tensor_list)
input_tensor_buffer_list = [
torch._utils._flatten_dense_tensors(
[input_tensor_list[group_rank] for input_tensor_list in input_tensor_lists]
)
for group_rank in range(group_size)
]
work = dist.reduce_scatter(output_tensor_buffer, input_tensor_buffer_list, op=op, group=group, async_op=async_op)
def update_output():
for original_buffer, reduced_buffer in zip(
output_tensor_list, torch._utils._unflatten_dense_tensors(output_tensor_buffer, output_tensor_list)
):
original_buffer.copy_(reduced_buffer)
if async_op is True:
return work.get_future().then(lambda fut: update_output())
else:
# No need to run `work.wait()` since `dist.reduce_scatter` already waits
update_output()
def all_reduce_coalesced( # pylint: disable=function-redefined
tensors: List[torch.Tensor],
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[torch._C.Future]:
if group is None:
group = dist.torch_dist.distributed_c10d._get_default_group()
if group.size() == 1:
return
return dist.all_reduce_coalesced(tensors, op=op, group=group, async_op=async_op)
def all_gather_coalesced( # pylint: disable=function-redefined
output_tensor_lists: List[List[torch.Tensor]],
input_tensor_list: List[torch.Tensor],
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[torch._C.Future]:
"""
`torch` has a deprecated version of this method that doesn't work over NCCL.
All gathers a list of tensors to all processes in a group.
Args:
output_tensor_lists (list[list[Tensor]]): Output tensor.
input_tensor_list (list[Tensor]): List of tensors to all_gather from.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op.
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group.
"""
assert len(output_tensor_lists) > 0
assert len(input_tensor_list) == len(output_tensor_lists)
device = input_tensor_list[0].device
dtype = input_tensor_list[0].dtype
group_size = len(output_tensor_lists[0])
assert (
group_size > 1
), "You should probably not call `all_gather_coalesced` with a single rank, as it copies data over"
for input_tensor in input_tensor_list:
assert device == input_tensor.device
assert dtype == input_tensor.dtype
for output_tensor_list in output_tensor_lists:
assert len(output_tensor_list) == group_size
for output_tensor in output_tensor_list:
assert device == output_tensor.device
assert dtype == output_tensor.dtype
# Invert from `[param_idx][group_rank]` to `[group_rank][param_idx]`
output_tensor_lists = [
[output_tensor_list[group_rank] for output_tensor_list in output_tensor_lists]
for group_rank in range(group_size)
]
input_tensor_buffer = torch._utils._flatten_dense_tensors(input_tensor_list)
output_tensor_buffer_list = [
torch._utils._flatten_dense_tensors(output_tensor_list) for output_tensor_list in output_tensor_lists
]
work = dist.all_gather(output_tensor_buffer_list, input_tensor_buffer, group=group, async_op=async_op)
def update_output():
for original_buffer_list, gathered_buffer_tensor in zip(output_tensor_lists, output_tensor_buffer_list):
for original_buffer, gathered_buffer in zip(
original_buffer_list,
torch._utils._unflatten_dense_tensors(gathered_buffer_tensor, original_buffer_list),
):
original_buffer.copy_(gathered_buffer)
if async_op is True:
return work.get_future().then(lambda fut: update_output())
else:
# No need to run `work.wait()` since `dist.all_gather` already waits
update_output()
# This cache has a speedup of 4 tflops on a 7b model
@cache
def get_global_rank(group: ProcessGroup, group_rank: int) -> int: # pylint: disable=function-redefined
if torch_version_above_1_13:
return dist.get_global_rank(group, group_rank=group_rank)
else:
# Support pytorch 1.12
return dist.distributed_c10d._get_global_rank(group=group, rank=group_rank)
def get_global_ranks(group: ProcessGroup) -> Tuple[int]:
return tuple(sorted((get_global_rank(group, i) for i in range(group.size()))))
# We cache for dp, pp, tp process groups, world group, and tied process group for tied params
@lru_cache
def get_rank(group: Optional[ProcessGroup] = None) -> int: # pylint: disable=function-redefined
"""Similar to `get_rank` except we raise an exception instead of returning -1 when the current rank is not part of the group"""
result = dist.get_rank(group)
if result == -1:
raise RuntimeError("Cannot call `get_rank` on a group that the current process is not part of")
return result
def initialize_torch_distributed():
"""Initializes torch distributed with the environment variables"""
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if torch.cuda.is_available():
# Set the device id.
# `torch.cuda.device_count` should return the number of devices on a single node.
# We assume the nodes to be homogeneous (same number of gpus per node)
device_id = local_rank
torch.cuda.set_device(torch.cuda.device(device_id))
backend = "nccl"
else:
# TODO @thomasw21: Maybe figure out a way to do distributed `cpu` training at some point
raise NotImplementedError(f"CUDA was not found: torch.cuda.is_available(): {torch.cuda.is_available()}")
backend = "gloo"
# Call the init process.
port = os.getenv("MASTER_PORT")
if port is None:
port = find_free_port()
else:
port = int(port)
init_method = f"env://localhost:{port}"
dist.init_process_group(
init_method=init_method, backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout
)
return True
import warnings
from nanotron.fp8.dtypes import DTypes # noqa
from nanotron.fp8.linear import FP8Linear # noqa
from nanotron.fp8.parameter import FP8Parameter # noqa
from nanotron.fp8.tensor import FP8Tensor # noqa
try:
import transformer_engine as te # noqa
import transformer_engine_extensions as tex # noqa
except ImportError:
warnings.warn("Please install Transformer Engine for FP8 training!")
import torch
from nanotron.fp8.dtypes import DTypes
FP8_GPU_NAMES = ["h100", "rtx 4090"]
INITIAL_AMAX = 1.0
INITIAL_SCALING_FACTOR = 1.0
# FP8_DTYPES = [torch.fp8e4m3, torch.fp8e5m2]
# FP8E4M3_DTYPE = torch.fp8e4m3
# FP8E5M2_DTYPE = torch.fp8e5m2
FP8_DTYPES = [torch.int8, torch.uint8]
FP8E4M3_DTYPE = torch.int8
FP8E5M2_DTYPE = torch.uint8
DTYPE_TO_FP8_MAX = {DTypes.FP8E4M3: 448.0, DTypes.FP8E5M2: 57344.0, DTypes.KFLOAT16: 65504.0}
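# --- Illustrative sketch (not part of the original file) ---
# FP8 E4M3 saturates at +/-448 while E5M2 trades precision for a +/-57344
# range; KFLOAT16 matches float16's 65504 maximum.
assert DTYPE_TO_FP8_MAX[DTypes.FP8E4M3] == 448.0
assert DTYPE_TO_FP8_MAX[DTypes.FP8E5M2] == 57344.0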