Commit 0371621a authored by chenzk

v1.0
"""
Tensor Parallelism API requires PyTorch 2.3.0+
"""
from allamo.logging import logger
from allamo.configuration import AllamoConfiguration
from allamo.torch_utils import (
TORCH_DTYPE_MAP,
)
from allamo.training_context import TrainingContext
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.device_mesh import init_device_mesh, DeviceMesh
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
PrepareModuleInput,
SequenceParallel,
)
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.utils.checkpoint import checkpoint
def build_world_mesh(train_ctx: TrainingContext, device_type: str = "cuda"):
dims = (train_ctx.pp, train_ctx.dp, train_ctx.tp)
dim_names = ("pp", "dp", "tp")
device_mesh = init_device_mesh(device_type, dims, mesh_dim_names=dim_names)
logger.info(f"{len(dims)}-D device mesh built: {dim_names} = {dims}")
return device_mesh
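# Example (hypothetical launch, not part of this module): with WORLD_SIZE=8 and
# tensor_parallel_degree=2, TrainingContext derives dp = 8 // (2 * 1) = 4, so this builds
# init_device_mesh("cuda", (1, 4, 2), mesh_dim_names=("pp", "dp", "tp")), and the sub-meshes
# world_mesh["dp"] and world_mesh["tp"] are what the FSDP/TP helpers below operate on.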
def parallelize_model_with_fsdp2(model, world_mesh, config, with_activation_checkpointing):
if world_mesh['tp'].size() > 1:
apply_tensor_parallelism(model, world_mesh)
if with_activation_checkpointing:
apply_activation_checkpointing(model)
apply_fsdp(model, world_mesh, config)
if config.compile:
logger.info("Compiling model")
try:
model = torch.compile(model, mode=config.compile_mode)
logger.info("Model compiled and ready to use")
except Exception as err:
logger.warning(f"Unable to compile the model: {err}")
return model
def apply_tensor_parallelism(model: nn.Module, world_mesh: DeviceMesh):
logger.warning(
"Tensor parallelism is in an early experimental stage. "
"Strided sharding is required for 2D/3D DCP, but it is only available in nightly builds "
"newer than 20240809 and in PyTorch version 2.5 or later."
)
parallelize_module(
model,
world_mesh["tp"],
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"lm_head": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate(),
use_local_output=True,
),
},
)
for layer in model.layers:
layer_plan = {
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1), None, None),
desired_input_layouts=(Replicate(), None, None),
),
"attention.q_proj": ColwiseParallel(),
"attention.k_proj": ColwiseParallel(),
"attention.v_proj": ColwiseParallel(),
"attention.c_proj": RowwiseParallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.gate_proj": ColwiseParallel(),
"feed_forward.down_proj": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.up_proj": ColwiseParallel(),
}
layer.attention.num_heads //= world_mesh["tp"].size()
layer.attention.num_kv_heads //= world_mesh["tp"].size()
parallelize_module(
module=layer,
device_mesh=world_mesh["tp"],
parallelize_plan=layer_plan,
)
logger.info(f"Model parallelized with Tensor Parallelism (size: {world_mesh['tp'].size()})")
def apply_activation_checkpointing(model: nn.Module):
for layer_id in range(len(model.layers)):
model.layers[layer_id] = checkpoint_wrapper(
model.layers[layer_id],
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
use_reentrant=False,
preserve_rng_state=False,
)
def apply_fsdp(model: nn.Module, world_mesh: DeviceMesh, config: AllamoConfiguration):
fsdp_config = {"mesh": world_mesh["dp"]}
if config.dtype != 'float32':
fsdp_config["mp_policy"] = MixedPrecisionPolicy(
param_dtype=TORCH_DTYPE_MAP[config.dtype],
reduce_dtype=torch.float32
)
pp_enabled = world_mesh['pp'].size() > 1
for layer_id, layer in enumerate(model.layers):
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# layer since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
fully_shard(
layer,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
logger.info(f"Model parallelized with FSDP2: {model}\n")
import torch
import functools
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from allamo.logging import logger
from allamo.configuration import AllamoConfiguration
from allamo.model.model import SelfAttentionBlock
from allamo.torch_utils import (
TORCH_DTYPE_MAP,
)
FSDP_SHARDING_STRATEGY_MAP = {
'FULL_SHARD': ShardingStrategy.FULL_SHARD,
'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD,
'_HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2,
'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
'NO_SHARD': ShardingStrategy.NO_SHARD
}
def enable_activation_checkpointing(model):
non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
check_fn = lambda submodule: isinstance(submodule, SelfAttentionBlock)
apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
logger.info(f"Activation checkpointing applied to the model")
def parallelize_model_with_fsdp1(model, config: AllamoConfiguration, with_activation_checkpointing: bool = False):
logger.info("Configuring model with FSDP1")
ptdtype = TORCH_DTYPE_MAP[config.dtype]
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
SelfAttentionBlock,
},
)
sharding_strategy = FSDP_SHARDING_STRATEGY_MAP[config.fsdp_sharding_strategy]
fsdp_config = dict(
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=sharding_strategy,
device_id=torch.cuda.current_device(),
mixed_precision=MixedPrecision(
param_dtype=ptdtype,
reduce_dtype=ptdtype,
buffer_dtype=ptdtype,
),
limit_all_gathers=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # will use slightly more memory vs. no prefetch
use_orig_params=True, # required to use torch.compile()
)
model = FSDP(model, **fsdp_config)
logger.info(f"Model configured with FSDP1 and {sharding_strategy=}")
if with_activation_checkpointing:
enable_activation_checkpointing(model)
logger.info(f"Model after parallelization {model=}\n")
if config.compile:
logger.info("Compiling model")
try:
model = torch.compile(model, mode=config.compile_mode)
logger.info("Model compiled and ready to use")
except Exception as err:
logger.warning(f"Unable to compile the model: {err}")
return model
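# Minimal usage sketch for the FSDP1 path (hedged; assumes an initialized process group and
# a model already materialized on the current CUDA device):
#
#   model = parallelize_model_with_fsdp1(model, config, with_activation_checkpointing=True)
#
# The returned module is FSDP-wrapped and, when config.compile is set, torch.compile'd.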
import math
import os
from typing import Optional
import torch
import torch.distributed as dist
from allamo.logging import logger
from allamo.configuration import AllamoConfiguration
from allamo.training_context import TrainingContext
TORCH_DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"bfloat16-true": torch.bfloat16,
}
def override_numa_affinity(local_rank: int, verbose: Optional[bool] = None) -> None:
if torch.cuda.is_available():
try:
import pynvml as nvml
except ImportError:
logger.warning("To set CPU affinity on CUDA GPUs the `pynvml` package must be available. (`pip install pynvml`)")
return
# The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py
nvml.nvmlInit()
num_elements = math.ceil(os.cpu_count() / 64)
handle = nvml.nvmlDeviceGetHandleByIndex(local_rank)
affinity_string = ""
for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements):
# assume nvml returns list of 64 bit ints
affinity_string = f"{j:064b}{affinity_string}"
affinity_list = [int(x) for x in affinity_string]
affinity_list.reverse() # so core 0 is the 0th element
affinity_to_set = set([i for i, e in enumerate(affinity_list) if e != 0])
current_affinity = set(os.sched_getaffinity(0))
affinity_to_set = affinity_to_set.intersection(current_affinity)
if affinity_to_set:
os.sched_setaffinity(0, affinity_to_set)
if verbose:
cpu_cores = os.sched_getaffinity(0)
logger.info(f"Assigning {len(cpu_cores)} cpu cores to process {local_rank}: {cpu_cores}")
else:
logger.info("No affinity available to set")
def configure_torch(config: AllamoConfiguration, rank: int = 0):
torch.manual_seed(config.seed + rank)
if 'cuda' in config.device:
torch.cuda.manual_seed(config.seed + rank)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Use for setting the internal precision of float32 matrix multiplications
# torch.set_float32_matmul_precision("highest")
def init_torch(train_ctx: TrainingContext, config: AllamoConfiguration, distributed=True):
if distributed:
dist.init_process_group(backend=config.backend)
if 'cuda' in config.device:
config.device = f'cuda:{train_ctx.local_rank}'
torch.cuda.set_device(config.device)
if train_ctx.master_process:
logger.info(
f"RANK: {train_ctx.rank}, LOCAL_RANK: {train_ctx.local_rank}, "
f"WORLD_SIZE: {train_ctx.world_size}, LOCAL_WORLD_SIZE: {train_ctx.local_world_size}"
)
os.makedirs(config.out_dir, exist_ok=True)
configure_torch(config, train_ctx.rank)
# override_numa_affinity(train_ctx.local_rank)
import dataclasses
import hashlib
import os
from allamo.model.model import AllamoTransformerConfig
def rename_file_to_prev_version(file_path):
if os.path.exists(file_path):
os.rename(file_path, file_path + '.prev')
def calculate_md5(file_path, chunk_size=1024*1024):
md5 = hashlib.md5()
with open(file_path, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
return md5.hexdigest()
def remove_unwanted_prefix_from_model_state_dict(state_dict, unwanted_prefix = '_orig_mod.'):
unwanted_prefix_len = len(unwanted_prefix)
for k, _ in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[unwanted_prefix_len:]] = state_dict.pop(k)
def remove_unwanted_prefix_from_optimizer_state_dict(optimizer_state_dict, unwanted_prefix = '_orig_mod.'):
if "param_groups" in optimizer_state_dict:
unwanted_prefix_len = len(unwanted_prefix)
for param_group in optimizer_state_dict["param_groups"]:
param_group['params'] = [p[unwanted_prefix_len:] if p.startswith(unwanted_prefix) else p for p in param_group['params']]
def format_seconds_as_time(seconds):
hours, remainder = divmod(seconds, 3600)
minutes, seconds = divmod(remainder, 60)
return f"{int(hours)}:{int(minutes):02}:{int(seconds):02}"
def estimate_mfu(model_num_params, config, fwdbwd_per_iter, dt):
# estimate model flops utilization (MFU) in units of GPU bfloat16 peak FLOPS
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
N = model_num_params
L, H, Q, T = config.n_layer, config.n_head, config.head_size, config.block_size
flops_per_token = 6 * N + 12 * L * H * Q * T
flops_per_fwdbwd = flops_per_token * T
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
# express our flops throughput as ratio of GPU bfloat16 peak flops
flops_achieved = flops_per_iter * (1.0/dt) # per second
return flops_achieved / config.mfu_flops_peak
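# Worked example with hypothetical numbers (not measured values): for N=7e9 params, L=32 layers,
# H=32 heads, Q=128 head size and T=4096 block size,
# flops_per_token = 6*7e9 + 12*32*32*128*4096 ≈ 4.2e10 + 6.4e9 ≈ 4.84e10.
# Multiplying by T tokens, fwdbwd_per_iter and 1/dt gives the achieved FLOPS rate, which is then
# divided by config.mfu_flops_peak (the GPU's bfloat16 peak) to obtain the MFU ratio.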
def get_model_checkpoint_path(ckpt_file_name, ckpt_dir):
return os.path.join(ckpt_dir, f'model_{ckpt_file_name}.pt')
def get_config_checkpoint_path(ckpt_file_name, ckpt_dir):
return os.path.join(ckpt_dir, f'config_{ckpt_file_name}.json')
def get_optimizer_checkpoint_path(ckpt_file_name, ckpt_dir):
return os.path.join(ckpt_dir, f'optimizer_{ckpt_file_name}.pt')
def model_checkpoint_files_exist(ckpt_file_name, ckpt_dir):
return os.path.exists(get_config_checkpoint_path(ckpt_file_name, ckpt_dir)) \
and os.path.exists(get_model_checkpoint_path(ckpt_file_name, ckpt_dir))
def get_model_config_field_names():
return [f.name for f in dataclasses.fields(AllamoTransformerConfig)]
def create_model_config(config):
model_args = {k: getattr(config, k) for k in get_model_config_field_names() if hasattr(config, k)}
return AllamoTransformerConfig(**model_args)
import gc
import os
import time
import math
import datetime
import subprocess
import wandb
import torch
import torch.distributed as dist
from allamo.checkpoint.checkpoint_manager import CheckpointManager
from allamo.configuration import AllamoConfiguration
from allamo.dataset.data_loader import AllamoDataLoader
from allamo.logging import configure_logger, logger
from allamo.model.attentions import attention_version
from allamo.torch_utils import init_torch
from allamo.train_utils import (
format_seconds_as_time,
estimate_mfu,
get_model_checkpoint_path,
get_config_checkpoint_path,
create_model_config,
)
from allamo.training_context import TrainingContext
class BaseTrainer:
def __init__(self, config: AllamoConfiguration):
self.train_ctx = TrainingContext(
tp = config.tensor_parallel_degree,
)
if self.train_ctx.master_process:
configure_logger(config, True)
self.config = config
self.init_torch(config)
logger.info(f"Torch initialized for run {self.train_ctx.run_uuid}")
self.data_loader = AllamoDataLoader(config, self.train_ctx.rank, self.train_ctx.world_size)
self.init_training()
def distributed(self):
raise NotImplementedError("Not implemented")
def init_torch(self, config: AllamoConfiguration):
self.device_type = 'cuda' if 'cuda' in config.device else 'cpu'
init_torch(self.train_ctx, config, distributed=self.distributed())
def init_training(self):
attention_version.configure(self.config)
self.checkpoint_manager = CheckpointManager(self.config, self.train_ctx, self.data_loader)
self.checkpoint_manager.init_checkpoint()
self.data_loader.load_datasets()
self.model_config = create_model_config(self.config)
def init_gradient_accumulation_scheduler(self):
if self.config.grad_accum_schedule:
self.config.grad_accum_max = self.config.gradient_accumulation_steps
self.config.gradient_accumulation_steps = self.config.grad_accum_initial
logger.info(
f"Gradient accumulation scheduler enabled. "
f"Current gradient accumulation steps: {self.config.gradient_accumulation_steps}"
)
self.gradient_accumulation_steps = self.config.gradient_accumulation_steps
def log_init_learning_rate(self):
if self.config.decay_lr:
logger.info(f"Cosing decay learning rate enabled. Currect learning rate: {self.get_lr()}")
else:
logger.info(f"Using constant learning rate: {self.config.learning_rate}")
def init_wandb(self):
if self.config.wandb_log and self.train_ctx.master_process:
wandb_run_name = self.config.wandb_run_name + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
wandb.init(project=self.config.wandb_project, name=wandb_run_name, config=self.config)
def trigger_gc(self):
gc.collect()
torch.cuda.empty_cache()
def should_evaluate(self):
return self.config.eval_interval > 0 and self.train_ctx.iter_num % self.config.eval_interval == 0
def should_save_last_checkpoint(self):
return self.config.checkpoint_interval > 0 and self.train_ctx.iter_num > self.start_iter and self.train_ctx.iter_num % self.config.checkpoint_interval == 0
def should_log_metrics(self):
return self.config.log_interval > 0 and self.train_ctx.iter_num % self.config.log_interval == 0 and self.train_ctx.master_process
def clip_grad_norm(self):
return torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip).item()
def has_next_iter_to_perform(self):
if self.config.num_train_epochs is not None and self.data_loader.epoch >= self.config.num_train_epochs:
return False
return self.train_ctx.iter_num <= self.config.max_iters
def calculate_eta(self):
current_time = datetime.datetime.now()
elapsed_time = current_time - self.start_timestamp
elapsed_iters = self.train_ctx.iter_num - self.start_iter
if elapsed_iters < 1:
return 'N/A'
avg_time_per_iter = elapsed_time.total_seconds() / elapsed_iters
eta_seconds = math.ceil(avg_time_per_iter * (self.config.max_iters - self.train_ctx.iter_num))
return format_seconds_as_time(eta_seconds)
def get_grad_accum(self):
if self.config.grad_accum_schedule and self.gradient_accumulation_steps < self.config.grad_accum_max and self.train_ctx.iter_num % (self.config.grad_accum_max_iter/100) == 0:
return min(self.gradient_accumulation_steps + 1, self.config.grad_accum_max)
else:
return self.gradient_accumulation_steps
def get_lr(self):
""" learning rate decay scheduler (cosine with warmup) """
if self.train_ctx.iter_num < self.config.warmup_iters:
return self.config.learning_rate * self.train_ctx.iter_num / self.config.warmup_iters
if self.config.decay_lr:
if self.train_ctx.iter_num >= self.config.lr_decay_iters:
return self.config.min_lr
if self.config.lr_decay_reset_iters is not None:
decay_ratio = (self.train_ctx.iter_num % self.config.lr_decay_reset_iters) / self.config.lr_decay_reset_iters
else:
decay_ratio = (self.train_ctx.iter_num - self.config.warmup_iters) / (self.config.lr_decay_iters - self.config.warmup_iters)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
return self.config.min_lr + coeff * (self.config.learning_rate - self.config.min_lr)
else:
return self.config.learning_rate
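# Example schedule (hypothetical config: learning_rate=1e-3, min_lr=1e-4, warmup_iters=100,
# lr_decay_iters=1000, decay_lr=True, lr_decay_reset_iters=None):
#   iter 50   -> 1e-3 * 50/100 = 5e-4 (linear warmup)
#   iter 550  -> decay_ratio=0.5, coeff=0.5, lr = 1e-4 + 0.5*(1e-3 - 1e-4) = 5.5e-4
#   iter 1000 -> min_lr = 1e-4 (fully decayed)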
def run_checkpoint_hook_program(self, hook_program, current_epoch, ckpt_file_name):
env_variables = {
"ALLAMO_EPOCH_HOOK_RUN_UUID": self.train_ctx.run_uuid,
"ALLAMO_EPOCH_HOOK_TRAINING_UUID": self.train_ctx.training_uuid,
"ALLAMO_EPOCH_HOOK_EPOCH": str(current_epoch),
"ALLAMO_EPOCH_HOOK_ITERATION": str(self.train_ctx.iter_num),
"ALLAMO_EPOCH_HOOK_MODEL_CKPT_PATH": str(os.path.abspath(get_model_checkpoint_path(ckpt_file_name, self.config.out_dir))),
"ALLAMO_EPOCH_HOOK_CONFIG_CKPT_PATH": str(os.path.abspath(get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)))
}
try:
# merge with the current environment so the hook program still sees PATH and friends
process = subprocess.Popen(hook_program, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, env={**os.environ, **env_variables})
return process.pid
except Exception as err:
return f"n/a - Error: {err}"
def dist_all_reduce(self, x: torch.Tensor, op: dist.ReduceOp):
if self.distributed():
dist.all_reduce(x, op=op)
return x
# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss(self):
losses_out = {}
accuracies = {}
self.model.eval()
for split in self.data_loader.splits:
validation_metrics = torch.zeros(3).to(self.config.device)
for _ in range(self.config.eval_iters):
batch = self.data_loader.get_batch(split, True)
logits, loss, _ = self.model(**batch)
if batch["target_weights"] is not None:
loss = loss / torch.sum(batch["target_weights"] > 0).item()
validation_metrics[0] += loss.item()
validation_metrics[1] += (logits.max(2).indices == batch["target_ids"]).sum().item() / torch.sum(batch["target_ids"].view(-1) != self.config.ignore_index).item()
validation_metrics[2] += 1
validation_metrics = self.dist_all_reduce(validation_metrics, op=dist.ReduceOp.SUM)
losses_out[split] = validation_metrics[0] / (self.config.eval_iters * self.train_ctx.world_size)
accuracies[split] = validation_metrics[1] / validation_metrics[2]
self.model.train()
if 'val' not in losses_out:
losses_out['val'] = losses_out['train']
accuracies['val'] = accuracies['train']
return losses_out, accuracies
def evaluate(self):
eval_time = time.time()
losses, accuracies = self.estimate_loss()
eval_time = time.time() - eval_time
train_loss = losses['train'].item()
val_loss = losses['val'].item()
if self.train_ctx.iter_num > self.start_iter:
if train_loss < self.train_ctx.best_train_loss:
self.train_ctx.best_train_loss = train_loss
if val_loss < self.train_ctx.best_val_loss:
self.train_ctx.best_val_loss = val_loss
if self.config.save_best_checkpoint:
self.save_checkpoint('ckpt')
if self.train_ctx.master_process:
train_ppl = torch.exp(losses['train']).item()
val_ppl = torch.exp(losses['val']).item()
logger.info(
f"iter {self.train_ctx.iter_num:,}: train loss={train_loss:.4f} ppl={train_ppl:.4f} "
f"acc={accuraces['train']:.4f} (best loss={self.train_ctx.best_train_loss:.4f}), "
f"val loss={val_loss:.4f} ppl={val_ppl:.4f} acc={accuraces['val']:.4f} "
f"(best loss={self.train_ctx.best_val_loss:.4f}), tokens {self.train_ctx.processed_tokens:,}"
)
if self.config.wandb_log:
wandb.log({
"iter": self.train_ctx.iter_num,
"eval/time": eval_time*1000,
"eval/samples_per_second": (self.config.eval_iters * len(self.data_loader.splits)) / eval_time,
"eval/train_loss": train_loss,
"eval/val_loss": val_loss,
"eval/train_ppl": train_ppl,
"eval/val_ppl": val_ppl,
"eval/train_acc": accuraces['train'].item(),
"eval/val_acc": accuraces['val'].item(),
"eval/diff_loss": (val_loss-train_loss),
"eval/diff_acc": (accuraces['train']-accuraces['val']).item(),
"eval/diff_ppl": (val_ppl-train_ppl),
"eval/best_train_loss": self.train_ctx.best_train_loss,
"eval/best_val_loss": self.train_ctx.best_val_loss
})
self.trigger_gc()
def train(self):
logger.info(f"Starting training (run id: {self.train_ctx.run_uuid}, world size: {self.train_ctx.world_size}) with configuration:\n{self.config}")
batch = self.data_loader.get_batch('train') # fetch the very first batch
self.start_iter = self.train_ctx.iter_num
self.start_timestamp = datetime.datetime.now()
current_epoch = self.data_loader.epoch
current_num_loaded_files = self.data_loader.get_num_loaded_files()
iter_metrics = torch.zeros(5).to(self.config.device)
self.trigger_gc()
while self.has_next_iter_to_perform():
if current_epoch < self.data_loader.epoch:
ckpt_file_name = f'epoch_{current_epoch}'
self.save_checkpoint(ckpt_file_name, model_only=True, epoch_ckpt=True)
if self.config.epoch_completion_hook_program and self.train_ctx.master_process:
pid = self.run_checkpoint_hook_program(self.config.epoch_completion_hook_program, current_epoch, ckpt_file_name)
logger.info(f"Epoch completion hook program started with pid {pid}")
current_epoch = self.data_loader.epoch
current_num_loaded_files = self.data_loader.get_num_loaded_files()
elif self.config.save_checkpoint_on_dataset_reload and current_num_loaded_files != self.data_loader.get_num_loaded_files():
ckpt_file_name = f'ds_reload_{current_epoch}-{current_num_loaded_files}'
self.save_checkpoint(ckpt_file_name, model_only=True, epoch_ckpt=False)
current_num_loaded_files = self.data_loader.get_num_loaded_files()
elif self.config.should_override_config(self.train_ctx.iter_num):
self.config.override_config_properties()
timer = time.time()
lr = self.get_lr()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
# determine and set batch_size and gradient_accumulation_steps for this iteration
micro_batch_size = self.data_loader.update_batch_size(self.train_ctx.iter_num)
total_batch_size = self.config.block_size * micro_batch_size * self.gradient_accumulation_steps * self.train_ctx.world_size
self.gradient_accumulation_steps = self.get_grad_accum()
# evaluate the loss on train/val sets and write best checkpoint
if self.should_evaluate():
self.evaluate()
if self.should_save_last_checkpoint():
ckpt_file_name = 'last_eval_ckpt'
self.save_checkpoint(ckpt_file_name)
if self.config.regular_checkpoint_hook_program and self.train_ctx.master_process:
pid = self.run_checkpoint_hook_program(self.config.regular_checkpoint_hook_program, current_epoch, ckpt_file_name)
logger.info(f"Regular checkpoint hook program started with pid {pid}")
accuracy = 0
iter_metrics.zero_()
batch_mfu_excluded_time = 0
fwdbwd_time = time.time()
# forward backward update, with optional gradient accumulation to simulate larger batch size
for micro_step in range(self.gradient_accumulation_steps):
loss, unmasked_labels, accuracy = self.forward(batch, (micro_step == self.gradient_accumulation_steps - 1))
mfu_excluded_time = time.time()
iter_metrics[0] += loss.item()
iter_metrics[1] += unmasked_labels
iter_metrics[2] += accuracy
iter_metrics[3] += 1
# immediately async prefetch next batch while model is doing the forward pass on the GPU
batch = self.data_loader.get_batch('train')
batch_mfu_excluded_time += time.time() - mfu_excluded_time
# backward pass, with gradient scaling
self.scaler.scale(loss).backward()
# clip the gradient
if self.config.grad_clip != 0.0:
self.scaler.unscale_(self.optimizer)
iter_metrics[4] += self.clip_grad_norm()
mfu_excluded_time = time.time()
# sync loss and acc over all processes
iter_metrics = self.dist_all_reduce(iter_metrics, op=dist.ReduceOp.SUM)
# adjust learning rate
if self.config.adaptive_learning_rate:
lr = lr * math.sqrt(iter_metrics[1].item() / total_batch_size)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
if self.train_ctx.master_process:
self.train_ctx.processed_tokens += int(iter_metrics[1])
batch_mfu_excluded_time += time.time() - mfu_excluded_time
# step the optimizer and scaler
self.scaler.step(self.optimizer)
self.scaler.update()
# flush the gradients as soon as we can, no need for this memory anymore
self.optimizer.zero_grad(set_to_none=True)
fwdbwd_time = time.time() - fwdbwd_time - batch_mfu_excluded_time
if self.should_log_metrics():
iter_time = time.time() - timer
# get loss as float. note: this is a CPU-GPU sync point
lossf = iter_metrics[0].item() / self.train_ctx.world_size
ppl = torch.exp(torch.tensor(lossf)).item()
accuracy = iter_metrics[2].item() / iter_metrics[3].item()
grad_norm = iter_metrics[4].item() / self.train_ctx.world_size
if self.config.mfu_flops_peak > 0 and self.train_ctx.iter_num > self.start_iter:
mfu = estimate_mfu(self.model_num_params, self.config, micro_batch_size * self.gradient_accumulation_steps, fwdbwd_time)
mfu_str = f'{mfu*100:.2f}%'
else:
mfu = -1.0
mfu_str = 'n/a'
mtu = fwdbwd_time/iter_time # model time utilization
iter_time_ms = iter_time * 1000
logger.info(
f"iter {self.train_ctx.iter_num:,}: loss {lossf:.4f}, ppl {ppl:.4f}, acc {accuracy:.4f}, "
f"iter time {iter_time_ms:.2f}ms, tokens {self.train_ctx.processed_tokens:,}, lr {lr:.8f}, "
f"mfu {mfu_str}, mtu {mtu*100:.2f}%, epoch {self.data_loader.epoch}, "
f"ETA: {self.calculate_eta()}"
)
if self.config.wandb_log:
metrics = {
"iter": self.train_ctx.iter_num,
"train/loss": lossf,
"train/acc": accuracy,
"train/ppl": ppl,
"train/grad_norm": grad_norm,
"train/lr": lr,
"train/mtu": mtu,
"train/tokens_per_sec": (total_batch_size/iter_time),
"train/tokens_per_gpu_per_sec": (total_batch_size/self.train_ctx.world_size/iter_time),
"train/tokens": self.train_ctx.processed_tokens,
"train/epoch": self.data_loader.epoch,
"train/total_batch_size": total_batch_size,
"train/iter_time": iter_time_ms,
}
if mfu > 0:
metrics['train/mfu'] = mfu
if self.config.dataset_seq_train:
metrics['train/ds_offset'] = self.data_loader.dataset_offset
wandb.log(metrics)
self.train_ctx.iter_num += 1
training_time = format_seconds_as_time((datetime.datetime.now() - self.start_timestamp).total_seconds())
logger.info(f"Training finished in {training_time}")
ckpt_file_name = 'final_ckpt'
self.save_checkpoint(ckpt_file_name, model_only=True, epoch_ckpt=True)
if self.config.epoch_completion_hook_program and self.train_ctx.master_process:
pid = self.run_checkpoint_hook_program(self.config.epoch_completion_hook_program, current_epoch, ckpt_file_name)
logger.info(f"Epoch completion hook program started with pid {pid}")
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from copy import deepcopy
from allamo.logging import logger
from allamo.model.model import AllamoTransformer
from allamo.configuration import AllamoConfiguration
from allamo.train_utils import model_checkpoint_files_exist
from allamo.trainer.fsdp_trainer import FSDPTrainer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
def get_log_prob(logits, target_ids, ignore_index):
"""
Args:
logits: unnormalized logits [B, T, V]
target_ids: masked labels [B, T]
ignore_index: masked label id
Returns:
aggregated log probabilities [B, ]
"""
labels = target_ids.clone()
loss_mask = (labels != ignore_index)
labels[labels == ignore_index] = 0 # will be ignored for the loss calc
log_probs = F.log_softmax(logits, dim=-1)
per_token_logps = torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1)
return (per_token_logps * loss_mask).sum(-1)
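# Shape sketch (hypothetical tensors): for B=2 sequences of length T=5 over a vocabulary of
# V=32000, logits is [2, 5, 32000] and target_ids is [2, 5]; positions equal to ignore_index
# contribute nothing, and the result is a [2]-shaped tensor holding the summed log-probability
# of the unmasked target tokens of each sequence.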
class DPOTrainer(FSDPTrainer):
def __init__(self, config: AllamoConfiguration):
super().__init__(config)
def init_training(self):
super().init_training()
if model_checkpoint_files_exist(self.config.reference_checkpoint_name, self.checkpoint_dir):
ref_model_conf = deepcopy(self.model.config)
ref_model = AllamoTransformer(ref_model_conf)
self.load_model_checkpoint(ref_model, os.path.join(self.checkpoint_dir, f'model_{self.config.reference_checkpoint_name}.pt'), self.config)
logger.info("Configuring reference model with FSDP")
ref_model = FSDP(ref_model, **self.fsdp_config)
logger.info(f"Reference model configured with FSDP and sharding strategy {self.sharding_strategy}")
# compile the model - requires PyTorch 2.0
if self.config.compile:
logger.info("Compiling reference model")
try:
ref_model = torch.compile(ref_model, mode=self.config.compile_mode)
logger.info("Reference model compiled and ready to use")
except Exception as err:
logger.warning(f"Unable to compile the reference model: {err}")
self.ref_model = ref_model
self.ref_model.eval()
else:
self.ref_model = None
logger.warning("Reference model checkpoint not provided. Reference log probabilities must be supplied via DataLoader")
def forward(self, batch, last_micro_step):
policy_chosen_logits, _, _ = self.model(input_ids=batch["chosen_input_ids"], target_ids=batch["chosen_target_ids"])
policy_rejected_logits, _, _ = self.model(input_ids=batch["rejected_input_ids"], target_ids=batch["rejected_target_ids"])
policy_chosen_logps = get_log_prob(policy_chosen_logits, batch["chosen_target_ids"], self.config.ignore_index)
policy_rejected_logps = get_log_prob(policy_rejected_logits, batch["rejected_target_ids"], self.config.ignore_index)
if "reference_chosen_logps" in batch and batch["reference_chosen_logps"] is not None:
reference_chosen_logps = batch["reference_chosen_logps"]
reference_rejected_logps = batch["reference_rejected_logps"]
else:
assert self.ref_model is not None
with torch.no_grad():
reference_chosen_logits, _, _ = self.ref_model(input_ids=batch["chosen_input_ids"], target_ids=batch["chosen_target_ids"])
reference_rejected_logits, _, _ = self.ref_model(input_ids=batch["rejected_input_ids"], target_ids=batch["rejected_target_ids"])
reference_chosen_logps = get_log_prob(reference_chosen_logits, batch["chosen_target_ids"], self.config.ignore_index)
reference_rejected_logps = get_log_prob(reference_rejected_logits, batch["rejected_target_ids"], self.config.ignore_index)
# calculate DPO loss
chosen_rewards = self.config.dpo_chosen_beta * (policy_chosen_logps - reference_chosen_logps)
rejected_rewards = self.config.dpo_rejected_beta * (policy_rejected_logps - reference_rejected_logps)
reward_penalty = self.config.dpo_penalty_lambda * torch.maximum(torch.zeros_like(policy_chosen_logps), reference_chosen_logps - policy_chosen_logps)
dpo_loss = -F.logsigmoid(chosen_rewards - rejected_rewards - reward_penalty).mean()
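# In closed form, with b_c = dpo_chosen_beta, b_r = dpo_rejected_beta and l = dpo_penalty_lambda:
#   loss = -log sigmoid( b_c*(log pi(y_w|x) - log pi_ref(y_w|x))
#                        - b_r*(log pi(y_l|x) - log pi_ref(y_l|x))
#                        - l*max(0, log pi_ref(y_w|x) - log pi(y_w|x)) ),
# averaged over the batch; with b_c == b_r and l == 0 this reduces to the standard DPO objective.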
if self.gradient_accumulation_steps > 1:
dpo_loss = dpo_loss / self.gradient_accumulation_steps # scale the loss to account for micro steps
chosen_unmasked_labels = torch.sum(batch["chosen_target_ids"].view(-1) != self.config.ignore_index).item()
rejected_unmasked_labels = torch.sum(batch["rejected_target_ids"].view(-1) != self.config.ignore_index).item()
unmasked_labels = chosen_unmasked_labels + rejected_unmasked_labels
accuracy = (policy_chosen_logits.max(2).indices == batch["chosen_target_ids"]).sum().item() / chosen_unmasked_labels
if last_micro_step and self.config.log_interval > 0 and self.train_ctx.iter_num % self.config.log_interval == 0:
chosen_rewards = chosen_rewards.detach()
rejected_rewards = rejected_rewards.detach()
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
reward_margins = (chosen_rewards - rejected_rewards).mean()
chosen_rewards = chosen_rewards.mean()
rejected_rewards = rejected_rewards.mean()
reward_penalty = reward_penalty.mean()
policy_chosen_logps = policy_chosen_logps.detach()
policy_rejected_logps = policy_rejected_logps.detach()
policy_accuracies = (policy_chosen_logps > policy_rejected_logps).float().mean()
policy_chosen_logps = policy_chosen_logps.mean()
policy_rejected_logps = policy_rejected_logps.mean()
metrics = torch.tensor([
1,
reward_accuracies.item(),
reward_margins.item(),
chosen_rewards.item(),
rejected_rewards.item(),
reward_penalty.item(),
policy_accuracies.item(),
policy_chosen_logps.item(),
policy_rejected_logps.item()
]).to(self.config.device)
metrics = self.dist_all_reduce(metrics, op=dist.ReduceOp.SUM)
if self.train_ctx.master_process:
cnt = metrics[0].item()
reward_accuracies = metrics[1].item() / cnt
reward_margins = metrics[2].item() / cnt
chosen_rewards = metrics[3].item() / cnt
rejected_rewards = metrics[4].item() / cnt
reward_penalty = metrics[5].item() / cnt
policy_accuracies = metrics[6].item() / cnt
policy_chosen_logps = metrics[7].item() / cnt
policy_rejected_logps = metrics[8].item() / cnt
if self.config.wandb_log:
wandb.log({
"iter": self.train_ctx.iter_num,
"dpo/rewards/accuracies": reward_accuracies,
"dpo/rewards/margins": reward_margins,
"dpo/rewards/chosen": chosen_rewards,
"dpo/rewards/rejected": rejected_rewards,
"dpo/rewards/penalty": reward_penalty,
"dpo/logps/chosen": policy_chosen_logps,
"dpo/logps/rejected": policy_rejected_logps,
"dpo/logps/accuracies": policy_accuracies
})
else:
logger.info(
f"iter {self.train_ctx.iter_num:,}: "
f"reward_acc={reward_accuracies:.4f} reward_marg={reward_margins:.4f} "
f"reward_chosen={chosen_rewards:.4f} reward_rejected={rejected_rewards:.4f} "
f"reward_penalty={reward_penalty:.4f}"
)
return dpo_loss, unmasked_labels, accuracy
import os
import shutil
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType,
FullStateDictConfig, # general model non-sharded, non-flattened params
)
from allamo.trainer.base import BaseTrainer
from allamo.logging import logger
from allamo.model.model import AllamoTransformer
from allamo.configuration import AllamoConfiguration
from allamo.parallelisms.fsdp_utils import parallelize_model_with_fsdp1
from allamo.parallelisms.fsdp2_utils import build_world_mesh, parallelize_model_with_fsdp2
from allamo.train_utils import (
get_model_checkpoint_path,
get_config_checkpoint_path,
get_optimizer_checkpoint_path,
)
class FSDPTrainer(BaseTrainer):
def __init__(self, config: AllamoConfiguration):
super().__init__(config)
def distributed(self):
return True
def init_torch(self, config: AllamoConfiguration):
super().init_torch(config)
if config.dtype == 'bfloat16-true':
raise Exception('Full bfloat16 training is not supported with FSDP')
self.fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
if config.gradient_checkpointing:
self.fsdp_activation_checkpointing = True
config.gradient_checkpointing = False # control gradient checkpointing with FSDP
logger.info(
"Deactivated gradient checkpointing at the model configuration level. "
"Activated gradient checkpointing at the FSDP level."
)
else:
self.fsdp_activation_checkpointing = False
# DCP activates FSDP2
if self.config.distributed_checkpoint:
assert self.config.dtype != 'float16', "GradScaler is not functioning properly with FSDP2"
self.world_mesh = build_world_mesh(self.train_ctx, self.device_type)
else:
self.world_mesh = None
def init_training(self):
super().init_training()
self.model_config.gradient_checkpointing = False # AC is handled by FSDP
with torch.device('meta'):
model = AllamoTransformer(self.model_config)
self.model_num_params = model.model_num_params
if self.checkpoint_manager.checkpoint_name is None:
if self.world_mesh is None:
self.model = parallelize_model_with_fsdp1(model, self.config, self.fsdp_activation_checkpointing)
else:
self.model = parallelize_model_with_fsdp2(model, self.world_mesh, self.config, self.fsdp_activation_checkpointing)
self.model.to_empty(device=self.device_type)
self.model.init_model_weights()
logger.info("Initialized a new model from scratch")
self.optimizer = self.model.configure_optimizers(self.config, self.device_type)
logger.info("Initializing optimizer from scratch")
else:
if self.config.distributed_checkpoint:
self.model = parallelize_model_with_fsdp2(model, self.world_mesh, self.config, self.fsdp_activation_checkpointing)
logger.info("model.to_empty")
self.model.to_empty(device=self.device_type)
logger.info("model.init_model_weights")
self.model.init_model_weights()
logger.info("checkpoint_manager.load_distributed_model_checkpoint")
self.checkpoint_manager.load_distributed_model_checkpoint(self.model)
logger.info("model.configure_optimizers")
self.optimizer = self.model.configure_optimizers(self.config, self.device_type)
logger.info("checkpoint_manager.load_distributed_optimizer_checkpoint")
self.checkpoint_manager.load_distributed_optimizer_checkpoint(self.model, self.optimizer)
logger.info("ready")
else:
model.to_empty(device=self.device_type)
model.init_model_weights()
self.checkpoint_manager.load_regular_model_checkpoint(model)
self.model = parallelize_model_with_fsdp1(model, self.config, self.fsdp_activation_checkpointing)
self.optimizer = self.model.configure_optimizers(self.config, self.device_type)
self.load_optimizer_checkpoint(self.model, self.optimizer)
# initialize a GradScaler only for FSDP's built-in mixed precision with fp16
self.scaler = torch.amp.GradScaler(self.device_type, enabled=(self.config.dtype == 'float16'))
self.init_gradient_accumulation_scheduler()
self.log_init_learning_rate()
def load_optimizer_checkpoint(self, model, optimizer):
ckpt_path = get_optimizer_checkpoint_path(self.checkpoint_manager.checkpoint_name, self.checkpoint_manager.checkpoint_dir)
if os.path.exists(ckpt_path):
# requires each rank to have the full dict in CPU memory to reduce communication
full_osd = torch.load(ckpt_path, map_location='cpu')
sharded_osd = FSDP.optim_state_dict_to_load(model, optimizer, full_osd)
optimizer.load_state_dict(sharded_osd)
logger.info(f"Shared optimizer state loaded from checkpoint {ckpt_path}")
else:
if self.train_ctx.master_process:
logger.warning("Optimizer checkpoint file not found. Initializing optimizer from scratch")
# saves the checkpoint to files (model, config, and optionally optimizer state)
def save_checkpoint(self, ckpt_file_name, model_only=False, epoch_ckpt=False):
if self.config.distributed_checkpoint:
config_ckpt_file_path = get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)
self.checkpoint_manager.save_config_checkpoint(config_ckpt_file_path, None, self.model_config)
if self.train_ctx.master_process and not self.config.ignore_last_checkpoint_backup:
logger.warning("Backing up a previous checkpoint is not supported for distributed checkpoints")
model_ckpt_dir_path = self.checkpoint_manager.save_distributed_model_checkpoint(self.model, ckpt_file_name)
if not model_only and self.checkpoint_manager.should_save_optimizer():
self.checkpoint_manager.save_distributed_optimizer_checkpoint(self.model, self.optimizer, ckpt_file_name)
if self.config.optimizer_checkpoint_interval is not None:
shutil.copytree(model_ckpt_dir_path, model_ckpt_dir_path + '-optim')
shutil.copy(config_ckpt_file_path, config_ckpt_file_path + '.optim')
else:
with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, self.fullstate_save_policy):
full_msd = self.model.state_dict()
if self.train_ctx.master_process:
model_ckpt_file_path = get_model_checkpoint_path(ckpt_file_name, self.config.out_dir)
md5sum = self.checkpoint_manager.save_regular_model_checkpoint(full_msd, model_ckpt_file_path, epoch_ckpt)
del full_msd
config_ckpt_file_path = get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)
self.checkpoint_manager.save_config_checkpoint(config_ckpt_file_path, md5sum, self.model_config)
if not model_only and self.checkpoint_manager.should_save_optimizer():
# pull all sharded optimizer states to rank0 cpu.
full_osd = FSDP.full_optim_state_dict(self.model, self.optimizer)
if self.train_ctx.master_process:
optim_ckpt_file_path = get_optimizer_checkpoint_path(ckpt_file_name, self.config.out_dir)
self.checkpoint_manager.save_regular_optimizer_checkpoint(full_osd, optim_ckpt_file_path)
del full_osd
if self.config.optimizer_checkpoint_interval is not None:
shutil.copy(model_ckpt_file_path, model_ckpt_file_path + '.optim')
shutil.copy(config_ckpt_file_path, config_ckpt_file_path + '.optim')
def dist_all_reduce(self, x: torch.Tensor, op: dist.ReduceOp):
if self.world_mesh is None:
dist.all_reduce(x, op=op)
return x
else:
return funcol.all_reduce(x, reduceOp=op.name, group=self.world_mesh["dp"])
def clip_grad_norm(self):
if self.world_mesh is None:
return self.model.clip_grad_norm_(self.config.grad_clip).item()
else:
return super().clip_grad_norm()
def forward(self, batch, last_micro_step):
logits, loss, _ = self.model(**batch)
if self.gradient_accumulation_steps > 1:
loss = loss / self.gradient_accumulation_steps # scale the loss to account for micro steps
if batch["target_weights"] is not None:
if self.config.weighted_loss_method == 'openchat':
target_weights = batch["target_weights"].sum()
# sum loss weights over all processes
target_weights = self.dist_all_reduce(target_weights, op=dist.ReduceOp.SUM)
loss = (self.train_ctx.world_size / target_weights) * loss
else:
loss = loss / torch.sum(batch["target_weights"] > 0).item()
unmasked_labels = torch.sum(batch["target_ids"].view(-1) != self.config.ignore_index).item()
accuracy = (logits.max(2).indices == batch["target_ids"]).sum().item() / unmasked_labels
return loss, unmasked_labels, accuracy
def close(self):
dist.barrier()
dist.destroy_process_group()
import os
import shutil
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from allamo.trainer.base import BaseTrainer
from allamo.logging import logger
from allamo.model.model import AllamoTransformer
from allamo.configuration import AllamoConfiguration
from allamo.torch_utils import TORCH_DTYPE_MAP
from allamo.train_utils import (
get_model_checkpoint_path,
get_config_checkpoint_path,
get_optimizer_checkpoint_path,
)
class SimpleTrainer(BaseTrainer):
def __init__(self, config: AllamoConfiguration):
super().__init__(config)
if config.distributed_checkpoint:
config.distributed_checkpoint = False
logger.warning("PyTorch Distributed Checkpoint (DCP) is only available for FSDP training! Falling back to regular checkpoints")
def distributed(self):
return self.train_ctx.world_size > 1
def init_torch(self, config: AllamoConfiguration):
super().init_torch(config)
self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=TORCH_DTYPE_MAP[config.dtype])
if config.dtype == 'bfloat16-true':
# torch.set_float32_matmul_precision("high")
torch.set_default_dtype(torch.bfloat16)
def init_training(self):
super().init_training()
model = AllamoTransformer(self.model_config)
print("model: ", model)
self.model_num_params = model.model_num_params
if self.checkpoint_manager.is_checkpoint_available():
self.checkpoint_manager.load_regular_model_checkpoint(model)
else:
logger.info("New model initialized from scratch")
model.to(self.config.device)
if self.config.compile:
logger.info("Compiling model")
try:
model = torch.compile(model, mode=self.config.compile_mode)
logger.info("Model compiled and ready to use")
except Exception as err:
logger.warn(f"Unable to compile the model: {err}")
self.raw_model = model # neeeded in DDP training
self.model = model
# wrap model into DDP container
if self.distributed():
self.model = DDP(self.model, device_ids=[self.train_ctx.local_rank])
# initialize a GradScaler. If enabled=False scaler is a no-op
self.scaler = torch.amp.GradScaler(self.device_type, enabled=(self.config.dtype == 'float16' or self.config.dtype == 'bfloat16'))
# optimizer
self.optimizer = self.raw_model.configure_optimizers(self.config, self.device_type)
if self.checkpoint_manager.is_checkpoint_available():
self.load_optimizer_checkpoint(self.optimizer)
self.init_gradient_accumulation_scheduler()
self.log_init_learning_rate()
def load_optimizer_checkpoint(self, optimizer):
ckpt_path = get_optimizer_checkpoint_path(self.checkpoint_manager.checkpoint_name, self.checkpoint_manager.checkpoint_dir)
if os.path.exists(ckpt_path):
state_dict = torch.load(ckpt_path, map_location=self.config.device)
optimizer.load_state_dict(state_dict)
logger.info(f"Optimizer state loaded from checkpoint {ckpt_path}")
else:
logger.warning("Optimizer checkpoint file not found. Initializing optimizer from scratch")
# saves the checkpoint to files (model, config, and optionally optimizer state)
def save_checkpoint(self, ckpt_file_name, model_only=False, epoch_ckpt=False):
if not self.train_ctx.master_process:
return
model_ckpt_file_path = get_model_checkpoint_path(ckpt_file_name, self.config.out_dir)
md5sum = self.checkpoint_manager.save_regular_model_checkpoint(self.raw_model.state_dict(), model_ckpt_file_path, epoch_ckpt)
config_ckpt_file_path = get_config_checkpoint_path(ckpt_file_name, self.config.out_dir)
self.checkpoint_manager.save_config_checkpoint(config_ckpt_file_path, md5sum, self.model_config)
if not model_only and self.checkpoint_manager.should_save_optimizer():
optim_ckpt_file_path = get_optimizer_checkpoint_path(ckpt_file_name, self.config.out_dir)
self.checkpoint_manager.save_regular_optimizer_checkpoint(self.optimizer.state_dict(), optim_ckpt_file_path)
if self.config.optimizer_checkpoint_interval is not None:
shutil.copy(model_ckpt_file_path, model_ckpt_file_path + '.optim')
shutil.copy(config_ckpt_file_path, config_ckpt_file_path + '.optim')
logger.info(f"checkpoint files saved in {self.config.out_dir}")
def should_evaluate(self):
return super().should_evaluate() and self.train_ctx.master_process
def forward(self, batch, last_micro_step):
if self.distributed():
# in DDP training we only need to sync gradients at the last micro step.
# the official way to do this is with model.no_sync() context manager, but
# I really dislike that this bloats the code and forces us to repeat code
# looking at the source of that context manager, it just toggles this variable
self.model.require_backward_grad_sync = last_micro_step
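# For reference, the equivalent using the official API would look like this (sketch, not used here):
#   sync_ctx = nullcontext() if last_micro_step else self.model.no_sync()
#   with sync_ctx, self.ctx:
#       logits, loss, _ = self.model(**batch)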
with self.ctx:
logits, loss, _ = self.model(**batch)
if self.gradient_accumulation_steps > 1:
loss = loss / self.gradient_accumulation_steps # scale the loss to account for micro steps
if batch["target_weights"] is not None:
if self.config.weighted_loss_method == 'openchat':
target_weights = batch["target_weights"].sum()
# sum loss weights over all processes
target_weights = self.dist_all_reduce(target_weights, op=dist.ReduceOp.SUM)
loss = (self.train_ctx.world_size / target_weights) * loss
else:
loss = loss / torch.sum(batch["target_weights"] > 0).item()
unmasked_labels = torch.sum(batch["target_ids"].view(-1) != self.config.ignore_index).item()
accuracy = (logits.max(2).indices == batch["target_ids"]).sum().item() / unmasked_labels
return loss, unmasked_labels, accuracy
def close(self):
if self.distributed():
dist.barrier()
dist.destroy_process_group()
import os
import uuid
from dataclasses import dataclass
@dataclass
class TrainingContext:
dp: int = -1
tp: int = 1
pp: int = 1
def __post_init__(self):
self.world_size = int(os.environ.get('WORLD_SIZE', 1))
self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
self.rank = int(os.environ.get('RANK', 0))
self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
self.run_uuid = str(uuid.uuid4())
self.training_uuid = self.run_uuid
self.iter_num = 0
self.best_train_loss = 1e2
self.best_val_loss = 1e2
self.processed_tokens = 0
self._validate()
def _validate(self):
if self.pp < 1:
self.pp = 1
if self.tp < 1:
self.tp = 1
if self.dp < 1:
self.dp = self.world_size // (self.tp * self.pp)
assert self.dp > 0
assert self.tp > 0
assert self.pp == 1, f"pp({self.pp}) > 1 is not supported"
assert self.dp * self.tp * self.pp == self.world_size, f"dp({self.dp}) * tp({self.tp}) * pp({self.pp}) != world_size({self.world_size})"
@property
def master_process(self):
return self.rank == 0
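# Example (hypothetical launch): torchrun with WORLD_SIZE=16 and tensor_parallel_degree=2 yields
# TrainingContext(tp=2), so _validate() sets dp = 16 // (2 * 1) = 8 and the
# dp * tp * pp == world_size assertion holds; pp > 1 is rejected for now.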
import os
import requests
import tiktoken
import numpy as np
import torch
# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file_path, 'w', encoding='utf-8') as f:
f.write(requests.get(data_url).text)
with open(input_file_path, 'r', encoding='utf-8') as f:
data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]
# encode with tiktoken gpt2 bpe
# enc = tiktoken.get_encoding("gpt2")
enc = tiktoken.get_encoding("cl100k_base")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")
# export to bin files
# note: cl100k_base token ids can exceed the uint16 range (vocab ~100k), so store as uint32
train_ids = np.array(train_ids, dtype=np.uint32)
val_ids = np.array(val_ids, dtype=np.uint32)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
# torch.save(train_ids, os.path.join(os.path.dirname(__file__), 'train.pt'))
# torch.save(val_ids, os.path.join(os.path.dirname(__file__), 'val.pt'))
# with the GPT-2 BPE tokenizer: train.bin has 301,966 tokens, val.bin has 36,059 tokens
# (the cl100k_base encoding used above produces different counts)
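# One way to consume the exported files later (sketch, not part of this script) is a memory map
# with the dtype that was used when writing, e.g.:
#   data = np.memmap(os.path.join(os.path.dirname(__file__), 'train.bin'), dtype=np.uint32, mode='r')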
# tiny shakespeare
Tiny shakespeare, of the good old char-rnn fame :)
After running `prepare.py` (the token counts below are for the GPT-2 BPE tokenizer; the cl100k_base encoding used in the script yields different counts):
- train.bin has 301,966 tokens
- val.bin has 36,059 tokens
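To regenerate the files, run e.g. `python prepare.py` (assumes the `requests`, `tiktoken`, and `numpy` packages are installed).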