Commit 0371621a authored by chenzk

v1.0
"""
Use this file to import Huggingface MistralForCausalLM weights to ALLaMo format.
"""
import argparse
import dataclasses
import json
import os
import torch
from transformers import MistralForCausalLM
from allamo.logging import configure_logger, logger
from allamo.model.model import AllamoTransformerConfig, AllamoTransformer
def import_model(hf_model_path, output_model_path):
logger.info(f"Importing Huggingface MistralForCausalLM weights")
hf_model = MistralForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True)
logger.info(f"Huggingface MistralForCausalLM model loaded")
assert hf_model.config.hidden_act == "silu"
config = AllamoTransformerConfig()
config.block_size = hf_model.config.sliding_window # hf_model.config.max_position_embeddings
config.vocab_size = hf_model.config.vocab_size
config.n_layer = hf_model.config.num_hidden_layers
config.n_head = hf_model.config.num_attention_heads
config.n_embd = hf_model.config.hidden_size
config.intermediate_size = hf_model.config.intermediate_size
config.head_size = config.n_embd // config.n_head
config.num_kv_heads = hf_model.config.num_key_value_heads
config.sliding_window = hf_model.config.sliding_window
config.dropout = 0.0
config.bias = False
config.norm_eps = hf_model.config.rms_norm_eps
config.rope_freq_base = int(hf_model.config.rope_theta)
logger.info(f"initializing vanilla ALLaMo model")
model = AllamoTransformer(config)
logger.info(f"preparing weights")
state_dicts_map = {}
sd_hf_model = hf_model.state_dict()
model_sd = model.state_dict()
for layer_i in range(config.n_layer):
state_dicts_map[f"layers.{layer_i}.attention.q_proj.weight"] = f"model.layers.{layer_i}.self_attn.q_proj.weight"
state_dicts_map[f"layers.{layer_i}.attention.k_proj.weight"] = f"model.layers.{layer_i}.self_attn.k_proj.weight"
state_dicts_map[f"layers.{layer_i}.attention.v_proj.weight"] = f"model.layers.{layer_i}.self_attn.v_proj.weight"
state_dicts_map[f"layers.{layer_i}.attention.c_proj.weight"] = f"model.layers.{layer_i}.self_attn.o_proj.weight"
state_dicts_map[f"layers.{layer_i}.attention.rotary_emb.inv_freq"] = f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"
state_dicts_map[f"layers.{layer_i}.feed_forward.gate_proj.weight"] = f"model.layers.{layer_i}.mlp.gate_proj.weight"
state_dicts_map[f"layers.{layer_i}.feed_forward.down_proj.weight"] = f"model.layers.{layer_i}.mlp.down_proj.weight"
state_dicts_map[f"layers.{layer_i}.feed_forward.up_proj.weight"] = f"model.layers.{layer_i}.mlp.up_proj.weight"
state_dicts_map[f"layers.{layer_i}.attention_norm.weight"] = f"model.layers.{layer_i}.input_layernorm.weight"
state_dicts_map[f"layers.{layer_i}.ffn_norm.weight"] = f"model.layers.{layer_i}.post_attention_layernorm.weight"
state_dicts_map["tok_embeddings.weight"] = "model.embed_tokens.weight"
state_dicts_map["norm.weight"] = "model.norm.weight"
state_dicts_map["lm_head.weight"] = "lm_head.weight"
logger.info(f"checking params coverage")
for k, v in model_sd.items():
if k not in state_dicts_map:
logger.info(f"{k} param won't be updated in the ALLaMo model!")
for k, v in sd_hf_model.items():
if k not in state_dicts_map.values():
logger.info(f"{k} param won't be copied to the ALLaMo model!")
logger.info(f"copying params to the ALLaMo model")
param_count = 0
for k, v in state_dicts_map.items():
if not k.endswith('rotary_emb.inv_freq'):
assert sd_hf_model[v].shape == model_sd[k].shape
with torch.no_grad():
model_sd[k].copy_(sd_hf_model[v])
param_count += model_sd[k].numel()
logger.info(f"{param_count} params copied to the ALLaMo model")
for k, _ in model_sd.items():
# skip params that were not copied above (unmapped keys and rotary inv_freq buffers)
if k not in state_dicts_map or k.endswith('rotary_emb.inv_freq'):
continue
if not torch.all(torch.eq(model_sd[k], sd_hf_model[state_dicts_map[k]])):
logger.info(f"{k} param in the ALLaMo model is not the same as {state_dicts_map[k]} param in the source model!")
logger.info(f"params verified")
ckpt_file_name = 'import_ckpt'
config_checkpoint = {
'model_args': dataclasses.asdict(config)
}
ckpt_file_path = os.path.join(output_model_path, f'config_{ckpt_file_name}.json')
logger.info(f"saving config checkpoint to {ckpt_file_path}")
with open(ckpt_file_path, "w", encoding="utf-8") as f:
json.dump(config_checkpoint, f, indent=4, ensure_ascii=False)
ckpt_file_path = os.path.join(output_model_path, f'model_{ckpt_file_name}.pt')
logger.info(f"saving model checkpoint to {ckpt_file_path}")
torch.save(model_sd, ckpt_file_path)
logger.info(f"checkpoint files saved in {output_model_path}")
if __name__ == '__main__':
configure_logger()
parser = argparse.ArgumentParser(description='Import Huggingface MistralForCausalLM weights to ALLaMo format')
parser.add_argument(
"--huggingface_model",
help="Huggingface model path",
)
parser.add_argument(
"--output_dir",
help="Path to output directory",
)
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
import_model(
hf_model_path=args.huggingface_model,
output_model_path=args.output_dir,
)
logger.info("import completed")
"""
Use this file to import original LLaMA weights to ALLaMo format.
"""
import argparse
import dataclasses
import json
import os
import torch
import shutil
from allamo.logging import configure_logger, logger
from allamo.model.model import AllamoTransformerConfig, AllamoTransformer
DEFAULT_BLOCK_SIZE = 4096
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
# permute for sliced rotary
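# The original LLaMA checkpoints store q/k projection rows for interleaved (pairwise) rotary embeddings;
# this permutation rearranges them into the half-split ("rotate_half") layout, matching the Hugging Face
# convention that the ALLaMo attention appears to follow (the Mistral importer above copies q/k weights
# without any permutation).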
def permute(w, dim, n_heads):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
def import_model(input_base_path, output_model_path, max_num_layers, max_block_size):
logger.info(f"start importing llama weights")
params = read_json(os.path.join(input_base_path, "params.json"))
config = AllamoTransformerConfig()
config.block_size = min(max_block_size, DEFAULT_BLOCK_SIZE) if max_block_size else DEFAULT_BLOCK_SIZE
config.vocab_size = 32000
config.n_layer = min(max_num_layers, params["n_layers"]) if max_num_layers else params["n_layers"]
config.n_head = params["n_heads"]
config.n_embd = params["dim"]
config.head_size = config.n_embd // config.n_head
config.dropout = 0.0
config.bias = False
config.multiple_of = params["multiple_of"]
config.norm_eps = params["norm_eps"]
# Switch to half tensors
torch.set_default_tensor_type(torch.cuda.HalfTensor)
logger.info(f"initializing vanilla model")
model = AllamoTransformer(config)
logger.info(f"loading llama weights")
# Sharded models are not supported!
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
logger.info(f"copying llama weights to the model")
theta = 10000.0
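# standard RoPE inverse frequencies: inv_freq[i] = theta^(-2i / head_size) for i = 0 .. head_size/2 - 1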
inv_freq = 1.0 / (theta ** (torch.arange(0, config.head_size, 2).float() / config.head_size))
# Switch back to full tensors
torch.set_default_tensor_type(torch.FloatTensor)
param_count = 0
model_sd = model.state_dict()
for layer_i in range(config.n_layer):
logger.info(f"copying weights in layer {layer_i}")
state_dict = {
f"layers.{layer_i}.attention.q_proj.weight": permute(loaded[f"layers.{layer_i}.attention.wq.weight"], config.n_embd, config.n_head),
f"layers.{layer_i}.attention.k_proj.weight": permute(loaded[f"layers.{layer_i}.attention.wk.weight"], config.n_embd, config.n_head),
f"layers.{layer_i}.attention.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
f"layers.{layer_i}.attention.c_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
f"layers.{layer_i}.feed_forward.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
f"layers.{layer_i}.feed_forward.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
f"layers.{layer_i}.feed_forward.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
f"layers.{layer_i}.attention_norm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
f"layers.{layer_i}.ffn_norm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
}
state_dict[f"layers.{layer_i}.attention.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
assert v.shape == model_sd[k].shape
with torch.no_grad():
model_sd[k].copy_(v)
param_count += v.numel()
state_dict = {
"tok_embeddings.weight": loaded["tok_embeddings.weight"],
"norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
for k, v in state_dict.items():
assert v.shape == model_sd[k].shape
with torch.no_grad():
model_sd[k].copy_(v)
param_count += v.numel()
logger.info(f"{param_count} params imported to the model")
ckpt_file_name = 'import_ckpt'
config_checkpoint = {
'model_args': dataclasses.asdict(config)
}
ckpt_file_path = os.path.join(output_model_path, f'config_{ckpt_file_name}.json')
logger.info(f"saving config checkpoint to {ckpt_file_path}")
with open(ckpt_file_path, "w", encoding="utf-8") as f:
json.dump(config_checkpoint, f, indent=4, ensure_ascii=False)
ckpt_file_path = os.path.join(output_model_path, f'model_{ckpt_file_name}.pt')
logger.info(f"saving model checkpoint to {ckpt_file_path}")
torch.save(model_sd, ckpt_file_path)
logger.info(f"checkpoint files saved in {output_model_path}")
def import_tokenizer(input_tokenizer_path, output_model_path, max_block_size):
logger.info(f"start importing tokenizer")
model_max_length = min(max_block_size, DEFAULT_BLOCK_SIZE) if max_block_size else DEFAULT_BLOCK_SIZE
write_json({}, os.path.join(output_model_path, "special_tokens_map.json"))
write_json(
{
"bos_token": "<s>",
"eos_token": "</s>",
"model_max_length": model_max_length,
"tokenizer_class": "LlamaTokenizer",
"unk_token": "<unk>",
},
os.path.join(output_model_path, "tokenizer_config.json"),
)
shutil.copyfile(input_tokenizer_path, os.path.join(output_model_path, "tokenizer.model"))
logger.info(f"tokenizer files saved in {output_model_path}")
if __name__ == '__main__':
configure_logger()
parser = argparse.ArgumentParser(description='Import LLaMA weights to ALLaMo model')
parser.add_argument('--input_data_dir', type=str, help='Path to a directory with LLaMA model files')
parser.add_argument('--input_tokenizer_path', type=str, help='Path to LLaMA tokenizer.model file')
parser.add_argument('--output_data_dir', type=str, required=True, help='Path to output directory')
parser.add_argument('--max_num_layers', type=int, help='Crop layers to make the model smaller')
parser.add_argument('--max_block_size', type=int, help='Crop block size to make the model smaller')
args = parser.parse_args()
os.makedirs(args.output_data_dir, exist_ok=True)
if args.input_tokenizer_path:
import_tokenizer(args.input_tokenizer_path, args.output_data_dir, args.max_block_size)
if args.input_data_dir:
import_model(args.input_data_dir, args.output_data_dir, args.max_num_layers, args.max_block_size)
logger.info("import completed")
import argparse
import datetime
import glob
import numpy as np
import os.path
import pandas as pd
from allamo.logging import configure_logger, logger
import torch
EOS_TOKEN = "</s>"
def init_tokenizer(output_dir, tiktoken_tokenizer_name, hf_tokenizer_path):
vocab_size = None
if hf_tokenizer_path:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path)
if vocab_size is None:
vocab_size = len(tokenizer)
logger.info(f"HuggingFace {hf_tokenizer_path} tokenizer loaded with the vocab size {vocab_size}")
elif tiktoken_tokenizer_name:
import tiktoken
tokenizer = tiktoken.get_encoding(tiktoken_tokenizer_name)
if vocab_size is None:
vocab_size = tokenizer.max_token_value + 1 # values start from 0
logger.info(f"Tiktoken {tiktoken_tokenizer_name} tokenizer loaded with the vocab size {vocab_size}")
else:
raise Exception('Tokenizer is not provided. Please specify either a Tiktoken tokenizer or a HuggingFace tokenizer')
return tokenizer
def load_list_of_txt_files(index_file_path, input_data_dir, data_split):
if index_file_path:
# Load the csv file into a pandas dataframe
index_df = pd.read_csv(index_file_path)
# Replace any NaN values in the "File" column with an empty string
index_df['File'] = index_df['File'].fillna('')
# Filter the dataframe to only include rows where the "File" column ends with ".txt"
txt_files_df = index_df[(index_df['File'].str.endswith('.txt'))].copy()
if 'Split' not in txt_files_df.columns:
txt_files_df['Split'] = data_split if data_split else 'train'
elif input_data_dir:
txt_files = glob.glob(os.path.join(input_data_dir, "*.txt"))
txt_files_df = pd.DataFrame({'File': txt_files})
txt_files_df['Split'] = data_split if data_split else 'train'
else:
raise Exception('Either an index file or an input data dir must be provided')
logger.info(f"{len(txt_files_df)} txt files found to process")
return txt_files_df
def encode_file(input_file, output_file, tokenizer):
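# Note: token ids are written as uint16, which assumes the tokenizer vocabulary fits in 16 bits (ids < 65536)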
enc_data = tokenizer.encode(input_file.read())
enc_data = np.array(enc_data, dtype=np.uint16)
enc_data.tofile(output_file)
# torch.save(enc_data, output_file)
tokens = len(enc_data)
enc_data = tokenizer.encode(EOS_TOKEN)
enc_data = np.array(enc_data, dtype=np.uint16)
enc_data.tofile(output_file)
# torch.save(enc_data, output_file)
tokens += len(enc_data)
return tokens
def create_datasets(txt_files_df, tokenizer, input_data_dir, output_data_dir):
train_tokens = 0
val_tokens = 0
files_cnt = 0
train_file_path = os.path.join(output_data_dir, 'train.bin')
val_file_path = os.path.join(output_data_dir, 'val.bin')
# train_file_path = os.path.join(output_data_dir, 'train.pt')
# val_file_path = os.path.join(output_data_dir, 'val.pt')
with open(train_file_path, 'wb+') as train_file, open(val_file_path, 'wb+') as val_file:
# Process each of the txt files
for _, row in txt_files_df.iterrows():
filename = os.path.join(input_data_dir, row['File']) if input_data_dir else row['File']
if not os.path.isfile(filename):
logger.info(f"File {filename} does not exist.")
continue
with open(filename, 'r', encoding="utf-8") as txt_file:
logger.info(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - Start processing {filename}")
if row['Split'] == 'test':
tokens = encode_file(txt_file, val_file, tokenizer)
val_tokens += tokens
else:
tokens = encode_file(txt_file, train_file, tokenizer)
train_tokens += tokens
logger.info(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - {filename} added ({tokens} tokens) to the {row['Split']} dataset")
files_cnt += 1
total_tokens = train_tokens + val_tokens
logger.info(f"Datasets created in {output_data_dir} from {files_cnt} files. Tokens: {total_tokens:,} (Train: {train_tokens:,} Val: {val_tokens:,})")
if __name__ == '__main__':
configure_logger()
parser = argparse.ArgumentParser(description='Prepare your datasets')
parser.add_argument('--index_file', type=str, help='Path to an index file')
parser.add_argument('--input_data_dir', type=str, help='Path to a directory with txt files')
parser.add_argument('--data_split', type=str, default='train', choices=['train', 'test'], help='Data split')
parser.add_argument('--output_data_dir', type=str, required=True, help='Path to a directory for output dataset files')
parser.add_argument('--tiktoken_tokenizer_name', type=str, help='Tiktoken tokenizer name')
parser.add_argument('--hf_tokenizer_path', type=str, help='HuggingFace tokenizer path')
args = parser.parse_args()
tokenizer = init_tokenizer(args.output_data_dir, args.tiktoken_tokenizer_name, args.hf_tokenizer_path)
txt_files_df = load_list_of_txt_files(args.index_file, args.input_data_dir, args.data_split)
create_datasets(txt_files_df, tokenizer, args.input_data_dir, args.output_data_dir)
python prepare_datasets.py --index_file data/train_index.txt --input_data_dir data --data_split train --output_data_dir data --tiktoken_tokenizer_name "cl100k_base"
# python prepare_datasets.py --index_file data/test_index.txt --input_data_dir data --data_split test --output_data_dir data --tiktoken_tokenizer_name "cl100k_base"
"""
Use this file to create a dataset for DPO (Direct Preference Optimization) training
The script performs the following steps:
1. Reads the input JSONL file with dialogues (single or multi-turn)
2. Applies the OpenChatML or Llama2 chat template to each dialogue
3. Tokenizes the formatted dialogues
4. Saves the processed data in a binary format
5. Generates and saves summary statistics for the dataset
Example record with single-turn dialogue:
```json
{"messages": [{"role": "user", "content": "1+2=?"}], "chosen": {"role": "assistant", "content": "3"}, "rejected": {"role": "assistant", "content": "4"}}
```
Example record with multi-turn dialogue:
```json
{"messages": [{"role": "user", "content": "1+2=?"}, {"role": "assistant", "content": "3"}, {"role": "user", "content": "2+2=?"}], "chosen": {"role": "assistant", "content": "4"}, "rejected": {"role": "assistant", "content": "5"}
```
"""
import argparse
import concurrent.futures
import joblib
import json
import numpy as np
import os
import pyarrow as pa
import pyarrow.parquet as pq
import random
import time
from collections import Counter
from itertools import chain
from tqdm import tqdm
from transformers import AutoTokenizer
from allamo.logging import configure_logger, logger
def tokenize_openchatml_conversation(messages, tokenizer, ignore_index):
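# Renders OpenChatML; e.g. the single-turn example from the module docstring becomes:
#   <s><|im_start|>user\n1+2=?<|im_end|>\n<|im_start|>assistant\n3<|im_end|>\n</s>
# Only assistant content (plus its <|im_end|> and the trailing </s>) is kept in target_ids;
# every other position is set to ignore_index.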
result = {'input_ids': [], 'target_ids': []}
last_idx = len(messages) - 1
for idx, entry in enumerate(messages):
if entry["role"] == 'assistant':
pre_content = '<|im_start|>assistant\n'
pre_input_ids = tokenizer.encode(pre_content, add_special_tokens=False)
pre_input_ids_len = len(pre_input_ids)
content = entry['content'] + '<|im_end|>\n'
if idx == last_idx:
content += "</s>"
full_input_ids = tokenizer.encode(pre_content + content, add_special_tokens=False)
if full_input_ids[:pre_input_ids_len] == pre_input_ids:
result['input_ids'].extend(full_input_ids)
result['target_ids'].extend(list(
ignore_index if i < pre_input_ids_len else full_input_ids[i] for i in range(len(full_input_ids))
))
else:
logger.warning("Tokenization inconsistency detected. Performing separate tokenization")
content_input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(pre_input_ids)
result['input_ids'].extend(content_input_ids)
result['target_ids'].extend(list(ignore_index for _ in range(pre_input_ids_len)))
result['target_ids'].extend(content_input_ids)
else:
content = "<s><|im_start|>" if idx == 0 else "<|im_start|>"
content += entry["role"] + '\n' + entry["content"] + '<|im_end|>\n'
input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(input_ids)
result['target_ids'].extend(list(ignore_index for _ in range(len(input_ids))))
assert len(result['input_ids']) == len(result['target_ids'])
return result
def tokenize_llama2_conversation(messages, tokenizer, ignore_index):
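# Renders the Llama 2 chat template; e.g. the single-turn example from the module docstring becomes:
#   <s>[INST] 1+2=? [/INST] 3</s>
# An optional system message is wrapped in <<SYS>> ... <</SYS>> inside the first [INST] block.
# Only assistant tokens are kept in target_ids; user/system tokens are set to ignore_index.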
result = {'input_ids': [], 'target_ids': []}
if messages[0]['role'] == 'system':
sys_message = f"<<SYS>>\n{messages[0]['content']}\n<</SYS>>\n\n"
messages = messages[1:]
else:
sys_message = ''
for idx, entry in enumerate(messages):
if entry['role'] == 'user':
content = '<s>[INST] '+sys_message if idx <= 1 else '[INST] '
content += entry['content'] + ' [/INST]'
input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(input_ids)
result['target_ids'].extend(list(ignore_index for _ in range(len(input_ids))))
elif entry['role'] == 'assistant':
content = ' ' + entry['content'] + '</s>'
input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(input_ids)
result['target_ids'].extend(input_ids)
assert len(result['input_ids']) == len(result['target_ids'])
return result
def tokenize_conversation(data, tokenizer, ignore_index, chat_format):
if chat_format == 'OpenChatML':
return tokenize_openchatml_conversation(data, tokenizer, ignore_index)
elif chat_format == 'llama2':
return tokenize_llama2_conversation(data, tokenizer, ignore_index)
else:
raise Exception(f"Unsupported chat format: {chat_format}")
def convert_to_numpy_array(pylist, target_length, pad_token, data_type):
padded = np.full(target_length, pad_token, dtype=data_type)
padded[:len(pylist)] = pylist
return padded
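# pad_and_align prepares next-token-prediction pairs: target_ids end up shifted one position
# relative to input_ids, and (when pad_token_id >= 0) both are padded or truncated to block_size.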
def pad_and_align(sample, input_ids_key, target_ids_key, block_size, max_sample_size, pad_token_id, ignore_index, data_dtype):
padding = max_sample_size - len(sample[input_ids_key])
if pad_token_id >= 0:
assert padding >= 0
if padding > 0:
if padding > 1:
sample[input_ids_key] = convert_to_numpy_array(sample[input_ids_key], block_size, pad_token_id, data_dtype)
else:
sample[input_ids_key] = np.array(sample[input_ids_key], dtype=data_dtype)
sample[target_ids_key] = convert_to_numpy_array(sample[target_ids_key][1:], block_size, ignore_index, data_dtype)
else:
assert len(sample[input_ids_key]) == max_sample_size
assert len(sample[target_ids_key]) == max_sample_size
sample[input_ids_key] = np.array(sample[input_ids_key][:-1], dtype=data_dtype)
sample[target_ids_key] = np.array(sample[target_ids_key][1:], dtype=data_dtype)
else:
expected_len = len(sample[input_ids_key]) - 1 if padding > 0 else block_size
sample[input_ids_key] = np.array(sample[input_ids_key][:expected_len], dtype=data_dtype)
sample[target_ids_key] = np.array(sample[target_ids_key][1:expected_len+1], dtype=data_dtype)
def process_chunk(args):
chunk_file, tokenizer_path, chat_format, block_size, ignore_index, pad_token_id, min_unmasked_tokens = args
max_sample_size = block_size + 1
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
data_dtype = np.int16 if len(tokenizer) < 32767 else np.int32
truncated = 0
rejected = 0
data = []
pa_table = pq.read_table(chunk_file)
for i in range(len(pa_table['rows'])):
cols = pa_table['rows'][i].as_py().split(';', 1)
row = json.loads(cols[1])
if 'messages' not in row or 'chosen' not in row or 'rejected' not in row:
rejected += 1
else:
chosen_sample = tokenize_conversation(row['messages']+[row['chosen']], tokenizer, ignore_index, chat_format)
chosen_input_ids_len = len(chosen_sample['input_ids'])
if chosen_input_ids_len > max_sample_size:
chosen_sample['input_ids'] = chosen_sample['input_ids'][:max_sample_size]
chosen_sample['target_ids'] = chosen_sample['target_ids'][:max_sample_size]
truncated += 1
rejected_sample = tokenize_conversation(row['messages']+[row['rejected']], tokenizer, ignore_index, chat_format)
rejected_input_ids_len = len(rejected_sample['input_ids'])
if rejected_input_ids_len > max_sample_size:
rejected_sample['input_ids'] = rejected_sample['input_ids'][:max_sample_size]
rejected_sample['target_ids'] = rejected_sample['target_ids'][:max_sample_size]
truncated += 1
data.append({
'chosen_input_ids': chosen_sample['input_ids'],
'chosen_target_ids': chosen_sample['target_ids'],
'rejected_input_ids': rejected_sample['input_ids'],
'rejected_target_ids': rejected_sample['target_ids'],
'source_file': cols[0]
})
del pa_table
created = len(data)
result = []
for sample in data:
pad_and_align(sample, "chosen_input_ids", "chosen_target_ids", block_size, max_sample_size, pad_token_id, ignore_index, data_dtype)
pad_and_align(sample, "rejected_input_ids", "rejected_target_ids", block_size, max_sample_size, pad_token_id, ignore_index, data_dtype)
assert isinstance(sample["chosen_input_ids"], np.ndarray)
assert isinstance(sample["chosen_target_ids"], np.ndarray)
assert isinstance(sample["rejected_input_ids"], np.ndarray)
assert isinstance(sample["rejected_target_ids"], np.ndarray)
if np.sum(sample['chosen_target_ids'] != ignore_index) >= min_unmasked_tokens and np.sum(sample['rejected_target_ids'] != ignore_index) >= min_unmasked_tokens:
result.append(sample)
else:
rejected += 1
with open(chunk_file, 'wb') as f:
joblib.dump(result, f)
return {'created': created, 'truncated': truncated, 'rejected': rejected}
def save_chunk_for_rank(rows, rank, output_dir, chunk_files):
chunk_file = os.path.join(output_dir, f"chunk_{rank:05}.tmp")
pa_array = pa.array(rows)
pa_table = pa.table([pa_array], names=['rows'])
pq.write_table(pa_table, chunk_file)
chunk_files.append(chunk_file)
def format_seconds_as_time(seconds):
hours, remainder = divmod(seconds, 3600)
minutes, seconds = divmod(remainder, 60)
return f"{int(hours)}:{int(minutes):02}:{int(seconds):02}"
if __name__ == "__main__":
configure_logger()
parser = argparse.ArgumentParser(description='Tokenize dialogues for DPO training')
parser.add_argument("-c", "--config_path", help="Config file with a list of input files")
parser.add_argument("-f", "--input_file", help="Input file")
parser.add_argument("-i", "--input_dir", help="Directory with input jsonl files")
parser.add_argument("-o", "--output_dir", help="Output dir")
parser.add_argument("-n", "--num_output_files", type=int, default=1, help="Number of final output files")
parser.add_argument("-t", "--tokenizer_path", required=True, help="Tokenizer path")
parser.add_argument("-p", "--max_workers", type=int, default=20, help="The max number of processes")
parser.add_argument("-b", "--block_size", type=int, default=4096, help="Block/context size")
parser.add_argument('--chat_format', type=str, choices=['OpenChatML', 'llama2'], default='OpenChatML', help='Chat format')
parser.add_argument("--min_unmasked_tokens", type=int, default=1, help="Minimum number of unmasked target tokens required for a sample to be included in training")
parser.add_argument("--ignore_index", type=int, default=-100, help="Specifies a target value that is ignored in loss computation. Default is -100")
parser.add_argument("--pad_token_id", type=int, default=0, help="Specifies the padding token id. Default is 0")
parser.add_argument("--chunk_size", type=int, default=100000, help="Chunk size")
parser.add_argument('--save_samples', type=int, default=-1, help='Save this number of samples if positive')
parser.add_argument('--verbose', action='store_true', help='Be verbose')
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
logger.info(f"Loaded tokenizer with vocab size {len(tokenizer)}")
logger.info(f"Active chat template type: {args.chat_format}")
timer = time.time()
max_sample_size = args.block_size + 1
configs = []
if args.config_path:
with open(args.config_path, "r", encoding="utf-8") as f:
configs = json.load(f)
if args.input_file:
configs.append({'path': args.input_file})
if args.input_dir:
for root, dirs, files in os.walk(args.input_dir):
for f in files:
if f.endswith('.jsonl'):
configs.append({'path': os.path.join(root, f)})
logger.info(f"Initialized with {len(configs)} input files")
logger.info("Loading data")
def load_data_file(config):
filename_prefix = os.path.basename(config['path']) + ";"
with open(config['path'], 'r') as f:
return list(filename_prefix + line for line in f if line)
chunks = joblib.Parallel(n_jobs=args.max_workers)(joblib.delayed(load_data_file)(config) for config in configs)
all_rows = list(chain.from_iterable(chunks))
del chunks
del configs
instruction_count = len(all_rows)
logger.info(f"Loaded {instruction_count:,} rows")
logger.info("Shuffling data")
random.shuffle(all_rows)
logger.info("Shuffling completed")
# adjust num of workers if needed
if len(all_rows) < 10*args.max_workers:
args.max_workers = max(1, len(all_rows) // 10)
logger.info(f"Chunking {len(all_rows):,} rows into {args.max_workers} files")
chunk_files = []
for rank in tqdm(range(args.max_workers), total=args.max_workers, desc="Chunking", disable=(not args.verbose)):
save_chunk_for_rank(all_rows[rank::args.max_workers], rank, args.output_dir, chunk_files)
del all_rows
logger.info(f"Saved {len(chunk_files)} chunks in {args.output_dir}")
logger.info(f"Tokenizing {len(chunk_files)} files")
processed_chunk_stats = []
max_workers = min(len(chunk_files), args.max_workers)
chunk_batches = list((chunk_file, args.tokenizer_path, args.chat_format, args.block_size, args.ignore_index, args.pad_token_id, args.min_unmasked_tokens) for chunk_file in chunk_files)
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
for result in tqdm(executor.map(process_chunk, chunk_batches), total=len(chunk_batches), desc="Tokenizing", disable=(not args.verbose)):
processed_chunk_stats.append(result)
del executor
stats = {'created': 0, 'truncated': 0, 'rejected': 0}
for s in processed_chunk_stats:
for k, v in s.items():
stats[k] += v
del processed_chunk_stats
logger.info(f"Tokenization finished in {len(chunk_files)} chunks. Stats: {stats}")
logger.info(f"Merging {len(chunk_files)} chunks")
chunks = joblib.Parallel(n_jobs=args.max_workers)(joblib.delayed(joblib.load)(f) for f in chunk_files)
all_samples = list(chain.from_iterable(chunks))
sample_count = len(all_samples)
logger.info(f"{sample_count:,} samples loaded")
assert isinstance(all_samples[0]["chosen_input_ids"], np.ndarray)
assert isinstance(all_samples[0]["chosen_target_ids"], np.ndarray)
assert isinstance(all_samples[0]["rejected_input_ids"], np.ndarray)
assert isinstance(all_samples[0]["rejected_target_ids"], np.ndarray)
assert sample_count > 0
if args.save_samples > 0:
logger.info(f"Saving samples")
samples_file = os.path.join(args.output_dir, "samples.jsonl")
with open(samples_file, 'w') as f:
for sample in all_samples[:args.save_samples]:
chosen_input_ids = sample["chosen_input_ids"].tolist()
rejected_input_ids = sample["rejected_input_ids"].tolist()
new_sample = {
"chosen_input": tokenizer.decode(chosen_input_ids),
"chosen_input_ids": chosen_input_ids,
"chosen_target_ids": sample["chosen_target_ids"].tolist(),
"rejected_input": tokenizer.decode(rejected_input_ids),
"rejected_input_ids": rejected_input_ids,
"rejected_target_ids": sample["rejected_target_ids"].tolist(),
}
f.write(json.dumps(new_sample, ensure_ascii=False))
f.write('\n')
logger.info(f"Samples saved in {samples_file}")
if args.num_output_files > 1:
for i in tqdm(range(args.num_output_files), desc="Saving", disable=(not args.verbose)):
bucket = all_samples[i::args.num_output_files]
output_file = os.path.join(args.output_dir, f"samples_part_{i:05}.alm")
with open(output_file, 'wb') as f:
joblib.dump(bucket, f)
logger.info(f"Saved {len(bucket)} samples into {output_file}")
else:
output_file = os.path.join(args.output_dir, "all_samples.alm")
with open(output_file, 'wb') as f:
joblib.dump(all_samples, f)
logger.info(f"All ({sample_count}) samples saved in {output_file}")
# cleanup
for chunk_file in chunk_files:
os.remove(chunk_file)
logger.info(f"Calculating stats")
chosen_sample_lenghts = [np.sum(sample['chosen_input_ids'] != args.pad_token_id).item() for sample in all_samples]
rejected_sample_lenghts = [np.sum(sample['rejected_input_ids'] != args.pad_token_id).item() for sample in all_samples]
chosen_shortest_sample_tokens = min(chosen_sample_lenghts)
chosen_longest_sample_tokens = max(chosen_sample_lenghts)
chosen_total_tokens_count = sum(chosen_sample_lenghts)
rejected_shortest_sample_tokens = min(rejected_sample_lenghts)
rejected_longest_sample_tokens = max(rejected_sample_lenghts)
rejected_total_tokens_count = sum(rejected_sample_lenghts)
stats = {
'samples_count': sample_count,
'chosen_shortest_sample_tokens': chosen_shortest_sample_tokens,
'chosen_longest_sample_tokens': chosen_longest_sample_tokens,
'chosen_total_tokens_count': chosen_total_tokens_count,
'rejected_shortest_sample_tokens': rejected_shortest_sample_tokens,
'rejected_longest_sample_tokens': rejected_longest_sample_tokens,
'rejected_total_tokens_count': rejected_total_tokens_count,
'chosen_avg_sample_size': (chosen_total_tokens_count // sample_count),
'rejected_avg_sample_size': (rejected_total_tokens_count // sample_count)
}
stats_str = json.dumps(stats, indent=4, ensure_ascii=False)
logger.info(f"Stats:\n{stats_str}")
stats_file = os.path.join(args.output_dir, "dataset_stats.json")
with open(stats_file, 'w') as fin:
json.dump(stats, fin)
logger.info(f"Stats saved in {stats_file}")
chosen_sample_lenght_histogram = dict(Counter(chosen_sample_lenghts))
histogram_file = os.path.join(args.output_dir, "dataset_chosen_histogram.csv")
with open(histogram_file, 'w') as fin:
fin.write("token_count; sample_count\n")
for length in range(0, max_sample_size+1):
fin.write(f"{length}; {chosen_sample_lenght_histogram.get(length, 0)}\n")
logger.info(f"Chosen samples histogram saved in {histogram_file}")
rejected_sample_lenght_histogram = dict(Counter(rejected_sample_lenghts))
histogram_file = os.path.join(args.output_dir, "dataset_rejected_histogram.csv")
with open(histogram_file, 'w') as fin:
fin.write("token_count; sample_count\n")
for length in range(0, max_sample_size+1):
fin.write(f"{length}; {rejected_sample_lenght_histogram.get(length, 0)}\n")
logger.info(f"Rejected samples istogram saved in {histogram_file}")
logger.info(f"Dataset with {sample_count:,} samples ({(chosen_total_tokens_count+rejected_total_tokens_count):,} tokens) has been created in {format_seconds_as_time(time.time()-timer)}")
"""
Use this file to create a dataset for Supervised Fine-Tuning (SFT) training
The script performs the following steps:
1. Reads the input JSONL file with dialogues (single or multi-turn)
2. Applies the OpenChatML or Llama2 chat template to each dialogue
3. Tokenizes the formatted dialogues
4. Generates token weights
5. Optionally packs dialogues to maximize GPU utilization
6. Saves the processed data in a binary format
7. Generates and saves summary statistics for the dataset
Example record with single-turn dialogue:
```json
{"messages": [{"role": "user", "content": "1+2=?"}, {"role": "assistant", "content": "3"}]}
```
Example record with multi-turn dialogue:
```json
{"messages": [{"role": "user", "content": "1+2=?"}, {"role": "assistant", "content": "3"}, {"role": "user", "content": "2+2=?"}, {"role": "assistant", "content": "4"}]}
```
"""
import argparse
import concurrent.futures
import joblib
import json
import numpy as np
import os
import pyarrow as pa
import pyarrow.parquet as pq
import random
import time
from collections import Counter
from itertools import chain
from tqdm import tqdm
from transformers import AutoTokenizer
from allamo.logging import configure_logger, logger
MIN_WEIGHT = 0.001
def tokenize_openchatml_conversation(data, tokenizer, ignore_index):
conversation = data["messages"]
weight = data["weight"]
result = {'input_ids': [], 'target_ids': []}
if weight > MIN_WEIGHT:
result['target_weights'] = []
last_idx = len(conversation) - 1
for idx, entry in enumerate(conversation):
if entry["role"] == 'assistant':
pre_content = '<|im_start|>assistant\n'
pre_input_ids = tokenizer.encode(pre_content, add_special_tokens=False)
pre_input_ids_len = len(pre_input_ids)
content = entry['content'] + '<|im_end|>\n'
if idx == last_idx:
content += "</s>"
full_input_ids = tokenizer.encode(pre_content + content, add_special_tokens=False)
if full_input_ids[:pre_input_ids_len] == pre_input_ids:
result['input_ids'].extend(full_input_ids)
result['target_ids'].extend(list(
ignore_index if i < pre_input_ids_len else full_input_ids[i] for i in range(len(full_input_ids))
))
if weight > MIN_WEIGHT:
result['target_weights'].extend(list(
0.0 if i < pre_input_ids_len else weight for i in range(len(full_input_ids))
))
else:
logger.warning("Tokenization inconsistency detected. Performing separate tokenization")
content_input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(pre_input_ids)
result['input_ids'].extend(content_input_ids)
result['target_ids'].extend(list(ignore_index for _ in range(pre_input_ids_len)))
result['target_ids'].extend(content_input_ids)
if weight > MIN_WEIGHT:
result['target_weights'].extend(list(0.0 for _ in range(pre_input_ids_len)))
result['target_weights'].extend(list(weight for _ in range(len(content_input_ids))))
else:
content = "<s><|im_start|>" if idx == 0 else "<|im_start|>"
content += entry["role"] + '\n' + entry["content"] + '<|im_end|>\n'
input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(input_ids)
result['target_ids'].extend(list(ignore_index for _ in range(len(input_ids))))
if weight > MIN_WEIGHT:
result['target_weights'].extend(list(0.0 for _ in range(len(input_ids))))
assert len(result['input_ids']) == len(result['target_ids'])
if weight > MIN_WEIGHT:
assert len(result['input_ids']) == len(result['target_weights'])
return result
def tokenize_llama2_conversation(data, tokenizer, ignore_index):
conversation = data["messages"]
weight = data["weight"]
result = {'input_ids': [], 'target_ids': []}
if weight > MIN_WEIGHT:
result['target_weights'] = []
if conversation[0]['role'] == 'system':
sys_message = f"<<SYS>>\n{conversation[0]['content']}\n<</SYS>>\n\n"
conversation = conversation[1:]
else:
sys_message = ''
for idx, entry in enumerate(conversation):
if entry['role'] == 'user':
content = '<s>[INST] '+sys_message if idx <= 1 else '[INST] '
content += entry['content'] + ' [/INST]'
input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(input_ids)
result['target_ids'].extend(list(ignore_index for _ in range(len(input_ids))))
if weight > MIN_WEIGHT:
result['target_weights'].extend(list(0.0 for _ in range(len(input_ids))))
elif entry['role'] == 'assistant':
content = ' ' + entry['content'] + '</s>'
input_ids = tokenizer.encode(content, add_special_tokens=False)
result['input_ids'].extend(input_ids)
result['target_ids'].extend(input_ids)
if weight > MIN_WEIGHT:
result['target_weights'].extend(
list(weight for _ in range(len(input_ids)))
)
assert len(result['input_ids']) == len(result['target_ids'])
if weight > MIN_WEIGHT:
assert len(result['input_ids']) == len(result['target_weights'])
return result
def tokenize_conversation(data, tokenizer, ignore_index, chat_format):
if chat_format == 'OpenChatML':
return tokenize_openchatml_conversation(data, tokenizer, ignore_index)
elif chat_format == 'llama2':
return tokenize_llama2_conversation(data, tokenizer, ignore_index)
else:
raise Exception(f"Unsupported chat format: {chat_format}")
def convert_to_numpy_array(pylist, target_length, pad_token, data_type):
padded = np.full(target_length, pad_token, dtype=data_type)
padded[:len(pylist)] = pylist
return padded
def process_chunk(args):
chunk_file, pack, tokenizer_path, chat_format, block_size, ignore_index, pad_token_id, min_unmasked_tokens = args
max_sample_size = block_size + 1
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
data_dtype = np.int16 if len(tokenizer) < 32767 else np.int32
truncated = 0
rejected = 0
data = []
pa_table = pq.read_table(chunk_file)
for i in range(len(pa_table['rows'])):
cols = pa_table['rows'][i].as_py().split(';', 1)
messages = json.loads(cols[1])
if 'messages' not in messages:
rejected += 1
else:
weight = messages['weight'] if "weight" in messages and messages['weight'] > 0 else float(cols[0])
sample = tokenize_conversation({
"messages": messages['messages'],
"weight": weight
}, tokenizer, ignore_index, chat_format)
input_ids_len = len(sample['input_ids'])
if input_ids_len > max_sample_size:
sample['input_ids'] = sample['input_ids'][:max_sample_size]
sample['target_ids'] = sample['target_ids'][:max_sample_size]
if 'target_weights' in sample:
sample['target_weights'] = sample['target_weights'][:max_sample_size]
truncated += 1
data.append(sample)
del pa_table
created = len(data)
packed = 0
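# Packing: greedily concatenate tokenized samples (popped from the end of the shuffled chunk) into one
# buffer as long as the combined length stays within max_sample_size; "seq_lens" records the length of
# each original sequence inside the packed sample.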
if pack:
packed_data = []
while data:
instructions_buffer = data.pop()
instructions_buffer["seq_lens"] = [len(instructions_buffer["input_ids"])]
while len(data) > 0 and len(instructions_buffer["input_ids"]) + len(data[-1]["input_ids"]) <= max_sample_size:
instruction = data.pop()
instructions_buffer["input_ids"].extend(instruction["input_ids"])
instructions_buffer["target_ids"].extend(instruction["target_ids"])
if "target_weights" in instructions_buffer:
instructions_buffer["target_weights"].extend(instruction["target_weights"])
instructions_buffer["seq_lens"].append(len(instruction["input_ids"]))
packed_data.append(instructions_buffer)
packed = len(packed_data)
data = packed_data
del packed_data
result = []
for sample in data:
padding = max_sample_size - len(sample['input_ids'])
if pad_token_id >= 0:
assert padding >= 0
if padding > 0:
if padding > 1:
sample["input_ids"] = convert_to_numpy_array(sample["input_ids"], block_size, pad_token_id, data_dtype)
else:
sample["input_ids"] = np.array(sample["input_ids"], dtype=data_dtype)
sample["target_ids"] = convert_to_numpy_array(sample["target_ids"][1:], block_size, ignore_index, data_dtype)
if "target_weights" in sample:
sample["target_weights"] = convert_to_numpy_array(sample["target_weights"][1:], block_size, 0, np.float16)
else:
assert len(sample["input_ids"]) == max_sample_size
assert len(sample["target_ids"]) == max_sample_size
if "target_weights" in sample:
assert len(sample["target_weights"]) == max_sample_size
if "seq_lens" in sample:
assert sum(sample["seq_lens"]) == max_sample_size
sample["input_ids"] = np.array(sample["input_ids"][:-1], dtype=data_dtype)
sample["target_ids"] = np.array(sample["target_ids"][1:], dtype=data_dtype)
if "target_weights" in sample:
sample["target_weights"] = np.array(sample["target_weights"][1:], dtype=np.float16)
if "seq_lens" in sample:
sample["seq_lens"][-1] -= 1
else:
expected_len = len(sample['input_ids']) - 1 if padding > 0 else block_size
sample["input_ids"] = np.array(sample["input_ids"][:expected_len], dtype=data_dtype)
sample["target_ids"] = np.array(sample["target_ids"][1:expected_len+1], dtype=data_dtype)
if "target_weights" in sample:
sample["target_weights"] = np.array(sample["target_weights"][1:expected_len+1], dtype=np.float16)
assert isinstance(sample["input_ids"], np.ndarray)
assert isinstance(sample["target_ids"], np.ndarray)
if np.sum(sample['target_ids'] != ignore_index) >= min_unmasked_tokens:
result.append(sample)
else:
rejected += 1
with open(chunk_file, 'wb') as f:
joblib.dump(result, f)
return {'created': created, 'truncated': truncated, 'rejected': rejected, 'packed': packed}
def save_chunk_for_rank(rows, rank, output_dir, chunk_files):
chunk_file = os.path.join(output_dir, f"chunk_{rank:05}.tmp")
pa_array = pa.array(rows)
pa_table = pa.table([pa_array], names=['rows'])
pq.write_table(pa_table, chunk_file)
chunk_files.append(chunk_file)
def format_seconds_as_time(seconds):
hours, remainder = divmod(seconds, 3600)
minutes, seconds = divmod(remainder, 60)
return f"{int(hours)}:{int(minutes):02}:{int(seconds):02}"
def create_sample_for(input_ids, target_weights, seq_lens, data_dtype):
sample = {'input_ids': np.array(input_ids, dtype=data_dtype)}
if target_weights and isinstance(target_weights, list):
sample['target_weights'] = np.array(target_weights, dtype=np.float16)
if seq_lens and isinstance(seq_lens, list):
sample['seq_lens'] = seq_lens
return sample
if __name__ == "__main__":
configure_logger()
parser = argparse.ArgumentParser(description='Tokenize dialogues with weights')
parser.add_argument("-c", "--config_path", help="Config file with a list of input files")
parser.add_argument("-f", "--input_file", help="Input file")
parser.add_argument("-i", "--input_dir", help="Directory with input jsonl files")
parser.add_argument("-o", "--output_dir", help="Output dir")
parser.add_argument("-n", "--num_output_files", type=int, default=1, help="Number of final output files")
parser.add_argument("-t", "--tokenizer_path", required=True, help="Tokenizer path")
parser.add_argument("-w", "--default_weight", type=float, default=-1, help="Default weight for input files")
parser.add_argument("-p", "--max_workers", type=int, default=20, help="The max number of processes")
parser.add_argument("-b", "--block_size", type=int, default=4096, help="Block/context size")
parser.add_argument('--chat_format', type=str, choices=['OpenChatML', 'llama2'], default='OpenChatML', help='Chat format')
parser.add_argument("--min_unmasked_tokens", type=int, default=1, help="Minimum number of unmasked target tokens required for a sample to be included in training")
parser.add_argument("--ignore_index", type=int, default=-100, help="Specifies a target value that is ignored in loss computation. Default is -100")
parser.add_argument("--pad_token_id", type=int, default=0, help="Specifies the padding token id. Default is 0")
parser.add_argument("--chunk_size", type=int, default=100000, help="Chunk size")
parser.add_argument('--save_samples', action='store_true', help='Save some samples')
parser.add_argument('--pack', action='store_true', help='Pack')
parser.add_argument('--verbose', action='store_true', help='Be verbose')
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
logger.info(f"Loaded tokenizer with vocab size {len(tokenizer)}")
logger.info(f"Active chat template type: {args.chat_format}")
if not args.pack:
logger.warning("Padding not applied as packing is disabled")
timer = time.time()
max_sample_size = args.block_size + 1
configs = []
if args.config_path:
with open(args.config_path, "r", encoding="utf-8") as f:
configs = json.load(f)
if args.input_file:
configs.append({'path': args.input_file})
if args.input_dir:
for root, dirs, files in os.walk(args.input_dir):
for f in files:
if f.endswith('.jsonl'):
configs.append({'path': os.path.join(root, f)})
logger.info(f"Initialized with {len(configs)} input files")
logger.info("Loading data")
def load_data_file(config):
weight = config['weight'] if 'weight' in config else args.default_weight
weight_doc_prefix = f"{weight};"
with open(config['path'], 'r') as f:
return list(weight_doc_prefix + line for line in f if line)
chunks = joblib.Parallel(n_jobs=args.max_workers)(joblib.delayed(load_data_file)(config) for config in configs)
all_rows = list(chain.from_iterable(chunks))
del chunks
del configs
instruction_count = len(all_rows)
logger.info(f"Loaded {instruction_count:,} rows")
logger.info("Shuffling data")
random.shuffle(all_rows)
logger.info("Shuffling completed")
# adjust num of workers if needed
if len(all_rows) < 10*args.max_workers:
args.max_workers = max(1, len(all_rows) // 10)
logger.info(f"Chunking {len(all_rows):,} rows into {args.max_workers} files")
chunk_files = []
for rank in tqdm(range(args.max_workers), total=args.max_workers, desc="Chunking", disable=(not args.verbose)):
save_chunk_for_rank(all_rows[rank::args.max_workers], rank, args.output_dir, chunk_files)
del all_rows
logger.info(f"Saved {len(chunk_files)} chunks in {args.output_dir}")
logger.info(f"Tokenizing {len(chunk_files)} files")
processed_chunk_stats = []
max_workers = min(len(chunk_files), args.max_workers)
chunk_batches = list((chunk_file, args.pack, args.tokenizer_path, args.chat_format, args.block_size, args.ignore_index, args.pad_token_id, args.min_unmasked_tokens) for chunk_file in chunk_files)
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
for result in tqdm(executor.map(process_chunk, chunk_batches), total=len(chunk_batches), desc="Tokenizing", disable=(not args.verbose)):
processed_chunk_stats.append(result)
del executor
stats = {'created': 0, 'truncated': 0, 'rejected': 0, 'packed': 0}
for s in processed_chunk_stats:
for k, v in s.items():
stats[k] += v
del processed_chunk_stats
logger.info(f"Tokenization finished in {len(chunk_files)} chunks. Stats: {stats}")
logger.info(f"Merging {len(chunk_files)} chunks")
chunks = joblib.Parallel(n_jobs=args.max_workers)(joblib.delayed(joblib.load)(f) for f in chunk_files)
all_samples = list(chain.from_iterable(chunks))
sample_count = len(all_samples)
logger.info(f"{sample_count:,} samples loaded")
assert isinstance(all_samples[0]["input_ids"], np.ndarray)
assert isinstance(all_samples[0]["target_ids"], np.ndarray)
assert sample_count > 0
if args.save_samples:
logger.info(f"Saving samples")
samples_file = os.path.join(args.output_dir, "samples.jsonl")
with open(samples_file, 'w') as f:
for sample in all_samples[:100]:
input_ids = sample["input_ids"].tolist()
new_sample = {
"input": tokenizer.decode(input_ids),
"input_ids": input_ids,
"target_ids": sample["target_ids"].tolist(),
}
if 'target_weights' in sample:
new_sample["target_weights"] = sample["target_weights"].tolist()
if 'seq_lens' in sample:
new_sample["seq_lens"] = sample["seq_lens"]
f.write(json.dumps(new_sample, ensure_ascii=False))
f.write('\n')
logger.info(f"Samples saved in {samples_file}")
if args.num_output_files > 1:
for i in tqdm(range(args.num_output_files), desc="Saving", disable=(not args.verbose)):
bucket = all_samples[i::args.num_output_files]
output_file = os.path.join(args.output_dir, f"samples_part_{i:05}.alm")
with open(output_file, 'wb') as f:
joblib.dump(bucket, f)
logger.info(f"Saved {len(bucket)} samples into {output_file}")
else:
output_file = os.path.join(args.output_dir, "all_samples.alm")
with open(output_file, 'wb') as f:
joblib.dump(all_samples, f)
logger.info(f"All ({sample_count}) samples saved in {output_file}")
# cleanup
for chunk_file in chunk_files:
os.remove(chunk_file)
logger.info(f"Calculating stats")
if args.pack:
sample_lenghts = [sum(sample['seq_lens']) for sample in all_samples]
else:
sample_lenghts = [np.sum(sample['input_ids'] != args.pad_token_id).item() for sample in all_samples]
shortest_sample_tokens = min(sample_lenghts)
longest_sample_tokens = max(sample_lenghts)
total_tokens_count = sum(sample_lenghts)
stats = {
'instruction_count': instruction_count,
'samples_count': sample_count,
'shortest_sample_tokens': shortest_sample_tokens,
'longest_sample_tokens': longest_sample_tokens,
'total_tokens_count': total_tokens_count,
'avg_instruction_size': (total_tokens_count // instruction_count),
'avg_sample_size': (total_tokens_count // sample_count),
'packing_ratio': (instruction_count / sample_count),
'packing_level': (total_tokens_count / (sample_count * args.block_size) * 100),
}
stats_str = json.dumps(stats, indent=4, ensure_ascii=False)
logger.info(f"Stats:\n{stats_str}")
stats_file = os.path.join(args.output_dir, "dataset_stats.json")
with open(stats_file, 'w') as fin:
json.dump(stats, fin)
logger.info(f"Stats saved in {stats_file}")
sample_lenght_histogram = dict(Counter(sample_lenghts))
histogram_file = os.path.join(args.output_dir, "dataset_histogram.csv")
with open(histogram_file, 'w') as fin:
fin.write("token_count; sample_count\n")
for length in range(0, max_sample_size+1):
fin.write(f"{length}; {sample_lenght_histogram.get(length, 0)}\n")
logger.info(f"Histogram saved in {histogram_file}")
logger.info(f"Dataset with {sample_count:,} samples ({total_tokens_count:,} tokens) has been created in {format_seconds_as_time(time.time()-timer)}")
"""
Use this file for pruning model layers
"""
import argparse
import json
import os
import torch
from allamo.logging import configure_logger, logger
from allamo.train_utils import (
get_model_checkpoint_path,
get_config_checkpoint_path,
)
def prepare_layer_keys_mapping(n_layers, num_layers_to_remove):
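# Keeps every step-th layer (integer stride) and always retains the final layer, returning
# (dest_layer, src_layer) pairs. Worked example: n_layers=32, num_layers_to_remove=8 gives step=1,
# so layers 0-22 plus layer 31 are kept, layers 23-30 are dropped, and 24 mapping pairs are returned.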
num_layers = n_layers - 1 # last layer is excluded since we will always keep it
step = num_layers // (num_layers - num_layers_to_remove)
indices_to_keep = [int(i * step) for i in range(num_layers - num_layers_to_remove)]
mapping_pairs = []
for layer_i in range(num_layers):
if layer_i in indices_to_keep:
mapping_pairs.append((len(mapping_pairs), layer_i))
assert len(mapping_pairs) == len(indices_to_keep)
mapping_pairs.append((len(mapping_pairs), n_layers - 1))
return mapping_pairs
def prune_model(input_dir_path, input_checkpoint_name_base, output_dir_path, output_checkpoint_name_base, num_layers_to_remove, bfloat16):
os.makedirs(output_dir_path, exist_ok=True)
logger.info(f"loading checkpoint from {input_dir_path}...")
with open(get_config_checkpoint_path(input_checkpoint_name_base, input_dir_path), "r", encoding="utf-8") as f:
config_checkpoint = json.load(f)
model_checkpoint = torch.load(get_model_checkpoint_path(input_checkpoint_name_base, input_dir_path), map_location='cpu')
unwanted_prefix = '_orig_mod.'
for k,v in list(model_checkpoint.items()):
if k.startswith(unwanted_prefix):
model_checkpoint[k[len(unwanted_prefix):]] = model_checkpoint.pop(k)
state_dict = {
"tok_embeddings.weight": model_checkpoint["tok_embeddings.weight"],
"norm.weight": model_checkpoint["norm.weight"],
"lm_head.weight": model_checkpoint["lm_head.weight"],
}
layer_mapping_pairs = prepare_layer_keys_mapping(config_checkpoint['model_args']['n_layer'], num_layers_to_remove)
# you can customize the mapping here, e.g., replace some layers with others
for dest_layer_idx, src_layer_idx in layer_mapping_pairs:
logger.info(f"copying weights from layer {src_layer_idx} to layer {dest_layer_idx}")
state_dict[f"layers.{dest_layer_idx}.attention.q_proj.weight"] = model_checkpoint[f"layers.{src_layer_idx}.attention.q_proj.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.attention.k_proj.weight"] = model_checkpoint[f"layers.{src_layer_idx}.attention.k_proj.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.attention.v_proj.weight"] = model_checkpoint[f"layers.{src_layer_idx}.attention.v_proj.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.attention.c_proj.weight"] = model_checkpoint[f"layers.{src_layer_idx}.attention.c_proj.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.feed_forward.gate_proj.weight"] = model_checkpoint[f"layers.{src_layer_idx}.feed_forward.gate_proj.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.feed_forward.down_proj.weight"] = model_checkpoint[f"layers.{src_layer_idx}.feed_forward.down_proj.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.feed_forward.up_proj.weight"] = model_checkpoint[f"layers.{src_layer_idx}.feed_forward.up_proj.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.attention_norm.weight"] = model_checkpoint[f"layers.{src_layer_idx}.attention_norm.weight"].clone()
state_dict[f"layers.{dest_layer_idx}.ffn_norm.weight"] = model_checkpoint[f"layers.{src_layer_idx}.ffn_norm.weight"].clone()
if bfloat16:
logger.info("converting weights to bfloat16")
param_count = 0
param_bytes = 0
for k, v in state_dict.items():
if bfloat16:
v = v.to(torch.bfloat16)
state_dict[k] = v
param_count += v.numel()
param_bytes += v.numel() * v.element_size()
config_checkpoint['model_args']['n_layer'] = len(layer_mapping_pairs)
param_count /= 1e6
param_bytes /= 1024**2
logger.info(f"New model layers: {config_checkpoint['model_args']['n_layer']}. Model parameters: {param_count:.2f}M. Est. Size: {param_bytes:.3f}MB")
ckpt_file_path = get_config_checkpoint_path(output_checkpoint_name_base, output_dir_path)
logger.info(f"saving config checkpoint to {ckpt_file_path}")
with open(ckpt_file_path, "w", encoding="utf-8") as f:
json.dump(config_checkpoint, f, indent=4, ensure_ascii=False)
ckpt_file_path = get_model_checkpoint_path(output_checkpoint_name_base, output_dir_path)
logger.info(f"saving model checkpoint to {ckpt_file_path}")
torch.save(state_dict, ckpt_file_path)
logger.info(f"checkpoint files saved in {output_dir_path}")
if __name__ == "__main__":
configure_logger()
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of ALLaMo weights, which contains a checkpoint file",
)
parser.add_argument(
"--input_checkpoint_name_base",
default='ckpt',
help="Source checkpoint file name base",
)
parser.add_argument(
"--output_dir",
help="Location to write up-scaled model",
)
parser.add_argument(
"--output_checkpoint_name_base",
default='pruned_ckpt',
help="Output checkpoint file name base",
)
parser.add_argument(
"--num_layers_to_remove", type=int,
help="Number of layers to remove from the model",
)
parser.add_argument(
"--bfloat16", action='store_true',
help="Convert weights to bfloat16",
)
args = parser.parse_args()
prune_model(
input_dir_path=args.input_dir,
input_checkpoint_name_base=args.input_checkpoint_name_base,
output_dir_path=args.output_dir,
output_checkpoint_name_base=args.output_checkpoint_name_base,
num_layers_to_remove=args.num_layers_to_remove,
bfloat16=args.bfloat16
)
from setuptools import setup
setup(name='allamo',
version='5.0.0',
author='Krzysztof (Chris) Ociepa',
packages=['allamo'],
description='Simple, hackable and fast implementation for training/finetuning medium-sized LLaMA-based models',
license='MIT',
install_requires=[
'torch',
'numpy',
'joblib',
'wandb'
],
)
from allamo.configuration import AllamoConfiguration
from allamo.trainer.simple_trainer import SimpleTrainer
if __name__ == '__main__':
config = AllamoConfiguration()
trainer = SimpleTrainer(config)
trainer.init_wandb()
trainer.train()
trainer.close()
# Refer to allamo/configuration.py. For SFT, set: "training_type": "sft", "init_from": "resume", ...
python train.py --config="./train_configs/train_1B.json"
{
"training_type": "pre",
"init_from": "scratch",
"data_dir": "../data/",
"out_dir": "../data/out-allamo-1B/",
"checkpoint_interval": 1000,
"save_best_checkpoint": false,
"eval_interval": 1000,
"eval_iters": 200,
"log_interval": 1,
"vocab_size": 50307,
"custom_tokenizer_path": "../data/allamo_1B_dataset/tokenizer.json",
"wandb_log": false,
"wandb_project": "allamo",
"wandb_run_name": "allamo-1B",
"dataset": "allamo_1B_dataset",
"batch_size": 1,
"block_size": 2048,
"gradient_accumulation_steps": 264,
"dataset_seq_train": true,
"grad_accum_schedule": false,
"n_layer": 20,
"n_head": 16,
"head_size": 128,
"n_embd": 2048,
"dropout": 0,
"weight_decay": 0.1,
"multiple_of": 256,
"norm_eps": 0.000001,
"learning_rate": 0.0003,
"max_iters": 38000,
"decay_lr": true,
"lr_decay_iters": 38000,
"lr_decay_reset_iters": 3800,
"min_lr": 0.0002,
"warmup_iters": 3800,
"device": "cuda:0",
"dtype": "float16",
"compile": true
}
torchrun --standalone --nnodes=1 --nproc-per-node=8 train.py --config="./train_configs/train_1B.json"