Commit dfcb88ff authored by chenzk

v1.0.8
import os
from pathlib import Path
from typing import Optional, cast
import torch
from datasets.download.streaming_download_manager import xPath
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from nanotron import distributed as dist
from nanotron import logging
from nanotron import optim as optim
from nanotron.config import Config
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
from nanotron.sanity_checks import (
assert_tensor_synced_across_pg,
check_optim_state_in_sync,
)
from nanotron.serialize.metadata import TrainingMetadata, save_meta
from nanotron.serialize.optimizer import (
save_lr_scheduler,
save_optimizer,
)
from nanotron.serialize.weights import save_weights
"""
We're going to use safetensors. The reason is that loading segments is going to be much easier
Requirements:
- serialized format needs to be able to recover the current training state (random states, weights, optimizer states)
- serialized format should be topology agnostic. Will make things much easier with varying topologies
Current way of thinking:
- one file = one tensor (it would create a huge amount of files, but we should revisit only if that's a problem; see the illustrative sketch below)
Version 1:
- serialize -> dumps every process weights in individual files
- load -> assume topology is exactly the same.
"""
logger = logging.get_logger(__name__)
def save(
config: "Config",
model: nn.Module,
optimizer: optim.BaseOptimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
parallel_context: ParallelContext,
training_metadata: TrainingMetadata,
root_folder: Path,
should_save_config: bool = True,
should_save_model: bool = True,
should_save_optimizer: bool = True,
should_save_lr_scheduler: bool = True,
sanity_checks: bool = True,
) -> None:
assert isinstance(training_metadata, TrainingMetadata)
try:
if should_save_config:
config.save_as_yaml(root_folder / "config.yaml")
except Exception as e:
# TODO @nouamane: catch full disk error
log_rank(
f"Error while saving config: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
try:
if should_save_model:
save_weights(model=model, parallel_context=parallel_context, root_folder=root_folder)
except Exception as e:
log_rank(
f"Error while saving weights checkpoint: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
try:
if should_save_optimizer:
save_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder)
except Exception as e:
log_rank(
f"Error while saving optimizer checkpoint: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
try:
if should_save_lr_scheduler:
lr_scheduler = cast(LambdaLR, lr_scheduler)
assert len(lr_scheduler.lr_lambdas) == len(
optimizer.param_groups
), "The number of lambdas functions in the scheduler should be equal to the number of parameter groups in the optimizer."
save_lr_scheduler(
lr_scheduler=lr_scheduler,
is_zero=config.optimizer.zero_stage,
parallel_context=parallel_context,
root_folder=root_folder,
)
except Exception as e:
log_rank(
f"Error while saving lr_scheduler checkpoint: {e}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
save_meta(root_folder=root_folder, parallel_context=parallel_context, training_metadata=training_metadata)
# TODO @thomas21: sanity check, not sure whether that needs to happen at testing or now (depends how much it costs)
###
# SANITY CHECK: Check that the model params are synchronized across `parallel_context.dp_pg`
if sanity_checks:
for name, param_or_buffer in sorted(model.state_dict().items(), key=lambda x: x[0]):
assert_tensor_synced_across_pg(
tensor=param_or_buffer,
pg=parallel_context.dp_pg,
msg=lambda err: f"{name} are not synced across DP {err}",
)
# SANITY CHECK: Check that the tied parameters are synchronized
sorted_tied_parameters = sorted(
(
param
for parameters_group in optimizer.param_groups
for param in parameters_group["params"]
if param.requires_grad and isinstance(param, NanotronParameter) and param.is_tied
),
key=lambda param: param.get_tied_info().name,
)
for tied_param in sorted_tied_parameters:
tied_info = tied_param.get_tied_info()
group_ranks = tied_info.global_ranks
group = parallel_context.world_ranks_to_pg[group_ranks]
assert_tensor_synced_across_pg(
tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}"
)
if not optimizer.inherit_from(optim.ZeroDistributedOptimizer):
check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg)
# SANITY CHECK: tied parameters have their optimizer states synchronized
# Compute a mapping from id_ to index in the optimizer sense
state_dict = optimizer.state_dict()
assert len(optimizer.param_groups) == len(state_dict["param_groups"])
index_to_param = {}
for real_param_group, index_param_group in zip(optimizer.param_groups, state_dict["param_groups"]):
indices = index_param_group["params"]
parameters = real_param_group["params"]
assert len(indices) == len(parameters)
for param, index in zip(parameters, indices):
assert index not in index_to_param
index_to_param[index] = param
current_state_dict = optimizer.state_dict()
for index, optim_state in sorted(current_state_dict["state"].items(), key=lambda x: x[0]):
param = index_to_param[index]
if not isinstance(param, NanotronParameter):
continue
if not param.is_tied:
# If it's not tied, we don't need to check that it's synced
continue
tied_info = param.get_tied_info()
group_ranks = tied_info.global_ranks
group = parallel_context.world_ranks_to_pg[group_ranks]
reference_rank = 0
current_rank = dist.get_rank(group)
for name, tensor in optim_state.items():
# FIXME @thomasw21: Some data is actually on `cpu`, just for this test we move it to `cuda`
tensor = tensor.to("cuda")
if current_rank == reference_rank:
reference_tensor = tensor
else:
reference_tensor = torch.empty_like(tensor)
dist.broadcast(
reference_tensor,
src=get_global_rank(group=group, group_rank=reference_rank),
group=group,
)
torch.testing.assert_close(
tensor,
reference_tensor,
atol=0,
rtol=0,
msg=lambda msg: f"tensor at {current_state_dict['names'][index]} doesn't match with our reference. Optimizer key: {name}\nCur: {tensor}\nRef: {reference_tensor}\n{msg}",
)
dist.barrier(parallel_context.world_pg)
def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
"""Parse checkpoint path from config and download checkpoint from S3 if needed.
Args:
config: Config object.
Returns:
Path to checkpoint or None if no checkpoint.
"""
load_from_candidate = config.checkpoints.resume_checkpoint_path
if load_from_candidate is not None:
if check_path_is_local(load_from_candidate):
latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt"
if latest_meta_path.exists():
with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi:
# TODO @thomasw21: make a better structure system so that we get typing correct
load_from_candidate = int(fi.read())
checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate)
elif (config.checkpoints.resume_checkpoint_path / "model_config.json").exists():
# we assume that the checkpoint path is a path to a checkpoint
checkpoint_path = config.checkpoints.resume_checkpoint_path
else:
log_rank(
f"No previous checkpoint found in: {latest_meta_path}",
logger=logger,
level=logging.INFO,
rank=0,
)
return None
log_rank(
f"Loading checkpoint from {checkpoint_path}",
logger=logger,
level=logging.INFO,
rank=0,
)
else:
latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt"
if latest_meta_path.exists():
# if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint
with fs_open(latest_meta_path, mode="r") as fi:
latest_iteration = int(fi.read())
s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path
checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration) # save_path
elif config.checkpoints.resume_checkpoint_path.exists():
# we assume that the checkpoint path is a path to a checkpoint
s3_path = config.checkpoints.resume_checkpoint_path # load_path
checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name # save_path
else:
log_rank(
f"No previous checkpoint found in: {config.checkpoints.resume_checkpoint_path}\n Initializing from scratch.",
logger=logger,
level=logging.WARNING,
rank=0,
)
return None
log_rank(
f"Downloading checkpoint from S3 in {checkpoint_path} ",
logger=logger,
level=logging.WARNING,
rank=0,
)
# Download checkpoint from S3
s3_mover = S3Mover(
local_path=os.path.join(checkpoint_path),
s3_path=os.path.join(s3_path),
s5cmd_numworkers=config.s3_upload.s5cmd_numworkers,
s5cmd_concurrency=config.s3_upload.s5cmd_concurrency,
s5cmd_path=config.s3_upload.s5cmd_path,
dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0),
)
s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
s3_mover.start_downloading()
s3_mover.distributed_wait_for_completion(parallel_context.world_pg)
return checkpoint_path
import dataclasses
import json
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
import dacite
import torch
from dacite import from_dict
from packaging.version import Version
from nanotron import distributed as dist
from nanotron.constants import CHECKPOINT_FILE_NAME, CHECKPOINT_VERSION
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import SlicesPair
@dataclasses.dataclass
class DataStageMetadata:
"""
consumed_train_samples: The number of samples consumed by the model in this stage (each stage starts from zero).
# NOTE: we should allow people to change the name of the data stages in the config file.
# but not the start_training_step, because it could
"""
name: str
start_training_step: int
consumed_train_samples: int
@dataclasses.dataclass
class TrainingMetadata:
"""
consumed_train_samples: The number of samples consumed globally, across all stages.
last_train_step: The last training step across all stages.
last_stage_idx: The index of the last stage that was trained.
data_stages: The metadata for each stage.
"""
consumed_train_samples: int
last_train_step: int
# TODO(xrsrke): make this not optional, once we entirely remove
# the old checkpoint version
last_stage_idx: Optional[int] = None
data_stages: Optional[List[DataStageMetadata]] = None
def __post_init__(self):
# NOTE: this is a sanity check after loading a trained checkpoint
total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages)
assert (
self.consumed_train_samples == total_consumed_samples_across_stages
), "Mismatch between the total consumed samples and the sum of consumed samples across stages! Something went wrong in the training."
# TODO(xrsrke): remove this once we entirely remove non-data-stage training
if self.last_stage_idx is not None:
assert self.data_stages is not None, "data_stages should not be None if last_stage_idx is not None"
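# Example of a metadata object that satisfies the checks above (illustrative values):
#
#     TrainingMetadata(
#         consumed_train_samples=1000,
#         last_train_step=10,
#         last_stage_idx=1,
#         data_stages=[
#             DataStageMetadata(name="stable", start_training_step=1, consumed_train_samples=800),
#             DataStageMetadata(name="annealing", start_training_step=9, consumed_train_samples=200),
#         ],
#     )  # 800 + 200 == consumed_train_samples, so __post_init__ passes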
@dataclasses.dataclass
class CheckpointMetadata:
version: Version
tp: int
dp: int
metas: TrainingMetadata
custom_metas: Optional[Dict[str, Any]] = None
@dataclasses.dataclass
class TensorMetadata:
# Mandatory for checkpoint version higher than 1.2
version: Version
# Anything users want to store
# Info about which slices of the unsharded tensor (global_slices) the current sharded tensor corresponds to (local_slices)
local_global_slices_pairs: Tuple[SlicesPair, ...]
# The shape of the unsharded tensor
unsharded_shape: Tuple[int, ...]
_metadata_config: ClassVar[dacite.Config] = dacite.Config(
cast=[Version],
type_hooks={
Tuple[SlicesPair, ...]: SlicesPair.tuple_from_str,
Tuple[int, ...]: lambda x: torch.Size(int(size) for size in x.strip("()").split(",") if size),
},
strict=True,
)
def to_str_dict(self) -> Dict[str, str]:
return {
"version": str(self.version),
"local_global_slices_pairs": SlicesPair.tuple_to_str(self.local_global_slices_pairs),
"unsharded_shape": str(tuple(self.unsharded_shape)),
}
@classmethod
def from_str_dict(cls, dictionary: Dict[str, str]) -> "TensorMetadata":
tensor_metadata: TensorMetadata = dacite.from_dict(
data_class=TensorMetadata,
data=dictionary,
config=cls._metadata_config,
)
return tensor_metadata
def process_type(elt: Any, type_hooks: Dict[Type, Callable[[Any], Any]]):
if isinstance(elt, dict):
return to_dict(elt, type_hooks=type_hooks)
elif elt.__class__ in type_hooks:
return type_hooks[elt.__class__](elt)
elif isinstance(elt, (list, tuple)):
return to_list(elt, type_hooks=type_hooks)
else:
return elt
def to_dict(dict_: Dict, type_hooks: Dict[Type, Callable[[Any], Any]]):
result = {}
for key, value in dict_.items():
result[key] = process_type(value, type_hooks=type_hooks)
return result
def to_list(list_: Union[List, Tuple], type_hooks: Dict[Type, Callable[[Any], Any]]):
return list_.__class__((process_type(elt, type_hooks=type_hooks) for elt in list_))
def save_meta(parallel_context: ParallelContext, root_folder: Path, training_metadata: TrainingMetadata):
assert isinstance(training_metadata, TrainingMetadata)
if dist.get_rank(parallel_context.world_pg) != 0:
return
root_folder.mkdir(exist_ok=True, parents=True)
checkpoint_metadata = CheckpointMetadata(
version=CHECKPOINT_VERSION,
tp=parallel_context.tp_pg.size(),
dp=parallel_context.dp_pg.size(),
metas=training_metadata,
)
# There are some types that require manual casting in order to work correctly.
processed_metadata = process_type(dataclasses.asdict(checkpoint_metadata), type_hooks={Version: lambda x: str(x)})
with open(root_folder / CHECKPOINT_FILE_NAME, mode="w") as fo:
json.dump(processed_metadata, fo, indent=2, sort_keys=True)
def load_meta(parallel_context: ParallelContext, root_folder: Path) -> CheckpointMetadata:
with open(root_folder / CHECKPOINT_FILE_NAME, mode="r") as fi:
checkpoint_metadata = json.load(fi)
checkpoint_metadata = from_dict(
data_class=CheckpointMetadata,
data=checkpoint_metadata,
config=dacite.Config(
cast=[Version],
),
)
# Assume that we're always backward compatible, we only increment CHECKPOINT_VERSION when there's a breaking change.
assert (
checkpoint_metadata.version <= CHECKPOINT_VERSION
), f"Checkpoint is of version {checkpoint_metadata.version}, Current `nanotron` checkpoint version is {CHECKPOINT_VERSION}"
return checkpoint_metadata
import json
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
from torch import nn
from tqdm import tqdm
from nanotron import distributed as dist
from nanotron import optim
from nanotron.optim.zero import (
ZeroDistributedOptimizer,
extract_parallel_ranks_from_shard_path,
find_optim_index_from_param_name,
get_sliced_tensor,
merge_dp_shard_in_zero1_optimizer,
)
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.serialize.metadata import TensorMetadata
from nanotron.serialize.utils import ObjectType, merge_and_shard_tp_tensors
# TODO(xrsrke): take rank instead of parallel_context
def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
if is_zero is True:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
else:
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
def lr_scheduler_filename(parallel_context: ParallelContext, is_zero: bool):
if is_zero is True:
return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
else:
return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
def save_optimizer(
optimizer: optim.BaseOptimizer,
parallel_context: ParallelContext,
root_folder: Path,
):
"""Saves optimizer states
- If Zero-0 is used, optimizer states are replicated across all DPs. Only DP-0 saves the states
- If Zero-1 is used, optimizer states are sharded across all DPs. Each DP saves its own states
"""
if (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank(parallel_context.dp_pg) > 0:
# this is Zero-0, so only DP-0 saves the optimizer states
return
# TODO: Figure out if I need to save param groups. Right now I'm assuming no as we only store what's trainable
# TODO: We can probably "rotate" so that every process stores something (maybe doesn't matter if we're I/O bound)
root_folder = root_folder / "optimizer"
root_folder.mkdir(exist_ok=True, parents=True)
if dist.get_rank(parallel_context.world_pg) == 0:
with open(root_folder / "optimizer_config.json", "w") as fo:
tp_size = parallel_context.tp_pg.size()
pp_size = parallel_context.pp_pg.size()
dp_size = parallel_context.dp_pg.size()
expert_parallel_size = parallel_context.expert_parallel_size
config = {
"type": str(optimizer.__class__.__name__),
"parallelism": {
"tp_size": str(tp_size),
"dp_size": str(dp_size),
"pp_size": str(pp_size),
"expert_parallel_size": str(expert_parallel_size),
},
"configs": {},
}
if isinstance(optimizer, ZeroDistributedOptimizer):
# NOTE: in order to serialize, we must save all keys and values as strings
def convert_to_string(input_item):
if isinstance(input_item, dict):
return {str(key): convert_to_string(value) for key, value in input_item.items()}
elif isinstance(input_item, list):
return [convert_to_string(element) for element in input_item]
elif isinstance(input_item, tuple):
return tuple(convert_to_string(element) for element in input_item)
else:
return str(input_item)
# NOTE: if it's a ZeRO-1 optimizer, then we save how the parameters are sharded
# across the data parallel dimension, so that we can reconstruct the optimizer states
assert optimizer.param_name_to_dp_rank_offsets is not None, "param_name_to_dp_rank_offsets is required"
config["configs"]["param_name_to_dp_rank_offsets"] = convert_to_string(
optimizer.param_name_to_dp_rank_offsets
)
# NOTE: since TP-sharded params are flattened, we need to save the original param shapes
# so that we can reconstruct the original shapes => reconstruct the unsharded params in the tensor parallel dimension
config["configs"]["orig_param_shapes"] = convert_to_string(optimizer._orig_param_shapes)
json.dump(config, fo)
# We dump the optimizer state using `torch.save`
torch.save(
optimizer.state_dict(),
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
)
def save_lr_scheduler(
lr_scheduler,
is_zero,
parallel_context: ParallelContext,
root_folder: Path,
):
"""Saves lr scheduler states"""
if not is_zero and dist.get_rank(parallel_context.dp_pg) > 0:
# this is ZeRO-0, so only DP-0 saves the lr scheduler states
return
root_folder = root_folder / "lr_scheduler"
root_folder.mkdir(exist_ok=True, parents=True)
# We dump the lr scheduler state using `torch.save`
torch.save(
lr_scheduler.state_dict(),
root_folder / lr_scheduler_filename(parallel_context, is_zero),
)
# Helper functions to move optimizer states
@torch.no_grad()
def state_dict_to_device(state_dict: Dict, device: str) -> Dict:
assert (
state_dict["state"][0]["exp_avg"].device.type == "cpu"
), "Optimizer states should be on CPU to avoid extra memory usage when loading from checkpoint"
torch.cuda.empty_cache()
for _, optim_state in sorted(state_dict["state"].items(), key=lambda x: x[0]):
for name, tensor in optim_state.items():
optim_state[name] = tensor.to(device)
assert (
state_dict["state"][0]["exp_avg"].device.type == "cuda"
), "Optimizer states should be on GPU because model is on GPU"
torch.cuda.empty_cache()
@torch.no_grad()
def load_optimizer(
optimizer: optim.BaseOptimizer,
parallel_context: ParallelContext,
root_folder: Path,
map_location: Optional[str] = None,
param_shard_metadata: Tuple[Tuple[int, int], TensorMetadata] = None, # (pp_rank, tp_rank) -> TensorMetadata
model: Optional[nn.Module] = None,
):
root_folder = root_folder / "optimizer"
ckp_optimizer_config_path = root_folder / "optimizer_config.json"
with open(ckp_optimizer_config_path, "r") as file:
ckp_optimizer_config = json.load(file)
ckp_pp_size = ckp_optimizer_config["parallelism"]["pp_size"]
ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
ckpt_expert_parallel_size = ckp_optimizer_config["parallelism"]["expert_parallel_size"]
if int(ckp_tp_size) != int(parallel_context.tp_pg.size()) or int(ckp_pp_size) != int(
parallel_context.pp_pg.size()
):
if int(ckp_pp_size) != int(parallel_context.pp_pg.size()):
warnings.warn(
"You are resuming in a different PP size, so optimizer states need to be checked. Feel free to open a PR if you work on this!"
)
assert (
param_shard_metadata is not None
), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
assert (
model is not None
), "You have to pass the model in order to adjust the optimizer states according to how the current parameters are sharded"
def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -> TensorMetadata:
return param_shard_metadata[param_name.replace("module.", "")][(str(pp_rank), str(tp_rank))]
ckp_optim_type = ckp_optimizer_config["type"]
if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: if the checkpoint is from a Zero-1 optimizer, then we need to merge the shards
# across data parallel dimension, before merging the shards across tensor parallel dimension
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}-exp-*-of-{ckpt_expert_parallel_size}.pt"
)
)
ckp_sharded_optim_states = merge_dp_shard_in_zero1_optimizer(
model, ckp_optimizer_config, shard_paths, parallel_context, map_location
)
else:
# NOTE: if the checkpoint is from a Zero-0 optimizer, then we don't need to merge the shards
# across data parallel dimension, just directly load the checkpoints
shard_paths = list(
root_folder.glob(
f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_tp-*-of-{ckp_tp_size}.pt"
) # WARN: wildcard here after tp can hold `0-of-1_exp-0`
)
ckp_sharded_optim_states = {}
for shard_path in shard_paths:
pp_rank, tp_rank = extract_parallel_ranks_from_shard_path(shard_path, is_zero1=False)
ckp_sharded_optim_states[(pp_rank, tp_rank)] = torch.load(
shard_path, map_location=map_location
) # load all optim states in mem
model_state_dict = model.state_dict()
new_optim_state_dict = optimizer.state_dict()
new_optim_state_dict["state"] = defaultdict(dict)
# TODO: this does not handle the edge case of different pipeline parallel optimizer state shards saving different state keys
OPTIMIZER_STATE_NAMES = sorted(ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"])
OPTIMIZER_STATE_DTYPE = ckp_sharded_optim_states[(0, 0)]["state"][0][OPTIMIZER_STATE_NAMES[0]].dtype
# NOTE: because we can only resume training with the same optimizer type
# (0, 0) = (pp_rank, tp_rank)
# NOTE: also we don't merge "step" because it's just a scalar
param_names = list(model_state_dict.keys())
new_optim_state_param_names = {}
# NOTE: iterates through all model parameters in the local pipeline parallel rank (hence, might not be the full model).
# Since model parameters and optimizer states are aligned, loads only the optimizer states for these parameters from the checkpoint shards.
for param_index, param_name in tqdm(
enumerate(param_names),
disable=dist.get_rank(parallel_context.world_pg) != 0,
desc="Topology-agnostic optimizer loading",
):
try:
param = model.get_parameter(param_name)
except AttributeError:
param = None
if not isinstance(param, NanotronParameter):
raise NotImplementedError("Parameters are required to be NanotronParameter")
# NOTE: for tied parameters, the metadata is stored using the parameter name,
# while the data is stored using the name of the main tied parameter,
# which may be different (e.g. `model.token_position_embeddings.pp_block.token_embedding.weight`
# for `model.lm_head.pp_block.weight`).
base_name = param.get_tied_info().name if param.is_tied else param_name
if param_name != base_name:
# NOTE: skip tied parameter if main tied parameter has already been loaded
# (not always the case if pipeline parallel)
if base_name in new_optim_state_param_names.values():
continue
new_optim_state_param_names[param_index] = base_name
if param.is_sharded:
# NOTE: an optimizer state's shape is equal to the parameter's shape
# NOTE: sometimes an unsharded parameter's shape differs
# from an unsharded optimizer state's shape
new_shard_metadata = param.get_sharded_info()
new_unshared_shape = new_shard_metadata.unsharded_shape
# NOTE: restore each state tensor (e.g. exp_avg) by iterating through
# the optimizer state shards saved using the previous topology
for state_key in OPTIMIZER_STATE_NAMES:
# TODO(xrsrke): free the memory of the shards that don't
# correspond to the current rank
# TODO: maybe better to allocate memory for all states at once
buffer = torch.zeros_like(param, device=map_location, dtype=OPTIMIZER_STATE_DTYPE)
unsharded_buffer = torch.empty(
new_unshared_shape, device=map_location, dtype=OPTIMIZER_STATE_DTYPE
)
for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
)
if old_optim_state_index is None:
continue # NOTE: param is not in this pp shard
ckp_shard_data = ckp_optim_state["state"][old_optim_state_index][state_key]
# NOTE: the metadata for the main parameter of a tied parameter might be in a
# different pipeline parallel shard.
if param.is_tied:
metadata_pp_rank = next(
iter(param_shard_metadata[param_name.replace("module.", "")].keys())
)[0]
else:
metadata_pp_rank = pp_rank
ckp_shard_metadata = get_checkpoint_state_metadata(param_name, metadata_pp_rank, tp_rank)
# NOTE: if the checkpoint is from a ZeRO-1 optimizer,
# its states are flattened, so we need to reshape them
if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: this is the original shape of the parameter before being flattened
orig_shape = ckp_optimizer_config["configs"]["orig_param_shapes"][param_name]
orig_shape = [int(dim) for dim in orig_shape]
ckp_shard_data = ckp_shard_data.view(orig_shape)
new_optim_state_dict["state"][param_index][state_key] = merge_and_shard_tp_tensors(
buffer,
unsharded_buffer,
[
(ckp_shard_data, ckp_shard_metadata.local_global_slices_pairs),
],
new_shard_metadata,
)
else:
# Handle non-sharded params (e.g. layernorm)
for (pp_rank, tp_rank), ckp_optim_state in ckp_sharded_optim_states.items():
old_optim_state_index = find_optim_index_from_param_name(
base_name, ckp_sharded_optim_states, is_zero1=False, pp_rank=pp_rank
)
if old_optim_state_index is None:
continue # Param not in this PP shard
# For non-sharded params, just copy over the state directly
for state_key in OPTIMIZER_STATE_NAMES:
new_optim_state_dict["state"][param_index][state_key] = ckp_optim_state["state"][
old_optim_state_index
][state_key]
if ckp_optim_type == ZeroDistributedOptimizer.__name__:
# NOTE: flatten the optimizer states
new_optim_state_dict["state"][param_index][state_key] = new_optim_state_dict["state"][param_index][
state_key
].flatten()
# NOTE: a bit awkward, but while we're already reading this (pp,tp) shard for whatever state_key,
# try to get the step value as well.
step = ckp_optim_state["state"][old_optim_state_index].get("step")
if step is not None:
new_optim_state_dict["state"][param_index]["step"] = step
# NOTE: we throw away ckp_optim_state['gradient_accumulator'] which has fp32 grads
new_optim_state_dict["names"] = new_optim_state_param_names
state_dict = new_optim_state_dict
else:
# TODO @thomasw21: Load optimizer type and check that it's compatible, otherwise we might be loading something else completely
state_dict = torch.load(
root_folder
/ optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)),
map_location=map_location,
)
if isinstance(optimizer, ZeroDistributedOptimizer):
# NOTE: only reshard after merging tp shards
# or when we get a new dp_size
if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
# NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
current_dp_rank = dist.get_rank(parallel_context.dp_pg)
OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"]
for param_index in state_dict["state"]:
param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
for state_name in OPTIMIZER_STATE_NAMES:
sliced_tensor = get_sliced_tensor(
param=state_dict["state"][param_index][state_name],
start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
)
state_dict["state"][param_index][state_name] = sliced_tensor
optimizer.load_state_dict(state_dict, map_location=map_location)
def load_lr_scheduler(
lr_scheduler,
is_zero,
parallel_context: ParallelContext,
root_folder: Path,
):
root_folder = root_folder / "lr_scheduler"
state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context, is_zero))
lr_scheduler.load_state_dict(state_dict)
lr_scheduler._initial_step() # NOTE: this is required to set the initial learning rate
from pathlib import Path
import torch
from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.random import RandomStates
def save_random_states(
random_states: RandomStates,
parallel_context: ParallelContext,
root_folder: Path,
):
"""All processes save their own random state"""
filename = (
root_folder
/ "random"
/ f"tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt"
)
filename.parent.mkdir(exist_ok=True, parents=True)
# TODO @thomasw21: That's annoying, but this actually uses pickle; we might need to change that for something else
torch.save(random_states, filename)
def load_random_states(parallel_context: ParallelContext, root_folder: Path):
# TODO @thomasw21: This basically assumes that we have exactly the same topology as the one we used when saving.
filename = (
root_folder
/ "random"
/ f"tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt"
)
# TODO @thomasw21: That's annoying, but this actually uses pickle; we might need to change that for something else
state = torch.load(filename)
return state
import re
from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple
import torch
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import SlicesPair
from nanotron.serialize.metadata import TensorMetadata
class ObjectType(Enum):
MODEL = "model"
OPTIMIZER = "optimizer"
LR_SCHEDULER = "lr_scheduler"
def get_exp_tp_pp_rank_and_size_from(
world_rank: int, parallel_context: ParallelContext
) -> Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]:
result = parallel_context.get_local_ranks(world_rank=world_rank)
return (
(result[0], parallel_context.expert_pg.size()),
(result[3], parallel_context.tp_pg.size()),
(result[1], parallel_context.pp_pg.size()),
)
def get_path(
tensor_name: str,
type: ObjectType,
exp_tp_pp_rank_and_size: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
is_expert_sharded: bool,
prefix: Optional[Path] = None,
) -> List[str]:
suffix = tensor_name.split(".")
suffix_path, suffix_name = suffix[:-1], suffix[-1]
if exp_tp_pp_rank_and_size:
# We always show pp_rank and tp_rank if `exp_tp_pp_rank_and_size` is provided
(exp_rank, exp_size), (tp_rank, tp_size), (pp_rank, pp_size) = exp_tp_pp_rank_and_size
if not is_expert_sharded or exp_size == 1:
suffix_name = (
f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}.safetensors"
)
else:
# We only show exp_rank if tensor is exp_sharded and exp_size > 1
suffix_name = f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}_exp-rank-{exp_rank}-of-{exp_size}.safetensors"
else:
suffix_name = f"{type.value}_{suffix_name}.safetensors"
suffix_path.append(suffix_name)
if prefix is None:
return suffix_path
else:
return prefix.joinpath(*suffix_path)
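# Example (illustrative): for tensor_name="model.decoder.0.attn.qkv.weight", type=ObjectType.MODEL,
# exp_tp_pp_rank_and_size=((0, 1), (1, 2), (0, 4)) and is_expert_sharded=False, this returns
#     <prefix>/model/decoder/0/attn/qkv/model_weight_pp-rank-0-of-4_tp-rank-1-of-2.safetensors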
def extract_tp_pp_rank_from_shard_path(shard_path: Path):
pattern = r"pp-rank-(\d+)-of-\d+_tp-rank-(\d+)-of-\d+"
match = re.search(pattern, str(shard_path))
pp_rank, tp_rank = match.groups()
return pp_rank, tp_rank
def merge_and_shard_tp_tensors(
buffer: torch.Tensor,
unsharded_buffer: torch.Tensor,
shards_and_slices_maps: List[Tuple[torch.Tensor, Tuple[SlicesPair, ...]]],
shard_metadata: TensorMetadata,
) -> torch.Tensor:
for shard, slices_pairs in shards_and_slices_maps:
for slices_pair in slices_pairs:
local_slices = slices_pair.local_slices
global_slices = slices_pair.global_slices
unsharded_buffer[global_slices] = shard[local_slices]
for slices_pair in shard_metadata.local_global_slices_pairs:
local_slices = slices_pair.local_slices
global_slices = slices_pair.global_slices
buffer[local_slices] = unsharded_buffer[global_slices]
return buffer
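# Conceptual sketch of the two phases above, using plain slices instead of SlicesPair (illustrative only):
#
#     import torch
#     unsharded = torch.empty(8)
#     unsharded[0:4] = torch.arange(0., 4.)   # phase 1: copy each checkpoint shard into its global slice
#     unsharded[4:8] = torch.arange(4., 8.)
#     new_shard = torch.empty(4)
#     new_shard[:] = unsharded[2:6]           # phase 2: re-slice for the shard the current topology owns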
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import dacite
import torch
from packaging.version import Version
from safetensors.torch import safe_open, save_file
from torch import nn
from tqdm import tqdm
from nanotron import distributed as dist
from nanotron import logging
from nanotron.constants import CHECKPOINT_VERSION
from nanotron.distributed import get_global_rank
from nanotron.logging import log_rank
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter, ShardedInfo, SlicesPair
from nanotron.serialize.metadata import CheckpointMetadata, TensorMetadata, load_meta
from nanotron.serialize.utils import (
ObjectType,
extract_tp_pp_rank_from_shard_path,
get_exp_tp_pp_rank_and_size_from,
get_path,
merge_and_shard_tp_tensors,
)
logger = logging.get_logger(__name__)
def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folder: Path):
root_folder = root_folder / "model"
# We save only `dist.get_rank(parallel_context.dp_pg) == 0`
# TODO @thomasw21: Figure how this works with Zero-3
if dist.get_rank(parallel_context.dp_pg) != 0:
return
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
# We chunk everything by `tp_world_size` in order to make sure that we gather all the weights into a single device before saving it
for name, param_or_buffer in tqdm(model.state_dict().items(), desc="Saving weights"):
# exp_rank=0 saves all weights whereas exp_rank>0 saves only MLP weights
if dist.get_rank(parallel_context.expert_pg) != 0:
if "experts" not in name:
continue
# `state_dict` doesn't return a Param or a buffer, just tensors, which lose some metadata
try:
param = model.get_parameter(name)
except AttributeError:
# TODO @nouamanetazi: Handle buffers
param = None
if isinstance(param, NanotronParameter):
metadata = {}
if param.is_tied:
tied_info = param.get_tied_info()
base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
group_ranks = tied_info.global_ranks
group = parallel_context.world_ranks_to_pg[group_ranks]
# Only the first rank of the group of the tied weights saves weights
# TODO @thomasw21: We could rotate in order to balance the load.
if dist.get_rank(group) != 0:
continue
else:
base_name = name
if param.is_sharded:
sharded_info: ShardedInfo = param.get_sharded_info()
group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
world_rank=get_global_rank(group=group, group_rank=dist.get_rank(group)),
parallel_context=parallel_context,
)
is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
metadata = TensorMetadata(
version=CHECKPOINT_VERSION,
local_global_slices_pairs=sharded_info.local_global_slices_pairs,
unsharded_shape=sharded_info.unsharded_shape,
).to_str_dict()
else:
exp_tp_pp_rank_and_size = None
is_expert_sharded = False
path = get_path(
base_name,
type=ObjectType.MODEL,
exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
is_expert_sharded=is_expert_sharded,
prefix=root_folder,
)
path.parent.mkdir(exist_ok=True, parents=True)
try:
tensors = {"data": param_or_buffer}
save_file(tensors=tensors, filename=path, metadata=metadata)
except Exception as e:
log_rank(
f"Error saving {path} with {metadata}",
logger=logger,
level=logging.ERROR,
rank=0,
)
raise e
else:
raise NotImplementedError("Parameters are required to be NanotronParameter")
class CheckpointVersionFromShardFileException(Exception):
"""Raise when loading checkpoint version from shard file fails"""
def read_checkpoint_version_from_shard_file(param_save_path: Path) -> Version:
try:
with safe_open(param_save_path, framework="pt", device=str("cpu")) as fi:
param_metadata = fi.metadata()
param_metadata = TensorMetadata.from_str_dict(param_metadata)
checkpoint_version = param_metadata.version
except (dacite.exceptions.MissingValueError, dacite.exceptions.UnexpectedDataError):
raise CheckpointVersionFromShardFileException()
return checkpoint_version
def read_checkpoint_version_from_meta(parallel_context: ParallelContext, root_folder: Path) -> Version:
checkpoint_metadata: CheckpointMetadata = load_meta(parallel_context=parallel_context, root_folder=root_folder)
checkpoint_version = checkpoint_metadata.version
return checkpoint_version
def get_checkpoint_version(parallel_context, root_folder, param_save_path: Path) -> Version:
try:
checkpoint_version = read_checkpoint_version_from_shard_file(param_save_path=param_save_path)
except CheckpointVersionFromShardFileException:
log_rank(
f"Failed to read checkpoint version from shard file {param_save_path}, reading from meta file.",
logger=logger,
level=logging.ERROR,
rank=0,
)
checkpoint_version = read_checkpoint_version_from_meta(
parallel_context=parallel_context, root_folder=root_folder
)
return checkpoint_version
def load_sharded_param_latest(
param_or_buffer: torch.Tensor,
sharded_info: ShardedInfo,
shards_path: List[Path],
param_shard_metadata: Optional[Dict] = None,
):
checkpoint_unsharded_shape = None
shards_and_slices_maps: List[Tuple[torch.Tensor, Tuple[SlicesPair, ...]]] = []
for shard_path in shards_path:
with safe_open(shard_path, framework="pt", device=str(param_or_buffer.device)) as fi:
# TODO @thomasw21: Choose only a slice if we switch the TP topology
param_metadata = fi.metadata()
param_metadata = TensorMetadata.from_str_dict(param_metadata)
shards_and_slices_maps.append((fi.get_tensor("data"), param_metadata.local_global_slices_pairs))
if checkpoint_unsharded_shape is None:
checkpoint_unsharded_shape = param_metadata.unsharded_shape
else:
assert checkpoint_unsharded_shape == param_metadata.unsharded_shape
if param_shard_metadata is not None:
# NOTE: store how the model parameters are sharded
# so that we can shard optimizer checkpoints the same way
pp_rank, tp_rank = extract_tp_pp_rank_from_shard_path(shard_path)
param_shard_metadata[(pp_rank, tp_rank)] = param_metadata
assert checkpoint_unsharded_shape is not None
# TODO @thomasw21: Interestingly enough we don't actually need to instantiate the entire model at all.
unsharded_tensor = torch.empty(checkpoint_unsharded_shape, device=param_or_buffer.device)
merge_and_shard_tp_tensors(
buffer=param_or_buffer,
unsharded_buffer=unsharded_tensor,
shards_and_slices_maps=shards_and_slices_maps,
shard_metadata=sharded_info,
)
return param_shard_metadata
def load_weights(
model: nn.Module,
parallel_context: ParallelContext,
root_folder: Path,
filtered_state_dict: Optional[Dict[str, Any]] = None,
):
"""Load weights from a checkpoint
Args:
model: model to load weights into
parallel_context: distributed process groups
root_folder: root folder of the checkpoint
filtered_state_dict: state dict to load from (overrides model.state_dict()). if None, load from model.state_dict()
"""
param_root_folder = root_folder / "model"
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
checkpoint_version: Optional[Version] = None
filtered_state_dict = filtered_state_dict if filtered_state_dict is not None else model.state_dict()
param_shard_metadata = {}
for name, param_or_buffer in tqdm(
filtered_state_dict.items(), disable=dist.get_rank(parallel_context.world_pg) != 0, desc="Loading weights"
):
# NOTE: extract how the current model parameters are sharded
# so that we can load optimizer checkpoints the same way
param_shard_metadata[name] = {}
# `state_dict` doesn't return a Param or a buffer, just tensors, which lose some metadata
try:
param = model.get_parameter(name)
except AttributeError:
param = None
if isinstance(param, NanotronParameter):
if param.is_tied:
tied_info = param.get_tied_info()
base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
else:
base_name = name
if param.is_sharded:
sharded_info = param.get_sharded_info()
if param.is_tied:
# When params are tied only the first rank of tied param group stores weights (see save_weights)
group = parallel_context.world_ranks_to_pg[tied_info.global_ranks]
group_rank = 0
else:
group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
group_rank = dist.get_rank(group)
exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context
)
# TODO @nouamane: do we consider exp_size=1 expert_sharded?
is_expert_sharded = sharded_info.is_expert_sharded(parallel_context)
else:
exp_tp_pp_rank_and_size = None
is_expert_sharded = False
path = get_path(
base_name,
type=ObjectType.MODEL,
exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
prefix=param_root_folder,
is_expert_sharded=is_expert_sharded,
)
if path.exists():
with safe_open(path, framework="pt", device=str(param.device)) as fi:
# TODO @thomasw21: Choose only a slice if we switch the TP topology
param_or_buffer[:] = fi.get_tensor("data")
elif not path.parent.exists():
raise ValueError(
f"Checkpoint is empty or checkpoint structure is not matching the model architecture."
f"Couldn't find folder {path.parent} in checkpoint at {root_folder}"
)
else:
# Let's assume that the topology changed and the param is sharded.
# We search for all the files from the shards, concatenate the "unsharded" tensor
# and load the specific shard we're interested in.
if not param.is_sharded:
raise ValueError(
f"`{name}` is not a sharded parameter. It's possible you were expecting {path} to exist."
)
# TODO @thomasw21: Make so that we don't need to code this logic somewhere else than in `get_path`
sharded_info = param.get_sharded_info()
suffix = base_name.rsplit(".", 1)[-1]
shards_path = list(path.parent.glob(f"{ObjectType.MODEL.value}_{suffix}*.safetensors"))
if len(shards_path) <= 0:
raise ValueError(
f"Could not find any shards {ObjectType.MODEL.value}_{suffix}*.safetensors in {path.parent}."
f"If you notice `.safetensors` in the middle of the name of some of the checkpoints files. You need to run `scripts/fix_checkpoint_bad_naming.py`."
)
if checkpoint_version is None:
checkpoint_version = get_checkpoint_version(
parallel_context, root_folder, param_save_path=shards_path[0]
)
else:
current_checkpoint_version = None
try:
current_checkpoint_version = read_checkpoint_version_from_shard_file(
param_save_path=shards_path[0]
)
except CheckpointVersionFromShardFileException:
# The checkpoint version is read from the meta file
current_checkpoint_version = checkpoint_version
finally:
assert (
current_checkpoint_version == checkpoint_version
), f"Checkpoint version mismatch at {shards_path[0]}."
if checkpoint_version <= CHECKPOINT_VERSION:
load_sharded_param_latest(
param_or_buffer=param_or_buffer,
sharded_info=sharded_info,
shards_path=shards_path,
param_shard_metadata=param_shard_metadata[name],
)
else:
raise ValueError(f"Unsupported checkpoint version {checkpoint_version}")
else:
raise NotImplementedError(f"Parameters {param} should be a NanotronParameter")
return param_shard_metadata
def get_checkpoint_paths_list(
model: nn.Module,
parallel_context: ParallelContext,
root_folder: Path,
only_list_folders: bool = False,
only_list_current_process: bool = True,
filtered_state_dict: Optional[Dict[str, Any]] = None,
):
"""Return the list of all the files or folders created/accessed by the current process in a checkpoint
Args:
model: model to load weights into
parallel_context: distributed process groups
root_folder: root folder of the checkpoint
filtered_state_dict: state dict to load from (overrides model.state_dict()). if None, load from model.state_dict()
"""
param_root_folder = root_folder / "model"
module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
# Fix the root_model
module_id_to_prefix[id(model)] = ""
paths = []
filtered_state_dict = filtered_state_dict if filtered_state_dict is not None else model.state_dict()
for name in tqdm(
filtered_state_dict.keys(),
disable=dist.get_rank(parallel_context.world_pg) != 0,
desc="Listing checkpoint paths",
):
# `state_dict` doesn't return a Param or a buffer, just tensors, which lose some metadata
try:
param = model.get_parameter(name)
except AttributeError:
param = None
if isinstance(param, NanotronParameter) or not only_list_current_process:
if param.is_tied:
tied_info = param.get_tied_info()
base_name = tied_info.get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
else:
base_name = name
if param.is_sharded:
sharded_info = param.get_sharded_info()
if param.is_tied:
# When params are tied only the first rank of tied param group stores weights (see save_weights)
group = parallel_context.world_ranks_to_pg[tied_info.global_ranks]
group_rank = 0
else:
group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks]
group_rank = dist.get_rank(group)
exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from(
world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context
)
else:
exp_tp_pp_rank_and_size = None
if only_list_folders:
paths.append(param_root_folder.joinpath(*base_name.split(".")[:-1]))
else:
paths.append(
get_path(
base_name,
type=ObjectType.MODEL,
exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size,
prefix=param_root_folder,
)
)
return paths
import datetime
import gc
import json
import os
import shutil
import time
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import (
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import (
Config,
DatasetStageArgs,
ExistingCheckpointInit,
ParallelismArgs,
RandomInit,
SpectralMupInit,
get_config_from_file,
)
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.dataloader import sanity_check_dataloader
from nanotron.helpers import (
_vocab_size_with_padding,
compute_remain_train_steps_of_a_data_stage_from_ckp,
get_consumed_train_samples_of_a_data_stage_from_ckp,
get_profiler,
init_optimizer_and_grad_accumulator,
init_random_states,
log_throughput,
lr_scheduler_builder,
)
from nanotron.logging import (
LoggerWriter,
LogItem,
human_format,
log_memory,
log_rank,
set_ranks_logging_level,
)
from nanotron.models import NanotronModel, build_model
from nanotron.models.base import check_model_has_grad
from nanotron.models.llama import LlamaForTraining, RotaryEmbedding
from nanotron.models.starcoder2 import Starcoder2ForTraining
from nanotron.optim.clip_grads import clip_grad_norm
from nanotron.parallel import ParallelContext
from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
PipelineEngine,
TensorPointer,
)
from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.nn import TensorParallelRowLinear
from nanotron.parallel.tied_parameters import (
create_pg_for_tied_weights,
get_tied_id_to_param,
sync_tied_weights_gradients,
tie_parameters,
)
from nanotron.random import set_random_seed
from nanotron.s3_checkpoints import S3Mover, check_path_is_local
from nanotron.sanity_checks import (
after_optim_step_sanity_checks,
after_tbi_sanity_checks,
before_optim_step_sanity_checks,
before_tbi_sanity_checks,
)
from nanotron.scaling.parametrization import ParametrizationMethod
from nanotron.serialize import (
load_lr_scheduler,
load_meta,
load_weights,
parse_ckpt_path,
save,
save_random_states,
)
from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata
from nanotron.serialize.optimizer import load_optimizer, state_dict_to_device
logger = logging.get_logger(__name__)
# Reduce the logging noise from torch.distributed when creating new process groups
dist_logger = logging.get_logger(dist.dist.__name__)
dist_logger.setLevel(logging.WARNING)
CONFIG_TO_MODEL_CLASS = {
"LlamaConfig": LlamaForTraining,
"Starcoder2Config": Starcoder2ForTraining,
}
try:
import wandb
except ImportError:
wandb = None
class DistributedTrainer:
def __init__(
self,
config_or_config_file: Union[Config, str],
config_class: Type[Config] = Config,
model_config_class: Optional[Type] = None,
model_class: Type[NanotronModel] = None,
):
"""
Nanotron's distributed trainer.
Args:
config_or_config_file: Either a `Config` object or a path to a YAML file containing the config.
config_class: The `Config` class to use.
model_config_class: The `ModelConfig` class to use (for example `LlamaConfig`). Defaults to `None` which will use the model config class defined in the config.
model_class: The `NanotronModel` class to use (for example `LlamaForTraining`). Defaults to `None` which will use the model class defined in the config.
"""
super().__init__()
self.config = get_config_from_file(
config_or_config_file, config_class=config_class, model_config_class=model_config_class
)
self.model_config = self.config.model.model_config
if model_class is not None:
CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class
########################################
## We start with setting up loggers and process groups
########################################
# Initialise all process groups
self.parallel_context = ParallelContext(
tensor_parallel_size=self.config.parallelism.tp,
pipeline_parallel_size=self.config.parallelism.pp,
data_parallel_size=self.config.parallelism.dp,
expert_parallel_size=self.config.parallelism.expert_parallel_size,
)
self.pre_init()
# Set log levels
set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)
# Log benchmark info
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
log_throughput(self.config, self.parallel_context)
########################################
## Setting up our model, optimizers, schedulers, etc.
########################################
# Set random states
set_random_seed(self.config.general.seed)
# Init model and build on pp ranks
self.random_states = init_random_states(
parallel_config=self.config.parallelism, tp_pg=self.parallel_context.tp_pg
)
self.model = self.init_model() # Defines self.model
print("self.model:", self.model)
self.unwrapped_model: NanotronModel = (
self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
)
# TODO: find a better way to handle this
parametrization_method = (
ParametrizationMethod.SPECTRAL_MUP
if hasattr(self.config.model.init_method, "use_mup") and self.config.model.init_method.use_mup
else ParametrizationMethod.STANDARD
)
# Init optimizer
self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator(
parametrization_method=parametrization_method,
model=self.model,
optimizer_args=self.config.optimizer,
parallel_context=self.parallel_context,
)
if self.init_checkpoint_path is not None and self.config.checkpoints.load_optimizer:
load_optimizer(
optimizer=self.optimizer,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
param_shard_metadata=self.param_shard_metadata,
model=self.unwrapped_model,
map_location="cpu",
)
# Init learning rate scheduler
self.lr_scheduler = lr_scheduler_builder(
optimizer=self.optimizer,
lr_scheduler_args=self.config.optimizer.learning_rate_scheduler,
total_training_steps=self.config.tokens.train_steps,
)
if self.init_checkpoint_path is not None and self.config.checkpoints.load_lr_scheduler:
load_lr_scheduler(
lr_scheduler=self.lr_scheduler,
is_zero=self.config.optimizer.zero_stage,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
# Define iteration start state
if self.init_checkpoint_path is not None and self.config.checkpoints.load_lr_scheduler:
checkpoint_metadata = load_meta(
parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
)
assert isinstance(checkpoint_metadata.metas, TrainingMetadata)
log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0)
self.metadata: TrainingMetadata = checkpoint_metadata.metas
# NOTE: we should not change data stages
assert (
self.config.tokens.train_steps > self.metadata.last_train_step
), f"Loaded checkpoint has already trained {self.metadata.last_train_step} batches, you need to specify a higher `config.tokens.train_steps`"
else:
data_stages = [
DataStageMetadata(
name=stage.name, start_training_step=stage.start_training_step, consumed_train_samples=0
)
for stage in self.config.data_stages
]
self.metadata: TrainingMetadata = TrainingMetadata(
consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages
)
# Set up tensorboard and log writers on output rank
self.logger_ranks = self.parallel_context.get_global_rank(
ep_rank=0, pp_rank=self.unwrapped_model.output_pp_rank, dp_rank=0, tp_rank=0
).flatten()
self.loggerwriter = self.setup_log_writers()
# Log where each module is instantiated
self.unwrapped_model.log_modules(level=logging.DEBUG, group=self.parallel_context.world_pg, rank=0)
self.micro_batch_size = self.config.tokens.micro_batch_size
self.n_micro_batches_per_batch = self.config.tokens.batch_accumulation_per_replica
self.global_batch_size = (
self.micro_batch_size * self.n_micro_batches_per_batch * self.parallel_context.dp_pg.size()
)
self.sequence_length = self.config.tokens.sequence_length
self.iteration_step = self.metadata.last_train_step
self.limit_val_batches = self.config.tokens.limit_val_batches
# NOTE: the dataloader currently in use for the current training stage
self.current_dataloader: Optional[DataLoader] = None
self.post_init()
def pre_init(self):
self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context)
def post_init(self):
# S3 Mover and save initial state
if self.config.s3_upload is not None:
# NOTE: Only local rank 0 should upload
dummy = bool(int(os.environ.get("LOCAL_RANK", None)) != 0)
self.s3_mover = S3Mover(
local_path=self.config.checkpoints.checkpoints_path,
s3_path=self.config.s3_upload.upload_s3_path,
remove_after_upload=self.config.s3_upload.remove_after_upload,
s5cmd_numworkers=self.config.s3_upload.s5cmd_numworkers,
s5cmd_concurrency=self.config.s3_upload.s5cmd_concurrency,
s5cmd_path=self.config.s3_upload.s5cmd_path,
dummy=dummy,
)
else:
self.s3_mover = None
def pre_training(self, *args, **kwargs):
self._print_training_plan()
metadata: TrainingMetadata = self.metadata
log_rank(
f"[Start training] datetime: {datetime.datetime.now()} | mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | global_batch_size: {self.global_batch_size} | sequence_length: {self.sequence_length} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_train_samples: {metadata.consumed_train_samples}", # noqa
logger=logger,
level=logging.INFO,
rank=0,
)
current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S")
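# NOTE: only the first logger rank initializes the wandb run (and only if wandb is installed)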
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
wandb.init(
project=self.config.general.project,
name=f"{current_time}_{self.config.general.run}",
config={"nanotron_config": self.config.as_dict()},
)
def post_train_step(self):
# Update our background upload/removal of checkpoints
if self.s3_mover is not None:
self.s3_mover.update()
def post_training(self):
if self.s3_mover is not None:
self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg)
def _print_training_plan(self):
if hasattr(self.config, "data_stages") and self.config.data_stages is not None:
stages_info = "".join(
f"[Stage {stage.name}] start from step {stage.start_training_step} \n"
for stage in self.config.data_stages
)
full_log_message = (
f"[Training Plan] There are {len(self.config.data_stages)} training stages \n{stages_info}"
)
log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)
def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
from collections.abc import Generator
if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
if self.current_dataloader is None:
if isinstance(dataloaders, tuple):
dataloader = dataloaders[0]
else:
dataloader = dataloaders
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
return
elif isinstance(dataloaders, Generator):
# TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
# remove this in the next PR
self.current_dataloader = dataloaders
return
assert len(dataloaders) > 0, "No dataloaders provided"
assert len(dataloaders) == len(
self.config.data_stages
), "Number of dataloaders should match the number of dataset stages"
def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
import gc
log_rank(
f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
logger=logger,
level=logging.INFO,
)
# NOTE: Clear dataloader from memory
del dataloader.dataset
del dataloader.sampler
del dataloader.batch_sampler
gc.collect()
dataloader = None
def find_stage_idx_to_resume():
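# Walk the stages from the latest start step to the earliest and return the index of the stage
# that the current iteration_step falls into (None if training hasn't reached the first stage yet)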
reversed_data_stages = sorted(self.config.data_stages, key=lambda x: x.start_training_step, reverse=True)
for idx, stage in enumerate(reversed_data_stages):
if self.iteration_step >= stage.start_training_step:
return len(self.config.data_stages) - idx - 1
return None
stage_idx_to_resume = find_stage_idx_to_resume()
for stage_idx, stage in enumerate(self.config.data_stages):
if stage_idx < self.metadata.last_stage_idx:
continue
stage = cast(DatasetStageArgs, stage)
is_resume_from_training = self.current_dataloader is None and stage_idx_to_resume == stage_idx
if (stage.start_training_step == self.iteration_step) or is_resume_from_training:
if self.current_dataloader is not None:
prev_stage_name = self.config.data_stages[stage_idx - 1].name
prev_dataloader = dataloaders[prev_stage_name]
if isinstance(prev_dataloader, DataLoader):
# NOTE: we don't need to clear dummy data generator from memory
clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)
self.metadata.last_stage_idx = stage_idx
if is_resume_from_training:
remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, self.config, self.metadata
)
consumed_train_steps = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, self.metadata)
log_rank(
f"Resuming training from stage {stage.name}, it has trained for {consumed_train_steps} samples and has {remaining_train_steps} remaining train steps",
logger=logger,
level=logging.INFO,
rank=0,
)
dataloader = dataloaders[stage.name]
# NOTE: if a dataloader is lazy initialized, we need to call it to initialize it
dataloader = dataloader() if callable(dataloader) else dataloader
break
if dataloader is not None:
self.current_dataloader = sanity_check_dataloader(
dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
)
def train(
self,
dataloader_or_dls: Dict[
str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]
],
**kwargs,
) -> None:
self.pre_training(**kwargs)
if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None:
self.save_checkpoint()
self.pipeline_engine: PipelineEngine = self.config.parallelism.pp_engine
self.pipeline_engine.nb_microbatches = self.n_micro_batches_per_batch
# TODO @nouamanetazi: refactor this
# Useful mapping
self.unwrapped_model = self.model.module if isinstance(self.model, DistributedDataParallel) else self.model
self.unwrapped_model.module_id_to_prefix = {
id(module): f"{module_name}." for module_name, module in self.unwrapped_model.named_modules()
}
# Fix the root_model
self.unwrapped_model.module_id_to_prefix[id(self.unwrapped_model)] = ""
self.initial_iter_step = self.metadata.last_train_step + 1
self.last_iter_step = self.config.tokens.train_steps
prof = get_profiler(config=self.config)
# free memory
gc.collect()
torch.cuda.empty_cache()
with prof:
for self.iteration_step in range(self.initial_iter_step, self.last_iter_step + 1):
if isinstance(prof, torch.profiler.profile):
prof.step()
self.iteration_start_time = time.time()
self._update_dataloader_based_on_training_stages(dataloader_or_dls)
# Training step
outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
# Training Logs
# TODO(xrsrke): refactor using callbacks would be better
self.metadata.consumed_train_samples += self.global_batch_size
self.metadata.last_train_step = self.iteration_step
self.metadata.data_stages[
self.metadata.last_stage_idx
].consumed_train_samples += self.global_batch_size
if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0:
self.train_step_logs(outputs=outputs, loss_avg=loss_avg)
# Checkpoint
if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0:
self.save_checkpoint()
dist.barrier() # let's wait for everyone before leaving
if self.config.checkpoints.save_final_state:
self.save_checkpoint()
self.post_training()
def training_step(
self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]
) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]:
before_tbi_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler
)
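# Log memory usage for the first few iterations after (re)starting training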
if self.iteration_step < self.initial_iter_step + 5:
log_memory(logger=logger)
outputs = self.pipeline_engine.train_batch_iter(
model=self.model,
pg=self.parallel_context.pp_pg,
batch=(next(dataloader) for _ in range(self.n_micro_batches_per_batch)),
nb_microbatches=self.n_micro_batches_per_batch,
grad_accumulator=self.grad_accumulator,
)
if self.iteration_step < self.initial_iter_step + 5:
log_memory(logger=logger)
after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
if isinstance(self.model, DistributedDataParallel) and self.grad_accumulator is not None:
# Wait for fp32 grads allreduce to finish to make sure grads are synced across DP
assert (
self.grad_accumulator.fp32_grads_allreduce_handle is not None
), "No fp32_grads_allreduce_handle maybe you're using only a single training process"
self.grad_accumulator.fp32_grads_allreduce_handle.wait()
# Sync tied weights
if not isinstance(self.model, DistributedDataParallel):
# Manually sync across DP if it's not handled by DDP
sync_gradients_across_dp(
module=self.model,
dp_pg=self.parallel_context.dp_pg,
reduce_op=dist.ReduceOp.AVG,
# TODO @thomasw21: This is too memory hungry, instead we run all_reduce
reduce_scatter=False, # optimizer.inherit_from(ZeroDistributedOptimizer),
grad_accumulator=self.grad_accumulator,
)
# TODO @nouamane: Put this in hooks so we can overlap communication with gradient computation on the last backward pass.
sync_tied_weights_gradients(
module=self.unwrapped_model,
parallel_context=self.parallel_context,
grad_accumulator=self.grad_accumulator,
)
# Clip gradients
if self.config.optimizer.clip_grad is not None:
# Unwrap DDP
named_parameters = [
(name, param)
for name, param in self.unwrapped_model.get_named_params_with_correct_tied()
if param.requires_grad
]
self.grad_norm_unclipped = clip_grad_norm(
mp_pg=self.parallel_context.mp_pg,
named_parameters=named_parameters,
grad_accumulator=self.grad_accumulator,
max_norm=self.config.optimizer.clip_grad,
)
# Compute DP average loss and overlap with optimizer step
if isinstance(outputs[0]["loss"], torch.Tensor):
# This is an average on only one data rank.
loss_avg = torch.stack(
[output["loss"] for output in outputs]
).sum() # already divided by n_micro_batches_per_batch
# sync loss across DP
handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG)
else:
loss_avg = None
handle = None
# Move optimizer states back to GPU before optimizer step
if (
self.init_checkpoint_path is not None
and self.config.checkpoints.load_optimizer
and self.iteration_step == self.initial_iter_step
):
state_dict_to_device(self.optimizer.state_dict(), "cuda")
before_optim_step_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer
)
# Apply gradient
self.optimizer.step()
self.optimizer.zero_grad()
# Update the learning rate
self.lr_scheduler.step()
after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)
if handle is not None:
handle.wait()
self.post_train_step()
return outputs, loss_avg
def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]:
outputs = self.pipeline_engine.validate_batch_iter(
model=self.model,
batch=(next(dataloader) for _ in range(self.limit_val_batches)),
nb_microbatches=self.limit_val_batches,
)
return outputs
def train_step_logs(
self,
outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
loss_avg: Optional[torch.Tensor],
) -> None:
# TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607
dist.barrier()
torch.cuda.synchronize()
elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000
tokens_per_sec = (
self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000)
) # tokens_per_sec is calculated using sequence_length
model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec(
iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000,
sequence_length=self.sequence_length,
global_batch_size=self.global_batch_size,
)
if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks"
lr = self.lr_scheduler.get_last_lr()[0]
log_entries = [
# LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"),
LogItem(
"consumed_tokens",
self.metadata.consumed_train_samples * self.config.tokens.sequence_length,
"human_format",
), # , "12d"),
LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"),
LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"),
LogItem(
"tokens_per_sec_per_gpu", tokens_per_sec / self.parallel_context.world_pg.size(), "human_format"
), # , "1.6E"),
LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"),
LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"),
LogItem("lr", lr, "human_format"), # , ".3E"),
LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"),
LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"),
]
if self.config.optimizer.clip_grad is not None:
log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f"))
# Log not too often the memory
if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0:
total, used, free = shutil.disk_usage("/")
log_entries.extend(
[
LogItem(
"cuda_memory_allocated", torch.cuda.memory_allocated(), "human_format"
), # / 1024**2, ".2f"),
LogItem(
"cuda_max_memory_reserved", torch.cuda.max_memory_reserved(), "human_format"
), # / 1024**2, ".2f"),
LogItem("hd_total_memory_tb", total, "human_format"), # / (2**40), ".2f"),
LogItem("hd_used_memory_tb", used, "human_format"), # / (2**40), ".2f"),
LogItem("hd_free_memory_tb", free, "human_format"), # / (2**40), ".2f"),
]
)
# NOTE: only one rank writes to wandb
if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
wandb.log(
{
**{log_item.tag: log_item.scalar_value for log_item in log_entries},
"iteration_step": self.iteration_step,
}
)
self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)
# Nanotron Benchmark mode: we log the throughput and exit
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1" and self.iteration_step == 3:
log_throughput(
self.config,
self.parallel_context,
model_tflops,
hardware_tflops,
tokens_per_sec,
)
log_rank("Throughput logging complete", logger=logger, level=logging.INFO)
if "SLURM_JOB_ID" in os.environ:
os.system("scancel " + os.environ["SLURM_JOB_ID"])
else:
exit(0)
def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
"""Initialize the model and load weights from checkpoint if needed."""
# TODO: add max_position_embeddings
self.model_config.vocab_size = _vocab_size_with_padding(
self.model_config.vocab_size,
pg_size=self.parallel_context.tp_pg.size(),
make_vocab_size_divisible_by=self.config.model.make_vocab_size_divisible_by,
)
if (
getattr(self.model_config, "max_position_embeddings", None) is not None
and self.model_config.max_position_embeddings != self.config.tokens.sequence_length
):
if isinstance(self.config.model.init_method, ExistingCheckpointInit):
log_rank(
f"Finetuning a model with a sequence length {self.config.tokens.sequence_length} that is different from the checkpoint's max_position_embeddings {self.model_config.max_position_embeddings}.", # noqa
logger=logger,
level=logging.WARNING,
rank=0,
)
else:
assert (
self.config.tokens.sequence_length == self.model_config.max_position_embeddings
), "The tokenizer's sequence length does not match the model's maximum position embeddings."
log_rank("Config:\n" + pformat(self.config), logger=logger, level=logging.INFO, rank=0)
log_rank("Model Config:\n" + pformat(self.model_config), logger=logger, level=logging.INFO, rank=0)
model = self._init_model_instance()
model = self._load_model_checkpoint(model)
return model
def _init_model_instance(self) -> NanotronModel:
model_config_cls = self.model_config.__class__.__name__
assert (
model_config_cls in CONFIG_TO_MODEL_CLASS
), f"Unsupported model config {model_config_cls}. Only {CONFIG_TO_MODEL_CLASS.keys()} are supported"
model = self._init_model(
model_builder=lambda: CONFIG_TO_MODEL_CLASS[model_config_cls](
config=self.model_config,
parallel_context=self.parallel_context,
parallel_config=self.config.parallelism,
random_states=self.random_states,
),
)
return model
def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model
# Load or initialize model weights
reloaded_from_checkpoint = False
if self.init_checkpoint_path is not None:
# Load from a pre-existing checkpoint
if check_path_is_local(self.init_checkpoint_path):
# Reload from a training checkpoint
log_rank(
f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0
)
self.param_shard_metadata = load_weights(
model=unwrapped_model,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
reloaded_from_checkpoint = True
if not reloaded_from_checkpoint:
log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0)
if isinstance(self.config.model.init_method, ExistingCheckpointInit):
# Initialize the model from a pretrained checkpoint (without optimizer or lr_scheduler state)
self.param_shard_metadata = load_weights(
model=unwrapped_model,
parallel_context=self.parallel_context,
root_folder=self.config.model.init_method.path,
)
elif isinstance(self.config.model.init_method, (RandomInit, SpectralMupInit)):
unwrapped_model.init_model_randomly(config=self.config)
# Synchronize parameters so that the model is consistent
# sync all params across dp
for _, param in sorted(model.named_parameters(), key=lambda x: x[0]):
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg)
# sync tied params across tied groups
for (_, group_ranks), param in sorted(
get_tied_id_to_param(
parameters=model.parameters(),
root_module=unwrapped_model,
).items(),
key=lambda x: x[0],
):
group = self.parallel_context.world_ranks_to_pg[group_ranks]
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=group)
else:
raise ValueError(f"Unsupported {self.config.model.init_method}")
return model
def _init_model(
self,
model_builder: Callable[[], NanotronModel],
target_pp_ranks: Optional[List[int]] = None,
) -> NanotronModel:
config = self.config
parallel_context = self.parallel_context
parallel_config = config.parallelism
make_ddp = parallel_context.data_parallel_size > 1 and not (
config.optimizer.accumulate_grad_in_fp32 and config.optimizer.zero_stage > 0
)
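# NOTE: DDP is only used when there is data parallelism and gradient reduction isn't already handled
# elsewhere (with fp32 grad accumulation + ZeRO, grads are synced outside DDP, see training_step)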
# Build model and set pp ranks
model = build_model(
parallel_context=parallel_context,
dtype=config.model.dtype,
target_pp_ranks=target_pp_ranks,
model_builder=model_builder,
)
# Initialize rotary embeddings
for module in model.modules():
if not isinstance(module, RotaryEmbedding):
continue
module.init_rotary_embeddings()
# Mark some parameters as tied
self._mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
# count number of parameters
num_params = sum(p.numel() for p in model.parameters())
size_params = sum(p.numel() * p.element_size() for p in model.parameters())
total_params = torch.tensor(num_params, device="cuda")
total_size = torch.tensor(size_params, device="cuda")
dist.all_reduce(total_params, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM) # TP
dist.all_reduce(total_params, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM) # PP
dist.all_reduce(total_size, group=parallel_context.tp_pg, async_op=False, op=dist.ReduceOp.SUM)
dist.all_reduce(total_size, group=parallel_context.pp_pg, async_op=False, op=dist.ReduceOp.SUM)
# TODO @nouamanetazi: better memory logs
log_rank(
f"Total number of parameters: {human_format(total_params.item())} ({total_size.item() / 1024**2:.2f}MiB)",
logger=logger,
level=logging.INFO,
group=parallel_context.world_pg,
rank=0,
)
log_rank(
f"Local number of parameters: {human_format(num_params)} ({size_params / 1024**2:.2f}MiB)",
logger=logger,
level=logging.INFO,
group=parallel_context.dp_pg,
rank=0,
)
log_rank(
f"[After model building] Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
f" Peak allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB"
f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
logger=logger,
level=logging.INFO,
group=parallel_context.dp_pg,
rank=0,
)
# Wrap the model in DDP if needed
if make_ddp is True:
# Check that the model has at least one grad. Necessary for DDP
check_model_has_grad(model=model, parallel_context=parallel_context)
# TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway)
model = DistributedDataParallel(
model,
process_group=parallel_context.dp_pg,
broadcast_buffers=False,
bucket_cap_mb=config.model.ddp_bucket_cap_mb,
)
# Sanity check the model, all parameters must be NanotronParameter (either tied or sharded)
sanity_check(root_module=model)
return model
def setup_log_writers(
self,
) -> Optional[LoggerWriter]:
"""Setup all log writers on the appropriate ranks
Args:
config (Config): The config object
logger_ranks (Iterable[int]): The ranks that should log
parallel_context (DistributedProcessGroups): The distributed process groups
"""
if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
loggerwriter = LoggerWriter(global_step=self.config.tokens.train_steps)
else:
loggerwriter = None
return loggerwriter
def pre_save_checkpoint(self) -> None:
if self.s3_mover is not None:
self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg)
if self.s3_mover.post_upload_callback_outputs is not None:
slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs
self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval")
def post_save_checkpoint(self):
# Upload to S3
if self.s3_mover is not None:
self.s3_mover.start_uploading()
def save_checkpoint(self) -> Path:
self.pre_save_checkpoint()
checkpoints_path = self.config.checkpoints.checkpoints_path
checkpoint_path = checkpoints_path / f"{self.iteration_step}"
if self.config.checkpoints.checkpoints_path_is_shared_file_system:
should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0
else:
should_mkdir = int(os.environ.get("LOCAL_RANK", "0")) == 0
if should_mkdir:
checkpoint_path.mkdir(parents=True, exist_ok=True)
dist.barrier(self.parallel_context.world_pg)
log_rank(f"Saving checkpoint at {checkpoint_path}", logger=logger, level=logging.WARNING, rank=0)
# Update step/samples numbers before we save the config
self.config.general.step = self.metadata.last_train_step
self.config.general.consumed_train_samples = self.metadata.consumed_train_samples
save(
model=self.unwrapped_model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
should_save_model=bool(
dist.get_rank(self.parallel_context.dp_pg) == 0
), # We only save the weights on DP==0
should_save_optimizer=True,
should_save_lr_scheduler=True,
should_save_config=bool(
dist.get_rank(self.parallel_context.world_pg) == 0
), # We only save the config on world_rank==0
parallel_context=self.parallel_context,
root_folder=checkpoint_path,
training_metadata=self.metadata,
config=self.config,
)
save_random_states(
random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path
)
with open(checkpoints_path / "latest.txt", mode="w") as fo:
fo.write(f"{self.iteration_step}")
if hasattr(self.model_config, "to_json_file"):
self.model_config.to_json_file(checkpoint_path / MODEL_CONFIG_FILE_NAME)
else:
with open(checkpoint_path / MODEL_CONFIG_FILE_NAME, mode="w") as fo:
fo.write(json.dumps(asdict(self.model_config)))
self.post_save_checkpoint()
return checkpoint_path
def _mark_tied_parameters(
self,
model: NanotronModel,
parallel_context: ParallelContext,
parallel_config: Optional[ParallelismArgs] = None,
):
mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)
def mark_tied_parameters(
model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None
):
# Tie embeddings
embeddings_lm_head_tied_names = model.get_embeddings_lm_head_tied_names()
if len(embeddings_lm_head_tied_names) > 0:
shared_embeddings = [
(
target,
(
parallel_context.get_global_rank(
ep_rank=dist.get_rank(parallel_context.expert_pg),
pp_rank=get_pp_rank_of(target, module=model),
dp_rank=dist.get_rank(parallel_context.dp_pg),
tp_rank=dist.get_rank(parallel_context.tp_pg),
),
),
)
for target in embeddings_lm_head_tied_names
]
tie_parameters(
root_module=model, ties=shared_embeddings, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM
)
# Tie custom params
model.tie_custom_params()
# Sync all parameters that have the same name and that are not sharded across TP and EXP
assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point"
mark_unsharded_params_as_tied_across_tp(model, parallel_context, parallel_config)
mark_unsharded_params_as_tied_across_expert(model, parallel_context, parallel_config)
create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context)
def mark_unsharded_params_as_tied_across_tp(
model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs"
):
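# Tie parameters that are fully replicated across the TP group (e.g. LayerNorm weights) so that they
# stay in sync; TP-sharded and already-tied parameters are skipped below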
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if isinstance(param, NanotronParameter):
# We skip tying if param already tied or sharded along tp
if param.is_tied:
continue
if param.is_sharded:
sharded_info = param.get_sharded_info()
if sharded_info.is_tp_sharded(parallel_context=parallel_context):
continue
if isinstance(module, TensorParallelRowLinear) and "bias" == param_name:
# bias for TensorParallelRowLinear only exists on TP=0 so we don't need to tie it
continue
shared_weights = [
(
name,
# sync across TP group
tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
)
]
if parallel_config is None or parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE:
# We set `reduce_op=None` to signal that the weights are synced by design, so no reduction is needed
# (e.g. when TP=2, LayerNorm weights are duplicated across TP, so they are tied by construction)
reduce_op = None
else:
reduce_op = dist.ReduceOp.SUM
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op
)
def mark_unsharded_params_as_tied_across_expert(
model: NanotronModel, parallel_context: ParallelContext, parallel_config: "ParallelismArgs"
):
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
name = f"{module_name}.{param_name}"
if isinstance(param, NanotronParameter):
# We skip tying if param already tied or sharded along expert
if param.is_tied:
continue
if param.is_sharded:
sharded_info = param.get_sharded_info()
if sharded_info.is_expert_sharded(parallel_context):
continue
shared_weights = [
(
name,
# sync across expert group
tuple(sorted(dist.get_process_group_ranks(parallel_context.expert_pg))),
)
]
# Aside from the MoE block, which sees a shard of the tokens, the rest of the model sees the full tokens,
# so we don't need to reduce gradients across the expert group
reduce_op = None
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op
)
import functools
import inspect
import os
import random
import socket
from contextlib import ExitStack, contextmanager
from typing import ContextManager, List, Optional
import torch
from packaging import version
from torch import nn
from torch.utils.checkpoint import checkpoint
from nanotron import distributed as dist
class Singleton(type):
"""
Singleton metaclass.
Create objects using this class as the metaclass to enable singleton behaviour.
For instance:
```
class Logger(metaclass=Singleton):
...
```
"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
in the `transformers` library.
"""
def __init__(self, context_managers: List[ContextManager]):
self.context_managers = context_managers
self.stack = ExitStack()
def __enter__(self):
for context_manager in self.context_managers:
self.stack.enter_context(context_manager)
def __exit__(self, *args, **kwargs):
self.stack.__exit__(*args, **kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({[context_manager.gen.__qualname__ for context_manager in self.context_managers]})"
@contextmanager
def main_rank_first(group: dist.ProcessGroup):
"""Context manager that executes the code in the context with the rank zero of the group going first."""
is_main = dist.get_rank(group) == 0
if is_main:
yield
dist.barrier(group)
if not is_main:
yield
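# Example usage (illustrative sketch; `parallel_context` and `prepare_local_cache` are hypothetical):
#     with main_rank_first(parallel_context.dp_pg):
#         prepare_local_cache()  # rank 0 of the group runs first, the others wait at the barrier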
@contextmanager
def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None):
"""Context manager that executes the code in the context with all the local rank zero of the group going first.
Useful to run only once per node first (e.g. to create local files, etc)
"""
is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
if is_main:
yield
dist.barrier(group)
if not is_main:
yield
def checkpoint_method(attr_name: str):
"""Decorator to checkpoint a method of a class."""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
_self = args[0]
checkpoint_activated = getattr(_self, attr_name)
if checkpoint_activated:
all_args = list(args)
signature_params = inspect.signature(func).parameters
# Parameters are ordered in the function definition order: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
for i, (arg_name, arg_value) in enumerate(signature_params.items()):
if arg_value.kind in [inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL]:
raise NotImplementedError(
"Checkpointing of functions with *args or **kwargs is not supported."
)
if i < len(args):
continue
if arg_name not in kwargs:
assert (
arg_value.default is not inspect.Parameter.empty
), f"Missing argument {arg_name} from {kwargs} for {func.__name__}"
all_args.append(arg_value.default)
else:
all_args.append(kwargs[arg_name])
assert len(all_args) == len(signature_params), f"Missing arguments for {func.__name__}"
# TODO @nouamanetazi: we pass `self`(which is module) to checkpoint, so it's stored in `ctx.inputs` whereas some other methods create a custom fwd and pass only tensors without `self`. Need to investigate which is better
return checkpoint(func, *all_args)
else:
return func(*args, **kwargs)
return wrapper
return decorator
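# Example usage (illustrative sketch; the class and attribute names are hypothetical):
#     class Block(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.do_checkpoint = True  # flag read by the decorator at call time
#         @checkpoint_method(attr_name="do_checkpoint")
#         def forward(self, hidden_states):
#             ...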
def get_parameter_and_parent_module(target: str, root_module: nn.Module):
module_path, _, param_name = target.rpartition(".")
mod: torch.nn.Module = root_module.get_submodule(module_path)
if not hasattr(mod, param_name):
raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
param: torch.nn.Parameter = getattr(mod, param_name)
if not isinstance(param, torch.nn.Parameter):
raise AttributeError("`" + param_name + "` is not an " "nn.Parameter")
return param, mod, param_name
def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage:
if version.parse(torch.__version__) >= version.parse("2.0"):
return tensor.untyped_storage()
else:
return tensor.storage().untyped()
def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype):
# TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage.
device = untyped_storage.device
tensor = torch.empty([], dtype=dtype, device=device)
tensor.set_(source=untyped_storage)
return tensor
def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
while True:
port = random.randint(min_port, max_port)
try:
with socket.socket() as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", port))
return port
except OSError:
continue
10K slice of OpenWebText - An open-source replication of the WebText dataset from OpenAI.
This is a small subset representing the first 10K records from the original dataset - created for testing.
The full 8M-record dataset is [here](https://huggingface.co/datasets/openwebtext).
```
$ python -c "from datasets import load_dataset; ds=load_dataset('stas/openwebtext-10k'); print(ds)"
DatasetDict({
train: Dataset({
features: ['text'],
num_rows: 10000
})
})
```
* Records: 10,000
* compressed size: ~15MB
* uncompressed size: 50MB
To convert to jsonlines:
```
from datasets import load_dataset
dataset_name = "stas/openwebtext-10k"
name = dataset_name.split('/')[-1]
ds = load_dataset(dataset_name, split='train')
ds.to_json(f"{name}.jsonl", orient="records", lines=True)
```
To see how this subset was created, here is the [instructions file](https://huggingface.co/datasets/stas/openwebtext-10k/blob/main/process.txt).
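To quickly inspect just a few records without keeping the whole split around, the standard `datasets` split-slicing syntax can be used (a small sketch):
```
from datasets import load_dataset

# load only the first 3 records of the train split
ds = load_dataset("stas/openwebtext-10k", split="train[:3]")
for record in ds:
    print(record["text"][:80])
```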
{"plain_text": {"description": "An open-source replication of the WebText dataset from OpenAI.\n\nThis is a small subset representing the first 10K records from the original dataset - created for testing.\n\nThe full 8M-record dataset is at https://huggingface.co/datasets/openwebtext\n", "citation": "@misc{Gokaslan2019OpenWeb,\n title={OpenWebText Corpus},\n author={Aaron Gokaslan*, Vanya Cohen*, Ellie Pavlick, Stefanie Tellex},\n howpublished{\\url{http://Skylion007.github.io/OpenWebTextCorpus}},\n year={2019}\n}\n", "homepage": "https://skylion007.github.io/OpenWebTextCorpus/", "license": "", "features": {"text": {"dtype": "string", "id": null, "_type": "Value"}}, "post_processed": null, "supervised_keys": null, "builder_name": "openwebtext10k", "config_name": "plain_text", "version": {"version_str": "1.0.0", "description": null, "major": 1, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 49670861, "num_examples": 10000, "dataset_name": "openwebtext10k"}}, "download_checksums": {"https://cdn-datasets.huggingface.co/nlp/datasets/openwebtext/openwebtext-10k.tar.xz": {"num_bytes": 14723792, "checksum": "1dd150ffa3361ab32fa9f129d1b5ce20ac48728be16be436558f844d1761c572"}}, "download_size": 14723792, "post_processing_size": null, "dataset_size": 49670861, "size_in_bytes": 64394653}}
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Open WebText Corpus"""
import os
import re
from itertools import chain
import datasets
_CITATION = """\
@misc{Gokaslan2019OpenWeb,
title={OpenWebText Corpus},
author={Aaron Gokaslan*, Vanya Cohen*, Ellie Pavlick, Stefanie Tellex},
howpublished{\\url{http://Skylion007.github.io/OpenWebTextCorpus}},
year={2019}
}
"""
_DESCRIPTION = """\
An open-source replication of the WebText dataset from OpenAI.
This is a small subset representing the first 10K records from the original dataset - created for testing.
The full 8M-record dataset is at https://huggingface.co/datasets/openwebtext
"""
# _URL = "https://cdn-datasets.huggingface.co/nlp/datasets/openwebtext/openwebtext-10k.tar.xz"
_URL = "/home/nanotron/openwebtext-10k.tar.xz"
class Openwebtext10k(datasets.GeneratorBasedBuilder):
"""The Open WebText dataset."""
BUILDER_CONFIGS = [
datasets.BuilderConfig(
name="plain_text",
description="Plain text",
version=datasets.Version("1.0.0"),
)
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features({"text": datasets.Value("string")}),
homepage="https://skylion007.github.io/OpenWebTextCorpus/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
dl_dir = dl_manager.download_and_extract(_URL)
owt_dir = os.path.join(dl_dir, "openwebtext-10k")
subset_xzs = [
os.path.join(owt_dir, file_name)
for file_name in sorted(os.listdir(owt_dir))
if file_name.endswith("xz") # filter out ...xz.lock
]
# ex_dirs = dl_manager.extract(subset_xzs, num_proc=round(os.cpu_count() * 0.75))
ex_dirs = dl_manager.extract(subset_xzs)
nested_txt_files = [
[
os.path.join(ex_dir, txt_file_name)
for txt_file_name in sorted(os.listdir(ex_dir))
if txt_file_name.endswith("txt")
]
for ex_dir in ex_dirs
]
txt_files = chain(*nested_txt_files)
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"txt_files": txt_files}),
]
def _generate_examples(self, txt_files):
"""Yields examples."""
for idx, filepath in enumerate(txt_files):
with open(filepath, encoding="utf-8") as f:
yield idx, {"text": re.sub("\n\n\n+", "\n\n", f.read()).strip()}
# this is a small derivative from 8M-big openwebtext dataset for testing
# how this build script and dataset_infos.json were generated
#
mkdir openwebtext-10k
cd openwebtext-10k
# data
wget https://zenodo.org/record/3834942/files/openwebtext.tar.xz
tar xf openwebtext.tar.xz
cd openwebtext
rename.pl 's|-|-00|; s|-00(\d\d\d)|-$1|; s|-00(\d\d)|-0$1|;' *xz
# now open the first 30 archives
mkdir subset
cp urlsf_subset00-0[0-2]*_data.xz subset
cd subset
find . -name "*xz" -exec tar xf {} \;
mkdir 10k
find . -name "*txt" | sort | head -10000 | xargs mv -t 10k
tar cfJ 10k.xz -C 10k .
mkdir openwebtext-10k
mv 10k.xz openwebtext-10k
tar cfJ openwebtext-10k.tar.xz openwebtext-10k
# the openwebtext subdir gets created on the fly
aws s3 cp openwebtext-10k.tar.xz s3://datasets.huggingface.co/nlp/datasets/openwebtext/
# script
wget https://raw.githubusercontent.com/huggingface/datasets/master/datasets/openwebtext/openwebtext.py
mv openwebtext.py openwebtext-10k.py
perl -pi -e 's|openwebtext|openwebtext-10k|g' openwebtext-10k.py
perl -pi -e 's|https://zenodo.org/record/3834942/files/|https://cdn-datasets.huggingface.co/nlp/datasets/openwebtext/|g' openwebtext-10k.py
perl -pi -e 's|Openwebtext|Openwebtext10k|g' openwebtext-10k.py
# manually check that the script is correct - edit the descriptions
# create a new dataset entry on the hub
https://huggingface.co/new-dataset
# once created clone it
git clone https://huggingface.co/datasets/stas/openwebtext-10k
cp openwebtext-10k.py process.txt openwebtext-10k
cd openwebtext-10k
git add openwebtext-10k.py process.txt
git commit -m "build script" openwebtext-10k.py process.txt
git push
# test and generate config file
cd ..
datasets-cli test ./openwebtext-10k --save_infos --all_configs
# add and push the generated config
cd openwebtext-10k
git add dataset_infos.json
git commit -m "add dataset_infos.json" dataset_infos.json
git push
# test that the dataset is working
python -c "from datasets import load_dataset; ds=load_dataset('stas/openwebtext-10k'); print(ds)"
import torch
from nanotron.fp8 import DTypes, FP8Parameter, FP8Tensor
from nanotron.fp8.meta import FP8Meta
def test_create_fp8_parameter():
# TODO(xrsrke): test FP8E5M2 format
# TODO(xrsrke): test take a cpu tensor
tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32)
fp8_parameter = FP8Parameter(tensor, DTypes.FP8E4M3)
assert isinstance(fp8_parameter.data, FP8Tensor)
assert fp8_parameter.requires_grad is True
assert fp8_parameter.grad is None
assert isinstance(fp8_parameter.fp8_meta, FP8Meta)
assert isinstance(fp8_parameter.data.fp8_meta, FP8Meta)
# TODO(xrsrke): add test for preventing torch autograd do the backward pass
# on a FP8Parameter
import pytest
import torch
from nanotron.fp8 import DTypes, FP8Linear, FP8Parameter, FP8Tensor
from torch import nn
from torch.optim import Adam
@pytest.mark.parametrize("is_bias", [True, False])
def test_fp8_linear_forward_pass(is_bias):
input = torch.randn(16, 16, device="cuda", dtype=torch.float32)
ref_input = input.detach().clone()
ref_linear = nn.Linear(16, 16, bias=is_bias, device="cuda", dtype=torch.float32)
fp8_linear = FP8Linear(16, 16, bias=is_bias, device="cuda:0")
fp8_linear.weight = FP8Parameter(ref_linear.weight.detach().clone(), DTypes.FP8E4M3)
if is_bias:
fp8_linear.bias.data = ref_linear.bias.detach().clone()
ref_output = ref_linear(ref_input)
output = fp8_linear(input)
assert isinstance(output, torch.Tensor)
assert output.dtype == torch.float32
assert torch.allclose(output, ref_output, rtol=0, atol=0.1)
# TODO(xrsrke): add cases where the input requires and don't require grad
@pytest.mark.parametrize("input_requires_grad", [True, False])
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_fp8_linear_backward_pass(input_requires_grad, device):
input = torch.randn(16, 16, device=device, dtype=torch.float32, requires_grad=input_requires_grad)
ref_input = input.detach().clone().requires_grad_(True)
ref_linear = nn.Linear(16, 16, device=device, dtype=torch.float32)
fp8_linear = FP8Linear(16, 16, device=device)
if device == "cpu":
fp8_linear.weight.data = ref_linear.weight.detach().clone()
else:
fp8_linear.weight.data = FP8Tensor(ref_linear.weight.detach().clone(), dtype=DTypes.FP8E4M3)
fp8_linear.bias.data = ref_linear.bias.detach().clone()
ref_linear(ref_input).sum().backward()
fp8_linear(input).sum().backward()
# TODO(xrsrke): investigate why input.grad needs such a high tolerance
# assert torch.allclose(input.grad, ref_input.grad, 0.2, 0.2) if input_requires_grad else True
assert torch.allclose(fp8_linear.weight.grad, ref_linear.weight.grad, 0.1, 0.1)
assert torch.allclose(fp8_linear.bias.grad, ref_linear.bias.grad, 0, 0.1)
# TODO(xrsrke): test if FP8Linear has all the methods of a torch.nn.Linear
def test_fp8_linear_attrs():
fp8_linear = FP8Linear(16, 16, device="cuda:0")
assert next(fp8_linear.parameters()) is not None
assert all(p.requires_grad for p in fp8_linear.parameters()) is True
# TODO(xrsrke): test only calculating the gradients of the weight, bias, or input based
# on the requires_grad of the input, weight, or bias
def test_fp8_model_bwd():
HIDDEN_SIZE = 128
N_LAYERS = 5
N_EPOCHS = 3
input = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", requires_grad=True)
model = nn.Sequential(
*[nn.Sequential(FP8Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda"), nn.ReLU()) for _ in range(N_LAYERS)]
)
optim = Adam(model.parameters(), lr=1e-3)
for _ in range(N_EPOCHS):
optim.zero_grad()
model(input).sum().backward()
optim.step()
assert all(p.grad is not None for p in model.parameters())
from copy import deepcopy
import numpy as np
import pytest
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex
from nanotron.fp8 import DTypes, FP8Tensor
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import convert_tensor_from_fp8
@pytest.mark.parametrize("size", [4, 8, 16, 64])
def test_quantize_and_dequantize_tensor_in_fp8(size):
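# Round trip: quantize a float32 tensor to FP8 (E4M3), check the metadata/shape/dtype, then
# dequantize it back and compare against the original within FP8 precision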
tensor = torch.randn((size, size), dtype=torch.float32, device="cuda")
ref_tensor = deepcopy(tensor)
fp8_tensor = FP8Tensor(tensor, dtype=DTypes.FP8E4M3)
assert isinstance(fp8_tensor, FP8Tensor)
assert isinstance(fp8_tensor.fp8_meta, FP8Meta)
assert fp8_tensor.device == ref_tensor.device
assert fp8_tensor.dtype == torch.uint8
assert fp8_tensor.shape == ref_tensor.shape
assert fp8_tensor.numel() == ref_tensor.numel()
assert not np.array_equal(fp8_tensor.cpu().numpy(), ref_tensor.cpu().numpy())
# TODO(xrsrke): remove the fixed 1 factor
# it couples with the current implementation of FP8Meta
# because we initialize scale with 1
assert fp8_tensor.fp8_meta.amax == ref_tensor.abs().max()
assert isinstance(fp8_tensor.fp8_meta.inverse_scale, torch.Tensor)
assert fp8_tensor.fp8_meta.scale != 0.1 and fp8_tensor.fp8_meta.scale != 1
assert isinstance(fp8_tensor.fp8_meta.te_dtype, tex.DType)
tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32)
assert isinstance(tensor, torch.Tensor)
assert tensor.dtype == ref_tensor.dtype
assert torch.allclose(tensor, ref_tensor, rtol=1e-1, atol=1e-1)
def test_fp8_tensor_attrs():
SIZE = 64
tensor = torch.randn((SIZE, SIZE), dtype=torch.float32, device="cuda:0")
ref_tensor = tensor.detach().clone()
fp8_tensor = FP8Tensor(tensor, DTypes.FP8E4M3)
assert isinstance(fp8_tensor, FP8Tensor)
assert isinstance(fp8_tensor.fp8_meta, FP8Meta)
assert fp8_tensor.device == ref_tensor.device
assert fp8_tensor.dtype == torch.uint8
assert fp8_tensor.shape == ref_tensor.shape
assert fp8_tensor.numel() == ref_tensor.numel()
assert fp8_tensor.device == ref_tensor.device
# TODO(xrsrke): test it has all the methods of torch.Tensor
# TODO(xrsrke): test it has all the attributes of its input tensor
import shutil
import uuid
from functools import lru_cache
from pathlib import Path
class TestContext:
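# Creates a unique temporary directory under `.test_cache/` and removes it when the
# TestContext instance is garbage-collected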
def __init__(self):
self._random_string = str(uuid.uuid1())
self._root_dir = Path(__file__).parent.parent / ".test_cache"
self._root_dir.mkdir(parents=True, exist_ok=True)
@lru_cache(maxsize=1)
def get_auto_remove_tmp_dir(self):
path = self._root_dir / self._random_string
path.mkdir(parents=True, exist_ok=True)
return path
def __del__(self):
path = self.get_auto_remove_tmp_dir()
shutil.rmtree(path)
import hashlib
import importlib
import json
import os
import sys
from argparse import Namespace
from collections import OrderedDict
from pathlib import Path
package = importlib.import_module("nanotron")
package_path = Path(package.__file__).parent.parent.parent
sys.path.append(str(package_path))
import nanotron.distributed as dist
import torch
from nanotron.data.nanoset import Nanoset
from nanotron.parallel import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.sanity_checks import assert_tensor_synced_across_pg
from tools.preprocess_data import main
def create_dataset_paths(tmp_dir: str, quantity: int):
json_dataset_path = [os.path.join(tmp_dir, f"pytest_{i}.json") for i in range(quantity)]
datatrove_tokenized_dataset_paths = [os.path.join(tmp_dir, f"tokenized_documents_{i}") for i in range(quantity)]
return json_dataset_path, datatrove_tokenized_dataset_paths
def create_dummy_json_dataset(path_to_json: str, dummy_text: str, n_samples: int = 50000):
with open(path_to_json, "a") as json_file:
for sample in range(n_samples):
sample_dict = {"text": f"[{sample}] Hello! Im sample {sample}! And this is my dummy text: {dummy_text}"}
json_file.write(json.dumps(sample_dict))
json_file.write("\n")
def preprocess_dummy_dataset(json_dataset_path: str, datatrove_tokenized_dataset_path: str, tokenizer: str):
# Create args for preprocessing
args = Namespace(
readers="jsonl",
dataset=json_dataset_path,
column="text",
glob_pattern=None,
output_folder=datatrove_tokenized_dataset_path,
tokenizer_name_or_path=tokenizer,
eos_token=None,
n_tasks=1,
logging_dir=None,
)
# tools/preprocess_data.py main
main(args)
def assert_batch_dataloader(
batch: dict, parallel_context: ParallelContext, micro_batch_size: int, sequence_length: int
):
"""
batch (dict): Batch produced from the Dataloader, with keys input_ids, input_mask, label_ids, label_mask
"""
for element in batch:
tensor = batch[element]
# Assert that inputs are only present in input_pp_rank and outputs in output_pp_rank
input_pp_rank, output_pp_rank = 0, int(parallel_context.pp_pg.size() - 1)
if dist.get_rank(parallel_context.pp_pg) == input_pp_rank and element.startswith("input_"):
assert isinstance(tensor, torch.Tensor)
elif dist.get_rank(parallel_context.pp_pg) == output_pp_rank and element.startswith("label_"):
assert isinstance(tensor, torch.Tensor)
else:
assert isinstance(tensor, TensorPointer)
data_class = (
0 # 0 if tensor is from the ids, 1 if TensorPointer and 2 if mask. Used in the data parallel group check
)
# Check shape of mask and ids tensors
if isinstance(tensor, torch.Tensor):
assert tensor.shape == (micro_batch_size, sequence_length)
# TensorPointer case: Check that all TensorPointers from the same tp_pg point to the same group_rank. Create torch.tensor with group_rank
if isinstance(tensor, TensorPointer):
tensor = torch.tensor(tensor.group_rank)
data_class = 1
# Attention Masks case: dtype is torch.bool --> Transform to int64
if tensor.dtype == torch.bool:
tensor = tensor.long()
data_class = 2
# Assert that we have the SAME element in all the processes belonging to the same tensor parallel group
assert_tensor_synced_across_pg(
tensor=tensor.flatten().cuda(),
pg=parallel_context.tp_pg,
msg=lambda err: f"{element} is not synchronized across TP {err}",
)
# Assert that we have the SAME class of data in all processes belonging to the same data parallel group
assert_tensor_synced_across_pg(
tensor=torch.tensor(data_class, device="cuda"),
pg=parallel_context.dp_pg,
msg=lambda err: f"{element} is not synchronized across DP {err}",
)
def compute_hash(identifier: OrderedDict, n_digit: int = 8) -> int:
"""
Creates a sha256 hash from the elements of a OrderedDict
"""
unique_description = json.dumps(identifier, indent=4)
# Create n_digit description hash
unique_description_hash = int(hashlib.sha256(unique_description.encode("utf-8")).hexdigest(), 16) % 10**n_digit
return unique_description_hash
def assert_nanoset_sync_across_all_ranks(nanoset: Nanoset, parallel_context: ParallelContext):
"""
Checks that the same Nanoset is created in all processes
"""
# Extract a sample from the Nanoset
IDX_SAMPLE = 23
nanoset_identifiers = OrderedDict()
nanoset_identifiers["dataset_folders"] = nanoset.dataset_folders
nanoset_identifiers["dataset_weights"] = nanoset.dataset_weights.tolist()
nanoset_identifiers["sequence_length"] = nanoset.sequence_length
nanoset_identifiers["train_split_num_samples"] = nanoset.train_split_num_samples
nanoset_identifiers["random_seed"] = nanoset.random_seed
nanoset_identifiers["length"] = len(nanoset)
nanoset_identifiers["input_ids"] = nanoset[IDX_SAMPLE]["input_ids"].tolist()
nanoset_identifiers["dataset_index"] = nanoset.dataset_index.tolist()
nanoset_identifiers["dataset_sample_index"] = nanoset.dataset_sample_index.tolist()
nanoset_identifiers["token_size"] = nanoset.token_size
unique_description_hash = compute_hash(nanoset_identifiers)
assert_tensor_synced_across_pg(
tensor=torch.tensor(unique_description_hash, device="cuda"),
pg=parallel_context.world_pg,
msg=lambda err: f"Nanoset is not synchronized across all processes {err}",
)
def compute_batch_hash(batch: dict) -> int:
"""
Checks that the Nanoset/BlendedNanoset is in the same state after recovering from a crash
batch (dict): Batch produced from the Dataloader, with keys input_ids, input_mask, label_ids, label_mask
"""
batch_identifiers = OrderedDict()
for element in batch:
tensor = batch[element]
# TensorPointer
if isinstance(tensor, TensorPointer):
identifier = tensor.group_rank
# Attention Masks case: dtype is torch.bool --> Transform to int64
elif tensor.dtype == torch.bool:
identifier = tensor.long().tolist()
# Input IDs tensor
else:
identifier = tensor.tolist()
batch_identifiers[element] = identifier
unique_description_hash = compute_hash(batch_identifiers)
return unique_description_hash
import torch
from nanotron import distributed as dist
from nanotron.distributed import ProcessGroup, get_global_rank
def assert_tensor_equal_over_group(tensor: torch.Tensor, group: ProcessGroup, assert_: bool = True) -> bool:
"""We assume that tensors are already of correct size."""
reference_rank = 0
if dist.get_rank(group) == reference_rank:
reference_tensor = tensor
else:
reference_tensor = torch.empty_like(tensor)
dist.broadcast(
reference_tensor,
src=get_global_rank(group=group, group_rank=reference_rank),
group=group,
)
if assert_:
torch.testing.assert_close(tensor, reference_tensor, atol=0, rtol=0)
else:
result = torch.allclose(tensor, reference_tensor, atol=0.0, rtol=0.0)
results = [0] * group.size()
dist.all_gather_object(results, result, group)
return all(results)
from math import ceil
from typing import Union
import torch
from nanotron import distributed as dist
from nanotron.models import init_on_device_and_dtype
from nanotron.optim.base import BaseOptimizer
from nanotron.optim.named_optimizer import NamedOptimizer
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tied_parameters import tie_parameters
from nanotron.parallel.utils import initial_sync
from torch import nn
from torch.nn.parallel import DistributedDataParallel
class DummyModel(nn.Module):
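# Tiny pipeline-parallel model used in tests: one Linear + activation PipelineBlock pair per
# pipeline rank, followed by a sum() 'loss' block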
def __init__(
self,
p2p: P2P,
):
super().__init__()
self.p2p = p2p
self.mlp = nn.Sequential(
*(
nn.ModuleDict(
{
"linear": PipelineBlock(
p2p=p2p,
module_builder=nn.Linear,
module_kwargs={"in_features": 10, "out_features": 10},
module_input_keys={"input"},
module_output_keys={"output"},
),
"activation": PipelineBlock(
p2p=p2p,
module_builder=nn.Sigmoid if pp_rank < p2p.pg.size() - 1 else nn.Identity,
module_kwargs={},
module_input_keys={"input"},
module_output_keys={"output"},
),
}
)
for pp_rank in range(p2p.pg.size())
)
)
self.loss = PipelineBlock(
p2p=p2p,
module_builder=lambda: lambda x: x.sum(),
module_kwargs={},
module_input_keys={"x"},
module_output_keys={"output"},
)
def forward(self, x: Union[torch.Tensor, TensorPointer]):
for non_linear in self.mlp:
x = non_linear.linear(input=x)["output"]
x = non_linear.activation(input=x)["output"]
x = self.loss(x=x)["output"]
return x
def init_dummy_model(parallel_context: ParallelContext, dtype: torch.dtype = torch.float) -> DummyModel:
p2p = P2P(pg=parallel_context.pp_pg, device=torch.device("cuda"))
model = DummyModel(p2p=p2p)
# Build model using contiguous segments
pipeline_blocks = [module for name, module in model.named_modules() if isinstance(module, PipelineBlock)]
with init_on_device_and_dtype(device=torch.device("cuda"), dtype=dtype):
contiguous_size = ceil(len(pipeline_blocks) / parallel_context.pp_pg.size())
for i, block in enumerate(pipeline_blocks):
rank = i // contiguous_size
block.build_and_set_rank(rank)
# Sync all parameters that have the same name and that are not sharded across TP.
for name, param in model.named_parameters():
if isinstance(param, NanotronParameter) and param.is_sharded:
continue
shared_weights = [
(
name,
# sync across TP group
tuple(sorted(dist.get_process_group_ranks(parallel_context.tp_pg))),
)
]
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM
)
initial_sync(model=model, parallel_context=parallel_context)
if len(list(model.named_parameters())) > 0:
model = DistributedDataParallel(model, process_group=parallel_context.dp_pg)
else:
# No parameters, so no need to use DDP to sync parameters gradients
model = model
return model
def init_dummy_optimizer(model: nn.Module, parallel_context: ParallelContext) -> BaseOptimizer:
optimizer = NamedOptimizer(
named_params_or_groups=model.named_parameters(), optimizer_builder=lambda params: torch.optim.AdamW(params)
)
# Synchronize across dp: basic assumption, already done as nothing in optimizer initialization is stochastic
return optimizer
def dummy_infinite_data_loader(pp_pg: dist.ProcessGroup, dtype=torch.float, input_pp_rank=0):
micro_batch_size = 3
# We assume the first linear is always built on the first rank.
current_pp_rank = dist.get_rank(pp_pg)
while True:
yield {
"x": torch.randn(micro_batch_size, 10, dtype=dtype, device="cuda")
if current_pp_rank == input_pp_rank
else TensorPointer(group_rank=input_pp_rank)
}
import contextlib
import signal
from typing import Optional
from nanotron import distributed as dist
@contextlib.contextmanager
def assert_fail_with(exception_class, error_msg: Optional[str] = None):
try:
yield
except exception_class as e:
if error_msg is None:
return
if error_msg == str(e):
return
else:
raise AssertionError(f'Expected message to be "{error_msg}", but got "{str(e)}" instead.')
except Exception as e:
raise AssertionError(f"Expected {exception_class} to be raised, but got: {type(e)} instead:\n{e}")
raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")
@contextlib.contextmanager
def assert_fail_except_rank_with(
exception_class, rank_exception: int, pg: dist.ProcessGroup, error_msg: Optional[str] = None
):
try:
yield
except exception_class as e:
if rank_exception == dist.get_rank(pg):
raise AssertionError(f"Expected rank {rank_exception} to not raise {exception_class}.")
else:
if error_msg is None:
return
if error_msg == str(e):
return
else:
raise AssertionError(f'Expected message to be "{error_msg}", but got "{str(e)}" instead.')
except Exception as e:
raise AssertionError(f"Expected {exception_class} to be raised, but got: {type(e)} instead:\n{e}")
if dist.get_rank(pg) != rank_exception:
raise AssertionError(f"Expected {exception_class} to be raised, but no exception was raised.")
@contextlib.contextmanager
def timeout_after(ms=500):
"""Timeout context manager."""
def signal_handler(signum, frame):
raise TimeoutError(f"Timed out after {ms} ms.")
signal.signal(signal.SIGALRM, signal_handler)
signal.setitimer(signal.ITIMER_REAL, ms / 1000)
try:
yield
finally:
signal.alarm(0)