import argparse
import json
import os
import random
import re
import sqlite3
from collections import OrderedDict

import ijson
from nltk import ngrams
from nltk.tokenize import word_tokenize
from pyserini.search.lucene import LuceneSearcher
from tqdm import tqdm

SQL_RESERVED_WORDS = {'IDENTIFIED', 'FOREIGN', 'CONSTRAINT', 'USER', 'POSITION', 'DESCRIBE', 'CHECK', 'RECURSIVE', 'REAL', 'CONTINUE', 'GLOBAL', 'RLIKE', 'INSENSITIVE', 'BOOLEAN', 'CHAR', 'ROLE', 'CASE', 'SCHEMA', 'CLOB', 'RESIGNAL', 'ROW', 'DEC', 'TOP', 'EXCEPT', 'SENSITIVE', 'OUT', 'RENAME', 'READS', 'BLOB', 'INT', 'EXTERNAL', 'LOCALTIMESTAMP', 'DECLARE', 'DO', 'AS', 'OVER', 'CONDITION', 'SELECT', 'SAVEPOINT', 'WITHIN', 'ELSEIF', 'UNLOCK', 'DATABASE', 'TRIGGER', 'ACCESS', 'FALSE', 'BREAK', 'ITERATE', 'SMALLINT', 'ASC', 'YEAR', 'DELETE', 'ROLLBACK', 'ON', 'ESCAPE', 'CREATE', 'MONTH', 'SPECIFIC', 'SESSION', 'SQLSTATE', 'HOLD', 'SET', 'EXPLAIN', 'RETURN', 'ROWNUM', 'BINARY', 'SYSDATE', 'SQLWARNING', 'EXTEND', 'CAST', 'FOR', 'TERMINATED', 'VIEW', 'TRAILING', 'HOUR', 'VARYING', 'RESTRICT', 'RIGHT', 'DISTINCT', 'JOIN', 'UNKNOWN', 'VALUES', 'TABLE', 'OR', 'DOUBLE', 'DROP', 'COMMIT', 'PRECISION', 'LANGUAGE', 'START', 'INTERSECT', 'IGNORE', 'NULL', 'CURRENT_DATE', 'LOCK', 'INTO', 'NEW', 'DESC', 'STATIC', 'MODIFIES', 'GRANT', 'VALUE', 'LIMIT', 'MODULE', 'DATE', 'LOCALTIME', 'PERCENT', 'REPEAT', 'FULL', 'USAGE', 'ORDER', 'WHEN', 'PRIMARY', 'BETWEEN', 'CURSOR', 'DECIMAL', 'HAVING', 'IF', 'FILTER', 'INDEX', 'ILIKE', 'VARCHAR', 'EXEC', 'USING', 'ROWS', 'PLACING', 'WHILE', 'EXECUTE', 'EACH', 'LEFT', 'FLOAT', 'COLLATE', 'CURRENT_TIME', 'OPEN', 'RANGE', 'CROSS', 'FUNCTION', 'TIME', 'BOTH', 'NOT', 'CONVERT', 'NCHAR', 'KEY', 'DEFAULT', 'LIKE', 'ANALYZE', 'EXISTS', 'IN', 'BIT', 'INOUT', 'SUM', 'NUMERIC', 'AFTER', 'LEAVE', 'INSERT', 'TO', 'COUNT', 'THEN', 'BEFORE', 'OUTER', 'COLUMN', 'ONLY', 'END', 'PROCEDURE', 'OFFSET', 'ADD', 'INNER', 'RELEASE', 'FROM', 'DAY', 'NO', 'CALL', 'BY', 'LOCAL', 'ZONE', 'TRUE', 'EXIT', 'LEADING', 'INTEGER', 'MERGE', 'OLD', 'AVG', 'MIN', 'SQL', 'LOOP', 'SIGNAL', 'REFERENCES', 'MINUTE', 'UNIQUE', 'GENERATED', 'ALL', 'MATCH', 'CASCADE', 'UNION', 'COMMENT', 'FETCH', 'UNDO', 'UPDATE', 'WHERE', 'ELSE', 'PARTITION', 'BIGINT', 'CHARACTER', 'CURRENT_TIMESTAMP', 'ALTER', 'INTERVAL', 'REVOKE', 'CONNECT', 'WITH', 'TIMESTAMP', 'GROUP', 'BEGIN', 'CURRENT', 'REGEXP', 'NATURAL', 'SOME', 'SQLEXCEPTION', 'MAX', 'SUBSTRING', 'OF', 'AND', 'REPLACE', 'IS'}

SPECIAL_CHARS_PATTERN = re.compile(r'[^a-zA-Z0-9_]')


def load_json_file(file):
    dataset = []
    with open(file, 'r', encoding='utf-8') as f:
        # stream array items with ijson so very large JSON files need not fit in memory at once
        objects = ijson.items(f, 'item')
        for obj in tqdm(objects):
            dataset.append(obj)
    return dataset


def remove_sql_comments(sql):
    # remove single-line comments
    sql = re.sub(r'--.*', '', sql)
    # remove multi-line comments
    sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL)
    return sql.strip()


def obtain_db_ddls(db_file_dir):
    conn = sqlite3.connect(db_file_dir)
    cursor = conn.cursor()
    cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    create_statements = []
    for table in tables:
        _, create_statement = table
        create_statements.append(create_statement)
    cursor.close()
    conn.close()
    # table_schemas = [remove_sql_comments(stat) for stat in create_statements]
    return create_statements


def needs_backticks(identifier):
    # quote identifiers that collide with SQL keywords or contain special characters
    if identifier.upper() in SQL_RESERVED_WORDS:
        return True
    if SPECIAL_CHARS_PATTERN.search(identifier):
        return True
    return False


def format_identifier(identifier):
    if needs_backticks(identifier):
        return f'`{identifier}`'
    return identifier
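
# Illustrative behavior of the quoting helpers above (the identifier names are
# made-up examples, not taken from any dataset):
#   format_identifier("customer_id")  ->  'customer_id'    (no quoting needed)
#   format_identifier("order")        ->  '`order`'        (ORDER is a reserved word)
#   format_identifier("first name")   ->  '`first name`'   (space matches SPECIAL_CHARS_PATTERN)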

def sample_table_values(db_file_dir, table_names, limit_num):
    db_values_dict = dict()
    conn = sqlite3.connect(db_file_dir)
    cursor = conn.cursor()
    for table_name in table_names:
        cursor.execute(f"PRAGMA table_info(`{table_name}`);")
        columns = cursor.fetchall()
        column_names = [column[1] for column in columns]
        for column_name in column_names:
            # cursor.execute(f"SELECT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL LIMIT {limit_num};")
            query = f"""
            SELECT `{column_name}`
            FROM (
                SELECT DISTINCT `{column_name}`
                FROM `{table_name}`
                WHERE `{column_name}` IS NOT NULL AND `{column_name}` != ''
            ) AS unique_values
            LIMIT {limit_num};
            """
            cursor.execute(query)
            values = cursor.fetchall()
            values = [value[0] for value in values]
            # truncate overly long strings
            for idx in range(len(values)):
                if isinstance(values[idx], str):
                    values[idx] = values[idx][:40]
            if len(values) > 0:
                db_values_dict[f"{table_name}.{column_name}".lower()] = values
    cursor.close()
    conn.close()
    return db_values_dict


def calculate_substring_match_percentage(query, target):
    query = query.lower()
    target = target.lower()
    substrings = []
    for i in range(len(query)):
        for j in range(i + 1, len(query) + 1):
            substrings.append(query[i:j])
    # `default=0` guards against queries that share no characters with the target
    max_matched_substring_len = max((len(substring) for substring in substrings if substring in target), default=0)
    return max_matched_substring_len / len(query)


def retrieve_relevant_hits(searcher, queries):
    queries = list(dict.fromkeys(queries))  # dedup, preserving order
    # print("len(queries):", len(queries))
    q_ids = [f"{idx}" for idx in range(len(queries))]
    query2hits = dict()
    search_results = searcher.batch_search(queries, q_ids, k=10, threads=60)
    for query, q_id in zip(queries, q_ids):
        hits = search_results[q_id]
        hits = list(dict.fromkeys([hit.raw for hit in hits]))
        hits = [json.loads(hit) for hit in hits]
        query2hits[query] = hits
    return query2hits


def retrieve_question_related_db_values(hits, question):
    high_score_hits = []
    for idx, hit in enumerate(hits):
        table_name, column_name, c_id = hit["id"].split("-**-")
        score = calculate_substring_match_percentage(hit["contents"], question)
        if score > 0.85:
            high_score_hits.append(
                {
                    "table_dot_column_lower_case": f"{table_name}.{column_name}".lower(),
                    "db_value": hit["contents"],
                    "score": score,
                    "index": idx,
                }
            )
    high_score_hits = sorted(high_score_hits, key=lambda x: (x["score"], len(x["db_value"]), x["index"]), reverse=True)
    high_score_hits = high_score_hits[:20]  # keep only the top-20 db values

    relevant_db_values_dict = dict()
    for hit in high_score_hits:
        if hit["table_dot_column_lower_case"] in relevant_db_values_dict:
            relevant_db_values_dict[hit["table_dot_column_lower_case"]].append(hit["db_value"])
        else:
            relevant_db_values_dict[hit["table_dot_column_lower_case"]] = [hit["db_value"]]
    return relevant_db_values_dict


def obtain_n_grams(sequence, max_n):
    '''
    Return all n-grams of `sequence` for n from 1 up to `max_n`.
    '''
    tokens = word_tokenize(sequence)
    all_n_grams = []
    for n in range(1, max_n + 1):
        all_n_grams.extend([" ".join(gram) for gram in ngrams(tokens, n)])
    return all_n_grams
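
# A quick sketch of how the matching helpers above interact (strings invented
# for illustration):
#   obtain_n_grams("list all French singers", 2)
#     -> ['list', 'all', 'French', 'singers', 'list all', 'all French', 'French singers']
#   calculate_substring_match_percentage("colour", "color")
#     -> 4/6 ~= 0.67: "colo" is the longest substring of "colour" that also
#        appears in "color". Only hits scoring above 0.85 survive the filter
#        in retrieve_question_related_db_values().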

input_prompt_template = '''Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.

Database Engine:
{db_engine}

Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.

Question:
{question}

Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.

Output Format:
In your answer, please enclose the generated SQL query in a code block:
```sql
-- Your SQL query
```

Take a deep breath and think step by step to find the correct SQL query.
'''


def obtain_pk_fk_column_idx(db_info):
    pk_fk_column_idx_list = []
    for primary_keys_idx in db_info["primary_keys"]:
        if isinstance(primary_keys_idx, int):
            pk_fk_column_idx_list.append(primary_keys_idx)
        elif isinstance(primary_keys_idx, list):  # composite primary key
            pk_fk_column_idx_list.extend(primary_keys_idx)
    for (source_column_idx, target_column_idx) in db_info["foreign_keys"]:
        pk_fk_column_idx_list.append(source_column_idx)
        pk_fk_column_idx_list.append(target_column_idx)
    return pk_fk_column_idx_list


def prepare_schema_filter_data(question, db_info):
    data = dict()
    data["text"] = question
    data["schema"] = dict()
    data["schema"]["schema_items"] = []
    for outer_table_idx, table_name_original in enumerate(db_info["table_names_original"]):
        table_info = dict()
        table_info["table_name"] = table_name_original
        table_info["table_comment"] = ""
        table_info["column_names"] = []
        table_info["column_comments"] = []
        for (inner_table_idx, column_name_original), (_, column_comment) in zip(db_info["column_names_original"], db_info["column_names"]):
            if outer_table_idx == inner_table_idx:
                table_info["column_names"].append(column_name_original)
                table_info["column_comments"].append(column_comment)
        data["schema"]["schema_items"].append(table_info)
    return data
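
# Shape of the dict returned by prepare_schema_filter_data(), sketched for a
# hypothetical one-table schema (values invented for illustration):
#   {
#       "text": "How many singers do we have?",
#       "schema": {
#           "schema_items": [
#               {"table_name": "singer",
#                "table_comment": "",
#                "column_names": ["singer_id", "name"],
#                "column_comments": ["unique id of the singer", "singer name"]}
#           ]
#       }
#   }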

def obtain_db_details(db_info, data_source, sampled_db_values_dict, relevant_db_values_dict, output_seq, mode, question):
    db_details = []
    assert len(db_info["column_names_original"]) == len(db_info["column_names"]) == len(db_info["column_types"])

    if mode == "train":
        # To increase the training data's diversity, the input database schema keeps:
        # PK/FK columns, columns used by the output sequence, and randomly sampled unused columns.
        used_column_idx_list = obtain_pk_fk_column_idx(db_info)
        # keep columns referenced by the output SQL
        for column_idx, (inner_table_idx, column_name) in enumerate(db_info["column_names_original"]):
            if column_name.lower() in output_seq.lower():
                used_column_idx_list.append(column_idx)
        used_column_idx_list = list(set(used_column_idx_list))
        used_column_num = len(used_column_idx_list)

        all_column_idx_list = list(range(len(db_info["column_names_original"])))
        unused_column_idx_list = [idx for idx in all_column_idx_list if idx not in used_column_idx_list]
        # randomly select some unused columns to mimic noise in the input sequence
        if unused_column_idx_list:
            unused_column_prob = random.choice([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
            sample_size = int(unused_column_prob * len(unused_column_idx_list))
            max_column_num = 225
            if used_column_num > max_column_num:
                sample_size = 0
            elif used_column_num + sample_size > max_column_num:
                sample_size = max_column_num - used_column_num
            used_column_idx_list.extend(random.sample(unused_column_idx_list, sample_size))
    else:
        # put all tables and columns in the prompt
        used_column_idx_list = list(range(len(db_info["column_names_original"])))
    # print(used_column_idx_list)

    for outer_table_idx, table_name in enumerate(db_info["table_names_original"]):
        column_info_list = []
        pk_columns = []
        fk_info = []
        column_comment_prob = random.choice([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
        for column_idx, ((inner_table_idx, column_name), (_, column_comment), column_type) in enumerate(zip(
                db_info["column_names_original"], db_info["column_names"], db_info["column_types"])):
            if inner_table_idx == outer_table_idx:
                if column_idx not in used_column_idx_list:
                    continue
                column_values = []
                if f"{table_name}.{column_name}".lower() in relevant_db_values_dict:
                    column_values.extend(relevant_db_values_dict[f"{table_name}.{column_name}".lower()])
                if f"{table_name}.{column_name}".lower() in sampled_db_values_dict:
                    column_values.extend(sampled_db_values_dict[f"{table_name}.{column_name}".lower()])
                column_values = list(dict.fromkeys(column_values))  # dedup, preserving order
                column_values = column_values[:6]

                if data_source == "synthetic":
                    if random.random() < column_comment_prob:
                        column_info = f'    {format_identifier(column_name)} {column_type}, -- {column_comment}'
                        if len(column_values) > 0:
                            column_info += f", example: {column_values}"
                    else:
                        # simulate columns that have no comment
                        column_info = f'    {format_identifier(column_name)} {column_type},'
                        if len(column_values) > 0:
                            column_info += f" -- example: {column_values}"
                else:
                    # skip comments that are empty or merely restate the column name
                    if column_name.lower() in [column_comment.lower(), column_comment.lower().replace(" ", "_"), column_comment.lower().replace(" ", "")] \
                            or column_comment.strip() == "":
                        column_info = f'    {format_identifier(column_name)} {column_type},'
                        if len(column_values) > 0:
                            column_info += f" -- example: {column_values}"
                    else:
                        column_info = f'    {format_identifier(column_name)} {column_type}, -- {column_comment}'
                        if len(column_values) > 0:
                            column_info += f", example: {column_values}"
                column_info_list.append(column_info)

                for primary_keys_idx in db_info["primary_keys"]:
                    if isinstance(primary_keys_idx, int):
                        if column_idx == primary_keys_idx:
                            pk_columns.append(column_name)
                    elif isinstance(primary_keys_idx, list):  # composite primary key
                        if column_idx in primary_keys_idx:
                            pk_columns.append(column_name)

                for (source_column_idx, target_column_idx) in db_info["foreign_keys"]:
                    if column_idx == source_column_idx:
                        source_table_idx = db_info["column_names_original"][source_column_idx][0]
                        source_table_name = db_info["table_names_original"][source_table_idx]
                        source_column_name = db_info["column_names_original"][source_column_idx][1]
                        target_table_idx = db_info["column_names_original"][target_column_idx][0]
                        target_table_name = db_info["table_names_original"][target_table_idx]
                        target_column_name = db_info["column_names_original"][target_column_idx][1]
                        fk_info.append(f'    CONSTRAINT fk_{source_table_name.lower().replace(" ", "_")}_{source_column_name.lower().replace(" ", "_")} FOREIGN KEY ({format_identifier(source_column_name)}) REFERENCES {format_identifier(target_table_name)} ({format_identifier(target_column_name)}),')

        if len(column_info_list) > 0:
            pk_columns = list(OrderedDict.fromkeys(pk_columns))
            if len(pk_columns) > 0:
                pk_info = ['    PRIMARY KEY (' + ', '.join([f'{format_identifier(column_name)}' for column_name in pk_columns]) + '),']
            else:
                pk_info = []
            fk_info = list(OrderedDict.fromkeys(fk_info))

            table_ddl = f'CREATE TABLE {format_identifier(table_name)} (\n'
            table_ddl += "\n".join(column_info_list + pk_info + fk_info)
            if table_ddl.endswith(","):
                table_ddl = table_ddl[:-1]  # remove the trailing comma
            table_ddl += "\n);"
            db_details.append(table_ddl)

    if mode == "train":
        random.shuffle(db_details)
    db_details = "\n\n".join(db_details)

    # double check: every kept column name must appear in the serialized schema
    for column_idx, (_, column_name) in enumerate(db_info["column_names_original"]):
        if column_name == "*":
            continue
        if column_idx in used_column_idx_list:
            assert column_name.lower() in db_details.lower()

    return db_details
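
# Example of one CREATE TABLE block emitted by obtain_db_details() (the table,
# columns, comments, and values are hypothetical):
#   CREATE TABLE singer (
#       singer_id INTEGER, -- unique id of the singer, example: [1, 2, 3]
#       name TEXT, -- example: ['Joe Sharp', 'Timbaland']
#       country TEXT, -- example: ['Netherlands', 'United States']
#       PRIMARY KEY (singer_id)
#   );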

def deduplicate_dicts(dict_list):
    seen = set()
    unique_dicts = []
    for d in dict_list:
        dict_tuple = frozenset(d.items())
        if dict_tuple not in seen:
            seen.add(dict_tuple)
            unique_dicts.append(d)
    return unique_dicts


def prepare_input_output_pairs(data, ek_key, db_id2relevant_hits, sampled_db_values_dict, db_info, source, output_key, mode):
    if data[ek_key].strip() == "":
        question = data["question"]
    else:
        question = data[ek_key] + "\n" + data["question"]

    relevant_db_values_dict = dict()
    if db_id2relevant_hits:  # retrieve matched values from the databases
        queries = obtain_n_grams(question, 8) + [question]
        queries = list(dict.fromkeys(queries))
        hits = []
        for query in queries:
            hits.extend(db_id2relevant_hits[data["db_id"]][query])
        hits = deduplicate_dicts(hits)
        relevant_db_values_dict = retrieve_question_related_db_values(hits, question)

    db_details = obtain_db_details(
        db_info, source, sampled_db_values_dict, relevant_db_values_dict, data[output_key], mode, question
    )
    input_seq = input_prompt_template.format(
        db_engine="SQLite",
        db_details=db_details,
        question=question
    )
    return {"input_seq": input_seq, "output_seq": data[output_key]}


def process_data(args):
    # unpack a single argument tuple (e.g., for use with multiprocessing workers)
    data, ek_key, db_id2relevant_hits, sampled_db_values, db_info, source, output_key, mode = args
    return prepare_input_output_pairs(data, ek_key, db_id2relevant_hits, sampled_db_values, db_info, source, output_key, mode)
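
# prepare_input_output_pairs() produces one prompt/completion pair per question;
# a hypothetical, heavily truncated example:
#   {
#       "input_seq": "Task Overview:\n...\nDatabase Schema:\nCREATE TABLE singer (...)\n...\nQuestion:\nHow many singers do we have?\n...",
#       "output_seq": "SELECT count(*) FROM singer;"
#   }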

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_data_file", type=str)
    parser.add_argument("--output_data_file", type=str)
    parser.add_argument("--db_path", type=str)
    parser.add_argument("--tables", type=str)
    parser.add_argument("--source", type=str)
    parser.add_argument("--mode", type=str)
    parser.add_argument("--value_limit_num", type=int)
    parser.add_argument("--db_content_index_path", type=str)
    opt = parser.parse_args()
    print(opt)

    random.seed(42)
    assert opt.mode in ["train", "dev", "test"]
    dataset = load_json_file(opt.input_data_file)

    ek_key = "external_knowledge"
    if opt.source == "synthetic":
        output_key = "cot"
    elif opt.source == "spider2.0":
        output_key = "query"
        for data in dataset:
            data[output_key] = ""  # spider2.0 does not provide gold SQLs
    elif opt.source == "spider":
        if opt.mode == "train":
            output_key = "cot"  # use our synthetic CoT during training
        else:
            output_key = "query"
        for data in dataset:
            data[ek_key] = ""  # spider does not provide external knowledge
    elif opt.source == "bird":
        if opt.mode == "train":
            output_key = "cot"  # use our synthetic CoT during training
        else:
            output_key = "SQL"
        ek_key = "evidence"
    elif opt.source == "spider_dk":
        output_key = "query"
        for data in dataset:
            data[ek_key] = ""  # spider_dk does not provide external knowledge
    elif opt.source == "spider_realistic":
        output_key = "query"
        for data in dataset:
            data[ek_key] = ""  # spider_realistic does not provide external knowledge
    elif opt.source == "spider_syn":
        output_key = "query"
        for data in dataset:
            data[ek_key] = ""  # spider_syn does not provide external knowledge
            data["question"] = data["SpiderSynQuestion"]
    elif opt.source in ["ehrsql", "sciencebenchmark"]:
        output_key = "query"
        for data in dataset:
            data[ek_key] = ""  # ehrsql and sciencebenchmark do not provide external knowledge
    else:
        # a bare string assert is always truthy and never fires; raise instead
        raise ValueError("argument `source` should be in [xxxx].")

    used_db_ids = list(set([data["db_id"] for data in dataset]))
    db_id2sampled_db_values = dict()
    db_id2db_info = dict()
    for db_info in tqdm(load_json_file(opt.tables)):
        db_id = db_info["db_id"]
        if db_id not in used_db_ids:
            continue
        db_file = os.path.join(opt.db_path, db_id, db_id + ".sqlite")
        sampled_db_values_dict = sample_table_values(db_file, db_info["table_names_original"], opt.value_limit_num)
        db_id2sampled_db_values[db_id] = sampled_db_values_dict
        db_id2db_info[db_id] = db_info

    batch_size = 20000
    sliced_datasets = [dataset[i: i + batch_size] for i in range(0, len(dataset), batch_size)]
    print(len(dataset))
    print([len(batch_dataset) for batch_dataset in sliced_datasets])
    assert len(dataset) == sum([len(batch_dataset) for batch_dataset in sliced_datasets])

    new_dataset = []
    for batch_idx, batch_dataset in enumerate(sliced_datasets):
        print(f"Process: {batch_idx + 1}/{len(sliced_datasets)}")
        if opt.db_content_index_path:
            db_id2searcher = dict()
            batch_db_ids = list(set([data["db_id"] for data in batch_dataset]))
            # load db content index searchers
            for db_id in batch_db_ids:
                db_id2searcher[db_id] = LuceneSearcher(os.path.join(opt.db_content_index_path, db_id))

            db_id2queries = dict()
            for data in tqdm(batch_dataset):
                if data[ek_key].strip() == "":
                    question = data["question"]
                else:
                    question = data[ek_key] + "\n" + data["question"]
                queries = obtain_n_grams(question, 8) + [question]
                queries = list(set(queries))
                if data["db_id"] in db_id2queries:
                    db_id2queries[data["db_id"]].extend(queries)
                else:
                    db_id2queries[data["db_id"]] = queries

            # perform db content retrieval (in a large batch)
            db_id2relevant_hits = dict()
            for db_id in tqdm(batch_db_ids):
                db_id2relevant_hits[db_id] = retrieve_relevant_hits(db_id2searcher[db_id], db_id2queries[db_id])
        else:
            db_id2relevant_hits = None

        for data in tqdm(batch_dataset):
            new_dataset.append(
                prepare_input_output_pairs(data, ek_key, db_id2relevant_hits, db_id2sampled_db_values[data["db_id"]],
                                           db_id2db_info[data["db_id"]], opt.source, output_key, opt.mode)
            )

        # free per-batch resources; db_id2searcher only exists when an index path was given
        if opt.db_content_index_path:
            del db_id2searcher
        del db_id2relevant_hits

    with open(opt.output_data_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(new_dataset, indent=2, ensure_ascii=False))
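
# Illustrative invocation (the script name and all paths below are placeholders,
# not files shipped with this code):
#   python prepare_sft_data.py \
#       --input_data_file ./data/train_spider.json \
#       --output_data_file ./data/train_spider_sft.json \
#       --db_path ./database \
#       --tables ./data/tables.json \
#       --source spider \
#       --mode train \
#       --value_limit_num 2 \
#       --db_content_index_path ./db_contents_index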