"doc/vscode:/vscode.git/clone" did not exist on "c35662214779d0ab0ad6cefbc8349776779f1c23"
Commit 71e79847 authored by chenzk's avatar chenzk
Browse files

v1.0.3

from pathlib import Path
import torch
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.random import RandomStates
def save_random_states(
random_states: RandomStates,
parallel_context: ParallelContext,
root_folder: Path,
):
"""All processes save their own random state"""
filename = (
root_folder
/ "random"
/ f"tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt"
)
filename.parent.mkdir(exist_ok=True, parents=True)
# TODO @thomasw21: That's annoying, but this actually uses pickle; we might need to change that for something else
torch.save(random_states, filename)
def load_random_states(parallel_context: ParallelContext, root_folder: Path):
# TODO @thomasw21: This basically assumes that we have exactly the same topology as the one we used when saving.
filename = (
root_folder
/ "random"
/ f"tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt"
)
# TODO @thomasw21: That's annoying, but this actually uses pickle; we might need to change that for something else
state = torch.load(filename)
return state
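# Illustrative sketch (not part of the original module): a save/load round trip under an
# already-initialized ParallelContext. Each process writes its own file whose name encodes
# its (tp, dp, pp) rank, e.g. `random/tp-1-of-2_dp-0-of-1_pp-0-of-1.pt` for tp_rank=1 with
# tp=2, dp=1, pp=1. The checkpoint folder below is hypothetical.
def _example_random_states_round_trip(random_states: RandomStates, parallel_context: ParallelContext) -> RandomStates:
    root_folder = Path("checkpoints/10")
    save_random_states(random_states=random_states, parallel_context=parallel_context, root_folder=root_folder)
    return load_random_states(parallel_context=parallel_context, root_folder=root_folder)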
import re
from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import SlicesPair
from nanotron.serialize.metadata import TensorMetadata
class ObjectType(Enum):
MODEL = "model"
OPTIMIZER = "optimizer"
LR_SCHEDULER = "lr_scheduler"
def get_exp_tp_pp_rank_and_size_from(
world_rank: int, parallel_context: ParallelContext
) -> Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]:
result = parallel_context.get_local_ranks(world_rank=world_rank)
return (
(result[0], parallel_context.expert_pg.size()),
(result[3], parallel_context.tp_pg.size()),
(result[1], parallel_context.pp_pg.size()),
)
def get_path(
tensor_name: str,
type: ObjectType,
exp_tp_pp_rank_and_size: Optional[Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]],
is_expert_sharded: bool,
prefix: Optional[Path] = None,
) -> Union[List[str], Path]:
suffix = tensor_name.split(".")
suffix_path, suffix_name = suffix[:-1], suffix[-1]
if exp_tp_pp_rank_and_size:
# We always show pp_rank and tp_rank if `exp_tp_pp_rank_and_size` is provided
(exp_rank, exp_size), (tp_rank, tp_size), (pp_rank, pp_size) = exp_tp_pp_rank_and_size
if not is_expert_sharded or exp_size == 1:
suffix_name = (
f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}.safetensors"
)
else:
# We only show exp_rank if tensor is exp_sharded and exp_size > 1
suffix_name = f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}_exp-rank-{exp_rank}-of-{exp_size}.safetensors"
else:
suffix_name = f"{type.value}_{suffix_name}.safetensors"
suffix_path.append(suffix_name)
if prefix is None:
return suffix_path
else:
return prefix.joinpath(*suffix_path)
def extract_tp_pp_rank_from_shard_path(shard_path: Path):
pattern = r"pp-rank-(\d+)-of-\d+_tp-rank-(\d+)-of-\d+"
match = re.search(pattern, str(shard_path))
pp_rank, tp_rank = match.groups()
return pp_rank, tp_rank
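# Illustrative sketch: `get_path` and `extract_tp_pp_rank_from_shard_path` are pure helpers,
# so their behaviour can be shown directly. The tensor name and checkpoint prefix below are
# hypothetical; the rank/size tuples follow the (expert, tp, pp) order used above.
def _example_shard_path_round_trip() -> Tuple[str, str]:
    path = get_path(
        "model.decoder.0.attn.qkv.weight",
        type=ObjectType.MODEL,
        exp_tp_pp_rank_and_size=((0, 1), (1, 4), (0, 2)),
        is_expert_sharded=False,
        prefix=Path("checkpoints/10/model"),
    )
    # path == checkpoints/10/model/model/decoder/0/attn/qkv/model_weight_pp-rank-0-of-2_tp-rank-1-of-4.safetensors
    return extract_tp_pp_rank_from_shard_path(path)  # ("0", "1"), returned as strings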
def merge_and_shard_tp_tensors(
buffer: torch.Tensor,
unsharded_buffer: torch.Tensor,
shards_and_slices_maps: List[Tuple[torch.Tensor, Tuple[SlicesPair, ...]]],
shard_metadata: TensorMetadata,
) -> torch.Tensor:
for shard, slices_pairs in shards_and_slices_maps:
for slices_pair in slices_pairs:
local_slices = slices_pair.local_slices
global_slices = slices_pair.global_slices
unsharded_buffer[global_slices] = shard[local_slices]
for slices_pair in shard_metadata.local_global_slices_pairs:
local_slices = slices_pair.local_slices
global_slices = slices_pair.global_slices
buffer[local_slices] = unsharded_buffer[global_slices]
return buffer
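# Illustrative, self-contained sketch of what `merge_and_shard_tp_tensors` does: rebuild the
# unsharded tensor from old shards, then cut out the slice this rank owns in the new topology.
# The real code passes `SlicesPair` / `ShardedInfo` objects; here hypothetical namedtuple
# stand-ins expose the same `local_slices` / `global_slices` / `local_global_slices_pairs`
# fields, which is all the function reads.
def _example_merge_and_reshard() -> torch.Tensor:
    from collections import namedtuple

    _Pair = namedtuple("_Pair", ["local_slices", "global_slices"])
    _Meta = namedtuple("_Meta", ["local_global_slices_pairs"])

    full = torch.arange(8.0)  # the logical, unsharded tensor
    shards_and_slices_maps = [
        (full[:4].clone(), (_Pair((slice(0, 4),), (slice(0, 4),)),)),  # old tp_rank=0 shard
        (full[4:].clone(), (_Pair((slice(0, 4),), (slice(4, 8),)),)),  # old tp_rank=1 shard
    ]
    # New topology: this rank owns the second half of the tensor.
    new_meta = _Meta(local_global_slices_pairs=(_Pair((slice(0, 4),), (slice(4, 8),)),))
    buffer = torch.empty(4)
    merge_and_shard_tp_tensors(buffer, torch.empty(8), shards_and_slices_maps, new_meta)
    return buffer  # == tensor([4., 5., 6., 7.])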
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import dacite
import torch
from packaging.version import Version
from safetensors.torch import safe_open, save_file
from torch import nn
from tqdm import tqdm
from nanotron import distributed as dist
from nanotron import logging
from nanotron.constants import CHECKPOINT_VERSION
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter, ShardedInfo, SlicesPair
from nanotron.serialize.metadata import CheckpointMetadata, TensorMetadata, load_meta
from nanotron.serialize.utils import (
ObjectType,
extract_tp_pp_rank_from_shard_path,
get_exp_tp_pp_rank_and_size_from,
get_path,
merge_and_shard_tp_tensors,
)
logger = logging.get_logger(__name__)
def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folder: Path):
root_folder = root_folder / "model"
# We save only `dist.get_rank(parallel_context.dp_pg) == 0`
# TODO @thomasw21: Figure how this works with Zero-3
if dist.get_rank(parallel_context.dp_pg) != 0:
return
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
# We chunk everything by `tp_world_size` in order to make sure that we gather all the weights into a single device before saving it
for name, param_or_buffer in tqdm(model.state_dict().items(), desc="Saving weights"):
# exp_rank=0 saves all weights, whereas exp_rank>0 saves only MLP weights
if dist.get_rank(parallel_context.expert_pg) != 0:
if "experts" not in name:
continue
# `state_dict` doesn't return a Param or a buffer, just a tensor, which loses some metadata
try:
param = model.get_parameter(name)
except AttributeError:
# TODO @nouamanetazi: Handle buffers
param = None
if isinstance(param, NanotronParameter):
metadata = {}
if param.is_tied:
tied_info = param.get_tied_info()
base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
group_ranks = tied_info.global_ranks
group = parallel_context.world_ranks_to_pg[group_ranks]
# Only the first rank of the group of the tied weights saves weights
# TODO @thomasw21: We could rotate in order to balance the load.
if dist.get_rank(group) != 0:
continue
else:
base_name = name
if param.is_sharded:
sharded_info: ShardedInfo = param.get_sharded_info()
group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
world_rank=get_global_rank(group=group, group_rank=dist.get_rank(group)),
parallel_context=parallel_context,
)
is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
metadata = TensorMetadata(
version=CHECKPOINT_VERSION,
local_global_slices_pairs=sharded_info.local_global_slices_pairs,
unsharded_shape=sharded_info.unsharded_shape,
).to_str_dict()
else:
exp_tp_pp_rank_and_size = None
is_expert_sharded = False
path = get_path(
base_name,
type=ObjectType.MODEL,
exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
is_expert_sharded=is_expert_sharded,
prefix=root_folder,
)
path.parent.mkdir(exist_ok=True, parents=True)
try:
tensors = {"data": param_or_buffer}
save_file(tensors=tensors, filename=path, metadata=metadata)
except Exception as e:
log_rank(
f"Error saving {path} with {metadata}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
else:
raise NotImplementedError("Parameters are required to be NanotronParameter")
class CheckpointVersionFromShardFileException(Exception):
"""Raise when loading checkpoint version from shard file fails"""
def read_checkpoint_version_from_shard_file(param_save_path: Path) -> Version:
try:
with safe_open(param_save_path, framework="pt", device=str("cpu")) as fi:
param_metadata = fi.metadata()
param_metadata = TensorMetadata.from_str_dict(param_metadata)
checkpoint_version = param_metadata.version
except (dacite.exceptions.MissingValueError, dacite.exceptions.UnexpectedDataError):
raise CheckpointVersionFromShardFileException()
return checkpoint_version
def read_checkpoint_version_from_meta(parallel_context: ParallelContext, root_folder: Path) -> Version:
checkpoint_metadata: CheckpointMetadata = load_meta(parallel_context=parallel_context, root_folder=root_folder)
checkpoint_version = checkpoint_metadata.version
return checkpoint_version
def get_checkpoint_version(parallel_context, root_folder, param_save_path: Path) -> Version:
try:
checkpoint_version = read_checkpoint_version_from_shard_file(param_save_path=param_save_path)
except CheckpointVersionFromShardFileException:
log_rank(
f"Failed to read checkpoint version from shard file {param_save_path}, reading from meta file.",
logger=logger,
level=logging.ERROR,
rank=0,
)
checkpoint_version = read_checkpoint_version_from_meta(
parallel_context=parallel_context, root_folder=root_folder
)
return checkpoint_version
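# Illustrative sketch: resolving the checkpoint version for one shard, falling back to the
# checkpoint-level meta file when the shard carries no (or unreadable) metadata, exactly as
# `get_checkpoint_version` does above. The shard path below is hypothetical.
def _example_resolve_checkpoint_version(parallel_context: ParallelContext, root_folder: Path) -> Version:
    shard = (
        root_folder / "model" / "model" / "token_position_embeddings"
        / "model_weight_pp-rank-0-of-1_tp-rank-0-of-2.safetensors"
    )
    return get_checkpoint_version(parallel_context, root_folder, param_save_path=shard)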
def load_sharded_param_latest(
param_or_buffer: torch.Tensor,
sharded_info: ShardedInfo,
shards_path: List[Path],
param_shard_metadata: Optional[Dict] = None,
):
checkpoint_unsharded_shape = None
shards_and_slices_maps: List[Tuple[torch.Tensor, Tuple[SlicesPair, ...]]] = []
for shard_path in shards_path:
with safe_open(shard_path, framework="pt", device=str(param_or_buffer.device)) as fi:
# TODO @thomasw21: Choose only a slice if we switch the TP topology
param_metadata = fi.metadata()
param_metadata = TensorMetadata.from_str_dict(param_metadata)
shards_and_slices_maps.append((fi.get_tensor("data"), param_metadata.local_global_slices_pairs))
if checkpoint_unsharded_shape is None:
checkpoint_unsharded_shape = param_metadata.unsharded_shape
else:
assert checkpoint_unsharded_shape == param_metadata.unsharded_shape
if param_shard_metadata is not None:
# NOTE: store how the model parameters are sharded
# so that we can shard optimizer checkpoints the same way
pp_rank, tp_rank = extract_tp_pp_rank_from_shard_path(shard_path)
param_shard_metadata[(pp_rank, tp_rank)] = param_metadata
assert checkpoint_unsharded_shape is not None
# TODO @thomasw21: Interestingly enough we don't actually need to instantiate the entire model at all.
unsharded_tensor = torch.empty(checkpoint_unsharded_shape, device=param_or_buffer.device)
merge_and_shard_tp_tensors(
buffer=param_or_buffer,
unsharded_buffer=unsharded_tensor,
shards_and_slices_maps=shards_and_slices_maps,
shard_metadata=sharded_info,
)
return param_shard_metadata
def load_weights(
model: nn.Module,
parallel_context: ParallelContext,
root_folder: Path,
filtered_state_dict: Optional[Dict[str, Any]] = None,
):
"""Load weights from a checkpoint
Args:
model: model to load weights into
parallel_context: distributed process groups
root_folder: root folder of the checkpoint
filtered_state_dict: state dict to load from (overrides model.state_dict()). if None, load from model.state_dict()
"""
param_root_folder = root_folder / "model"
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
checkpoint_version: Optional[Version] = None
filtered_state_dict = filtered_state_dict if filtered_state_dict is not None else model.state_dict()
param_shard_metadata = {}
for name, param_or_buffer in tqdm(
filtered_state_dict.items(), disable=dist.get_rank(parallel_context.world_pg) != 0, desc="Loading weights"
):
# NOTE: record how the current model parameters are sharded
# so that we can load optimizer checkpoints the same way
param_shard_metadata[name] = {}
# `state_dict` doesn't return a Param or a buffer, just a tensor, which loses some metadata
try:
param = model.get_parameter(name)
except AttributeError:
param = None
if isinstance(param, NanotronParameter):
if param.is_tied:
tied_info = param.get_tied_info()
base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
else:
base_name = name
if param.is_sharded:
sharded_info = param.get_sharded_info()
if param.is_tied:
# When params are tied only the first rank of tied param group stores weights (see save_weights)
group = parallel_context.world_ranks_to_pg[tied_info.global_ranks]
group_rank = 0
else:
group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
group_rank = dist.get_rank(group)
exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context
)
# TODO @nouamane: do we consider exp_size=1 expert_sharded?
is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
else:
exp_tp_pp_rank_and_size = None
is_expert_sharded = False
path = get_path(
base_name,
type=ObjectType.MODEL,
exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
prefix=param_root_folder,
is_expert_sharded=is_expert_sharded,
)
if path.exists():
with safe_open(path, framework="pt", device=str(param.device)) as fi:
# TODO @thomasw21: Choose only a slice if we switch the TP topology
param_or_buffer[:] = fi.get_tensor("data")
elif not path.parent.exists():
raise ValueError(
f"Checkpoint is empty or checkpoint structure is not matching the model architecture."
f"Couldn't find folder {path.parent} in checkpoint at {root_folder}"
)
else:
# Let's assume that the topology changed and the param is sharded.
# We search for all the files from the shards, concatenate the "unsharded" tensor
# and load the specific shard we're interested in.
if not param.is_sharded:
raise ValueError(
f"`{name}` is not a sharded parameter. It's possible you were expecting {path} to exist."
)
# TODO @thomasw21: Make so that we don't need to code this logic somewhere else than in `get_path`
sharded_info = param.get_sharded_info()
suffix = base_name.rsplit(".", 1)[-1]
shards_path = list(path.parent.glob(f"{ObjectType.MODEL.value}_{suffix}*.safetensors"))
if len(shards_path) <= 0:
raise ValueError(
f"Could not find any shards {ObjectType.MODEL.value}_{suffix}*.safetensors in {path.parent}."
f"If you notice `.safetensors` in the middle of the name of some of the checkpoints files. You need to run `scripts/fix_checkpoint_bad_naming.py`."
)
if checkpoint_version is None:
checkpoint_version = get_checkpoint_version(
parallel_context, root_folder, param_save_path=shards_path[0]
)
else:
current_checkpoint_version = None
try:
current_checkpoint_version = read_checkpoint_version_from_shard_file(
param_save_path=shards_path[0]
)
except CheckpointVersionFromShardFileException:
# The checkpoint version is read from the meta file
current_checkpoint_version = checkpoint_version
finally:
assert (
current_checkpoint_version == checkpoint_version
), f"Checkpoint version mismatch at {shards_path[0]}."
if checkpoint_version <= CHECKPOINT_VERSION:
load_sharded_param_latest(
param_or_buffer=param_or_buffer,
sharded_info=sharded_info,
shards_path=shards_path,
param_shard_metadata=param_shard_metadata[name],
)
else:
raise ValueError(f"Unsupported checkpoint version {checkpoint_version}")
else:
raise NotImplementedError(f"Parameters {param} should be a NanotronParameter")
return param_shard_metadata
def get_checkpoint_paths_list(
model: nn.Module,
parallel_context: ParallelContext,
root_folder: Path,
only_list_folders: bool = False,
only_list_current_process: bool = True,
filtered_state_dict: Optional[Dict[str, Any]] = None,
):
"""Return the list of all the files or folders created/accessed by the current process in a checkpoint
Args:
model: model to load weights into
parallel_context: distributed process groups
root_folder: root folder of the checkpoint
filtered_state_dict: state dict to load from (overrides model.state_dict()). if None, load from model.state_dict()
"""
param_root_folder = root_folder / "model"
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
paths = []
filtered_state_dict = filtered_state_dict if filtered_state_dict is not None else model.state_dict()
for name in tqdm(
filtered_state_dict.keys(),
disable=dist.get_rank(parallel_context.world_pg) != 0,
desc="Listing checkpoint paths",
):
# `state_dict` doesn't return a Param or a buffer, just a tensor, which loses some metadata
try:
param = model.get_parameter(name)
except AttributeError:
param = None
if isinstance(param, NanotronParameter) or not only_list_current_process:
if param.is_tied:
tied_info = param.get_tied_info()
base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
else:
base_name = name
if param.is_sharded:
sharded_info = param.get_sharded_info()
if param.is_tied:
# When params are tied only the first rank of tied param group stores weights (see save_weights)
group = parallel_context.world_ranks_to_pg[tied_info.global_ranks]
group_rank = 0
else:
group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
group_rank = dist.get_rank(group)
exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context
)
is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
else:
exp_tp_pp_rank_and_size = None
is_expert_sharded = False
if only_list_folders:
paths.append(param_root_folder.joinpath(*base_name.split(".")[:-1]))
else:
paths.append(
get_path(
base_name,
type=ObjectType.MODEL,
exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
is_expert_sharded=is_expert_sharded,
prefix=param_root_folder,
)
)
return paths
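# Illustrative sketch (not part of the original module): the typical save/load round trip
# built from the helpers above, assuming an already-initialized ParallelContext and a model
# whose parameters are NanotronParameter. The checkpoint folder is hypothetical.
def _example_weights_round_trip(model: nn.Module, parallel_context: ParallelContext) -> Dict[str, Any]:
    root_folder = Path("checkpoints/10")
    save_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)
    # Returns, per parameter, how each checkpoint shard maps onto the current sharding.
    return load_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)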
import datetime
import gc
import json
import os
import shutil
import time
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import (
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import (
Config,
DatasetStageArgs,
ExistingCheckpointInit,
ParallelismArgs,
RandomInit,
SpectralMupInit,
get_config_from_file,
)
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.dataloader import sanity_check_dataloader
from nanotron.helpers import (
_vocab_size_with_padding,
compute_remain_train_steps_of_a_data_stage_from_ckp,
get_consumed_train_samples_of_a_data_stage_from_ckp,
get_profiler,
init_optimizer_and_grad_accumulator,
init_random_states,
log_throughput,
lr_scheduler_builder,
)
from nanotron.logging import (
LoggerWriter,
LogItem,
human_format,
log_memory,
log_rank,
set_ranks_logging_level,
)
from nanotron.models import NanotronModel, build_model
from nanotron.models.base import check_model_has_grad
from nanotron.models.llama import LlamaForTraining, RotaryEmbedding
from nanotron.models.starcoder2 import Starcoder2ForTraining
from nanotron.optim.clip_grads import clip_grad_norm
from nanotron.parallel import ParallelContext
from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
PipelineEngine,
TensorPointer,
)
from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelRowLinear
from nanotron.parallel.tied_parameters import (
create_pg_for_tied_weights,
get_tied_id_to_param,
sync_tied_weights_gradients,
tie_parameters,
)
from nanotron.random import set_random_seed
from nanotron.s3_checkpoints import S3Mover, check_path_is_local
from nanotron.sanity_checks import (
after_optim_step_sanity_checks,
after_tbi_sanity_checks,
before_optim_step_sanity_checks,
before_tbi_sanity_checks,
)
from nanotron.scaling.parametrization import ParametrizationMethod
from nanotron.serialize import (
load_lr_scheduler,
load_meta,
load_weights,
parse_ckpt_path,
save,
save_random_states,
)
from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata
from nanotron.serialize.optimizer import load_optimizer, state_dict_to_device
logger = logging.get_logger(__name__)
# Reduce the logging noise from torch.distributed when creating new process groups
dist_logger = logging.get_logger(dist.dist.__name__)
dist_logger.setLevel(logging.WARNING)
CONFIG_TO_MODEL_CLASS = {
"LlamaConfig": LlamaForTraining,
"Starcoder2Config": Starcoder2ForTraining,
}
try:
import wandb
except ImportError:
wandb = None
class DistributedTrainer:
def __init__(
self,
config_or_config_file: Union[Config, str],
config_class: Type[Config] = Config,
model_config_class: Optional[Type] = None,
model_class: Optional[Type[NanotronModel]] = None,
):
"""
Nanotron's distributed trainer.
Args:
config_or_config_file: Either a `Config` object or a path to a YAML file containing the config.
config_class: The `Config` class to use.
model_config_class: The `ModelConfig` class to use (for example `LlamaConfig`). Defaults to `None` which will use the model config class defined in the config.
model_class: The `NanotronModel` class to use (for example `LlamaForTraining`). Defaults to `None` which will use the model class defined in the config.
"""
super().__init__()
self.config = get_config_from_file(
config_or_config_file, config_class=config_class, model_config_class=model_config_class
)
self.model_config = self.config.model.model_config
if model_class is not None:
CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class
########################################
## We start with setting up loggers and process groups
########################################
# Initialise all process groups
self.parallel_context = ParallelContext(
tensor_parallel_size=self.config.parallelism.tp,
pipeline_parallel_size=self.config.parallelism.pp,
data_parallel_size=self.config.parallelism.dp,
expert_parallel_size=self.config.parallelism.expert_parallel_size,
)
self.pre_init()
# Set log levels
set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)
# Log benchmark info
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
log_throughput(self.config, self.parallel_context)
########################################
## Setting up our model, optimizers, schedulers, etc.
########################################
# Set random states
set_random_seed(self.config.general.seed)
# Init model and build on pp ranks
self.random_states = init_random_states(
parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg
)
self.model = self.init_model() # Defines self.model
print("self.model:", self.model)
self.unwrapped_model: NanotronModel = (
self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
)
# TODO: find a better way to handle this
parametrization_method = (
ParametrizationMethod.SPECTRAL_MUP
if hasattr(self.config.model.init_method, "use_mup") and self.config.model.init_method.use_mup
else ParametrizationMethod.STANDARD
)
# Init optimizer
self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator(
parametrization_method=parametrization_method,
model=self.model,
optimizer_args=self.config.optimizer,
parallel_context=self.parallel_context,
)
if self.init_checkpoint_path is not None:
load_optimizer(
optimizer=self.optimizer,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
param_shard_metadata=self.param_shard_metadata,
model=self.unwrapped_model,
map_location="cpu",
)
# Init learning rate scheduler
self.lr_scheduler = lr_scheduler_builder(
optimizer=self.optimizer,
lr_scheduler_args=self.config.optimizer.learning_rate_scheduler,
total_training_steps=self.config.tokens.train_steps,
)
if self.init_checkpoint_path is not None:
load_lr_scheduler(
lr_scheduler=self.lr_scheduler,
is_zero=self.config.optimizer.zero_stage,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
# Define iteration start state
if self.init_checkpoint_path is not None:
checkpoint_metadata = load_meta(
parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
)
assert isinstance(checkpoint_metadata.metas, TrainingMetadata)
log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0)
self.metadata: TrainingMetadata = checkpoint_metadata.metas
# NOTE: we should not change data stages
assert (
self.config.tokens.train_steps > self.metadata.last_train_step
), f"Loaded checkpoint has already trained {self.metadata.last_train_step} batches, you need to specify a higher `config.tokens.train_steps`"
else:
data_stages = [
DataStageMetadata(
name=stage.name, start_training_step=stage.start_training_step, consumed_train_samples=0
)
for stage in self.config.data_stages
]
self.metadata: TrainingMetadata = TrainingMetadata(
consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages
)
# Set up the tensorboard and log writers on the output rank
self.logger_ranks = self.parallel_context.get_global_rank(
ep_rank=0, pp_rank=self.unwrapped_model.output_pp_rank, dp_rank=0, tp_rank=0
).flatten()
self.loggerwriter = self.setup_log_writers()
# Log where each module is instantiated
self.unwrapped_model.log_modules(level=logging.DEBUG, group=self.parallel_context.world_pg, rank=0)
self.micro_batch_size = self.config.tokens.micro_batch_size
self.n_micro_batches_per_batch = self.config.tokens.batch_accumulation_per_replica
self.global_batch_size = (
self.micro_batch_size * self.n_micro_batches_per_batch * self.parallel_context.dp_pg.size()
)
self.sequence_length = self.config.tokens.sequence_length
self.iteration_step = self.metadata.last_train_step
self.limit_val_batches = self.config.tokens.limit_val_batches
# NOTE: the dataloader currently in use for the current training stage
self.current_dataloader: Optional[DataLoader] = None
self.post_init()
def pre_init(self):
self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context)
def post_init(self):
# S3 Mover and save initial state
if self.config.s3_upload is not None:
# NOTE: Only local rank 0 should upload
dummy = bool(int(os.environ.get("LOCAL_RANK", None)) != 0)
self.s3_mover = S3Mover(
local_path=self.config.checkpoints.checkpoints_path,
s3_path=self.config.s3_upload.upload_s3_path,
remove_after_upload=self.config.s3_upload.remove_after_upload,
s5cmd_numworkers=self.config.s3_upload.s5cmd_numworkers,
s5cmd_concurrency=self.config.s3_upload.s5cmd_concurrency,
s5cmd_path=self.config.s3_upload.s5cmd_path,
dummy=dummy,
)
else:
self.s3_mover = None
def pre_training(self, *args, **kwargs):
self._print_training_plan()
metadata: TrainingMetadata = self.metadata
log_rank(
f"[Start training] datetime: {datetime.datetime.now()} | mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | global_batch_size: {self.global_batch_size} | sequence_length: {self.sequence_length} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_train_samples: {metadata.consumed_train_samples}", # noqa
logger=logger,
level=logging.INFO,
rank=0,
)
current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
wandb.init(
project=self.config.general.project,
name=f"{current_time}_{self.config.general.run}",
config={"nanotron_config": self.config.as_dict()},
)
def post_train_step(self):
# Update our background upload/removal of checkpoints
if self.s3_mover is not None:
self.s3_mover.update()
def post_training(self):
if self.s3_mover is not None:
self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg)
def _print_training_plan(self):
if hasattr(self.config, "data_stages") and self.config.data_stages is not None:
stages_info = "".join(
f"[Stage {stage.name}] start from step {stage.start_training_step} \n"
for stage in self.config.data_stages
)
full_log_message = (
f"[Training Plan] There are {len(self.config.data_stages)} training stages \n{stages_info}"
)
log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)
def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
from collections.abc import Generator
if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
if self.current_dataloader is None:
if isinstance(dataloaders, tuple):
dataloader = dataloaders[0]
else:
dataloader = dataloaders
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
return
elif isinstance(dataloaders, Generator):
# TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
# remove this in the next PR
self.current_dataloader = dataloaders
return
assert len(dataloaders) > 0, "No dataloaders provided"
assert len(dataloaders) == len(
self.config.data_stages
), "Number of dataloaders should match the number of dataset stages"
def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
import gc
log_rank(
f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
logger=logger,
level=logging.INFO,
)
# NOTE: Clear dataloader from memory
del dataloader.dataset
del dataloader.sampler
del dataloader.batch_sampler
gc.collect()
dataloader = None
def find_stage_idx_to_resume():
reversed_data_stages = sorted(self.config.data_stages, key=lambda x: x.start_training_step, reverse=True)
for idx, stage in enumerate(reversed_data_stages):
if self.iteration_step >= stage.start_training_step:
return len(self.config.data_stages) - idx - 1
return None
stage_idx_to_resume = find_stage_idx_to_resume()
for stage_idx, stage in enumerate(self.config.data_stages):
if stage_idx < self.metadata.last_stage_idx:
continue
stage = cast(DatasetStageArgs, stage)
is_resume_from_training = self.current_dataloader is None and stage_idx_to_resume == stage_idx
if (stage.start_training_step == self.iteration_step) or is_resume_from_training:
if self.current_dataloader is not None:
prev_stage_name = self.config.data_stages[stage_idx - 1].name
prev_dataloader = dataloaders[prev_stage_name]
if isinstance(prev_dataloader, DataLoader):
# NOTE: we don't need to clear dummy data generator from memory
clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
self.metadata.last_stage_idx = stage_idx
if is_resume_from_training:
remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, self.config, self.metadata
)
consumed_train_steps = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, self.metadata)
log_rank(
f"Resuming training from stage {stage.name}, it has trained for {consumed_train_steps} samples and has {remaining_train_steps} remaining train steps",
logger=logger,
level=logging.INFO,
rank=0,
)
dataloader = dataloaders[stage.name]
# NOTE: if a dataloader is lazily initialized, we need to call it to initialize it
dataloader = dataloader() if callable(dataloader) else dataloader
break
if dataloader is not None:
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
def train(
self,
dataloader_or_dls: Dict[
str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]
],
**kwargs,
) -> None:
self.pre_training(**kwargs)
if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None:
self.save_checkpoint()
self.pipeline_engine: PipelineEngine = self.config.parallelism.pp_engine
self.pipeline_engine.nb_microbatches = self.n_micro_batches_per_batch
# TODO @nouamanetazi: refactor this
# Useful mapping
self.unwrapped_model = self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
self.unwrapped_model.module_id_to_prefix = {
id(module): f"{module_name}." for module_name, module in self.unwrapped_model.named_modules()
}
# Fix the root_model
self.unwrapped_model.module_id_to_prefix[id(self.unwrapped_model)] = ""
self.initial_iter_step = self.metadata.last_train_step + 1
self.last_iter_step = self.config.tokens.train_steps
prof = get_profiler(config=self.config)
# free memory
gc.collect()
torch.cuda.empty_cache()
with prof:
for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1):
if isinstance(prof, torch.profiler.profile):
prof.step()
self.iteration_start_time = time.time()
self._update_dataloader_based_on_training_stages(dataloader_or_dls)
# Training step
outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
# Training Logs
# TODO(xrsrke): refactor using callbacks would be better
self.metadata.consumed_train_samples += self.global_batch_size
self.metadata.last_train_step = self.iteration_step
self.metadata.data_stages[
self.metadata.last_stage_idx
].consumed_train_samples += self.global_batch_size
if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0:
self.train_step_logs(outputs=outputs, loss_avg=loss_avg)
# Checkpoint
if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0:
self.save_checkpoint()
dist.barrier() # let's wait for everyone before leaving
if self.config.checkpoints.save_final_state:
self.save_checkpoint()
self.post_training()
def training_step(
self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]
) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]:
before_tbi_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler
)
if self.iteration_step < self.initial_iter_step + 5:
log_memory(logger=logger)
outputs = self.pipeline_engine.train_batch_iter(
model=self.model,
pg=self.parallel_context.pp_pg,
batch=(next(dataloader) for _ in range(self.n_micro_batches_per_batch)),
nb_microbatches=self.n_micro_batches_per_batch,
grad_accumulator=self.grad_accumulator,
)
if self.iteration_step < self.initial_iter_step + 5:
log_memory(logger=logger)
after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
if isinstance(self.model, DistributedDataParallel) and self.grad_accumulator is not None:
# Wait for fp32 grads allreduce to finish to make sure grads are synced across DP
assert (
self.grad_accumulator.fp32_grads_allreduce_handle is not None
), "No fp32_grads_allreduce_handle maybe you're using only a single training process"
self.grad_accumulator.fp32_grads_allreduce_handle.wait()
# Sync tied weights
if not isinstance(self.model, DistributedDataParallel):
# Manually sync across DP if it's not handled by DDP
sync_gradients_across_dp(
module=self.model,
dp_pg=self.parallel_context.dp_pg,
reduce_op=dist.ReduceOp.AVG,
# TODO @thomasw21: This is too memory hungry, instead we run all_reduce
reduce_scatter=False, # optimizer.inherit_from(ZeroDistributedOptimizer),
grad_accumulator=self.grad_accumulator,
)
# TODO @nouamane: Put this in hooks so we can overlap communication with gradient computation on the last backward pass.
sync_tied_weights_gradients(
module=self.unwrapped_model,
parallel_context=self.parallel_context,
grad_accumulator=self.grad_accumulator,
)
# Clip gradients
if self.config.optimizer.clip_grad is not None:
# Unwrap DDP
named_parameters = [
(name, param)
for name, param in self.unwrapped_model.get_named_params_with_correct_tied()
if param.requires_grad
]
self.grad_norm_unclipped = clip_grad_norm(
mp_pg=self.parallel_context.mp_pg,
named_parameters=named_parameters,
grad_accumulator=self.grad_accumulator,
max_norm=self.config.optimizer.clip_grad,
)
# Compute DP average loss and overlap with optimizer step
if isinstance(outputs[0]["loss"], torch.Tensor):
# This is an average on only one data rank.
loss_avg = torch.stack(
[output["loss"] for output in outputs]
).sum() # already divided by n_micro_batches_per_batch
# sync loss across DP
handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG)
else:
loss_avg = None
handle = None
# Move optimizer states back to GPU before optimizer step
if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step:
state_dict_to_device(self.optimizer.state_dict(), "cuda")
before_optim_step_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
)
# Apply gradient
self.optimizer.step()
self.optimizer.zero_grad()
# Update the learning rate
self.lr_scheduler.step()
after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
if handle is not None:
handle.wait()
self.post_train_step()
return outputs, loss_avg
def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:
outputs = self.pipeline_engine.validate_batch_iter(
model=self.model,
batch=(next(dataloader) for _ in range(self.limit_val_batches)),
nb_microbatches=self.limit_val_batches,
)
return outputs
def train_step_logs(
self,
outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
loss_avg: Optional[torch.Tensor],
) -> None:
# TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
dist.barrier()
torch.cuda.synchronize()
elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000
tokens_per_sec = (
self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000)
) # tokens_per_sec is calculated using sequence_length
model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec(
iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000,
sequence_length=self.sequence_length,
global_batch_size=self.global_batch_size,
)
if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks"
lr = self.lr_scheduler.get_last_lr()[0]
log_entries = [
# LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"),
LogItem(
"consumed_tokens",
self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
"human_format",
), # , "12d"),
LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"),
LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"),
LogItem(
"tokens_per_sec_per_gpu", tokens_per_sec / self.parallel_context.world_pg.size(), "human_format"
), # , "1.6E"),
LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"),
LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"),
LogItem("lr", lr, "human_format"), # , ".3E"),
LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"),
LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"),
]
if self.config.optimizer.clip_grad is not None:
log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f"))
# Log not too often the memory
if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0:
total, used, free = shutil.disk_usage("/")
log_entries.extend(
[
LogItem(
"cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format"
), # / 1024**2, ".2f"),
LogItem(
"cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format"
), # / 1024**2, ".2f"),
LogItem("hd_total_memory_tb", total, "human_format"), # / (2**40), ".2f"),
LogItem("hd_used_memory_tb", used, "human_format"), # / (2**40), ".2f"),
LogItem("hd_free_memory_tb", free, "human_format"), # / (2**40), ".2f"),
]
)
# NOTE: only one rank writes to wandb
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
wandb.log(
{
**{log_item.tag: log_item.scalar_value for log_item in log_entries},
"iteration_step": self.iteration_step,
}
)
self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)
# Nanotron Benchmark mode: we log the throughput and exit
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1" and self.iteration_step == 3:
log_throughput(
self.config,
self.parallel_context,
model_tflops,
hardware_tflops,
tokens_per_sec,
)
log_rank("Throughput logging complete", logger=logger, level=logging.INFO)
if "SLURM_JOB_ID" in os.environ:
os.system("scancel " + os.environ["SLURM_JOB_ID"])
else:
exit(0)
def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
"""Initialize the model and load weights from checkpoint if needed."""
# TODO: add max_position_embeddings
self.model_config.vocab_size = _vocab_size_with_padding(
self.model_config.vocab_size,
pg_size=self.parallel_context.tp_pg.size(),
make_vocab_size_divisible_by=self.config.model.make_vocab_size_divisible_by,
)
if (
getattr(self.model_config, "max_position_embeddings", None) is not None
and self.model_config.max_position_embeddings != self.config.tokens.sequence_length
):
if isinstance(self.config.model.init_method, ExistingCheckpointInit):
log_rank(
f"Finetuning a model with a sequence length {self.config.tokens.sequence_length} that is different from the checkpoint's max_position_embeddings {self.model_config.max_position_embeddings}.", # noqa
logger=logger,
level=logging.WARNING,
rank=0,
)
else:
assert (
self.config.tokens.sequence_length == self.model_config.max_position_embeddings
), "The tokenizer's sequence length does not match the model's maximum position embeddings."
log_rank("Config:\n" + pformat(self.config), logger=logger, level=logging.INFO, rank=0)
log_rank("Model Config:\n" + pformat(self.model_config), logger=logger, level=logging.INFO, rank=0)
model = self._init_model_instance()
model = self._load_model_checkpoint(model)
return model
def _init_model_instance(self) -> NanotronModel:
model_config_cls = self.model_config.__class__.__name__
assert (
model_config_cls in CONFIG_TO_MODEL_CLASS
), f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
model = self._init_model(
model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls](
config=self.model_config,
parallel_context=self.parallel_context,
parallel_config=self.config.parallelism,
random_states=self.random_states,
),
)
return model
def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model
# Load or initialize model weights
reloaded_from_checkpoint = False
if self.init_checkpoint_path is not None:
# Load from a pre-existing checkpoint
if check_path_is_local(self.init_checkpoint_path):
# Reload from a training checkpoint
log_rank(
f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0
)
self.param_shard_metadata = load_weights(
model=unwrapped_model,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
reloaded_from_checkpoint = True
if not reloaded_from_checkpoint:
log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0)
if isinstance(self.config.model.init_method, ExistingCheckpointInit):
# Initialize the model from a pretrained model checkpoint (without optimizer, lr_scheduler...)
self.param_shard_metadata = load_weights(
model=unwrapped_model,
parallel_context=self.parallel_context,
root_folder=self.config.model.init_method.path,
)
elif isinstance(self.config.model.init_method, (RandomInit, SpectralMupInit)):
unwrapped_model.init_model_randomly(config=self.config)
# Synchronize parameters so that the model is consistent
# sync all params across dp
for _, param in sorted(model.named_parameters(), key=lambda x: x[0]):
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg)
# sync tied params across tied groups
for (_, group_ranks), param in sorted(
get_tied_id_to_param(
parameters=model.parameters(),
root_module=unwrapped_model,
).items(),
key=lambda x: x[0],
):
group = self.parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
else:
raise ValueError(f"Unsupported {self.config.model.init_method}")
return model
def _init_model(
self,
model_builder: Callable[[], NanotronModel],
target_pp_ranks: Optional[List[int]] = None,
) -> NanotronModel:
config = self.config
parallel_context = self.parallel_context
parallel_config = config.parallelism
make_ddp = parallel_context.data_parallel_size > 1 and not (
config.optimizer.accumulate_grad_in_fp32 and config.optimizer.zero_stage > 0
)
# Build model and set pp ranks
model = build_model(
parallel_context=parallel_context,
dtype=config.model.dtype,
target_pp_ranks=target_pp_ranks,
model_builder=model_builder,
)
# Initialize rotary embeddings
for module in model.modules():
if not isinstance(module, RotaryEmbedding):
continue
module.init_rotary_embeddings()
# Mark some parameters as tied
self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
# count number of parameters
num_params = sum(p.numel() for p in model.parameters())
size_params = sum(p.numel() * p.element_size() for p in model.parameters())
total_params = torch.tensor(num_params, device="cuda")
total_size = torch.tensor(size_params, device="cuda")
dist.all_reduce(total_params, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) # TP
dist.all_reduce(total_params, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM) # PP
dist.all_reduce(total_size, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM)
dist.all_reduce(total_size, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM)
# TODO @nouamanetazi: better memory logs
log_rank(
f"Total number of parameters: {human_format(total_params.item())} ({total_size.item() / 1024**2:.2f}MiB)",
logger=logger,
level=logging.INFO,
group=parallel_context.world_pg,
rank=0,
)
log_rank(
f"Local number of parameters: {human_format(num_params)} ({size_params / 1024**2:.2f}MiB)",
logger=logger,
level=logging.INFO,
group=parallel_context.dp_pg,
rank=0,
)
log_rank(
f"[After model building] Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
f" Peak allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB"
f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
logger=logger,
level=logging.INFO,
group=parallel_context.dp_pg,
rank=0,
)
# Wrap the model in DDP
if make_ddp is True:
# Check that the model has at least one grad. Necessary for DDP
check_model_has_grad(model=model, parallel_context=parallel_context)
# TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway)
model = DistributedDataParallel(
model,
process_group=parallel_context.dp_pg,
broadcast_buffers=False,
bucket_cap_mb=config.model.ddp_bucket_cap_mb,
)
# Sanity check the model, all parameters must be NanotronParameter (either tied or sharded)
sanity_check(root_module=model)
return model
def setup_log_writers(
self,
) -> Optional[LoggerWriter]:
"""Setup all log writers on the appropriate ranks
Args:
config (Config): The config object
logger_ranks (Iterable[int]): The ranks that should log
parallel_context (DistributedProcessGroups): The distributed process groups
"""
if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
loggerwriter = LoggerWriter(global_step=self.config.tokens.train_steps)
else:
loggerwriter = None
return loggerwriter
def pre_save_checkpoint(self) -> Path:
if self.s3_mover is not None:
self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg)
if self.s3_mover.post_upload_callback_outputs is not None:
slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs
self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval")
def post_save_checkpoint(self):
# Upload to S3
if self.s3_mover is not None:
self.s3_mover.start_uploading()
def save_checkpoint(self) -> Path:
self.pre_save_checkpoint()
checkpoints_path = self.config.checkpoints.checkpoints_path
checkpoint_path = checkpoints_path / f"{self.iteration_step}"
if self.config.checkpoints.checkpoints_path_is_shared_file_system:
should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0
else:
should_mkdir = bool(int(os.environ.get("LOCAL_RANK", None)) == 0)
if should_mkdir:
checkpoint_path.mkdir(parents=True, exist_ok=True)
dist.barrier(self.parallel_context.world_pg)
log_rank(f"Saving checkpoint at {checkpoint_path}", logger=logger, level=logging.WARNING, rank=0)
# Update step/samples numbers before we save the config
self.config.general.step = self.metadata.last_train_step
self.config.general.consumed_train_samples = self.metadata.consumed_train_samples
save(
model=self.unwrapped_model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
should_save_model=bool(
dist.get_rank(self.parallel_context.dp_pg) == 0
), # We only save the weights on DP==0
should_save_optimizer=True,
should_save_lr_scheduler=True,
should_save_config=bool(
dist.get_rank(self.parallel_context.world_pg) == 0
), # We only save the config on world_rank==0
parallel_context=self.parallel_context,
root_folder=checkpoint_path,
training_metadata=self.metadata,
config=self.config,
)
save_random_states(
random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path
)
with open(checkpoints_path / "latest.txt", mode="w") as fo:
fo.write(f"{self.iteration_step}")
if hasattr(self.model_config, "to_json_file"):
self.model_config.to_json_file(checkpoint_path / MODEL_CONFIG_FILE_NAME)
else:
with open(checkpoint_path / MODEL_CONFIG_FILE_NAME, mode="w") as fo:
fo.write(json.dumps(asdict(self.model_config)))
self.post_save_checkpoint()
return checkpoint_path
def _mark_tied_parameters(
self,
model: NanotronModel,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs] = None,
):
mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
def mark_tied_parameters(
model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None
):
# Tie embeddings
embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names()
if len(embeddings_lm_head_tied_names) > 0:
shared_embeddings = [
(
target,
(
parallel_context.get_global_rank(
ep_rank=dist.get_rank(parallel_context.expert_pg),
pp_rank=get_pp_rank_of(target, module=model),
dp_rank=dist.get_rank(parallel_context.dp_pg),
tp_rank=dist.get_rank(parallel_context.tp_pg),
),
),
)
for target in embeddings_lm_head_tied_names
]
tie_parameters(
root_module=model, ties=shared_embeddings, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM
)
# Tie custom params
model.tie_custom_params()
# Sync all parameters that have the same name and that are not sharded across TP and EXP
assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point"
mark_unsharded_params_as_tied_across_tp(model, parallel_context, parallel_config)
mark_unsharded_params_as_tied_across_expert(model, parallel_context, parallel_config)
create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context)
def mark_unsharded_params_as_tied_across_tp(
model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs"
):
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if isinstance(param, NanotronParameter):
# We skip tying if param already tied or sharded along tp
if param.is_tied:
continue
if param.is_sharded:
sharded_info = param.get_sharded_info()
if sharded_info.is_tp_sharded(parallel_context=parallel_context):
continue
if isinstance(module, TensorParallelRowLinear) and "bias" == param_name:
# bias for TensorParallelRowLinear only exists on TP=0 so we don't need to tie it
continue
shared_weights = [
(
name,
# sync across TP group
tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
)
]
if parallel_config is None or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:
# We set `reduce_op=None` to signal that the weights are synced by design and don't need to be reduced:
# e.g. with TP=2 the LayerNorm is duplicated across TP, so by design it's tied
reduce_op = None
else:
reduce_op = dist.ReduceOp.SUM
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op
)
def mark_unsharded_params_as_tied_across_expert(
model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs"
):
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if isinstance(param, NanotronParameter):
# We skip tying if param already tied or sharded along expert
if param.is_tied:
continue
if param.is_sharded:
sharded_info = param.get_sharded_info()
if sharded_info.is_expert_sharded(parallel_context):
continue
shared_weights = [
(
name,
# sync across expert group
tuple(sorted(dist.get_process_group_ranks(parallel_context.expert_pg))),
)
]
# Besides the MoE block, which sees sharded tokens, the rest of the model sees the full tokens,
# so we don't need to reduce the gradients across the expert group
reduce_op = None
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op
)
import functools
import inspect
import os
import random
import socket
from contextlib import ExitStack, contextmanager
from typing import ContextManager, List, Optional
import torch
from packaging import version
from torch import nn
from torch.utils.checkpoint import checkpoint
from nanotron import distributed as dist
class Singleton(type):
"""
Singleton metaclass.
Create objects using this class as the metaclass to enable singleton behaviour.
For instance:
```
class Logger(metaclass=Singleton):
...
```
"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
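# Illustrative only: with the metaclass above, repeated construction returns the same object.
class _ExampleRegistry(metaclass=Singleton):
    def __init__(self):
        self.entries = {}

# _ExampleRegistry() is _ExampleRegistry()  -> True: __call__ caches the first instance.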
class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
in the `transformers` library.
"""
def __init__(self, context_managers: List[ContextManager]):
self.context_managers = context_managers
self.stack = ExitStack()
def __enter__(self):
for context_manager in self.context_managers:
self.stack.enter_context(context_manager)
def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({[context_manager.gen.__qualname__ for context_manager in self.context_managers]})"
@contextmanager
def main_rank_first(group: dist.ProcessGroup):
"""Context manager that executes the code in the context with the rank zero of the group going first."""
is_main = dist.get_rank(group) == 0
if is_main:
yield
dist.barrier(group)
if not is_main:
yield
@contextmanager
def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None):
"""Context manager that executes the code in the context with all the local rank zero of the group going first.
Useful to run only once per node first (e.g. to create local files, etc)
"""
is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
if is_main:
yield
dist.barrier(group)
if not is_main:
yield
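# Illustrative usage sketch (assumes torch.distributed is already initialized): create a
# node-local cache directory once per node before the remaining local ranks proceed. The
# directory name is hypothetical.
def _example_prepare_local_cache(group: dist.ProcessGroup) -> None:
    with local_ranks_zero_first(group):
        os.makedirs("/tmp/nanotron_cache", exist_ok=True)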
def checkpoint_method(attr_name: str):
"""Decorator to checkpoint a method of a class."""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
_self = args[0]
checkpoint_activated = getattr(_self, attr_name)
if checkpoint_activated:
all_args = list(args)
signature_params = inspect.signature(func).parameters
# Parameters are ordered in the function definition order: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
for i, (arg_name, arg_value) in enumerate(signature_params.items()):
if arg_value.kind in [inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL]:
raise NotImplementedError(
"Checkpointing of functions with *args or **kwargs is not supported."
)
if i < len(args):
continue
if arg_name not in kwargs:
assert (
arg_value.default is not inspect.Parameter.empty
), f"Missing argument {arg_name} from {kwargs} for {func.__name__}"
all_args.append(arg_value.default)
else:
all_args.append(kwargs[arg_name])
assert len(all_args) == len(signature_params), f"Missing arguments for {func.__name__}"
# TODO @nouamanetazi: we pass `self`(which is module) to checkpoint, so it's stored in `ctx.inputs` whereas some other methods create a custom fwd and pass only tensors without `self`. Need to investigate which is better
return checkpoint(func, *all_args)
else:
return func(*args, **kwargs)
return wrapper
return decorator
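# Illustrative sketch: the decorator above looks up a boolean attribute on `self` and, when it
# is True, reruns the method under `torch.utils.checkpoint.checkpoint` during backward. The
# module and attribute names here are hypothetical.
class _ExampleBlock(nn.Module):
    def __init__(self, checkpoint_activations: bool = True):
        super().__init__()
        self.checkpoint_activations = checkpoint_activations  # attribute read by the decorator
        self.proj = nn.Linear(8, 8)

    @checkpoint_method(attr_name="checkpoint_activations")
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)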
def get_parameter_and_parent_module(target: str, root_module: nn.Module):
module_path, _, param_name = target.rpartition(".")
mod: torch.nn.Module = root_module.get_submodule(module_path)
if not hasattr(mod, param_name):
raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
param: torch.nn.Parameter = getattr(mod, param_name)
if not isinstance(param, torch.nn.Parameter):
raise AttributeError("`" + param_name + "` is not an " "nn.Parameter")
return param, mod, param_name
def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage:
if version.parse(torch.__version__) >= version.parse("2.0"):
return tensor.untyped_storage()
else:
return tensor.storage().untyped()
def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype):
# TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage.
device = untyped_storage.device
tensor = torch.empty([], dtype=dtype, device=device)
tensor.set_(source=untyped_storage)
return tensor
def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
while True:
port = random.randint(min_port, max_port)
try:
with socket.socket() as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", port))
return port
except OSError:
continue
10K slice of OpenWebText - An open-source replication of the WebText dataset from OpenAI.
This is a small subset representing the first 10K records from the original dataset - created for testing.
The full 8M-record dataset is [here](https://huggingface.co/datasets/openwebtext).
```
$ python -c "from datasets import load_dataset; ds=load_dataset('stas/openwebtext-10k'); print(ds)"
DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 10000
})
})
```
* Records: 10,000
* Compressed size: ~15MB
* Uncompressed size: 50MB
To convert to jsonlines:
```
from datasets import load_dataset
dataset_name = "stas/openwebtext-10k"
name = dataset_name.split('/')[-1]
ds = load_dataset(dataset_name, split='train')
ds.to_json(f"{name}.jsonl", orient="records", lines=True)
```
To see how this subset was created, here is the [instructions file](https://huggingface.co/datasets/stas/openwebtext-10k/blob/main/process.txt).
{"plain_text": {"description": "An open-source replication of the WebText dataset from OpenAI.\n\nThis is a small subset representing the first 10K records from the original dataset - created for testing.\n\nThe full 8M-record dataset is at https://huggingface.co/datasets/openwebtext\n", "citation": "@misc{Gokaslan2019OpenWeb,\n title={OpenWebText Corpus},\n author={Aaron Gokaslan*, Vanya Cohen*, Ellie Pavlick, Stefanie Tellex},\n howpublished{\\url{http://Skylion007.github.io/OpenWebTextCorpus}},\n year={2019}\n}\n", "homepage": "https://skylion007.github.io/OpenWebTextCorpus/", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "builder_name": "openwebtext10k", "config_name": "plain_text", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 49670861, "num_examples": 10000, "dataset_name": "openwebtext10k"}}, "download_checksums": {"https://cdn-datasets.huggingface.co/nlp/datasets/openwebtext/openwebtext-10k.tar.xz": {"num_bytes": 14723792, "checksum": "1dd150ffa3361ab32fa9f129d1b5ce20ac48728be16be436558f844d1761c572"}}, "download_size": 14723792, "post_processing_size": null, "dataset_size": 49670861, "size_in_bytes": 64394653}}
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Open WebText Corpus"""
import os
import re
from itertools import chain
import datasets
_CITATION = """\
@misc{Gokaslan2019OpenWeb,
title={OpenWebText Corpus},
author={Aaron Gokaslan*, Vanya Cohen*, Ellie Pavlick, Stefanie Tellex},
howpublished={\\url{http://Skylion007.github.io/OpenWebTextCorpus}},
year={2019}
}
"""
_DESCRIPTION = """\
An open-source replication of the WebText dataset from OpenAI.
This is a small subset representing the first 10K records from the original dataset - created for testing.
The full 8M-record dataset is at https://huggingface.co/datasets/openwebtext
"""
# _URL = "https://cdn-datasets.huggingface.co/nlp/datasets/openwebtext/openwebtext-10k.tar.xz"
_URL = "/home/nanotron/openwebtext-10k.tar.xz"
class Openwebtext10k(datasets.GeneratorBasedBuilder):
"""The Open WebText dataset."""
BUILDER_CONFIGS = [
datasets.BuilderConfig(
name="plain_text",
description="Plain text",
version=datasets.Version("1.0.0"),
)
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features({"text": datasets.Value("string")}),
homepage="https://skylion007.github.io/OpenWebTextCorpus/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
dl_dir = dl_manager.download_and_extract(_URL)
owt_dir = os.path.join(dl_dir, "openwebtext-10k")
subset_xzs = [
os.path.join(owt_dir, file_name)
for file_name in sorted(os.listdir(owt_dir))
if file_name.endswith("xz") # filter out ...xz.lock
]
# ex_dirs = dl_manager.extract(subset_xzs, num_proc=round(os.cpu_count() * 0.75))
ex_dirs = dl_manager.extract(subset_xzs)
nested_txt_files = [
[
os.path.join(ex_dir, txt_file_name)
for txt_file_name in sorted(os.listdir(ex_dir))
if txt_file_name.endswith("txt")
]
for ex_dir in ex_dirs
]
txt_files = chain(*nested_txt_files)
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"txt_files": txt_files}),
]
def _generate_examples(self, txt_files):
"""Yields examples."""
for idx, filepath in enumerate(txt_files):
with open(filepath, encoding="utf-8") as f:
yield idx, {"text": re.sub("\n\n\n+", "\n\n", f.read()).strip()}
# This is a small derivative of the 8M-record openwebtext dataset, created for testing.
# How this build script and dataset_infos.json were generated:
#
mkdir openwebtext-10k
cd openwebtext-10k
# data
wget https://zenodo.org/record/3834942/files/openwebtext.tar.xz
tar xf openwebtext.tar.xz
cd openwebtext
rename.pl 's|-|-00|; s|-00(\d\d\d)|-$1|; s|-00(\d\d)|-0$1|;' *xz
# now open the first 30 archives
mkdir subset
cp urlsf_subset00-0[0-2]*_data.xz subset
cd subset
find . -name "*xz" -exec tar xf {} \;
mkdir 10k
find . -name "*txt" | sort | head -10000 | xargs mv -t 10k
tar cfJ 10k.xz -C 10k .
mkdir openwebtext-10k
mv 10k.xz openwebtext-10k
tar cfJ openwebtext-10k.tar.xz openwebtext-10k
# the openwebtext subdir gets created on the fly
aws s3 cp openwebtext-10k.tar.xz s3://datasets.huggingface.co/nlp/datasets/openwebtext/
# script
wget https://raw.githubusercontent.com/huggingface/datasets/master/datasets/openwebtext/openwebtext.py
mv openwebtext.py openwebtext-10k.py
perl -pi -e 's|openwebtext|openwebtext-10k|g' openwebtext-10k.py
perl -pi -e 's|https://zenodo.org/record/3834942/files/|https://cdn-datasets.huggingface.co/nlp/datasets/openwebtext/|g' openwebtext-10k.py
perl -pi -e 's|Openwebtext|Openwebtext10k|g' openwebtext-10k.py
# manually check that the script is correct - edit the descriptions
# create a new dataset entry on the hub
https://huggingface.co/new-dataset
# once created clone it
git clone https://huggingface.co/datasets/stas/openwebtext-10k
cp openwebtext-10k.py process.txt openwebtext-10k
cd openwebtext-10k
git add openwebtext-10k.py process.txt
git commit -m "build script" openwebtext-10k.py process.txt
git push
# test and generate config file
cd ..
datasets-cli test ./openwebtext-10k --save_infos --all_configs
# add and push the generated config
cd openwebtext-10k
git add dataset_infos.json
git commit -m "add dataset_infos.json" dataset_infos.json
git push
# test that the dataset is working
python -c "from datasets import load_dataset; ds=load_dataset('stas/openwebtext-10k'); print(ds)"
import torch
from nanotron.fp8 import DTypes, FP8Parameter, FP8Tensor
from nanotron.fp8.meta import FP8Meta
def test_create_fp8_parameter():
# TODO(xrsrke): test FP8E5M2 format
# TODO(xrsrke): test take a cpu tensor
tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32)
fp8_parameter = FP8Parameter(tensor, DTypes.FP8E4M3)
assert isinstance(fp8_parameter.data, FP8Tensor)
assert fp8_parameter.requires_grad is True
assert fp8_parameter.grad is None
assert isinstance(fp8_parameter.fp8_meta, FP8Meta)
assert isinstance(fp8_parameter.data.fp8_meta, FP8Meta)
# TODO(xrsrke): add test for preventing torch autograd do the backward pass
# on a FP8Parameter
import pytest
import torch
from nanotron.fp8 import DTypes, FP8Linear, FP8Parameter, FP8Tensor
from torch import nn
from torch.optim import Adam
@pytest.mark.parametrize("is_bias", [True, False])
def test_fp8_linear_forward_pass(is_bias):
input = torch.randn(16, 16, device="cuda", dtype=torch.float32)
ref_input = input.detach().clone()
ref_linear = nn.Linear(16, 16, bias=is_bias, device="cuda", dtype=torch.float32)
fp8_linear = FP8Linear(16, 16, bias=is_bias, device="cuda:0")
fp8_linear.weight = FP8Parameter(ref_linear.weight.detach().clone(), DTypes.FP8E4M3)
if is_bias:
fp8_linear.bias.data = ref_linear.bias.detach().clone()
ref_output = ref_linear(ref_input)
output = fp8_linear(input)
assert isinstance(output, torch.Tensor)
assert output.dtype == torch.float32
assert torch.allclose(output, ref_output, rtol=0, atol=0.1)
# TODO(xrsrke): add cases where the input does and does not require grad
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_fp8_linear_backward_pass(input_requires_grad, device):
input = torch.randn(16, 16, device=device, dtype=torch.float32, requires_grad=input_requires_grad)
ref_input = input.detach().clone().requires_grad_(True)
ref_linear = nn.Linear(16, 16, device=device, dtype=torch.float32)
fp8_linear = FP8Linear(16, 16, device=device)
if device == "cpu":
fp8_linear.weight.data = ref_linear.weight.detach().clone()
else:
fp8_linear.weight.data = FP8Tensor(ref_linear.weight.detach().clone(), dtype=DTypes.FP8E4M3)
fp8_linear.bias.data = ref_linear.bias.detach().clone()
ref_linear(ref_input).sum().backward()
fp8_linear(input).sum().backward()
# TODO(xrsrke): investigate why input.grad needs such a high tolerance
# assert torch.allclose(input.grad, ref_input.grad, 0.2, 0.2) if input_requires_grad else True
assert torch.allclose(fp8_linear.weight.grad, ref_linear.weight.grad, 0.1, 0.1)
assert torch.allclose(fp8_linear.bias.grad, ref_linear.bias.grad, 0, 0.1)
# TODO(xrsrke): test if FP8Linear has all the methods of a torch.nn.Linear
def test_fp8_linear_attrs():
fp8_linear = FP8Linear(16, 16, device="cuda:0")
assert next(fp8_linear.parameters()) is not None
assert all(p.requires_grad for p in fp8_linear.parameters()) is True
# TODO(xrsrke): test only calculating the gradients of the weight, bias, or input based
# on the requires_grad of the input, weight, or bias
def test_fp8_model_bwd():
HIDDEN_SIZE = 128
N_LAYERS = 5
N_EPOCHS = 3
input = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", requires_grad=True)
model = nn.Sequential(
*[nn.Sequential(FP8Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda"), nn.ReLU()) for _ in range(N_LAYERS)]
)
optim = Adam(model.parameters(), lr=1e-3)
for _ in range(N_EPOCHS):
optim.zero_grad()
model(input).sum().backward()
optim.step()
assert all(p.grad is not None for p in model.parameters())
from copy import deepcopy
import numpy as np
import pytest
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8 import DTypes, FP8Tensor
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import convert_tensor_from_fp8
@pytest.mark.parametrize("size", [4, 8, 16, 64])
def test_quantize_and_dequantize_tensor_in_fp8(size):
tensor = torch.randn((size, size), dtype=torch.float32, device="cuda")
ref_tensor = deepcopy(tensor)
fp8_tensor = FP8Tensor(tensor, dtype=DTypes.FP8E4M3)
assert isinstance(fp8_tensor, FP8Tensor)
assert isinstance(fp8_tensor.fp8_meta, FP8Meta)
assert fp8_tensor.device == ref_tensor.device
assert fp8_tensor.dtype == torch.uint8
assert fp8_tensor.shape == ref_tensor.shape
assert fp8_tensor.numel() == ref_tensor.numel()
assert not np.array_equal(fp8_tensor.cpu().numpy(), ref_tensor.cpu().numpy())
# TODO(xrsrke): remove the fixed 1 factor
# it couples with the current implementation of FP8Meta
# because we initialize scale with 1
assert fp8_tensor.fp8_meta.amax == ref_tensor.abs().max()
assert isinstance(fp8_tensor.fp8_meta.inverse_scale, torch.Tensor)
assert fp8_tensor.fp8_meta.scale != 0.1 and fp8_tensor.fp8_meta.scale != 1
assert isinstance(fp8_tensor.fp8_meta.te_dtype, tex.DType)
tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32)
assert isinstance(tensor, torch.Tensor)
assert tensor.dtype == ref_tensor.dtype
assert torch.allclose(tensor, ref_tensor, rtol=1e-1, atol=1e-1)
def test_fp8_tensor_attrs():
SIZE = 64
tensor = torch.randn((SIZE, SIZE), dtype=torch.float32, device="cuda:0")
ref_tensor = tensor.detach().clone()
fp8_tensor = FP8Tensor(tensor, DTypes.FP8E4M3)
assert isinstance(fp8_tensor, FP8Tensor)
assert isinstance(fp8_tensor.fp8_meta, FP8Meta)
assert fp8_tensor.device == ref_tensor.device
assert fp8_tensor.dtype == torch.uint8
assert fp8_tensor.shape == ref_tensor.shape
assert fp8_tensor.numel() == ref_tensor.numel()
assert fp8_tensor.device == ref_tensor.device
# TODO(xrsrke): test it has all the methods of torch.Tensor
# TODO(xrsrke): test it has all the attributes of its input tensor
import shutil
import uuid
from functools import lru_cache
from pathlib import Path
class TestContext:
def __init__(self):
self._random_string = str(uuid.uuid1())
self._root_dir = Path(__file__).parent.parent / ".test_cache"
self._root_dir.mkdir(parents=True, exist_ok=True)
@lru_cache(maxsize=1)
def get_auto_remove_tmp_dir(self):
path = self._root_dir / self._random_string
path.mkdir(parents=True, exist_ok=True)
return path
def __del__(self):
path = self.get_auto_remove_tmp_dir()
shutil.rmtree(path)
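# Illustrative usage sketch (editor's addition): each TestContext owns a unique temporary
# directory under a shared .test_cache folder which is removed when the object is
# garbage collected.
#
#     test_context = TestContext()
#     checkpoint_path = test_context.get_auto_remove_tmp_dir() / "checkpoints"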
import hashlib
import importlib
import json
import os
import sys
from argparse import Namespace
from collections import OrderedDict
from pathlib import Path
package = importlib.import_module("nanotron")
package_path = Path(package.__file__).parent.parent.parent
sys.path.append(str(package_path))
import nanotron.distributed as dist
import torch
from nanotron.data.nanoset import Nanoset
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.sanity_checks import assert_tensor_synced_across_pg
from tools.preprocess_data import main
def create_dataset_paths(tmp_dir: str, quantity: int):
json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}.json") for i in range(quantity)]
datatrove_tokenized_dataset_paths = [os.path.join(tmp_dir, f"tokenized_documents_{i}") for i in range(quantity)]
return json_dataset_path, datatrove_tokenized_dataset_paths
def create_dummy_json_dataset(path_to_json: str, dummy_text: str, n_samples: int = 50000):
with open(path_to_json, "a") as json_file:
for sample in range(n_samples):
sample_dict = {"text": f"[{sample}] Hello! Im sample {sample}! And this is my dummy text: {dummy_text}"}
json_file.write(json.dumps(sample_dict))
json_file.write("\n")
def preprocess_dummy_dataset(json_dataset_path: str, datatrove_tokenized_dataset_path: str, tokenizer: str):
# Create args for preprocessing
args = Namespace(
readers="jsonl",
dataset=json_dataset_path,
column="text",
glob_pattern=None,
output_folder=datatrove_tokenized_dataset_path,
tokenizer_name_or_path=tokenizer,
eos_token=None,
n_tasks=1,
logging_dir=None,
)
# tools/preprocess_data.py main
main(args)
def assert_batch_dataloader(
batch: dict, parallel_context: ParallelContext, micro_batch_size: int, sequence_length: int
):
"""
Checks the shape, pipeline placement and cross-rank consistency of a batch.
batch (dict): Batch produced by the dataloader, with keys input_ids, input_mask, label_ids, label_mask
"""
for element in batch:
tensor = batch[element]
# Assert that inputs are only present in input_pp_rank and outputs in output_pp_rank
input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1)
if dist.get_rank(parallel_context.pp_pg) == input_pp_rank and element.startswith("input_"):
assert isinstance(tensor, torch.Tensor)
elif dist.get_rank(parallel_context.pp_pg) == output_pp_rank and element.startswith("label_"):
assert isinstance(tensor, torch.Tensor)
else:
assert isinstance(tensor, TensorPointer)
data_class = (
0 # 0 if tensor is from the ids, 1 if TensorPointer and 2 if mask. Used in the data parallel group check
)
# Check shape of mask and ids tensors
if isinstance(tensor, torch.Tensor):
assert tensor.shape == (micro_batch_size, sequence_length)
# TensorPointer case: Check that all TensorPointers from the same tp_pg point to the same group_rank. Create torch.tensor with group_rank
if isinstance(tensor, TensorPointer):
tensor = torch.tensor(tensor.group_rank)
data_class = 1
# Attention Masks case: dtype is torch.bool --> Transform to int64
if tensor.dtype == torch.bool:
tensor = tensor.long()
data_class = 2
# Assert that we have the SAME element in all the processes belonging to the same tensor parallel group
assert_tensor_synced_across_pg(
tensor=tensor.flatten().cuda(),
pg=parallel_context.tp_pg,
msg=lambda err: f"{element} is not synchronized across TP {err}",
)
# Assert that we have the SAME class of data in all processes belonging to the same data parallel group
assert_tensor_synced_across_pg(
tensor=torch.tensor(data_class, device="cuda"),
pg=parallel_context.dp_pg,
msg=lambda err: f"{element} is not synchronized across DP {err}",
)
def compute_hash(identifier: OrderedDict, n_digit: int = 8) -> int:
"""
Creates an `n_digit`-digit integer hash from the sha256 of the JSON-serialized OrderedDict
"""
unique_description = json.dumps(identifier, indent=4)
# Create n_digit description hash
unique_description_hash = int(hashlib.sha256(unique_description.encode("utf-8")).hexdigest(), 16) % 10**n_digit
return unique_description_hash
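# Illustrative example (editor's addition): the hash is deterministic for identical
# OrderedDict contents, which is what allows comparing it across ranks.
#
#     a = compute_hash(OrderedDict([("seed", 42), ("length", 10)]))
#     b = compute_hash(OrderedDict([("seed", 42), ("length", 10)]))
#     assert a == b and 0 <= a < 10**8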
def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: ParallelContext):
"""
Checks that the same Nanoset is created in all processes
"""
# Extract a sample from the Nanoset
IDX_SAMPLE = 23
nanoset_identifiers = OrderedDict()
nanoset_identifiers["dataset_folders"] = nanoset.dataset_folders
nanoset_identifiers["dataset_weights"] = nanoset.dataset_weights.tolist()
nanoset_identifiers["sequence_length"] = nanoset.sequence_length
nanoset_identifiers["train_split_num_samples"] = nanoset.train_split_num_samples
nanoset_identifiers["random_seed"] = nanoset.random_seed
nanoset_identifiers["length"] = len(nanoset)
nanoset_identifiers["input_ids"] = nanoset[IDX_SAMPLE]["input_ids"].tolist()
nanoset_identifiers["dataset_index"] = nanoset.dataset_index.tolist()
nanoset_identifiers["dataset_sample_index"] = nanoset.dataset_sample_index.tolist()
nanoset_identifiers["token_size"] = nanoset.token_size
unique_description_hash = compute_hash(nanoset_identifiers)
assert_tensor_synced_across_pg(
tensor=torch.tensor(unique_description_hash, device="cuda"),
pg=parallel_context.world_pg,
msg=lambda err: f"Nanoset is not synchronized across all processes {err}",
)
def compute_batch_hash(batch: dict) -> int:
"""
Computes a hash of a batch, used to check that the Nanoset/BlendedNanoset is in the same state after recovering from a crash
batch (dict): Batch produced by the dataloader, with keys input_ids, input_mask, label_ids, label_mask
"""
batch_identifiers = OrderedDict()
for element in batch:
tensor = batch[element]
# TensorPointer
if isinstance(tensor, TensorPointer):
identifier = tensor.group_rank
# Attention Masks case: dtype is torch.bool --> Transform to int64
elif tensor.dtype == torch.bool:
identifier = tensor.long().tolist()
# Input IDs tensor
else:
identifier = tensor.tolist()
batch_identifiers[element] = identifier
unique_description_hash = compute_hash(batch_identifiers)
return unique_description_hash
import torch
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup, get_global_rank
def assert_tensor_equal_over_group(tensor: torch.Tensor, group: ProcessGroup, assert_: bool = True) -> bool:
"""We assume that tensors are already of correct size."""
reference_rank = 0
if dist.get_rank(group) == reference_rank:
reference_tensor = tensor
else:
reference_tensor = torch.empty_like(tensor)
dist.broadcast(
reference_tensor,
src=get_global_rank(group=group, group_rank=reference_rank),
group=group,
)
if assert_:
torch.testing.assert_close(tensor, reference_tensor, atol=0, rtol=0)
else:
result = torch.allclose(tensor, reference_tensor, atol=0.0, rtol=0.0)
results = [0] * group.size()
dist.all_gather_object(results, result, group)
return all(results)
from math import ceil
from typing import Union
import torch
from nanotron import distributed as dist
from nanotron.models import init_on_device_and_dtype
from nanotron.optim.base import BaseOptimizer
from nanotron.optim.named_optimizer import NamedOptimizer
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tied_parameters import tie_parameters
from nanotron.parallel.utils import initial_sync
from torch import nn
from torch.nn.parallel import DistributedDataParallel
class DummyModel(nn.Module):
def __init__(
self,
p2p: P2P,
):
super().__init__()
self.p2p = p2p
self.mlp = nn.Sequential(
*(
nn.ModuleDict(
{
"linear": PipelineBlock(
p2p=p2p,
module_builder=nn.Linear,
module_kwargs={"in_features": 10, "out_features": 10},
module_input_keys={"input"},
module_output_keys={"output"},
),
"activation": PipelineBlock(
p2p=p2p,
module_builder=nn.Sigmoid if pp_rank < p2p.pg.size() - 1 else nn.Identity,
module_kwargs={},
module_input_keys={"input"},
module_output_keys={"output"},
),
}
)
for pp_rank in range(p2p.pg.size())
)
)
self.loss = PipelineBlock(
p2p=p2p,
module_builder=lambda: lambda x: x.sum(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(self, x: Union[torch.Tensor, TensorPointer]):
for non_linear in self.mlp:
x = non_linear.linear(input=x)["output"]
x = non_linear.activation(input=x)["output"]
x = self.loss(x=x)["output"]
return x
def init_dummy_model(parallel_context: ParallelContext, dtype: torch.dtype = torch.float) -> DummyModel:
p2p = P2P(pg=parallel_context.pp_pg, device=torch.device("cuda"))
model = DummyModel(p2p=p2p)
# Build model using contiguous segments
pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)]
with init_on_device_and_dtype(device=torch.device("cuda"), dtype=dtype):
contiguous_size = ceil(len(pipeline_blocks) / parallel_context.pp_pg.size())
for i, block in enumerate(pipeline_blocks):
rank = i // contiguous_size
block.build_and_set_rank(rank)
# Sync all parameters that have the same name and that are not sharded across TP.
for name, param in model.named_parameters():
if isinstance(param, NanotronParameter) and param.is_sharded:
continue
shared_weights = [
(
name,
# sync across TP group
tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
)
]
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM
)
initial_sync(model=model, parallel_context=parallel_context)
if len(list(model.named_parameters())) > 0:
model = DistributedDataParallel(model, process_group=parallel_context.dp_pg)
else:
# No parameters, so no need to use DDP to sync parameters gradients
model = model
return model
def init_dummy_optimizer(model: nn.Module, parallel_context: ParallelContext) -> BaseOptimizer:
optimizer = NamedOptimizer(
named_params_or_groups=model.named_parameters(), optimizer_builder=lambda params: torch.optim.AdamW(params)
)
# Synchronize across dp: basic assumption, already done as nothing in optimizer initialization is stochastic
return optimizer
def dummy_infinite_data_loader(pp_pg: dist.ProcessGroup, dtype=torch.float, input_pp_rank=0):
micro_batch_size = 3
# We assume the first linear is always built on the first rank.
current_pp_rank = dist.get_rank(pp_pg)
while True:
yield {
"x": torch.randn(micro_batch_size, 10, dtype=dtype, device="cuda")
if current_pp_rank == input_pp_rank
else TensorPointer(group_rank=input_pp_rank)
}
import contextlib
import signal
from typing import Optional
from nanotron import distributed as dist
@contextlib.contextmanager
def assert_fail_with(exception_class, error_msg: Optional[str] = None):
try:
yield
except exception_class as e:
if error_msg is None:
return
if error_msg == str(e):
return
else:
raise AssertionError(f'Expected message to be "{error_msg}", but got "{str(e)}" instead.')
except Exception as e:
raise AssertionError(f"Expected {exception_class} to be raised, but got: {type(e)} instead:\n{e}")
raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")
@contextlib.contextmanager
def assert_fail_except_rank_with(
exception_class, rank_exception: int, pg: dist.ProcessGroup, error_msg: Optional[str] = None
):
try:
yield
except exception_class as e:
if rank_exception == dist.get_rank(pg):
raise AssertionError(f"Expected rank {rank_exception} to not raise {exception_class}.")
else:
if error_msg is None:
return
if error_msg == str(e):
return
else:
raise AssertionError(f'Expected message to be "{error_msg}", but got "{str(e)}" instead.')
except Exception as e:
raise AssertionError(f"Expected {exception_class} to be raised, but got: {type(e)} instead:\n{e}")
if dist.get_rank(pg) != rank_exception:
raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")
@contextlib.contextmanager
def timeout_after(ms=500):
"""Timeout context manager."""
def signal_handler(signum, frame):
raise TimeoutError(f"Timed out after {ms} ms.")
signal.signal(signal.SIGALRM, signal_handler)
signal.setitimer(signal.ITIMER_REAL, ms / 1000)
try:
yield
finally:
signal.alarm(0)
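# Illustrative usage sketch (editor's addition): fail a test instead of hanging forever
# on a collective. Relies on SIGALRM, so it only works on Unix in the main thread.
# `group` is a hypothetical process group.
#
#     with timeout_after(ms=1000):
#         dist.barrier(group)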
import torch
from nanotron.config import (
AllForwardAllBackwardPipelineEngine,
CheckpointsArgs,
Config,
DataArgs,
DatasetStageArgs,
GeneralArgs,
LlamaConfig,
LoggingArgs,
LRSchedulerArgs,
ModelArgs,
OptimizerArgs,
ParallelismArgs,
TensorParallelLinearMode,
TokenizerArgs,
TokensArgs,
)
from nanotron.config.config import PretrainDatasetsArgs
from nanotron.models import build_model
from nanotron.models.llama import LlamaForTraining
from nanotron.parallel.context import ParallelContext
from nanotron.trainer import mark_tied_parameters
TINY_LLAMA_CONFIG = LlamaConfig(
**{
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 16,
"initializer_range": 0.02,
"intermediate_size": 32,
"is_llama_config": True,
"max_position_embeddings": 128,
"num_attention_heads": 8,
"num_hidden_layers": 4,
"num_key_value_heads": 4,
"pad_token_id": None,
"pretraining_tp": 1,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"tie_word_embeddings": False,
"use_cache": True,
"vocab_size": 4096,
}
)
def get_llama_training_config(model_config: ModelArgs):
return Config(
model=model_config,
general=GeneralArgs(project="unittest", run="sanity_llama", seed=42),
checkpoints=CheckpointsArgs(
checkpoints_path="./checkpoints",
checkpoint_interval=10,
),
parallelism=ParallelismArgs(
dp=1,
pp=1,
tp=2,
expert_parallel_size=2,
pp_engine="1f1b",
tp_mode="ALL_REDUCE",
tp_linear_async_communication=False,
),
tokenizer=TokenizerArgs("gpt2"),
optimizer=OptimizerArgs(
zero_stage=0,
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=False,
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=LRSchedulerArgs(
learning_rate=3e-4,
lr_warmup_steps=100,
lr_warmup_style="linear",
lr_decay_style="cosine",
min_decay_lr=1e-5,
),
),
logging=LoggingArgs(),
tokens=TokensArgs(sequence_length=16, train_steps=10, micro_batch_size=16, batch_accumulation_per_replica=1),
data_stages=[
DatasetStageArgs(
name="train",
start_training_step=1,
data=DataArgs(
seed=42,
num_loading_workers=1,
dataset=PretrainDatasetsArgs(
hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small",
hf_dataset_splits="train",
text_column_name="completion",
dataset_processing_num_proc_per_process=12,
),
),
)
],
)
def create_llama_from_config(
model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext
) -> LlamaForTraining:
"""
Creates and returns a nanotron `LlamaForTraining` model built from `model_config`
on the given device and parallel context, with tied parameters marked.
"""
parallel_config = ParallelismArgs(
dp=parallel_context.data_parallel_size,
pp=parallel_context.pipeline_parallel_size,
tp=parallel_context.tensor_parallel_size,
pp_engine=AllForwardAllBackwardPipelineEngine(),
tp_mode=TensorParallelLinearMode.ALL_REDUCE,
tp_linear_async_communication=False,
)
model = build_model(
model_builder=lambda: LlamaForTraining(
config=model_config,
parallel_context=parallel_context,
parallel_config=parallel_config,
random_states=None,
),
parallel_context=parallel_context,
dtype=torch.bfloat16,
device=device,
)
mark_tied_parameters(model=model, parallel_context=parallel_context)
return model
import contextlib
import os
import re
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch.cuda
import torch.multiprocessing as mp
from nanotron.parallel import ParallelContext
from packaging import version
def available_gpus():
if not torch.cuda.is_available():
return 0
device_properties = [torch.cuda.get_device_properties(i) for i in range(torch.cuda.device_count())]
# We filter out display-only GPUs that cannot be used for compute (e.g. NVIDIA DGX Display)
blacklisted_gpu_names = {"NVIDIA DGX Display"}
device_properties = [property_ for property_ in device_properties if property_.name not in blacklisted_gpu_names]
# TODO @thomasw21: Can we do this cross node
return len(device_properties)
# from https://stackoverflow.com/a/34333710/9201239
@contextlib.contextmanager
def mock_os_environ(remove_keys: List[str] = None, update_key_values: Dict[str, Any] = None):
"""
Temporarily updates the ``os.environ`` dictionary in-place.
The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.
Args:
remove_keys: Environment variables to remove.
update_key_values: Dictionary of environment variables and values to add/update.
"""
env = os.environ
update_key_values = update_key_values or {}
remove_keys = remove_keys or []
update_keys = set(update_key_values.keys())
remove_keys = set(remove_keys)
assert remove_keys.isdisjoint(update_keys)
stomped = (update_keys | remove_keys) & set(env.keys())
reverse_change = {
# Environment variables and values to restore on exit.
**{k: env[k] for k in update_keys & stomped},
# Environment variables and values to remove on exit.
**{k: env[k] for k in remove_keys & stomped},
}
try:
env.update(update_key_values)
for k in remove_keys:
env.pop(k, None)
yield
finally:
env.update(reverse_change)
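# Illustrative usage sketch (editor's addition): variables are overridden/removed only
# inside the block and restored afterwards.
#
#     with mock_os_environ(remove_keys=["MASTER_PORT"], update_key_values={"WORLD_SIZE": "1"}):
#         assert os.environ["WORLD_SIZE"] == "1"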
def is_dict_equal(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> Tuple[bool, Optional[str]]:
"""Returns True or False if the dictionaries match, and an additional message when it's False"""
if sub_paths is None:
sub_paths = []
first_keys = set(first.keys())
second_keys = set(second.keys())
if first_keys != second_keys:
return False, f"Keys don't match in {'.'.join(sub_paths)}.\nCur: {first_keys}\nRef: {second_keys}"
for key in first_keys:
first_elt = first[key]
second_elt = second[key]
if isinstance(first_elt, dict):
if not isinstance(second_elt, dict):
return (
False,
f"Object types don't match in {'.'.join(sub_paths + [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}",
)
match, msg = is_dict_equal(first_elt, second_elt, sub_paths=sub_paths + [str(key)])
if match is False:
return False, msg
elif isinstance(first_elt, torch.Tensor):
if not isinstance(second_elt, torch.Tensor):
return (
False,
f"Object types don't match in {'.'.join(sub_paths + [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}",
)
try:
torch.testing.assert_close(
first_elt,
second_elt,
atol=0.0,
rtol=0.0,
msg=lambda msg: f"Tensor at {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}\n{msg}",
)
except AssertionError as error:
return False, error.args[0]
else:
if first_elt != second_elt:
return (
False,
f"Objects at key {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}",
)
return True, None
def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]:
"""Given a number of gpus, we want all 3d configurations possible such that pp * dp * tp = gpus"""
result = []
for tp in range(1, gpus + 1):
if gpus % tp != 0:
continue
gpus_left_after_tp = gpus // tp
for dp in range(1, gpus_left_after_tp + 1):
if gpus_left_after_tp % dp != 0:
continue
gpus_left_after_dp = gpus_left_after_tp // dp
for pp in range(1, gpus_left_after_dp + 1):
if gpus_left_after_dp % pp != 0:
continue
if tp * dp * pp == gpus:
result.append((pp, dp, tp))
return result
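# Worked example (editor's addition): get_all_3d_configurations(4) returns
# [(4, 1, 1), (2, 2, 1), (1, 4, 1), (2, 1, 2), (1, 2, 2), (1, 1, 4)],
# i.e. every (pp, dp, tp) triple whose product is 4.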
def rerun_if_address_is_in_use(max_try: int = 500):
"""
This decorator reruns the wrapped function when an "address already in use" error occurs
in tests spawned with torch.multiprocessing
Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L157
Usage::
@rerun_if_address_is_in_use()
def test_something():
...
"""
# check version
torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
# only torch >= 1.8 has ProcessRaisedException
if torch_version >= version.parse("1.8.0"):
exception = torch.multiprocessing.ProcessRaisedException
else:
exception = Exception
func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try)
return func_wrapper
def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 10) -> Callable:
"""
A decorator that re-runs the wrapped function when an exception occurs.
Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L71
Usage::
# rerun for all kinds of exception
@rerun_on_exception()
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun for RuntimeError only
@rerun_on_exception(exception_type=RuntimeError)
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun for maximum 10 times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, max_try=10)
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, max_try=None)
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun only the exception message is matched with pattern
# for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
def test_method():
print('hey')
raise RuntimeError('Address already in use')
Args:
exception_type (Exception, Optional): The type of exception to detect for rerun
pattern (str, Optional): The pattern to match the exception message.
If the pattern is not None and matches the exception message,
the exception will be detected for rerun
max_try (int, Optional): Maximum number of attempts for this function. The default value is 10.
If max_try is None, it will rerun forever if exception keeps occurring
"""
def _match_lines(lines, pattern):
for line in lines:
if re.match(pattern, line):
return True
return False
def _wrapper(func):
def _run_until_success(*args, **kwargs):
try_count = 0
assert max_try is None or isinstance(
max_try, int
), f"Expected max_try to be None or int, but got {type(max_try)}"
while max_try is None or try_count < max_try:
try:
try_count += 1
ret = func(*args, **kwargs)
return ret
except exception_type as e:
error_lines = str(e).split("\n")
if (max_try is None or try_count < max_try) and (pattern is None or _match_lines(error_lines, pattern)):
print("Exception is caught, retrying...")
# when pattern is not specified, we always skip the exception
# when pattern is specified, we only skip when pattern is matched
continue
else:
print("Maximum number of attempts is reached or pattern is not matched, no more retrying...")
raise e
# Override signature
# otherwise pytest.mark.parametrize will raise the following error:
# function does not use argument xxx
sig = signature(func)
_run_until_success.__signature__ = sig
return _run_until_success
return _wrapper
def global_wrapper(rank, func, tp, pp, dp, port, kwargs):
def setup_dist_env(rank, world_size, port):
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(rank)
# NOTE: unit tests run on a single node, so using the global rank as LOCAL_RANK is fine
os.environ["LOCAL_RANK"] = str(rank)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
world_size = tp * pp * dp
setup_dist_env(rank, world_size, port)
parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp)
func(parallel_context, **kwargs)
def init_distributed(tp: int, dp: int, pp: int):
def _init_distributed(func):
def wrapper(**kwargs):
from nanotron.utils import find_free_port
world_size = tp * pp * dp
port = find_free_port()
# Note that kwargs needs to be passed as part of args in a way that can be unpacked
args = (func, tp, pp, dp, port, kwargs)
mp.spawn(global_wrapper, args=args, nprocs=world_size)
return wrapper
return _init_distributed
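# Illustrative usage sketch (editor's addition): a distributed test receives a
# ParallelContext built from the requested (tp, dp, pp) topology; the test body
# below is hypothetical.
#
#     def _test_something(parallel_context: ParallelContext):
#         assert parallel_context.tp_pg.size() == 2
#
#     @rerun_if_address_is_in_use()
#     def test_something():
#         init_distributed(tp=2, dp=1, pp=1)(_test_something)()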
import torch
from nanotron.logging import LoggerWriter
from nanotron.nn.layer_norm import TritonLayerNorm
from torch.nn import LayerNorm
def get_time_name():
import datetime
today = datetime.datetime.now()
return today.strftime("%d/%m/%Y_%H:%M:%S")
if __name__ == "__main__":
BATCH_SIZE = 1
SEQ_LEN = 2
DEVICE, DTYPE = torch.device("cuda:0"), torch.float32
HIDDEN_SIZE = 1024
NUM_STEPS = 10_000
inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=DEVICE, dtype=DTYPE)
layer_norm = LayerNorm(normalized_shape=inputs.size(-1), device=DEVICE, dtype=DTYPE)
fused_layer_norm = TritonLayerNorm(
normalized_shape=inputs.size(-1),
device=DEVICE,
dtype=DTYPE,
)
ref_optim = torch.optim.Adam(layer_norm.parameters(), lr=0.1)
optim = torch.optim.Adam(fused_layer_norm.parameters(), lr=0.1)
logger = LoggerWriter()
def loss_function(x):
return x.sum()
for step in range(NUM_STEPS):
# NOTE: just make the output fluctuate a bit
random = torch.randn(1, device=DEVICE) * 0.01
ref_outputs = layer_norm(inputs) * random
outputs = fused_layer_norm(inputs) * random
loss = loss_function(outputs)
ref_loss = loss_function(ref_outputs)
ref_optim.zero_grad()
ref_loss.backward()
ref_optim.step()
optim.zero_grad()
loss.backward()
optim.step()
print(f"Step: {step}, outputs: {outputs.sum()}, ref_loss: {ref_outputs.sum()}")
print(f"Step: {step}, loss: {loss}, ref_loss: {ref_loss}")
# wandb.log({"loss": loss.item(), "ref_loss": ref_loss.item(), "step": step})
logger.add_scalar("loss", loss.item(), step)
logger.add_scalar("ref_loss", ref_loss.item(), step)