Commit c2faed49 authored by chenzk's avatar chenzk
Browse files

v1.0.3

parent 207c6325
......@@ -118,7 +118,7 @@ class CheckpointManager:
def load_regular_model_checkpoint(self, model):
model_ckpt_file_path = get_model_checkpoint_path(self.checkpoint_name, self.checkpoint_dir)
state_dict = torch.load(model_ckpt_file_path, map_location='cpu')
state_dict = torch.load(model_ckpt_file_path, map_location='cpu', weights_only=True)
remove_unwanted_prefix_from_model_state_dict(state_dict)
model.load_state_dict(state_dict)
if self.config.log_checkpoint_md5_on_load and self.train_ctx.master_process:
......
......@@ -3,7 +3,8 @@ import json
import logging
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List
logger = logging.getLogger("AllamoConfiguration")
......@@ -97,13 +98,21 @@ class AllamoConfiguration:
fsdp_sharding_strategy: str = 'FULL_SHARD'
epoch_completion_hook_program: str = None
regular_checkpoint_hook_program: str = None
training_type: str = 'pre'
attention_implementation: str = 'sdpa'
tensor_parallel_degree: int = 1
# freezing params
freeze_embeddings: bool = False
freeze_lm_head: bool = False
freeze_layers: bool = False
keep_layers_trainable: List[int] = field(default_factory=list)
# DPO params
dpo_chosen_beta: float = 0.5
dpo_rejected_beta: float = 0.1
dpo_penalty_lambda: float = 50.0
reference_checkpoint_name: str = 'ref_ckpt'
training_type: str = 'pre'
attention_implementation: str = 'sdpa'
tensor_parallel_degree: int = 1
# inference params
prompt: str = "\n"
......@@ -203,13 +212,20 @@ class AllamoConfiguration:
parser.add_argument('--fsdp_sharding_strategy', type=str, choices=['FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2', 'SHARD_GRAD_OP', 'NO_SHARD'], help='FSDP sharding strategy')
parser.add_argument('--epoch_completion_hook_program', type=str, help='Path to the program/script to be executed after the epoch ends and the checkpoint is saved')
parser.add_argument('--regular_checkpoint_hook_program', type=str, help='Path to the program/script to be executed after the regualar checkpoint is saved')
parser.add_argument('--training_type', type=str, choices=['pre', 'sft', 'dpo'], help='Specifies the type of training: pre (pre-training), sft (supervised fine-tuning), or dpo (direct preference optimization)')
parser.add_argument('--attention_implementation', type=str, choices=['sdpa', 'flash_attention_2', 'eager'], help='Specifies attention implementation')
parser.add_argument('--tensor_parallel_degree', type=int, help='Specifies the degree of tensor parallelism. Activates TP when it is greater than 1')
parser.add_argument('--freeze_embeddings', action='store_true', help='Freeze embeddings')
parser.add_argument('--freeze_lm_head', action='store_true', help='Freeze lm_head')
parser.add_argument('--freeze_layers', action='store_true', help='Freeze all layers')
parser.add_argument('--keep_layers_trainable', type=int, nargs='*', default=[], help='List of layer indices to keep trainable (e.g., --keep_layers_trainable 0 31)')
parser.add_argument('--dpo_chosen_beta', type=float, help='Temperature parameter for the chosen part of the DPO loss, typically something in the range of 0.1 to 0.5')
parser.add_argument('--dpo_rejected_beta', type=float, help='Temperature parameter for the rejected part of the DPO loss, typically something in the range of 0.1 to 0.5')
parser.add_argument('--dpo_penalty_lambda', type=float, help='Temperature parameter for penalty-positive in the DPO loss, typically in the range of 1 to 100')
parser.add_argument('--reference_checkpoint_name', type=str, help='Checkpoint name for the reference model')
parser.add_argument('--training_type', type=str, choices=['pre', 'sft', 'dpo'], help='Specifies the type of training: pre (pre-training), sft (supervised fine-tuning), or dpo (direct preference optimization)')
parser.add_argument('--attention_implementation', type=str, choices=['sdpa', 'flash_attention_2', 'eager'], help='Specifies attention implementation')
parser.add_argument('--tensor_parallel_degree', type=int, help='Specifies the degree of tensor parallelism. Activates TP when it is greater than 1')
parser.add_argument('--prompt', type=str, help='Prompt for generating text. Can also specify a file, use as: "FILE:prompt.txt"')
parser.add_argument('--num_samples', type=int, help='Number of samples to generate')
parser.add_argument('--max_new_tokens', type=int, help='Number of tokens to generate in each sample')
......
......@@ -64,7 +64,7 @@ class AllamoDataset:
self.processed_files.append(load_dataset_file)
new_data = None
if load_dataset_file.endswith('.bin'):
# assert self.training_type == 'pre', 'NumPy format is supported only for pre-training'
assert self.training_type == 'pre', 'NumPy format is supported only for pre-training'
step_size = self.world_size * self.sample_size
new_data = torch.from_numpy(np.fromfile(load_dataset_file, dtype=np.uint16).astype(np.int16))
if step_size > len(new_data):
......@@ -77,7 +77,7 @@ class AllamoDataset:
new_data = self.limit_samples_to_rank(new_data)
elif load_dataset_file.endswith('.pt'):
assert self.training_type != 'dpo', 'DPO training only supports the ALM format'
new_data = torch.load(load_dataset_file, map_location='cpu')
new_data = torch.load(load_dataset_file, map_location='cpu', weights_only=True)
if isinstance(new_data, torch.Tensor):
step_size = self.world_size * self.sample_size
if step_size > len(new_data):
......
......@@ -71,3 +71,4 @@ def init_torch(train_ctx: TrainingContext, config: AllamoConfiguration, distribu
configure_torch(config, train_ctx.rank)
# override_numa_affinity(train_ctx.local_rank)
\ No newline at end of file
......@@ -11,6 +11,7 @@ import torch.distributed as dist
from allamo.checkpoint.checkpoint_manager import CheckpointManager
from allamo.configuration import AllamoConfiguration
from allamo.model.model import AllamoTransformer
from allamo.dataset.data_loader import AllamoDataLoader
from allamo.logging import configure_logger, logger
from allamo.model.attentions import attention_version
......@@ -54,7 +55,23 @@ class BaseTrainer:
self.checkpoint_manager.init_checkpoint()
self.data_loader.load_datasets()
self.model_config = create_model_config(self.config)
def freeze_model_params(self, model: AllamoTransformer):
    """Optionally freeze parts of the model as requested by the configuration.

    Three independent switches are honored:
    - ``freeze_embeddings``: freezes the token embedding table.
    - ``freeze_lm_head``: freezes the final norm together with the LM head.
    - ``freeze_layers``: freezes every transformer layer except the indices
      listed in ``keep_layers_trainable``.
    """
    cfg = self.config
    if cfg.freeze_embeddings:
        model.freeze_params(model.tok_embeddings)
        logger.info("Embeddings frozen")
    if cfg.freeze_lm_head:
        # The head is frozen together with the preceding normalization layer.
        model.freeze_params(model.norm)
        model.freeze_params(model.lm_head)
        logger.info("LM head frozen")
    if cfg.freeze_layers:
        # Build the exception set once instead of scanning the list per layer.
        trainable_ids = set(cfg.keep_layers_trainable)
        for layer_id in range(self.model_config.n_layer):
            if layer_id in trainable_ids:
                logger.info(f"Layer {layer_id} kept trainable")
            else:
                model.freeze_params(model.layers[layer_id])
                logger.info(f"Layer {layer_id} frozen")
def init_gradient_accumulation_scheduler(self):
if self.config.grad_accum_schedule:
self.config.grad_accum_max = self.config.gradient_accumulation_steps
......
......@@ -59,7 +59,9 @@ class FSDPTrainer(BaseTrainer):
with torch.device('meta'):
model = AllamoTransformer(self.model_config)
self.model_num_params = model.model_num_params
self.freeze_model_params(model) # Optionally freezes model parameters depending on the configuration
if self.checkpoint_manager.checkpoint_name is None:
if self.world_mesh is None:
self.model = parallelize_model_with_fsdp1(model, self.config, self.fsdp_activation_checkpointing)
......@@ -106,7 +108,7 @@ class FSDPTrainer(BaseTrainer):
ckpt_path = get_optimizer_checkpoint_path(self.checkpoint_manager.checkpoint_name, self.checkpoint_manager.checkpoint_dir)
if os.path.exists(ckpt_path):
# requires each rank to have the full dict in CPU memory to reduce communication
full_osd = torch.load(ckpt_path, map_location='cpu')
full_osd = torch.load(ckpt_path, map_location='cpu', weights_only=True)
sharded_osd = FSDP.optim_state_dict_to_load(model, optimizer, full_osd)
optimizer.load_state_dict(sharded_osd)
logger.info(f"Shared optimizer state loaded from checkpoint {ckpt_path}")
......
......@@ -41,6 +41,9 @@ class SimpleTrainer(BaseTrainer):
model = AllamoTransformer(self.model_config)
print("model: ", model)
self.model_num_params = model.model_num_params
self.freeze_model_params(model) # Optionally freezes model parameters depending on the configuration
if self.checkpoint_manager.is_checkpoint_available():
self.checkpoint_manager.load_regular_model_checkpoint(model)
else:
......@@ -75,7 +78,7 @@ class SimpleTrainer(BaseTrainer):
def load_optimizer_checkpoint(self, optimizer):
ckpt_path = get_optimizer_checkpoint_path(self.checkpoint_manager.checkpoint_name, self.checkpoint_manager.checkpoint_dir)
if os.path.exists(ckpt_path):
state_dict = torch.load(ckpt_path, map_location=self.config.device)
state_dict = torch.load(ckpt_path, map_location=self.config.device, weights_only=True)
optimizer.load_state_dict(state_dict)
logger.info(f"Optimizer state loaded from checkpoint {ckpt_path}")
else:
......
......@@ -68,7 +68,7 @@ class AllamoSampler:
else:
raise Exception('Tokenizer is not provided. Please specify either a Tiktoken tokenizer or a HuggingFace tokenizer')
# ensure that the tokenizer and model vocabulary sizes are equal
# assert len(tokenizer) == self.model.config.vocab_size
assert len(tokenizer) == self.model.config.vocab_size
self.tokenizer = tokenizer
def tokenize_prompt(self, text: str):
......
......@@ -36,7 +36,7 @@ def adjust_model(input_dir_path, input_checkpoint_name_base, output_dir_path, ou
logger.info(f"loading checkpoint from {input_dir_path}...")
with open(get_config_checkpoint_path(input_checkpoint_name_base, input_dir_path), "r", encoding="utf-8") as f:
config_checkpoint = json.load(f)
model_checkpoint = torch.load(get_model_checkpoint_path(input_checkpoint_name_base, input_dir_path), map_location='cpu')
model_checkpoint = torch.load(get_model_checkpoint_path(input_checkpoint_name_base, input_dir_path), map_location='cpu', weights_only=True)
unwanted_prefix = '_orig_mod.'
for k,v in list(model_checkpoint.items()):
......
......@@ -8,7 +8,7 @@ import torch
from allamo.logging import configure_logger, logger
def convert_ckpt(config_ckpt):
config_checkpoint = torch.load(config_ckpt, map_location='cpu')
config_checkpoint = torch.load(config_ckpt, map_location='cpu', weights_only=True)
json_checkpoint = {}
if 'model_args' in config_checkpoint:
json_checkpoint['model_args'] = dataclasses.asdict(config_checkpoint['model_args'])
......
......@@ -37,7 +37,7 @@ def depth_up_scale_model(input_dir_path, input_checkpoint_name_base, output_dir_
logger.info(f"loading checkpoint from {input_dir_path}...")
with open(get_config_checkpoint_path(input_checkpoint_name_base, input_dir_path), "r", encoding="utf-8") as f:
config_checkpoint = json.load(f)
model_checkpoint = torch.load(get_model_checkpoint_path(input_checkpoint_name_base, input_dir_path), map_location='cpu')
model_checkpoint = torch.load(get_model_checkpoint_path(input_checkpoint_name_base, input_dir_path), map_location='cpu', weights_only=True)
unwanted_prefix = '_orig_mod.'
for k,v in list(model_checkpoint.items()):
......
......@@ -35,7 +35,7 @@ def write_model(checkpoint_dir_path, checkpoint_name_base, hf_model_path, hf_mod
logger.info(f"loading checkpoint from {checkpoint_dir_path}...")
with open(get_config_checkpoint_path(checkpoint_name_base, checkpoint_dir_path), "r", encoding="utf-8") as f:
config_checkpoint = json.load(f)
model_checkpoint = torch.load(get_model_checkpoint_path(checkpoint_name_base, checkpoint_dir_path), map_location='cpu')
model_checkpoint = torch.load(get_model_checkpoint_path(checkpoint_name_base, checkpoint_dir_path), map_location='cpu', weights_only=True)
allamo_transformer_config = AllamoTransformerConfig(**config_checkpoint['model_args'])
n_layers = allamo_transformer_config.n_layer
......
......@@ -48,7 +48,7 @@ def import_model(input_base_path, output_model_path, max_num_layers, max_block_s
logger.info(f"loading llama weights")
# Sharded models are not supported!
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu", weights_only=True)
logger.info(f"copying llama weights to the model")
theta = 10000.0
......
......@@ -5,7 +5,6 @@ import numpy as np
import os.path
import pandas as pd
from allamo.logging import configure_logger, logger
import torch
EOS_TOKEN = "</s>"
......@@ -54,13 +53,11 @@ def encode_file(input_file, output_file, tokenizer):
enc_data = tokenizer.encode(input_file.read())
enc_data = np.array(enc_data, dtype=np.uint16)
enc_data.tofile(output_file)
# torch.save(enc_data, output_file)
tokens = len(enc_data)
enc_data = tokenizer.encode(EOS_TOKEN)
enc_data = np.array(enc_data, dtype=np.uint16)
enc_data.tofile(output_file)
# torch.save(enc_data, output_file)
tokens += len(enc_data)
return tokens
......@@ -72,8 +69,6 @@ def create_datasets(txt_files_df, tokenizer, input_data_dir, output_data_dir):
files_cnt = 0
train_file_path = os.path.join(output_data_dir, 'train.bin')
val_file_path = os.path.join(output_data_dir, 'val.bin')
# train_file_path = os.path.join(output_data_dir, 'train.pt')
# val_file_path = os.path.join(output_data_dir, 'val.pt')
with open(train_file_path, 'wb+') as train_file, open(val_file_path, 'wb+') as val_file:
# Process each of the txt files
for _, row in txt_files_df.iterrows():
......
......@@ -32,7 +32,7 @@ def prune_model(input_dir_path, input_checkpoint_name_base, output_dir_path, out
logger.info(f"loading checkpoint from {input_dir_path}...")
with open(get_config_checkpoint_path(input_checkpoint_name_base, input_dir_path), "r", encoding="utf-8") as f:
config_checkpoint = json.load(f)
model_checkpoint = torch.load(get_model_checkpoint_path(input_checkpoint_name_base, input_dir_path), map_location='cpu')
model_checkpoint = torch.load(get_model_checkpoint_path(input_checkpoint_name_base, input_dir_path), map_location='cpu', weights_only=True)
unwanted_prefix = '_orig_mod.'
for k,v in list(model_checkpoint.items()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment