Commit 4f4ba442 authored by mashun1

omnisql
import argparse
import json
import re
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
def parse_response(response):
pattern = r"```sql\s*(.*?)\s*```"
sql_blocks = re.findall(pattern, response, re.DOTALL)
if sql_blocks:
# Extract the last SQL query in the response text and remove extra whitespace characters
last_sql = sql_blocks[-1].strip()
return last_sql
else:
# print("No SQL blocks found.")
return ""
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_name_or_path", type = str, default = "/fs/fast/u2021000902/previous_nvme/xxx")
parser.add_argument("--input_file", type = str, help = "the input file path (prompts)")
parser.add_argument("--output_file", type = str, help = "the output file path (results)")
parser.add_argument("--tensor_parallel_size", type = int, help = "the number of used GPUs", default = 4)
parser.add_argument("--n", type = int, help = "the number of generated responses", default = 4)
parser.add_argument("--temperature", type = float, help = "temperature of llm's sampling", default = 1.0)
opt = parser.parse_args()
print(opt)
input_dataset = json.load(open(opt.input_file))
tokenizer = AutoTokenizer.from_pretrained(opt.pretrained_model_name_or_path, trust_remote_code=True)
if "Qwen2.5-" in opt.pretrained_model_name_or_path:
stop_token_ids = [151645] # 151645 is the token id of <|im_end|> (end of turn token in Qwen2.5)
elif "deepseek-coder-" in opt.pretrained_model_name_or_path:
stop_token_ids = [32021]
elif "DeepSeek-Coder-V2" in opt.pretrained_model_name_or_path:
stop_token_ids = [100001]
elif "OpenCoder-" in opt.pretrained_model_name_or_path:
stop_token_ids = [96539]
elif "Meta-Llama-" in opt.pretrained_model_name_or_path:
stop_token_ids = [128009, 128001]
elif "granite-" in opt.pretrained_model_name_or_path:
stop_token_ids = [0] # <|end_of_text|> is the end token of granite-3.1 and granite-code
elif "starcoder2-" in opt.pretrained_model_name_or_path:
stop_token_ids = [0] # <|end_of_text|> is the end token of starcoder2
elif "Codestral-" in opt.pretrained_model_name_or_path:
stop_token_ids = [2]
elif "Mixtral-" in opt.pretrained_model_name_or_path:
stop_token_ids = [2]
elif "OmniSQL-" in opt.pretrained_model_name_or_path:
stop_token_ids = [151645] # OmniSQL uses the same tokenizer as Qwen2.5
else:
print("Use Qwen2.5's stop tokens by default.")
stop_token_ids = [151645]
print("stop_token_ids:", stop_token_ids)
max_model_len = 8192 # used to allocate KV cache memory in advance
max_input_len = 6144
max_output_len = 2048 # (max_input_len + max_output_len) must be <= max_model_len
print("max_model_len:", max_model_len)
print("temperature:", opt.temperature)
sampling_params = SamplingParams(
temperature = opt.temperature,
max_tokens = max_output_len,
n = opt.n,
stop_token_ids = stop_token_ids
)
llm = LLM(
model = opt.pretrained_model_name_or_path,
dtype = "bfloat16",
tensor_parallel_size = opt.tensor_parallel_size,
max_model_len = max_model_len,
gpu_memory_utilization = 0.92,
swap_space = 42,
enforce_eager = True,
disable_custom_all_reduce = True,
trust_remote_code = True
)
chat_prompts = [tokenizer.apply_chat_template(
[{"role": "user", "content": data["input_seq"]}],
add_generation_prompt = True, tokenize = False
) for data in input_dataset]
outputs = llm.generate(chat_prompts, sampling_params)
results = []
for data, output in zip(input_dataset, outputs):
responses = [o.text for o in output.outputs]
sqls = [parse_response(response) for response in responses]
data["responses"] = responses
data["pred_sqls"] = sqls
results.append(data)
with open(opt.output_file, "w", encoding = "utf-8") as f:
f.write(json.dumps(results, indent = 2, ensure_ascii = False))
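# Example invocation (illustrative; the script name, model path, and file paths are placeholders):
#   python infer.py \
#       --pretrained_model_name_or_path seeklhy/OmniSQL-7B \
#       --input_file ./data/dev_bird.json \
#       --output_file ./results/dev_bird_pred.json \
#       --tensor_parallel_size 4 --n 8 --temperature 0.8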
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
if __name__ == "__main__":
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2.5-Coder-32B-Instruct",
torch_dtype = torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-32B-Instruct")
print(model.dtype)
peft_model_id = "your_lora_weight_path"
print(peft_model_id)
model = PeftModel.from_pretrained(model, peft_model_id)
print("before merging:")
for name, param in model.named_parameters():
print(name)
model = model.merge_and_unload(progressbar = True)
print(model.dtype)
print("after merging:")
for name, param in model.named_parameters():
print(name)
model.save_pretrained(
peft_model_id + "-full-model",
max_shard_size = "4GB"
)
tokenizer.save_pretrained(
peft_model_id + "-full-model"
)
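# Illustrative note (assumption): after merging, the resulting checkpoint can be loaded like any
# regular Hugging Face model, e.g.
#   AutoModelForCausalLM.from_pretrained("your_lora_weight_path-full-model")
# where the directory name follows the peft_model_id + "-full-model" convention used above.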
'''
This script originates from the GitHub repository:
https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
'''
import torch
import torch.nn.functional as F
import transformers
from typing import Optional
import sys
def get_max_seqlen_in_batch(attention_mask):
max_num = torch.max(attention_mask)
# attention_mask: B x N
counts = []
for i in range(1, max_num + 1):
counts.append(
torch.sum(attention_mask == i, axis=-1)
) # shape: (B,), length of the sub-sequence masked with value i
result = torch.stack(counts, axis=1)
result = result.flatten()
return result[result.nonzero()].squeeze(-1).to(dtype=torch.int32)
def get_unpad_data(attention_mask):
seqlens_in_batch = get_max_seqlen_in_batch(
attention_mask
) # attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
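# Worked example (illustrative, not part of the original script): for a packed batch in which
# each sub-sequence is labeled 1, 2, ... in the attention mask, e.g.
#   attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
# get_max_seqlen_in_batch returns tensor([2, 3], dtype=torch.int32), and get_unpad_data yields
#   indices = tensor([0, 1, 2, 3, 4]), cu_seqlens = tensor([0, 2, 5]), max_seqlen_in_batch = 3,
# which is the cumulative-sequence-length format expected by flash-attention's varlen kernels.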
# Copied from the original implementation of modeling_mixtral.py in transformers; the only change is the use of new_attention_mask below.
def load_balancing_loss_func(
gate_logits: torch.Tensor,
num_experts: torch.Tensor = None,
top_k=2,
attention_mask: Optional[torch.Tensor] = None,
) -> float:
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
Args:
gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
shape [batch_size X sequence_length, num_experts].
attention_mask (`torch.Tensor`, None):
The attention_mask used in forward function
shape [batch_size X sequence_length] if not None.
num_experts (`int`, *optional*):
Number of experts
Returns:
The auxiliary loss.
"""
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
concatenated_gate_logits = torch.cat(
[layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0
)
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
if attention_mask is None:
# Compute the percentage of tokens routed to each expert
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
else:
# The only modification: build new_attention_mask and use it in place of attention_mask below
new_attention_mask = (attention_mask != 0).int().to(attention_mask.device)
batch_size, sequence_length = new_attention_mask.shape
num_hidden_layers = concatenated_gate_logits.shape[0] // (
batch_size * sequence_length
)
# Compute the mask that masks all padding tokens as 0, with the same shape as expert_mask
expert_attention_mask = (
new_attention_mask[None, :, :, None, None]
.expand(
(num_hidden_layers, batch_size, sequence_length, top_k, num_experts)
)
.reshape(-1, top_k, num_experts)
.to(compute_device)
)
# Compute the percentage of tokens routed to each expert
tokens_per_expert = torch.sum(
expert_mask.float() * expert_attention_mask, dim=0
) / torch.sum(expert_attention_mask, dim=0)
# Compute the mask that masks all padding tokens as 0, with the same shape as tokens_per_expert
router_per_expert_attention_mask = (
new_attention_mask[None, :, :, None]
.expand((num_hidden_layers, batch_size, sequence_length, num_experts))
.reshape(-1, num_experts)
.to(compute_device)
)
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.sum(
routing_weights * router_per_expert_attention_mask, dim=0
) / torch.sum(router_per_expert_attention_mask, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
def monkey_patch_for_model_with_name(model_type: str, modelling_type: str):
    """Patch _get_unpad_data in a specific modelling module.
    For example, for Llama: model_type = "llama", modelling_type = "modeling_llama".
    Args:
        model_type (str): package name under transformers.models (e.g., "qwen2").
        modelling_type (str): modelling module name (e.g., "modeling_qwen2").
    """
    module = getattr(getattr(transformers.models, model_type), modelling_type)
    if hasattr(module, "_get_unpad_data"):
        module._get_unpad_data = get_unpad_data
    else:
        print(
            f"cannot monkey-patch packing because _get_unpad_data was not found in transformers.models.{model_type}.{modelling_type} or transformers.modeling_flash_attention_utils"
        )
        sys.exit(1)
def monkey_patch_packing_for_model(pretrained_model):
# Monkey-patch flash attention if this transformers already merged: https://github.com/huggingface/transformers/commit/e314395277d784a34ee99526f48155d4d62cff3d
# this will work for all models using flash attention: Llama, Mistral, Qwen2, Phi3, ...
model_config = transformers.AutoConfig.from_pretrained(pretrained_model)
config_type = type(model_config).__name__.lower()
if hasattr(transformers, "modeling_flash_attention_utils"):
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
else: # if this is an old version of transformers
model_type, modelling_type = "", ""
if config_type == "mistralconfig":
print("monkey_patch_packing for Mistral ")
transformers.models.mistral.modeling_mistral._get_unpad_data = (
get_unpad_data
)
elif config_type == "llamaconfig":
print("monkey_patch_packing for Llama ")
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
elif config_type == "mixtralconfig":
print("monkey_patch_packing for Mixtral")
transformers.models.mixtral.modeling_mixtral._get_unpad_data = (
get_unpad_data
)
elif config_type == "qwen2config":
print("monkey_patch_packing for Qwen2")
# transformers.models.qwen2.modeling_qwen2
model_type, modelling_type = "qwen2", "modeling_qwen2"
transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
elif config_type == "phi3config":
# transformers.models.phi3.modeling_phi3
print("monkey_patch_packing for Qwen2")
transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
else:
raise Exception(
f"{config_type} is not supported, currently we only support: Mistral, Mixtral, Llama, Qwen2 for monkey-patch-packing"
)
if model_type: # only the Qwen2 branch above sets these names; skip the call otherwise
    monkey_patch_for_model_with_name(model_type, modelling_type)
if config_type == "mixtralconfig":
# if it is mixtral, we need to monkey-patch the load_balancing_loss_func
transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = (
load_balancing_loss_func
)
import nltk
nltk.download('punkt')
import json
import sqlite3
import os
from tqdm import tqdm
import re
import argparse
import random
from collections import OrderedDict
from pyserini.search.lucene import LuceneSearcher
from nltk.tokenize import word_tokenize
from nltk import ngrams
import ijson
SQL_RESERVED_WORDS = {'IDENTIFIED', 'FOREIGN', 'CONSTRAINT', 'USER', 'POSITION', 'DESCRIBE', 'CHECK', 'RECURSIVE', 'REAL', 'CONTINUE', 'GLOBAL', 'RLIKE', 'INSENSITIVE', 'BOOLEAN', 'CHAR', 'ROLE', 'CASE', 'SCHEMA', 'CLOB', 'RESIGNAL', 'ROW', 'DEC', 'TOP', 'EXCEPT', 'SENSITIVE', 'OUT', 'RENAME', 'READS', 'BLOB', 'INT', 'EXTERNAL', 'LOCALTIMESTAMP', 'DECLARE', 'DO', 'AS', 'OVER', 'CONDITION', 'SELECT', 'SAVEPOINT', 'WITHIN', 'ELSEIF', 'UNLOCK', 'DATABASE', 'TRIGGER', 'ACCESS', 'FALSE', 'BREAK', 'ITERATE', 'SMALLINT', 'ASC', 'YEAR', 'DELETE', 'ROLLBACK', 'ON', 'ESCAPE', 'CREATE', 'MONTH', 'SPECIFIC', 'SESSION', 'SQLSTATE', 'HOLD', 'SET', 'EXPLAIN', 'RETURN', 'ROWNUM', 'BINARY', 'SYSDATE', 'SQLWARNING', 'EXTEND', 'CAST', 'FOR', 'TERMINATED', 'VIEW', 'TRAILING', 'HOUR', 'VARYING', 'RESTRICT', 'RIGHT', 'DISTINCT', 'JOIN', 'UNKNOWN', 'VALUES', 'TABLE', 'OR', 'DOUBLE', 'DROP', 'COMMIT', 'PRECISION', 'LANGUAGE', 'START', 'INTERSECT', 'IGNORE', 'NULL', 'CURRENT_DATE', 'LOCK', 'INTO', 'NEW', 'DESC', 'STATIC', 'MODIFIES', 'GRANT', 'VALUE', 'LIMIT', 'MODULE', 'DATE', 'LOCALTIME', 'PERCENT', 'REPEAT', 'FULL', 'USAGE', 'ORDER', 'WHEN', 'PRIMARY', 'BETWEEN', 'CURSOR', 'DECIMAL', 'HAVING', 'IF', 'FILTER', 'INDEX', 'ILIKE', 'VARCHAR', 'EXEC', 'USING', 'ROWS', 'PLACING', 'WHILE', 'EXECUTE', 'EACH', 'LEFT', 'FLOAT', 'COLLATE', 'CURRENT_TIME', 'OPEN', 'RANGE', 'CROSS', 'FUNCTION', 'TIME', 'BOTH', 'NOT', 'CONVERT', 'NCHAR', 'KEY', 'DEFAULT', 'LIKE', 'ANALYZE', 'EXISTS', 'IN', 'BIT', 'INOUT', 'SUM', 'NUMERIC', 'AFTER', 'LEAVE', 'INSERT', 'TO', 'COUNT', 'THEN', 'BEFORE', 'OUTER', 'COLUMN', 'ONLY', 'END', 'PROCEDURE', 'OFFSET', 'ADD', 'INNER', 'RELEASE', 'FROM', 'DAY', 'NO', 'CALL', 'BY', 'LOCAL', 'ZONE', 'TRUE', 'EXIT', 'LEADING', 'INTEGER', 'MERGE', 'OLD', 'AVG', 'MIN', 'SQL', 'LOOP', 'SIGNAL', 'REFERENCES', 'MINUTE', 'UNIQUE', 'GENERATED', 'ALL', 'MATCH', 'CASCADE', 'UNION', 'COMMENT', 'FETCH', 'UNDO', 'UPDATE', 'WHERE', 'ELSE', 'PARTITION', 'BIGINT', 'CHARACTER', 'CURRENT_TIMESTAMP', 'ALTER', 'INTERVAL', 'REVOKE', 'CONNECT', 'WITH', 'TIMESTAMP', 'GROUP', 'BEGIN', 'CURRENT', 'REGEXP', 'NATURAL', 'SOME', 'SQLEXCEPTION', 'MAX', 'SUBSTRING', 'OF', 'AND', 'REPLACE', 'IS'}
SPECIAL_CHARS_PATTERN = re.compile(r'[^a-zA-Z0-9_]')
def load_json_file(file):
dataset = []
with open(file, 'r', encoding='utf-8') as f:
objects = ijson.items(f, 'item')
for obj in tqdm(objects):
dataset.append(obj)
return dataset
def remove_sql_comments(sql):
# Remove single-line comments
sql = re.sub(r'--.*', '', sql)
# Remove multi-line comments
sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL)
return sql.strip()
def obtain_db_ddls(db_file_dir):
conn = sqlite3.connect(db_file_dir)
cursor = conn.cursor()
cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
create_statements = []
for table in tables:
_, create_statement = table
create_statements.append(create_statement)
cursor.close()
conn.close()
# table_schemas = [remove_sql_comments(stat) for stat in create_statements]
return create_statements
def needs_backticks(identifier):
if identifier.upper() in SQL_RESERVED_WORDS:
return True
if SPECIAL_CHARS_PATTERN.search(identifier):
return True
return False
def format_identifier(identifier):
if needs_backticks(identifier):
return f'`{identifier}`'
return identifier
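# Illustrative examples (not part of the original script):
#   format_identifier("singer_id")  -> "singer_id"        (no quoting needed)
#   format_identifier("order")      -> "`order`"          (ORDER is a SQL reserved word)
#   format_identifier("unit price") -> "`unit price`"     (contains a special character)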
def sample_table_values(db_file_dir, table_names, limit_num):
db_values_dict = dict()
conn = sqlite3.connect(db_file_dir)
cursor = conn.cursor()
for table_name in table_names:
cursor.execute(f"PRAGMA table_info(`{table_name}`);")
columns = cursor.fetchall()
column_names = [column[1] for column in columns]
for column_name in column_names:
# cursor.execute(f"SELECT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL LIMIT {limit_num};")
query = f"""
SELECT `{column_name}`
FROM (
SELECT DISTINCT `{column_name}`
FROM `{table_name}`
WHERE `{column_name}` IS NOT NULL and `{column_name}` != ''
) AS unique_values
LIMIT {limit_num};
"""
cursor.execute(query)
values = cursor.fetchall()
values = [value[0] for value in values]
# truncate too long strings
for idx in range(len(values)):
if isinstance(values[idx], str):
values[idx] = values[idx][:40]
if len(values) > 0:
db_values_dict[f"{table_name}.{column_name}".lower()] = values
cursor.close()
conn.close()
return db_values_dict
def calculate_substring_match_percentage(query, target):
query = query.lower()
target = target.lower()
substrings = []
for i in range(len(query)):
for j in range(i + 1, len(query) + 1):
substrings.append(query[i:j])
max_matched_substring_len = max([len(substring) for substring in substrings if substring in target], default=0)
return max_matched_substring_len/len(query)
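# Illustrative examples (not part of the original script):
#   calculate_substring_match_percentage("new york", "concerts held in New York City") -> 1.0
#   calculate_substring_match_percentage("yorks", "new york") -> 0.8  ("york" is the longest match)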
def retrieve_relevant_hits(searcher, queries):
queries = list(dict.fromkeys(queries))
# print("len(queries):", len(queries))
q_ids = [f"{idx}" for idx in range(len(queries))]
query2hits = dict()
search_results = searcher.batch_search(queries, q_ids, k = 10, threads=60)
for query, q_id in zip(queries, q_ids):
hits = search_results[q_id]
hits = list(dict.fromkeys(([hit.raw for hit in hits])))
hits = [json.loads(hit) for hit in hits]
query2hits[query] = hits
return query2hits
def retrieve_question_related_db_values(hits, question):
high_score_hits = []
for idx, hit in enumerate(hits):
table_name, column_name, c_id = hit["id"].split("-**-")
score = calculate_substring_match_percentage(hit["contents"], question)
if score > 0.85:
high_score_hits.append(
{
"table_dot_column_lower_case": f"{table_name}.{column_name}".lower(),
"db_value": hit["contents"],
"score": score,
"index": idx,
}
)
high_score_hits = sorted(high_score_hits, key=lambda x: (x["score"], len(x["db_value"]), x["index"]), reverse=True)
high_score_hits = high_score_hits[:20] # keep only the top-20 matched DB values
relavant_db_values_dict = dict()
for hit in high_score_hits:
if hit["table_dot_column_lower_case"] in relavant_db_values_dict:
relavant_db_values_dict[hit["table_dot_column_lower_case"]].append(hit["db_value"])
else:
relavant_db_values_dict[hit["table_dot_column_lower_case"]] = [hit["db_value"]]
return relavant_db_values_dict
def obtain_n_grams(sequence, max_n):
'''
Returns all n-grams of the sequence with n from 1 to `max_n`.
'''
tokens = word_tokenize(sequence)
all_n_grams = []
for n in range(1, max_n + 1):
all_n_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])
return all_n_grams
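# Illustrative example (not part of the original script):
#   obtain_n_grams("list all singers", 2)
#   -> ['list', 'all', 'singers', 'list all', 'all singers']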
input_prompt_template = '''Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
{db_engine}
Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
{question}
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
```sql
-- Your SQL query
```
Take a deep breath and think step by step to find the correct SQL query.
'''
def obtain_pk_fk_column_idx(db_info):
pk_fk_column_idx_list = []
for primary_keys_idx in db_info["primary_keys"]:
if isinstance(primary_keys_idx, int):
pk_fk_column_idx_list.append(primary_keys_idx)
elif isinstance(primary_keys_idx, list):
pk_fk_column_idx_list.extend(primary_keys_idx)
for (source_column_idx, target_column_idx) in db_info["foreign_keys"]:
pk_fk_column_idx_list.append(source_column_idx)
pk_fk_column_idx_list.append(target_column_idx)
return pk_fk_column_idx_list
def prepare_schema_filter_data(question, db_info):
data = dict()
data["text"] = question
data["schema"] = dict()
data["schema"]["schema_items"] = []
for outer_table_idx, table_name_original in enumerate(db_info["table_names_original"]):
table_info = dict()
table_info["table_name"] = table_name_original
table_info["table_comment"] = ""
table_info["column_names"] = []
table_info["column_comments"] = []
for (inner_table_idx, column_name_original), (_, column_comment) in zip(db_info["column_names_original"], db_info["column_names"]):
if outer_table_idx == inner_table_idx:
table_info["column_names"].append(column_name_original)
table_info["column_comments"].append(column_comment)
data["schema"]["schema_items"].append(table_info)
return data
def obtain_db_details(db_info, data_source, sampled_db_values_dict, relavant_db_values_dict, output_seq, mode, question):
db_details = []
assert len(db_info["column_names_original"]) == len(db_info["column_names"]) == len(db_info["column_types"])
if mode == "train":
'''
To increase the diversity of the training data, the input database schema includes:
[primary/foreign key columns, columns used in the output sequence, randomly sampled unused columns]
'''
# keep primary and foreign key columns
used_column_idx_list = obtain_pk_fk_column_idx(db_info)
# keep columns that appear in the output SQL
for column_idx, (inner_table_idx, column_name) in enumerate(db_info["column_names_original"]):
if column_name.lower() in output_seq.lower():
used_column_idx_list.append(column_idx)
used_column_idx_list = list(set(used_column_idx_list))
used_column_num = len(used_column_idx_list)
all_column_idx_list = list(range(len(db_info["column_names_original"])))
unused_column_idx_list = [idx for idx in all_column_idx_list if idx not in used_column_idx_list]
# randomly select some unused columns to mimic noise in the input sequence
if unused_column_idx_list:
unused_column_prob = random.choice([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
sample_size = int(unused_column_prob * len(unused_column_idx_list))
max_column_num = 225
if used_column_num > max_column_num:
sample_size = 0
elif used_column_num + sample_size > max_column_num:
sample_size = max_column_num - used_column_num
else:
sample_size = sample_size
used_column_idx_list.extend(random.sample(unused_column_idx_list, sample_size))
else:
# put all tables and columns in the prompt
used_column_idx_list = list(range(len(db_info["column_names_original"])))
# print(used_column_idx_list)
for outer_table_idx, table_name in enumerate(db_info["table_names_original"]):
column_info_list = []
pk_columns = []
fk_info = []
column_comment_prob = random.choice([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
for column_idx, ((inner_table_idx, column_name), (_, column_comment), column_type) in enumerate(zip(
db_info["column_names_original"], db_info["column_names"], db_info["column_types"]
)):
if inner_table_idx == outer_table_idx:
if column_idx not in used_column_idx_list:
continue
column_values = []
if f"{table_name}.{column_name}".lower() in relavant_db_values_dict:
column_values.extend(relavant_db_values_dict[f"{table_name}.{column_name}".lower()])
if f"{table_name}.{column_name}".lower() in sampled_db_values_dict:
column_values.extend(sampled_db_values_dict[f"{table_name}.{column_name}".lower()])
column_values = list(dict.fromkeys(column_values)) # dedup (preserve order)
column_values = column_values[:6]
if data_source == "synthetic":
if random.random() < column_comment_prob:
column_info = f' {format_identifier(column_name)} {column_type}, -- {column_comment}'
if len(column_values) > 0:
column_info += f", example: {column_values}"
else: # simulate some columns do not have comment
column_info = f' {format_identifier(column_name)} {column_type},'
if len(column_values) > 0:
column_info += f" -- example: {column_values}"
else:
if column_name.lower() in [column_comment.lower(), column_comment.lower().replace(" ", "_"), column_comment.lower().replace(" ", "")] \
or column_comment.strip() == "":
column_info = f' {format_identifier(column_name)} {column_type},'
if len(column_values) > 0:
column_info += f" -- example: {column_values}"
else:
column_info = f' {format_identifier(column_name)} {column_type}, -- {column_comment}'
if len(column_values) > 0:
column_info += f", example: {column_values}"
column_info_list.append(column_info)
for primary_keys_idx in db_info["primary_keys"]:
if isinstance(primary_keys_idx, int):
if column_idx == primary_keys_idx:
pk_columns.append(column_name) # f' PRIMARY KEY ("{ }")'
elif isinstance(primary_keys_idx, list):
if column_idx in primary_keys_idx:
pk_columns.append(column_name)
for (source_column_idx, target_column_idx) in db_info["foreign_keys"]:
if column_idx == source_column_idx:
source_table_idx = db_info["column_names_original"][source_column_idx][0]
source_table_name = db_info["table_names_original"][source_table_idx]
source_column_name = db_info["column_names_original"][source_column_idx][1]
target_table_idx = db_info["column_names_original"][target_column_idx][0]
target_table_name = db_info["table_names_original"][target_table_idx]
target_column_name = db_info["column_names_original"][target_column_idx][1]
fk_info.append(f' CONSTRAINT fk_{source_table_name.lower().replace(" ", "_")}_{source_column_name.lower().replace(" ", "_")} FOREIGN KEY ({format_identifier(source_column_name)}) REFERENCES {format_identifier(target_table_name)} ({format_identifier(target_column_name)}),')
if len(column_info_list) > 0:
pk_columns = list(OrderedDict.fromkeys(pk_columns))
if len(pk_columns) > 0:
pk_info = [' PRIMARY KEY (' + ', '.join([f'{format_identifier(column_name)}' for column_name in pk_columns]) + '),']
else:
pk_info = []
fk_info = list(OrderedDict.fromkeys(fk_info))
table_ddl = ""
table_ddl += f'CREATE TABLE {format_identifier(table_name)} (\n'
table_ddl += "\n".join(column_info_list + pk_info + fk_info)
if table_ddl.endswith(","):
table_ddl = table_ddl[:-1] # remove the trailing comma
table_ddl += "\n);"
db_details.append(table_ddl)
if mode == "train":
random.shuffle(db_details)
db_details = "\n\n".join(db_details)
# double check
for column_idx, (_, column_name) in enumerate(db_info["column_names_original"]):
if column_name == "*":
continue
if column_idx in used_column_idx_list:
assert column_name.lower() in db_details.lower()
return db_details
def deduplicate_dicts(dict_list):
seen = set()
unique_dicts = []
for d in dict_list:
dict_tuple = frozenset(d.items())
if dict_tuple not in seen:
seen.add(dict_tuple)
unique_dicts.append(d)
return unique_dicts
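# Illustrative example (not part of the original script):
#   deduplicate_dicts([{"id": "a"}, {"id": "a"}, {"id": "b"}]) -> [{"id": "a"}, {"id": "b"}]
# Note that this relies on all dict values being hashable (true for the retrieved hits here).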
def prepare_input_output_pairs(data, ek_key, db_id2relevant_hits, sampled_db_values_dict, db_info, source, output_key, mode):
if data[ek_key].strip() == "":
question = data["question"]
else:
question = data[ek_key] + "\n" + data["question"]
relavant_db_values_dict = dict()
if db_id2relevant_hits: # retrieve matched values from the databases
queries = obtain_n_grams(question, 8) + [question]
queries = list(dict.fromkeys(queries))
hits = []
for query in queries:
hits.extend(db_id2relevant_hits[data["db_id"]][query])
hits = deduplicate_dicts(hits)
relavant_db_values_dict = retrieve_question_related_db_values(hits, question)
db_details = obtain_db_details(
db_info, source, sampled_db_values_dict, relavant_db_values_dict,
data[output_key], mode, question
)
input_seq = input_prompt_template.format(
db_engine = "SQLite",
db_details = db_details,
question = question
)
return {"input_seq": input_seq, "output_seq": data[output_key]}
def process_data(args):
data, ek_key, searcher, sampled_db_values, db_info, source, output_key, mode = args
return prepare_input_output_pairs(data, ek_key, searcher, sampled_db_values, db_info, source, output_key, mode)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_data_file", type = str)
parser.add_argument("--output_data_file", type = str)
parser.add_argument("--db_path", type = str)
parser.add_argument("--tables", type = str)
parser.add_argument("--source", type = str)
parser.add_argument("--mode", type = str)
parser.add_argument("--value_limit_num", type = int)
parser.add_argument("--db_content_index_path", type = str)
opt = parser.parse_args()
print(opt)
random.seed(42)
assert opt.mode in ["train", "dev", "test"]
dataset = load_json_file(opt.input_data_file)
ek_key = "external_knowledge"
if opt.source == "synthetic":
output_key = "cot"
elif opt.source == "spider2.0":
output_key = "query"
for data in dataset:
data[output_key] = "" # spider2.0 does not provide gold sqls
elif opt.source == "spider":
if opt.mode == "train":
output_key = "cot" # use our synthetic CoT during training
else:
output_key = "query"
for data in dataset:
data[ek_key] = "" # spider does not provide external knowledge
elif opt.source == "bird":
if opt.mode == "train":
output_key = "cot" # use our synthetic CoT during training
else:
output_key = "SQL"
ek_key = "evidence"
elif opt.source == "spider_dk":
output_key = "query"
for data in dataset:
data[ek_key] = "" # spider_dk does not provide external knowledge
elif opt.source == "spider_realistic":
output_key = "query"
for data in dataset:
data[ek_key] = "" # spider_realistic does not provide external knowledge
elif opt.source == "spider_syn":
output_key = "query"
for data in dataset:
data[ek_key] = "" # spider_syn does not provide external knowledge
data["question"] = data["SpiderSynQuestion"]
elif opt.source in ["ehrsql", "sciencebenchmark"]:
output_key = "query"
for data in dataset:
data[ek_key] = "" # ehrsql and sciencebenchmark does not provide external knowledge
else:
    raise ValueError("argument `source` should be in [xxxx].")
used_db_ids = list(set([data["db_id"] for data in dataset]))
db_id2sampled_db_values = dict()
db_id2db_info = dict()
for db_info in tqdm(load_json_file(opt.tables)):
db_id = db_info["db_id"]
if db_id not in used_db_ids:
continue
db_file = os.path.join(opt.db_path, db_id, db_id + ".sqlite")
sampled_db_values_dict = sample_table_values(db_file, db_info["table_names_original"], opt.value_limit_num)
db_id2sampled_db_values[db_id] = sampled_db_values_dict
db_id2db_info[db_id] = db_info
batch_size = 20000
sliced_datasets = [dataset[i: i+batch_size] for i in range(0, len(dataset), batch_size)]
print(len(dataset))
print([len(batch_dataset) for batch_dataset in sliced_datasets])
assert len(dataset) == sum([len(batch_dataset) for batch_dataset in sliced_datasets])
new_dataset = []
for batch_idx, batch_dataset in enumerate(sliced_datasets):
print(f"Process: {batch_idx+1}/{len(sliced_datasets)}")
if opt.db_content_index_path:
db_id2searcher = dict()
batch_db_ids = list(set([data["db_id"] for data in batch_dataset]))
# load db context index searchers
for db_id in batch_db_ids:
db_id2searcher[db_id] = LuceneSearcher(os.path.join(opt.db_content_index_path, db_id))
db_id2queries = dict()
for data in tqdm(batch_dataset):
if data[ek_key].strip() == "":
question = data["question"]
else:
question = data[ek_key] + "\n" + data["question"]
queries = obtain_n_grams(question, 8) + [question]
queries = list(set(queries))
if data["db_id"] in db_id2queries:
db_id2queries[data["db_id"]].extend(queries)
else:
db_id2queries[data["db_id"]] = queries
# perform db content retrieval (in a large batch)
db_id2relevant_hits = dict()
for db_id in tqdm(batch_db_ids):
db_id2relevant_hits[db_id] = retrieve_relevant_hits(db_id2searcher[db_id], db_id2queries[db_id])
else:
db_id2relevant_hits = None
for data in tqdm(batch_dataset):
new_dataset.append(
prepare_input_output_pairs(data, ek_key, db_id2relevant_hits, db_id2sampled_db_values[data["db_id"]],
db_id2db_info[data["db_id"]], opt.source, output_key, opt.mode)
)
if opt.db_content_index_path:
    del db_id2searcher # only defined when DB content retrieval is enabled
del db_id2relevant_hits
with open(opt.output_data_file, "w", encoding = "utf-8") as f:
f.write(json.dumps(new_dataset, indent = 2, ensure_ascii = False))
set -e
# Spider (dev)
python process_dataset.py --input_data_file ./data/spider/dev.json --output_data_file ./data/dev_spider.json --db_path ./data/spider/database/ --tables ./data/spider/tables.json --source spider --mode dev --value_limit_num 2 --db_content_index_path ./data/spider/db_contents_index
# Spider (test)
python process_dataset.py --input_data_file ./data/spider/test.json --output_data_file ./data/test_spider.json --db_path ./data/spider/test_database/ --tables ./data/spider/test_tables.json --source spider --mode test --value_limit_num 2 --db_content_index_path ./data/spider/db_contents_index
# BIRD (dev)
python process_dataset.py --input_data_file ./data/bird/dev_20240627/dev.json --output_data_file ./data/dev_bird.json --db_path ./data/bird/dev_20240627/dev_databases/ --tables ./data/bird/dev_20240627/dev_tables.json --source bird --mode dev --value_limit_num 2 --db_content_index_path ./data/bird/dev_20240627/db_contents_index
# Spider2.0-SQLite
python process_dataset.py --input_data_file ./data/spider2_sqlite/test.json --output_data_file ./data/test_spider2_sqlite.json --db_path ./data/spider2_sqlite/databases/ --tables ./data/spider2_sqlite/tables.json --source spider2.0 --mode test --value_limit_num 2 --db_content_index_path ./data/spider2_sqlite/db_contents_index
# Spider-DK
python process_dataset.py --input_data_file ./data/Spider-DK/Spider-DK.json --output_data_file ./data/dev_spider_dk.json --db_path ./data/Spider-DK/database --tables ./data/Spider-DK/tables.json --source spider_dk --mode dev --value_limit_num 2 --db_content_index_path ./data/Spider-DK/db_contents_index
# Spider-Realistic
python process_dataset.py --input_data_file ./data/spider-realistic/spider-realistic.json --output_data_file ./data/dev_spider_realistic.json --db_path ./data/spider/database/ --tables ./data/spider/tables.json --source spider_realistic --mode dev --value_limit_num 2 --db_content_index_path ./data/spider/db_contents_index
# Spider-Syn
python process_dataset.py --input_data_file ./data/Spider-Syn/dev.json --output_data_file ./data/dev_spider_syn.json --db_path ./data/spider/database/ --tables ./data/spider/tables.json --source spider_syn --mode dev --value_limit_num 2 --db_content_index_path ./data/spider/db_contents_index
# EHRSQL
python process_dataset.py --input_data_file ./data/EHRSQL/dev.json --output_data_file ./data/dev_ehrsql.json --db_path ./data/EHRSQL/database --tables ./data/EHRSQL/tables.json --source ehrsql --mode dev --value_limit_num 2 --db_content_index_path ./data/EHRSQL/db_contents_index
# ScienceBenchmark
python process_dataset.py --input_data_file ./data/sciencebenchmark/dev.json --output_data_file ./data/dev_sciencebenchmark.json --db_path ./data/sciencebenchmark/databases --tables ./data/sciencebenchmark/tables.json --source sciencebenchmark --mode dev --value_limit_num 2 --db_content_index_path ./data/sciencebenchmark/db_contents_index
# Spider (Training set)
python process_dataset.py --input_data_file ./data/spider/train_spider_enhanced_with_cot.json --output_data_file ./data/train_spider.json --db_path ./data/spider/database/ --tables ./data/spider/tables.json --source spider --mode train --value_limit_num 2 --db_content_index_path ./data/spider/db_contents_index
# BIRD (Training set)
python process_dataset.py --input_data_file ./data/bird/train/train_enhanced_with_cot.json --output_data_file ./data/train_bird.json --db_path ./data/bird/train/train_databases/ --tables ./data/bird/train/train_tables.json --source bird --mode train --value_limit_num 2 --db_content_index_path ./data/bird/train/db_contents_index
# SynSQL-2.5M
python process_dataset.py --input_data_file ./data/SynSQL-2.5M/data.json --output_data_file ./data/train_synsql.json --db_path ./data/SynSQL-2.5M/databases --tables ./data/SynSQL-2.5M/tables.json --source synthetic --mode train --value_limit_num 2 --db_content_index_path ./data/SynSQL-2.5M/db_contents_index
import argparse
import os
import math
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.load_sft_dataset import SFTDataset
from utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.utils.data import DataLoader
from torch.optim import AdamW
from accelerate.utils import set_seed
from accelerate import Accelerator
from torch.utils.tensorboard import SummaryWriter
from peft import LoraConfig, TaskType, get_peft_model, AutoPeftModelForCausalLM
from monkey_patch_packing import monkey_patch_packing_for_model
'''
Train an LLM using Hugging Face Accelerate + DeepSpeed.
'''
def parse_option():
parser = argparse.ArgumentParser()
# global args
parser.add_argument('--per_device_train_batch_size', type = int, default = 4,
help = 'batch size per gpu device.')
parser.add_argument('--block_size', type = int, default = 8192,
help = 'block size, i.e., the length of training sequences.')
parser.add_argument('--seed', type = int, default = 42)
parser.add_argument('--pretrained_model_name_or_path', type = str, default = "deepseek-ai/deepseek-coder-6.7b-base")
parser.add_argument('--epochs', type = int, default = 1)
parser.add_argument('--lr', type = float, default = 5e-5, help = "5e-5 for pre-training, 5e-6 for fine-tuning.")
parser.add_argument('--ckpt_num', type = int, default = 20, help = "The number of checkpoints saved during training (uniformly spaced).")
parser.add_argument('--tensorboard_log_dir', type = str, default = "./train_logs")
parser.add_argument('--output_ckpt_dir', type = str, default = "./ckpts")
parser.add_argument('--mode', type = str, default = "pre-train")
# args for supervised fine-tuning
parser.add_argument('--sft_data_dir', type = str, default = "train_20240127.json")
# args for lora tuning
parser.add_argument('--use_lora', action = 'store_true', help = "Whether to use Lora to fine-tune the model")
parser.add_argument('--target_modules', type = str, help = "The names of the modules to apply the adapter to")
parser.add_argument('--r', type = int, help = "Lora attention dimension (the `rank`)")
parser.add_argument('--lora_alpha', type = int, help = "The alpha parameter for Lora scaling")
parser.add_argument('--lora_dropout', type = float, help = "The dropout probability for Lora layers")
opt = parser.parse_args()
return opt
def sanity_check(input, target, tokenizer):
print("Start Sanity Check -------->")
for t, m in zip(input, target):
decoded = tokenizer.decode([t])
print("%20s: %6d -> %6d" % (repr(decoded), t, m))
print("<-------- End Sanity Check")
assert len(input) == len(target), f"length mismatch: {len(input)} vs {len(target)}"
def checkpoint_model(accelerator, model, tokenizer, output_ckpt_dir, last_global_step):
'''
Utility function that checkpoints only the model state dict (i.e., only model parameters)
'''
ckpt_path = os.path.join(output_ckpt_dir, "ckpt-{}".format(last_global_step))
accelerator.print("checkpointing model state dict at {}".format(ckpt_path))
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
ckpt_path,
is_main_process = accelerator.is_main_process,
save_function = accelerator.save,
state_dict = accelerator.get_state_dict(model),
max_shard_size = "100GB"
)
if accelerator.is_main_process:
tokenizer.save_pretrained(ckpt_path)
return
def train(opt):
set_seed(opt.seed)
writer = SummaryWriter(opt.tensorboard_log_dir)
accelerator = Accelerator()
print("accelerator.is_main_process:", accelerator.is_main_process)
print("accelerator.device:", accelerator.device)
total_batch_size = opt.per_device_train_batch_size * accelerator.num_processes * accelerator.gradient_accumulation_steps
accelerator.print(opt)
accelerator.print("tokens per batch:", total_batch_size * opt.block_size)
accelerator.print("sequences per batch:", total_batch_size)
accelerator.print("using LLM from:", opt.pretrained_model_name_or_path)
# packing inputs without cross-contamination attention (must use flash attention)
monkey_patch_packing_for_model(opt.pretrained_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(opt.pretrained_model_name_or_path, trust_remote_code=True)
if tokenizer.pad_token_id is None:
if tokenizer.eos_token_id is None:
raise ValueError("please set a right eos_token_id in the tokenizer")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(
opt.pretrained_model_name_or_path,
torch_dtype = torch.bfloat16,
trust_remote_code = True,
attn_implementation = "flash_attention_2"
)
if opt.use_lora:
target_modules = [target_module.strip() for target_module in opt.target_modules.split(',')]
accelerator.print("Lora target_modules:", target_modules)
peft_config = LoraConfig(
task_type = TaskType.CAUSAL_LM,
target_modules = target_modules,
r = opt.r,
lora_alpha = opt.lora_alpha,
lora_dropout = opt.lora_dropout
)
model = get_peft_model(model, peft_config)
if accelerator.is_main_process:
model.print_trainable_parameters()
# enable gradient checkpointing to save GPU memory, at the cost of slowing training by roughly 20-30%.
# in addition, gradient checkpointing cannot be enabled when using DeepSpeed ZeRO-3
model.gradient_checkpointing_enable()
dataset = SFTDataset(opt.sft_data_dir, tokenizer, opt.block_size, opt.mode)
if accelerator.is_main_process:
sanity_check(dataset[0]["input_ids"], dataset[0]["labels"], tokenizer)
dataloader = DataLoader(dataset, batch_size = opt.per_device_train_batch_size, shuffle = True, drop_last = True)
num_total_batches = math.ceil(opt.epochs * math.ceil(len(dataset) / total_batch_size)) # number of total batches
checkpointing_steps = int(num_total_batches/opt.ckpt_num)
accelerator.print("checkpointing_steps:", checkpointing_steps)
optimizer = AdamW(model.parameters(), lr = opt.lr, betas = (0.9, 0.95), eps = 1e-8, weight_decay = 0.1)
num_warm_up_batches = int(num_total_batches * 0.05) # 5% of total batches for warm up
lr_scheduler = LinearWarmupCosineAnnealingLR(
optimizer = optimizer,
warmup_epochs = num_warm_up_batches * accelerator.num_processes, # * accelerator.num_processes
max_epochs = num_total_batches* accelerator.num_processes, # * accelerator.num_processes
warmup_start_lr = 0.0,
eta_min = 0.1 * opt.lr
)
optimizer, model, dataloader, lr_scheduler = accelerator.prepare(optimizer, model, dataloader, lr_scheduler)
# print(type(optimizer))
# print(type(model))
# print(type(dataloader))
# print(type(lr_scheduler))
accumulation_loss = 0
global_completed_steps = 0
model.train()
st = time.time()
for epoch in range(opt.epochs):
print("This is epoch:", epoch+1)
for batch_idx, batch in enumerate(dataloader):
accelerator.print(batch["input_ids"].shape)
# `accelerator.accumulate(model)` aims to set right `sync_gradients` state based on the recorded training steps
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
accumulation_loss += loss.detach().float()
# when deepspeed is enabled, `accelerator.backward(loss)` is doing optimizer.step(), optimizer.zero_grad(), and grad accumulation automatically.
# see `if self.is_gradient_accumulation_boundary():` line in path-to-env/site-packages/deepspeed/runtime/engine.py
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# 'accelerator.sync_gradients' checks if the accelerator has performed an optimization step on the `total_batch_size` examples
if accelerator.sync_gradients:
global_completed_steps += 1
accelerator.print("GPU 0, step {}, loss {}".format(global_completed_steps, accumulation_loss / accelerator.gradient_accumulation_steps))
accelerator.print("GPU 0, step {}, lr state dict:".format(global_completed_steps), lr_scheduler.state_dict())
accelerator.print(time.time()-st)
st = time.time()
writer.add_scalar(
'train-loss/gpu-{}'.format(accelerator.process_index),
accumulation_loss / accelerator.gradient_accumulation_steps,
global_completed_steps
)
writer.add_scalar(
'learning-rate/gpu-{}'.format(accelerator.process_index),
lr_scheduler.get_last_lr()[0],
global_completed_steps
)
# reset accumulation_loss to 0
accumulation_loss = 0
# save a checkpoint every `checkpointing_steps` optimizer steps
if global_completed_steps % checkpointing_steps == 0:
accelerator.print("after {} global training steps, save a checkpoint".format(global_completed_steps))
accelerator.wait_for_everyone()
checkpoint_model(accelerator, model, tokenizer, opt.output_ckpt_dir, global_completed_steps)
accelerator.print("in the end of an epoch, save a checkpoint")
accelerator.wait_for_everyone()
checkpoint_model(accelerator, model, tokenizer, opt.output_ckpt_dir, global_completed_steps)
if __name__ == "__main__":
opt = parse_option()
train(opt)
set -e
LR=4e-6
EPOCHS=2
CONFIG_FILE="./accelerate_config_14b.yaml"
PER_DEVICE_TRAIN_BATCH_SIZE=1
MODEL_PATH="Qwen/Qwen2.5-Coder-14B-Instruct"
CKPT_NUM=10
BASE_NAME="omnisql_14b_lr${LR}_epochs${EPOCHS}"
CKPT_DIR="./ckpts/$BASE_NAME"
LOG_DIR="./train_logs/$BASE_NAME"
DATASET_DIR="./data/train_synsql.json"
accelerate launch --main_process_port 10000 --config_file $CONFIG_FILE train.py \
--per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
--block_size 8192 \
--seed 42 \
--pretrained_model_name_or_path $MODEL_PATH \
--epochs $EPOCHS \
--lr $LR \
--ckpt_num $CKPT_NUM \
--tensorboard_log_dir $LOG_DIR \
--output_ckpt_dir $CKPT_DIR \
--sft_data_dir $DATASET_DIR \
--mode sft
set -e
LR=2e-4
EPOCHS=2
CONFIG_FILE="./accelerate_config_32b.yaml"
PER_DEVICE_TRAIN_BATCH_SIZE=2
MODEL_PATH="Qwen/Qwen2.5-Coder-32B-Instruct"
CKPT_NUM=10
BASE_NAME="omnisql_32b_lr${LR}_epochs${EPOCHS}-lora"
CKPT_DIR="./ckpts/$BASE_NAME"
LOG_DIR="./train_logs/$BASE_NAME"
DATASET_DIR="./data/train_synsql.json"
accelerate launch --main_process_port 10000 --config_file $CONFIG_FILE train.py \
--per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
--block_size 8192 \
--seed 42 \
--pretrained_model_name_or_path $MODEL_PATH \
--epochs $EPOCHS \
--lr $LR \
--ckpt_num $CKPT_NUM \
--tensorboard_log_dir $LOG_DIR \
--output_ckpt_dir $CKPT_DIR \
--sft_data_dir $DATASET_DIR \
--mode sft \
--use_lora \
--target_modules "q_proj, k_proj, v_proj" \
--r 256 \
--lora_alpha 512 \
--lora_dropout 0.1
set -e
LR=2e-5
EPOCHS=2
CONFIG_FILE="./accelerate_config_7b.yaml"
PER_DEVICE_TRAIN_BATCH_SIZE=1
MODEL_PATH="/home/ckpts/Qwen2.5-Coder-7B-Instruct"
CKPT_NUM=1
BASE_NAME="omnisql_7b_lr${LR}_epochs${EPOCHS}"
CKPT_DIR="./ckpts/$BASE_NAME"
LOG_DIR="./train_logs/$BASE_NAME"
# DATASET_DIR="./data/train_bird.json"
# DATASET_DIR="./data/train_spider.json"
DATASET_DIR="./data/train_synsql_part.json"
accelerate launch --main_process_port 10000 --config_file $CONFIG_FILE train.py \
--per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
--block_size 8192 \
--seed 42 \
--pretrained_model_name_or_path $MODEL_PATH \
--epochs $EPOCHS \
--lr $LR \
--ckpt_num $CKPT_NUM \
--tensorboard_log_dir $LOG_DIR \
--output_ckpt_dir $CKPT_DIR \
--sft_data_dir $DATASET_DIR \
--mode sft
import json
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np
def find_sublist_index(lst, sublist):
sublist_length = len(sublist)
for i in range(len(lst) - sublist_length + 1):
if lst[i:i + sublist_length] == sublist:
return i
return -1
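# Illustrative examples (not part of the original script):
#   find_sublist_index([5, 8, 151644, 77091, 9, 3], [151644, 77091]) -> 2
#   find_sublist_index([5, 8], [151644, 77091]) -> -1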
def obtain_labels(input_ids, assistant_start_token_ids):
'''
Mask everything before assistant_start_token_ids with -100
'''
assistant_start_idx = find_sublist_index(input_ids, assistant_start_token_ids)
if assistant_start_idx == -1:
labels = input_ids
print("length of the output sequence exceeds max length")
else:
labels = [-100] * assistant_start_idx + input_ids[assistant_start_idx: ]
assert len(input_ids) == len(labels)
return labels
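# Illustrative example (not part of the original script): with
#   input_ids = [5, 8, 151644, 77091, 9, 3] and assistant_start_token_ids = [151644, 77091],
# obtain_labels returns [-100, -100, 151644, 77091, 9, 3], so the loss is only computed
# from the assistant turn onwards.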
class SFTDataset(Dataset):
def __init__(self, data_dir, tokenizer, max_length, mode):
super().__init__()
self.mode = mode
assistant_start_token_ids = [151644, 77091] # for Qwen2.5's tokenizer, the start token ids of the Assistant (<|im_start|>assistant)
if mode == "pre-train":
packed_data = np.load(data_dir)
self.all_input_ids = torch.tensor(packed_data["all_packed_input_ids"], dtype=torch.int32)
self.all_attention_mask = torch.tensor(packed_data["all_packed_attention_masks"], dtype=torch.int32)
self.all_labels = torch.tensor(packed_data["all_packed_labels"], dtype=torch.int32)
del packed_data
elif mode == "sft":
dataset = json.load(open(data_dir))
sequences = [tokenizer.apply_chat_template([
{"role": "user", "content": data["input_seq"]},
{"role": "assistant", "content": data["output_seq"]}
], add_generation_prompt = False, tokenize = False) for data in tqdm(dataset)]
tokenized_results = tokenizer.batch_encode_plus(
sequences,
truncation = False
)
self.all_input_ids = []
self.all_attention_mask = []
self.all_labels = []
num = 0
for input_ids in tokenized_results["input_ids"]:
if len(input_ids) > max_length: # left-truncate, keeping the last max_length tokens
input_ids = input_ids[-max_length:]
num += 1
self.all_input_ids.append(input_ids + [tokenizer.pad_token_id] * (max_length-len(input_ids)))
self.all_attention_mask.append([1] * len(input_ids) + [0] * (max_length-len(input_ids)))
# mask prompt loss
self.all_labels.append(obtain_labels(input_ids, assistant_start_token_ids) + [-100] * (max_length-len(input_ids)))
# no-mask prompt loss
# self.all_labels.append(input_ids + [-100] * (max_length-len(input_ids)))
print(f"There are {num} sequences have been truncated.")
self.all_input_ids = torch.tensor(self.all_input_ids, dtype=torch.int64)
self.all_attention_mask = torch.tensor(self.all_attention_mask, dtype=torch.int64)
self.all_labels = torch.tensor(self.all_labels, dtype=torch.int64)
def __getitem__(self, index):
if self.mode == "pre-train":
return {
"input_ids": torch.tensor(self.all_input_ids[index], dtype=torch.int64),
"attention_mask": torch.tensor(self.all_attention_mask[index], dtype=torch.int64),
"labels": torch.tensor(self.all_labels[index], dtype=torch.int64)
}
elif self.mode == "sft":
return {
"input_ids": self.all_input_ids[index],
"attention_mask": self.all_attention_mask[index],
"labels": self.all_labels[index]
}
def __len__(self):
return self.all_input_ids.shape[0]
import math
import warnings
from typing import List
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer
'''
Copied from the source code of pl_bolts.
'''
class LinearWarmupCosineAnnealingLR(_LRScheduler):
"""Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
.. warning::
It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
after each iteration as calling it after each epoch will keep the starting lr at
warmup_start_lr for the first epoch which is 0 in most cases.
.. warning::
passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
:func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
train and validation methods.
Example:
>>> import torch.nn as nn
>>> from torch.optim import Adam
>>> #
>>> layer = nn.Linear(10, 1)
>>> optimizer = Adam(layer.parameters(), lr=0.02)
>>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
>>> # the default case
>>> for epoch in range(40):
... # train(...)
... # validate(...)
... scheduler.step()
>>> # passing epoch param case
>>> for epoch in range(40):
... scheduler.step(epoch)
... # train(...)
... # validate(...)
"""
def __init__(
self,
optimizer: Optimizer,
warmup_epochs: int,
max_epochs: int,
warmup_start_lr: float = 0.0,
eta_min: float = 0.0,
last_epoch: int = -1,
) -> None:
"""
Args:
optimizer (Optimizer): Wrapped optimizer.
warmup_epochs (int): Maximum number of iterations for linear warmup
max_epochs (int): Maximum number of iterations
warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
eta_min (float): Minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
"""
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs
self.warmup_start_lr = warmup_start_lr
self.eta_min = eta_min
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
"""Compute learning rate using chainable form of the scheduler."""
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
UserWarning,
)
if self.last_epoch == 0:
return [self.warmup_start_lr] * len(self.base_lrs)
if self.last_epoch < self.warmup_epochs:
return [
group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
]
if self.last_epoch == self.warmup_epochs:
return self.base_lrs
if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
return [
group["lr"]
+ (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
]
return [
(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
/ (
1
+ math.cos(
math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
)
)
* (group["lr"] - self.eta_min)
+ self.eta_min
for group in self.optimizer.param_groups
]
def _get_closed_form_lr(self) -> List[float]:
"""Called when epoch is passed as a param to the `step` function of the scheduler."""
if self.last_epoch < self.warmup_epochs:
return [
self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
for base_lr in self.base_lrs
]
return [
self.eta_min
+ 0.5
* (base_lr - self.eta_min)
* (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
for base_lr in self.base_lrs
]