Commit 61e92904 authored by chenzk

v1.0

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional, Union
@dataclass
class RandomInit:
std: float
@dataclass
class SpectralMupInit:
"""This is used to initialize the model with spectral mup. Set it to True to use it."""
use_mup: bool
def __post_init__(self):
assert self.use_mup, "Remove `use_mup` if you don't want to use it"
@dataclass
class ExistingCheckpointInit:
"""This is used to initialize from an already existing model (without optimizer, lr_scheduler...)"""
path: Path
@dataclass
class LlamaConfig:
"""Configuration for a LLAMA model
Be careful on having a coherent typing as we use it to reconstruct the model from yaml
"""
bos_token_id: int = 1
eos_token_id: int = 2
hidden_act: str = "silu"
hidden_size: int = 4096
initializer_range: float = 0.02
intermediate_size: int = 11008
    is_llama_config: bool = True  # We use this to help differentiate models in yaml/python conversion
max_position_embeddings: int = 2048
num_attention_heads: int = 32
num_hidden_layers: int = 32
num_key_value_heads: Optional[int] = None
pad_token_id: Optional[int] = None
pretraining_tp: int = 1
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
    rope_interleaved: bool = (
        False  # The default value used to be True, but to load Llama3 checkpoints you have to set it to False
    )
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
def __post_init__(self):
        # NOTE: users don't set self._init_method; ModelArgs sets it,
        # and then we only pass LlamaConfig around
self._is_using_mup: bool = False
# self._init_method: Optional[Union[RandomInit, SpectralMupInit, ExistingCheckpointInit]] = None
# for backward compatibility
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
@property
def is_using_mup(self) -> bool:
return self._is_using_mup
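
# --- Illustrative usage sketch (added in review, not part of the original commit) ---
# Shows the __post_init__ behaviour above: num_key_value_heads falls back to
# num_attention_heads when left unset, and mup stays off until an init method enables it.
def _example_llama_config_defaults() -> None:
    config = LlamaConfig(hidden_size=2048, num_attention_heads=16, num_hidden_layers=12)
    assert config.num_key_value_heads == 16  # backfilled from num_attention_heads
    assert config.is_using_mup is False
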
@dataclass
class Starcoder2Config:
"""Configuration for a Starcoder2 model
Be careful on having a coherent typing as we use it to reconstruct the model from yaml
"""
activation_function: str = "gelu_pytorch_tanh"
attention_softmax_in_fp32: bool = True # TODO: not used
attn_pdrop: float = 0.1
bos_token_id: int = 49152 # TODO: not used
embd_pdrop: float = 0.1
eos_token_id: int = 49152
global_attn_layers: List[int] = field(default_factory=list)
grouped_query: bool = False # GQA
hidden_size: int = 2048
initializer_range: float = 0.02 # TODO: not used
intermediate_size: Optional[int] = None
    is_starcoder2_config: bool = True  # We use this to help differentiate models in yaml/python conversion
layer_norm_epsilon: float = 1e-05
max_position_embeddings: int = 4096
multi_query: bool = False # MQA
num_attention_heads: int = 16
num_hidden_layers: int = 24
num_kv_heads: Optional[int] = None
resid_pdrop: float = 0.1
rope_theta: Optional[int] = 10000
scale_attention_softmax_in_fp32: bool = True
scale_attn_weights: bool = True
sliding_window_size: Optional[int] = None
use_position_embeddings: bool = False # TODO @nouamane this is not used
use_rotary_embeddings: bool = True
vocab_size: int = 49280
def __post_init__(self):
if self.global_attn_layers is None:
self.global_attn_layers = []
if self.grouped_query:
assert self.num_kv_heads is not None, "num_kv_heads must be specified for grouped query"
assert self.multi_query is False, "Cannot use both multi_query and grouped_query"
if not self.multi_query and not self.grouped_query:
self.multi_query = True
@property
def n_embed(self):
return self.hidden_size
@property
def n_head(self):
return self.num_attention_heads
@property
def n_layer(self):
return self.num_hidden_layers
@property
def n_positions(self):
return self.max_position_embeddings
@property
def n_inner(self):
return self.intermediate_size
NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any]
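
# --- Illustrative usage sketch (added in review, not part of the original commit) ---
# Starcoder2Config defaults to multi-query attention (MQA) when neither
# multi_query nor grouped_query is requested; grouped-query attention (GQA)
# additionally requires num_kv_heads.
def _example_starcoder2_attention_modes() -> None:
    mqa_config = Starcoder2Config()
    assert mqa_config.multi_query is True  # enabled by __post_init__

    gqa_config = Starcoder2Config(grouped_query=True, num_kv_heads=4)
    assert gqa_config.multi_query is False
    assert gqa_config.num_kv_heads == 4
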
from dataclasses import dataclass
from typing import Optional
from nanotron.config.utils_config import (
cast_str_to_pipeline_engine,
)
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
@dataclass
class ParallelismArgs:
"""Arguments related to TP/PP/DP
Args:
dp: Number of DP replicas
pp: Number of PP stages
tp: Number of TP replicas
expert_parallel_size: Number of expert parallel replicas (used only for MoEs)
pp_engine: Pipeline engine to use between "1f1b" and "afab"
        tp_mode: TP mode to use, either "all_reduce" or "reduce_scatter": all_reduce is the standard mode, reduce_scatter activates sequence parallelism
tp_linear_async_communication: Whether to use async communication in TP linear layers
recompute_layer: Whether to recompute each Transformer layer to save memory.
"""
dp: int
pp: int
tp: int
pp_engine: Optional[PipelineEngine] = None
tp_mode: Optional[TensorParallelLinearMode] = None
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False
tp_recompute_allgather: bool = True
expert_parallel_size: int = 1
def __post_init__(self):
# Conservative defaults
if self.pp_engine is None:
self.pp_engine = AllForwardAllBackwardPipelineEngine()
if self.tp_mode is None:
self.tp_mode = TensorParallelLinearMode.ALL_REDUCE
if self.tp_linear_async_communication is None:
self.tp_linear_async_communication = False
if isinstance(self.pp_engine, str):
self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine)
if isinstance(self.tp_mode, str):
self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()]
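
# --- Illustrative usage sketch (added in review, not part of the original commit) ---
# ParallelismArgs accepts the string forms used in YAML configs and casts them
# in __post_init__. The member name REDUCE_SCATTER is assumed from the
# docstring above ("all_reduce" / "reduce_scatter").
def _example_parallelism_args_from_yaml() -> None:
    args = ParallelismArgs(dp=2, pp=2, tp=4, pp_engine="1f1b", tp_mode="REDUCE_SCATTER")
    assert isinstance(args.pp_engine, PipelineEngine)
    assert args.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER
    assert args.tp_linear_async_communication is False  # conservative default
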
from dataclasses import fields
from enum import Enum, auto
from pathlib import Path
import torch
from nanotron.generation.sampler import SamplerType
from nanotron.parallel.pipeline_parallel.engine import (
AllForwardAllBackwardPipelineEngine,
OneForwardOneBackwardPipelineEngine,
PipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
class RecomputeGranularity(Enum):
SELECTIVE = auto()
FULL = auto()
def serialize(data) -> dict:
"""Recursively serialize a nested dataclass to a dict - do some type conversions along the way"""
if data is None:
return None
if not hasattr(data, "__dataclass_fields__"):
return data
result = {}
for field in fields(data):
value = getattr(data, field.name)
if hasattr(value, "__dataclass_fields__"):
result[field.name] = serialize(value)
elif isinstance(value, Path):
result[field.name] = str(value)
elif isinstance(value, PipelineEngine):
result[field.name] = cast_pipeline_engine_to_str(value)
elif isinstance(value, TensorParallelLinearMode):
result[field.name] = value.name
elif isinstance(value, RecomputeGranularity):
result[field.name] = value.name
elif isinstance(value, SamplerType):
result[field.name] = value.name
elif isinstance(value, torch.dtype):
result[field.name] = dtype_to_str[value]
elif isinstance(value, (list, tuple)):
result[field.name] = [serialize(v) for v in value]
elif isinstance(value, dict) and not value:
result[field.name] = None # So we can serialize empty dicts without issue with `datasets` in particular
else:
result[field.name] = value
return result
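
# --- Illustrative usage sketch (added in review, not part of the original commit) ---
# serialize() flattens a nested dataclass into plain Python types (Paths become
# strings, torch dtypes become their names), which makes it safe to dump to YAML.
# `_Checkpoint` is a made-up dataclass used only for this sketch.
def _example_serialize_dataclass() -> None:
    from dataclasses import dataclass

    @dataclass
    class _Checkpoint:
        path: Path
        dtype: torch.dtype

    payload = serialize(_Checkpoint(path=Path("checkpoints"), dtype=torch.bfloat16))
    assert payload == {"path": "checkpoints", "dtype": "bfloat16"}
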
str_to_dtype = {
"float32": torch.float32,
"float64": torch.float64,
"complex64": torch.complex64,
"complex128": torch.complex128,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"bool": torch.bool,
}
dtype_to_str = {
torch.float32: "float32",
torch.float64: "float64",
torch.complex64: "complex64",
torch.complex128: "complex128",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.uint8: "uint8",
torch.int8: "int8",
torch.int16: "int16",
torch.int32: "int32",
torch.int64: "int64",
torch.bool: "bool",
}
def cast_str_to_torch_dtype(str_dtype: str):
if str_dtype in str_to_dtype:
return str_to_dtype[str_dtype]
else:
raise ValueError(f"dtype should be a string selected in {str_to_dtype.keys()} and not {str_dtype}")
def cast_str_to_pipeline_engine(str_pp_engine: str) -> PipelineEngine:
if str_pp_engine == "afab":
return AllForwardAllBackwardPipelineEngine()
elif str_pp_engine == "1f1b":
return OneForwardOneBackwardPipelineEngine()
else:
raise ValueError(f"pp_engine should be a string selected in ['afab', '1f1b'] and not {str_pp_engine}")
def cast_pipeline_engine_to_str(pp_engine: PipelineEngine) -> str:
if isinstance(pp_engine, AllForwardAllBackwardPipelineEngine):
return "afab"
elif isinstance(pp_engine, OneForwardOneBackwardPipelineEngine):
return "1f1b"
else:
raise ValueError(
f"pp_engine should be aan instance of AllForwardAllBackwardPipelineEngine or OneForwardOneBackwardPipelineEngine, not {type(pp_engine)}"
)
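
# --- Illustrative usage sketch (added in review, not part of the original commit) ---
# Round trips through the casting helpers defined above: YAML strings map to
# torch dtypes / pipeline engines and back to their canonical string form.
def _example_cast_helpers() -> None:
    assert cast_str_to_torch_dtype("bfloat16") is torch.bfloat16
    engine = cast_str_to_pipeline_engine("1f1b")
    assert isinstance(engine, OneForwardOneBackwardPipelineEngine)
    assert cast_pipeline_engine_to_str(engine) == "1f1b"
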
import platform
from packaging.version import Version, parse
CHECKPOINT_VERSION = Version("1.4")
PY_VERSION = parse(platform.python_version())
#### FOR SERIALIZATION ####
CHECKPOINT_FILE_NAME = "checkpoint_metadata.json"
MODEL_CONFIG_FILE_NAME = "model_config.json"
import dataclasses
from typing import Dict, List, Union
import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
@dataclasses.dataclass
class NanosetDataCollatorForCLM:
"""
Data collator used for causal language modeling with Nanosets dataset.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""
sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext
def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
        # Handle the case where the current rank doesn't require data: we return `TensorPointer`s that point to the ranks holding the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}
        # Make sure we load only what's necessary, i.e. we only load an `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)
# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape
result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}
result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
return result
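
# --- Illustrative sketch (added in review, not part of the original commit) ---
# The collator expects samples of length sequence_length + 1 and produces the
# causal-LM shift shown above: inputs drop the last token, labels drop the
# first, so label[t] is the prediction target for input[t].
def _example_causal_shift(sequence_length: int = 4) -> None:
    sample = torch.arange(sequence_length + 1).unsqueeze(0)  # (1, seq_len + 1)
    input_ids = sample[:, :-1]
    label_ids = sample[:, 1:]
    assert input_ids.shape[-1] == sequence_length
    assert torch.equal(label_ids[:, :-1], input_ids[:, 1:])
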
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.dataloader import (
EmptyInfiniteDataset,
get_dataloader_worker_init,
get_sampler,
)
from nanotron.parallel import ParallelContext
from torch.utils.data import DataLoader
logger = logging.get_logger(__name__)
def build_nanoset_dataloader(
dataset,
sequence_length: int,
parallel_context: ParallelContext,
input_pp_rank: int,
output_pp_rank: int,
micro_batch_size: int,
dataloader_num_workers: int,
consumed_train_samples: int = 0,
dataloader_drop_last: bool = True,
dataloader_pin_memory: bool = True,
) -> DataLoader:
    # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do its job
if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]:
dataset_length = len(dataset)
dataset = EmptyInfiniteDataset(length=dataset_length)
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0
data_collator = NanosetDataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
)
# Compute size and rank of dataloader workers
dp_ranks_size = parallel_context.dp_pg.size()
dp_rank = parallel_context.dp_pg.rank()
sampler = get_sampler(
train_dataset=dataset,
dl_ranks_size=dp_ranks_size,
dl_rank=dp_rank,
drop_last=dataloader_drop_last,
consumed_train_samples=consumed_train_samples,
shuffle=False,
)
return DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
collate_fn=data_collator,
drop_last=dataloader_drop_last,
num_workers=dataloader_num_workers,
pin_memory=dataloader_pin_memory,
worker_init_fn=get_dataloader_worker_init(dp_rank=dp_rank),
)
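
# --- Illustrative usage sketch (added in review, not part of the original commit) ---
# Typical wiring at training time. `dataset` and `parallel_context` are assumed
# to be built elsewhere (e.g. a Nanoset and the trainer's ParallelContext); the
# numeric values are illustrative only.
def _example_build_dataloader(dataset, parallel_context: ParallelContext) -> DataLoader:
    return build_nanoset_dataloader(
        dataset,
        sequence_length=2048,
        parallel_context=parallel_context,
        input_pp_rank=0,
        output_pp_rank=parallel_context.pp_pg.size() - 1,  # last pipeline stage computes the loss
        micro_batch_size=4,
        dataloader_num_workers=2,
    )
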
import os
import warnings
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
from datatrove.utils.dataset import DatatroveFolderDataset
from nanotron import logging
from nanotron.data.utils import count_dataset_indexes, normalize
from nanotron.logging import log_rank
from numba import jit
logger = logging.get_logger(__name__)
class Nanoset(torch.utils.data.Dataset):
"""
The Nanoset dataset
Args:
dataset_folders (List[str]): List of folders with tokenized datasets
dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__
sequence_length (int): Sequence length of the built samples
token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise
train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size
"""
def __init__(
self,
dataset_folders: List[str],
sequence_length: int,
token_size: int,
train_split_num_samples: int,
dataset_weights: Union[List[float], None] = None,
random_seed: int = 1234,
) -> None:
# Checks
if isinstance(dataset_folders, str):
warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
dataset_folders = [dataset_folders]
# Init
self.dataset_folders = dataset_folders
self.sequence_length = sequence_length
self.token_size = token_size
self.train_split_num_samples = train_split_num_samples
self.random_seed = random_seed
self.datatrove_datasets = []
for dataset_folder in self.dataset_folders:
self.datatrove_datasets.append(
DatatroveFolderDataset(
folder_path=dataset_folder,
filename_pattern=os.path.join(dataset_folder, "*.ds"),
seq_len=sequence_length,
recursive=False,
token_size=token_size,
shuffle=True,
)
)
# Build Nanoset Index
## To build the index we need the length of each dataset
self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
## Set dataset weights
        if (
            dataset_weights is None
        ):  # Case of training with > 1 dataset without specifying weights: consume all datasets entirely on each epoch
            self.dataset_weights = normalize(self.dataset_lengths)
else:
self.dataset_weights = normalize(dataset_weights)
assert len(dataset_folders) == len(
self.dataset_weights
), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
## Build dataset index and dataset sample index
self.dataset_index, self.dataset_sample_index = self.build_nanoset_index()
self.print_nanoset_info()
def __len__(self) -> int:
"""
Returns:
int: The number of samples of the Nanoset
"""
return len(self.dataset_index)
def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"""
Returns sequence_length + 1 tokens from the memmap dataset
Args:
idx (int): The index into the dataset
Returns:
Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary
"""
dataset = self.dataset_index[idx]
dataset_sample = self.dataset_sample_index[idx]
return self.datatrove_datasets[dataset][dataset_sample]
    def build_nanoset_index(self) -> Tuple[np.ndarray, np.ndarray]:
"""
Build dataset index and dataset sample index
"""
# Compute samples per epoch and number of epochs
samples_per_epoch = sum(self.dataset_lengths)
num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1
# Build the dataset indexes for 1 epoch
dataset_index, dataset_sample_index = build_nanoset_index_helper(
n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths
)
# Shuffle the indexes the same way
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_index)
numpy_random_state = np.random.RandomState(self.random_seed)
numpy_random_state.shuffle(dataset_sample_index)
# Concatenate num_epochs the shuffled indexes
dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)])
dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)])
# Just keep the necessary samples
dataset_index = dataset_index[: self.train_split_num_samples]
dataset_sample_index = dataset_sample_index[: self.train_split_num_samples]
return dataset_index, dataset_sample_index
def print_nanoset_info(self):
log_rank(f"> Total number of samples: {len(self)}", logger=logger, level=logging.INFO, rank=0)
log_rank(
f"> Total number of tokens: {len(self) * self.sequence_length}", logger=logger, level=logging.INFO, rank=0
)
# Print samples from each dataset + weight
dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
for index, sample_count in enumerate(dataset_sample_count):
log_rank(
f"> Total number of samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
logger=logger,
level=logging.INFO,
rank=0,
)
@jit(nopython=True, cache=True)
def build_nanoset_index_helper(
n_samples: int, weights: np.ndarray, dataset_sizes: List[int]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Given multiple datasets and a weighting array, build samples indexes
such that it follows those weights
"""
# Create empty arrays for dataset indices and dataset sample indices
dataset_index = np.empty((n_samples,), dtype="uint")
    dataset_sample_index = np.empty((n_samples,), dtype="long")  # Supports datasets with up to 2**63 samples
# Initialize buffer for number of samples used for each dataset
current_samples = np.zeros((len(weights),), dtype="long")
# Iterate over all samples
for sample_idx in range(n_samples):
# Convert sample index to float for comparison against weights
sample_idx_float = max(sample_idx, 1.0)
# Find the dataset with the highest error
errors = weights * sample_idx_float - current_samples
max_error_index = np.argmax(errors)
# Assign the dataset index and update the sample index
dataset_index[sample_idx] = max_error_index
dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]
# Update the total samples for the selected dataset
current_samples[max_error_index] += 1
return dataset_index, dataset_sample_index
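
# --- Illustrative sketch (added in review, not part of the original commit) ---
# The helper above greedily assigns each global sample to the dataset that is
# furthest behind its target share, so exact binary weights like 0.75/0.25
# yield exactly that split (samples wrap around each dataset via the modulo).
def _example_weighted_index() -> None:
    weights = normalize([0.75, 0.25])
    dataset_index, _ = build_nanoset_index_helper(
        n_samples=8, weights=weights, dataset_sizes=[100, 100]
    )
    assert int(np.count_nonzero(dataset_index == 0)) == 6
    assert int(np.count_nonzero(dataset_index == 1)) == 2
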
from typing import List
import numpy as np
def normalize(weights: List[float]) -> np.ndarray:
    """
    Normalize the elements of a list so that they sum to 1
    Args:
        weights (List[float]): The weights
    Returns:
        numpy.ndarray: The normalized weights
    """
w = np.array(weights, dtype=np.float64)
w_sum = np.sum(w)
w = w / w_sum
return w
def count_dataset_indexes(dataset_idx: np.ndarray, n_datasets: int):
counts = []
for dataset in range(n_datasets):
counts.append(np.count_nonzero(dataset_idx == dataset))
return counts
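
# --- Illustrative usage sketch (added in review, not part of the original commit) ---
# normalize() rescales weights so they sum to 1; count_dataset_indexes() counts
# how many samples of a Nanoset index were drawn from each dataset.
def _example_normalize_and_count() -> None:
    assert normalize([2.0, 2.0]).tolist() == [0.5, 0.5]
    assert count_dataset_indexes(np.array([0, 1, 1, 0, 1]), n_datasets=2) == [2, 3]
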
import datetime
import os
from functools import cache, lru_cache
from typing import List, Optional, Tuple
import torch
from packaging import version
from torch import distributed as dist
from torch.distributed import * # noqa
from torch.distributed.distributed_c10d import ProcessGroup
from nanotron.utils import find_free_port
torch_version_above_1_13 = version.parse(torch.__version__) >= version.parse("1.13.0")
Work = dist.Work if torch_version_above_1_13 else dist._Work
default_pg_timeout = datetime.timedelta(minutes=10)
def new_group( # pylint: disable=function-redefined
ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None
) -> ProcessGroup:
if len(ranks) == 0:
raise ValueError("Cannot create a group with not ranks inside it")
return dist.new_group(ranks=ranks, timeout=timeout, backend=backend, pg_options=pg_options)
def reduce_scatter_tensor( # pylint: disable=function-redefined
output: torch.Tensor,
input: torch.Tensor,
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[Work]:
if group is None:
        group = dist.distributed_c10d._get_default_group()
assert (
group.size() > 1
), "You should probably not call `reduce_scatter_tensor` with a single rank, as it copies data over"
if torch_version_above_1_13:
return dist.reduce_scatter_tensor(output=output, input=input, group=group, op=op, async_op=async_op)
else:
# Support pytorch 1.12
return dist._reduce_scatter_base(output=output, input=input, group=group, op=op, async_op=async_op)
def all_gather_into_tensor( # pylint: disable=function-redefined
output_tensor, input_tensor, group: Optional[ProcessGroup] = None, async_op: bool = False
) -> Optional[Work]:
if group is None:
        group = dist.distributed_c10d._get_default_group()
assert (
group.size() > 1
), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over"
if torch_version_above_1_13:
return dist.all_gather_into_tensor(
output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op
)
else:
# Support Pytorch 1.12
return dist.distributed_c10d._all_gather_base(
output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op
)
def reduce_scatter_coalesced(
output_tensor_list: List[torch.Tensor],
input_tensor_lists: List[List[torch.Tensor]],
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[torch._C.Future]:
"""
Reduces, then scatters a list of tensors to all processes in a group.
Args:
output_tensor_list (list[Tensor]): Output tensor.
input_tensor_lists (list[list[Tensor]]): List of tensors to reduce and scatter.
op (optional): One of the values from
``torch.distributed.ReduceOp``
enum. Specifies an operation used for element-wise reductions.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op.
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group.
"""
assert len(output_tensor_list) > 0
assert len(input_tensor_lists) == len(output_tensor_list)
device = output_tensor_list[0].device
dtype = output_tensor_list[0].dtype
group_size = len(input_tensor_lists[0])
assert (
group_size > 1
), "You should probably not call `reduce_scatter_coalesced` with a single rank, as it copies data over"
for output_tensor in output_tensor_list:
assert device == output_tensor.device
assert dtype == output_tensor.dtype
for input_tensor_list in input_tensor_lists:
assert len(input_tensor_list) == group_size, f"Expected {len(input_tensor_list)} == {group_size}"
for input_tensor in input_tensor_list:
assert device == input_tensor.device
assert dtype == input_tensor.dtype
output_tensor_buffer = torch._utils._flatten_dense_tensors(output_tensor_list)
input_tensor_buffer_list = [
torch._utils._flatten_dense_tensors(
[input_tensor_list[group_rank] for input_tensor_list in input_tensor_lists]
)
for group_rank in range(group_size)
]
work = dist.reduce_scatter(output_tensor_buffer, input_tensor_buffer_list, op=op, group=group, async_op=async_op)
def update_output():
for original_buffer, reduced_buffer in zip(
output_tensor_list, torch._utils._unflatten_dense_tensors(output_tensor_buffer, output_tensor_list)
):
original_buffer.copy_(reduced_buffer)
if async_op is True:
return work.get_future().then(lambda fut: update_output())
else:
# No need to run `work.wait()` since `dist.reduce_scatter` already waits
update_output()
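
# --- Illustrative sketch (added in review, not part of the original commit) ---
# The coalesced collectives above all follow the same pattern: flatten a list of
# tensors into one contiguous buffer, run a single collective on it, then copy
# the result back into the original tensors. This shows the flatten/unflatten
# round trip on its own, without any process group.
def _example_flatten_unflatten_roundtrip() -> None:
    tensors = [torch.randn(3), torch.randn(2, 2)]
    flat = torch._utils._flatten_dense_tensors(tensors)
    for original, restored in zip(tensors, torch._utils._unflatten_dense_tensors(flat, tensors)):
        assert torch.equal(original, restored)
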
def all_reduce_coalesced( # pylint: disable=function-redefined
tensors: List[torch.Tensor],
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[torch._C.Future]:
if group is None:
        group = dist.distributed_c10d._get_default_group()
if group.size() == 1:
return
return dist.all_reduce_coalesced(tensors, op=op, group=group, async_op=async_op)
def all_gather_coalesced( # pylint: disable=function-redefined
output_tensor_lists: List[List[torch.Tensor]],
input_tensor_list: List[torch.Tensor],
group: Optional[ProcessGroup] = None,
async_op: bool = False,
) -> Optional[torch._C.Future]:
"""
`torch` has a deprecated version of this method that doesn't work over NCCL.
All gathers a list of tensors to all processes in a group.
Args:
output_tensor_lists (list[list[Tensor]]): Output tensor.
input_tensor_list (list[Tensor]): List of tensors to all_gather from.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op.
Returns:
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group.
"""
assert len(output_tensor_lists) > 0
assert len(input_tensor_list) == len(output_tensor_lists)
device = input_tensor_list[0].device
dtype = input_tensor_list[0].dtype
group_size = len(output_tensor_lists[0])
assert (
group_size > 1
), "You should probably not call `all_gather_coalesced` with a single rank, as it copies data over"
for input_tensor in input_tensor_list:
assert device == input_tensor.device
assert dtype == input_tensor.dtype
for output_tensor_list in output_tensor_lists:
assert len(output_tensor_list) == group_size
for output_tensor in output_tensor_list:
assert device == output_tensor.device
assert dtype == output_tensor.dtype
# Invert from `[param_idx][group_rank]` to `[group_rank][param_idx]`
output_tensor_lists = [
[output_tensor_list[group_rank] for output_tensor_list in output_tensor_lists]
for group_rank in range(group_size)
]
input_tensor_buffer = torch._utils._flatten_dense_tensors(input_tensor_list)
output_tensor_buffer_list = [
torch._utils._flatten_dense_tensors(output_tensor_list) for output_tensor_list in output_tensor_lists
]
work = dist.all_gather(output_tensor_buffer_list, input_tensor_buffer, group=group, async_op=async_op)
def update_output():
for original_buffer_list, gathered_buffer_tensor in zip(output_tensor_lists, output_tensor_buffer_list):
for original_buffer, gathered_buffer in zip(
original_buffer_list,
torch._utils._unflatten_dense_tensors(gathered_buffer_tensor, original_buffer_list),
):
original_buffer.copy_(gathered_buffer)
if async_op is True:
return work.get_future().then(lambda fut: update_output())
else:
        # No need to run `work.wait()` since `dist.all_gather` already waits
update_output()
# Caching this lookup gives a ~4 TFLOPS throughput gain on a 7b model
@cache
def get_global_rank(group: ProcessGroup, group_rank: int) -> int: # pylint: disable=function-redefined
if torch_version_above_1_13:
return dist.get_global_rank(group, group_rank=group_rank)
else:
# Support pytorch 1.12
return dist.distributed_c10d._get_global_rank(group=group, rank=group_rank)
def get_global_ranks(group: ProcessGroup) -> Tuple[int]:
return tuple(sorted((get_global_rank(group, i) for i in range(group.size()))))
# We cache for dp, pp, tp process groups, world group, and tied process group for tied params
@lru_cache
def get_rank(group: Optional[ProcessGroup] = None) -> int: # pylint: disable=function-redefined
"""Similar to `get_rank` except we raise an exception instead of return -1 when current rank is not part of the group"""
result = dist.get_rank(group)
if result == -1:
raise RuntimeError("Can not call `get_rank` on a group in which current process is not a part of")
return result
def initialize_torch_distributed():
"""Initializes torch distributed with the environment variables"""
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if torch.cuda.is_available():
# Set the device id.
        # `torch.cuda.device_count` should return the number of devices on a single node.
# We assume the nodes to be homogeneous (same number of gpus per node)
device_id = local_rank
torch.cuda.set_device(torch.cuda.device(device_id))
backend = "nccl"
else:
# TODO @thomasw21: Maybe figure out a way to do distributed `cpu` training at some point
raise NotImplementedError(f"CUDA was not found: torch.cuda.is_available(): {torch.cuda.is_available()}")
backend = "gloo"
# Call the init process.
port = os.getenv("MASTER_PORT")
if port is None:
port = find_free_port()
else:
port = int(port)
init_method = f"env://localhost:{port}"
dist.init_process_group(
init_method=init_method, backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout
)
return True
import warnings
from nanotron.fp8.dtypes import DTypes # noqa
from nanotron.fp8.linear import FP8Linear # noqa
from nanotron.fp8.parameter import FP8Parameter # noqa
from nanotron.fp8.tensor import FP8Tensor # noqa
try:
import transformer_engine as te # noqa
import transformer_engine_extensions as tex # noqa
except ImportError:
warnings.warn("Please install Transformer engine for FP8 training!")
import torch
from nanotron.fp8.dtypes import DTypes
FP8_GPU_NAMES = ["h100", "rtx 4090"]
INITIAL_AMAX = 1.0
INITIAL_SCALING_FACTOR = 1.0
# FP8_DTYPES = [torch.fp8e4m3, torch.fp8e5m2]
# FP8E4M3_DTYPE = torch.fp8e4m3
# FP8E5M2_DTYPE = torch.fp8e5m2
FP8_DTYPES = [torch.int8, torch.uint8]
FP8E4M3_DTYPE = torch.int8
FP8E5M2_DTYPE = torch.uint8
DTYPE_TO_FP8_MAX = {DTypes.FP8E4M3: 448.0, DTypes.FP8E5M2: 57344.0, DTypes.KFLOAT16: 65504.0}
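
# --- Illustrative sketch (added in review, not part of the original commit) ---
# Generic delayed-scaling arithmetic these constants are meant for (the exact
# recipe used elsewhere in nanotron.fp8 may differ): the scaling factor maps
# the running amax onto the representable range of the chosen FP8 format.
def _example_fp8_scaling_factor(amax: float = INITIAL_AMAX, dtype: DTypes = DTypes.FP8E4M3) -> float:
    fp8_max = DTYPE_TO_FP8_MAX[dtype]
    return fp8_max / max(amax, 1e-12)  # guard against a zero running amax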