"vscode:/vscode.git/clone" did not exist on "129e013b41133e9bf236642fa43362e68623716a"
Commit 4f4ba442 authored by mashun1's avatar mashun1
Browse files

omnisql

parents
Pipeline #2643 canceled with stages
__pycache__
ckpt
bin
pt
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
# OmniSQL
## Paper
`OmniSQL: Synthesizing High-quality Text-to-SQL Data at Scale`
* https://arxiv.org/pdf/2503.02240
## Model Structure
The framework consists of four key steps:
- Web table-driven database synthesis: leverages the abundance of tabular data on the web to synthesize databases that match realistic business scenarios. An LLM is prompted to generate a database related to a given table, including multiple relational tables and their schema information.
- Complexity-aware SQL query generation: generates SQL queries over the synthesized databases; the LLM produces queries at a specified complexity level (simple, moderate, complex, and highly complex).
- Stylized natural language question synthesis: converts the generated SQL queries into natural language questions in a range of linguistic styles (e.g., formal, colloquial, vague) to increase language diversity.
- Chain-of-thought solution synthesis: generates a step-by-step chain-of-thought solution for each synthesized text-to-SQL sample, spelling out the reasoning from question to SQL query and improving the interpretability of the data.
![alt text](readme_imgs/arch.png)
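For orientation, each fully synthesized sample ends up with the fields produced by the post-processing step in `data_synthesis/cot_synthesis/post_process_cot.py`. A minimal illustration (the values below are invented; only the field names follow the code):
```python
# Illustrative only: field names mirror post_process_cot.py; the values are made up.
synthetic_sample = {
    "db_id": "example_database",          # identifier of the synthesized database
    "sql_complexity": "moderate",         # simple / moderate / complex / highly complex
    "question_style": "colloquial",       # one of the stylized question styles
    "question": "Which customers placed more than five orders?",
    "external_knowledge": "",             # optional evidence attached to the question
    "cot": "step-by-step reasoning that ends with the final SQL in a fenced sql block",
    "sql": "SELECT ...",                  # final SQL extracted from the chosen CoT
}
```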
## Algorithm
The core idea is to use large language models (LLMs) together with automated pre-processing and post-processing strategies to generate high-quality, diverse data samples, reducing the reliance on manual effort, and then to train models on top of existing LLMs using the generated data.
![alt text](readme_imgs/alg.png)
## Environment Setup
### Docker (Option 1)
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
docker run --shm-size 500g --network=host --name=omnisql --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to the project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
```
### Dockerfile (Option 2)
```bash
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 500g --network=host --name=omnisql --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to the project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded from the developer community at https://developer.hpccube.com/tool/
```
dtk: 25.04
python: 3.10
torch: 2.4.1
deepspeed: 0.14.2
flash-attn: 2.6.1
vllm: 0.6.2
triton: 3.0.0
```
2. Install the remaining (non-specialized) libraries directly from requirements.txt:
```
pip install -r requirements.txt
```
## Datasets
Download the data into `train_and_evaluate` and unzip it.
* Training data - [HuggingFace](https://huggingface.co/datasets/seeklhy/OmniSQL-datasets/) | [ModelScope](https://modelscope.cn/datasets/seeklhy/OmniSQL-datasets)
* Test data - [Google Drive](https://drive.google.com/file/d/1iNa1WgA9tN_OFna08nq_tHZdXx9Lz2vO/view)

You can also synthesize your own data as needed; see `data_synthesis`.
A small dataset for verifying the training pipeline is provided under `train_and_evaluate/data`.
## Training
```bash
cd train_and_evaluate
# train OmniSQL-7B using SynSQL-2.5M
sh train_omnisql_7b.sh
# train OmniSQL-14B using SynSQL-2.5M
sh train_omnisql_14b.sh
# train OmniSQL-32B using SynSQL-2.5M
sh train_omnisql_32b.sh
```
Note: before training, update the model and data paths in the scripts; for more settings, see `accelerate_config_(7/14/32)b.yaml`.
## Inference
Note: before running inference, adjust the model path and the prompt in the code as needed.
### Transformers
```bash
cd inferences
python tf_inference.py
```
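If you want to call the model directly rather than through the provided script, a minimal `transformers` sketch looks like the following (the checkpoint path and prompt are placeholders; the project's actual prompt construction lives in `inferences/tf_inference.py`):
```python
# Minimal sketch, not the project's tf_inference.py; the path and prompt are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ckpts/OmniSQL-7B"  # adjust to your local checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

prompt = "..."  # build the text-to-SQL prompt as done in inferences/tf_inference.py
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```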
### vllm
```bash
cd inferences
python vllm_inference.py
```
### vllm_serve
```bash
vllm serve /path/to/model -tp 1
```
```bash
cd inferences
bash vllm_inference.sh
```
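The served endpoint is OpenAI-compatible, so it can also be queried directly from Python. A sketch, assuming the server started above is listening on `http://localhost:8000/v1` (the model name must match the path passed to `vllm serve`):
```python
# Sketch only: assumes `vllm serve /path/to/model -tp 1` is running locally.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="/path/to/model",  # must match the model name/path given to `vllm serve`
    messages=[{"role": "user", "content": "your text-to-SQL prompt here"}],
    temperature=0.0,
    max_tokens=1024,
)
print(response.choices[0].message.content)
```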
## Results
![alt text](readme_imgs/result.png)
### Accuracy
||loss|
|:---:|:---:|
|NVIDIA GPU|0.1449|
|DCU|0.1448|

epoch: 1
## Application Scenarios
### Algorithm Category
`Conversational question answering`
### Key Application Industries
`E-commerce, education, media`
## Pretrained Weights
Place the downloaded models in the `ckpts` directory (create it yourself).
| Model | url |
|-----------|------------------|
| OmniSQL-7B | [✨ Modelscope](https://modelscope.cn/models/seeklhy/OmniSQL-7B), [🤗 HuggingFace](https://huggingface.co/seeklhy/OmniSQL-7B) |
| OmniSQL-14B | [✨ Modelscope](https://modelscope.cn/models/seeklhy/OmniSQL-14B), [🤗 HuggingFace](https://huggingface.co/seeklhy/OmniSQL-14B) |
| OmniSQL-32B | [✨ Modelscope](https://modelscope.cn/models/seeklhy/OmniSQL-32B), [🤗 HuggingFace](https://huggingface.co/seeklhy/OmniSQL-32B) |
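One way to pull a checkpoint into `ckpts` is via `huggingface_hub` (a minimal sketch, assuming the package is installed; the repo id is taken from the table above):
```python
# Sketch: download OmniSQL-7B into the local ckpts directory (requires `pip install huggingface_hub`).
from huggingface_hub import snapshot_download

snapshot_download(repo_id="seeklhy/OmniSQL-7B", local_dir="ckpts/OmniSQL-7B")
```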
## Source Repository & Issue Feedback
* https://developer.sourcefind.cn/codes/modelzoo/omnisql_pytorch
## References
* https://github.com/RUCKBReasoning/OmniSQL
# Data Synthesis Framework
This directory contains the source code and prompts for our data synthesis framework.
- **Step 1:** Web Table-Driven Database Synthesis (see `database_synthesis`)
- **Step 2:** Complexity-Aware SQL Query Generation (see `sql_synthesis`)
- **Step 3:** Stylized Natural Language Question Synthesis (see `question_synthesis`)
- **Step 4:** Chain-of-Thought Solution Synthesis (see `cot_synthesis`)
These steps are sequential, but you can start at any intermediate step to synthesize text-to-SQL data samples. For instance, if you already have databases, you can skip Step 1 and generate high-quality `<question, SQL query, CoT solution>` pairs for your databases.
To set up the Anaconda environment for data synthesis:
```bash
conda create -n omnisql_data_synthesis python=3.9.5
conda activate omnisql_data_synthesis
pip install -U sentence-transformers
pip install json-repair ijson matplotlib func_timeout
```
# Chain-of-Thought Solution Synthesis
This is the final step of our data synthesis framework, focused on generating step-by-step chain-of-thought (CoT) solutions for the `<database, question, SQL query>` triplets.
## Step 1: Chain-of-Thought Generation
```bash
# Run to prepare the prompts used for CoT generation
mkdir prompts
python3 generate_cot_synthesis_prompts.py
```
```bash
# Run to generate CoT solutions for the samples
mkdir results
python3 synthesize_cot.py --model model_name --base_url vllm_serve_url(http://x.x.x.x:8000/v1)
```
## Step 2: Post-Processing
```bash
# Run execution-based major voting to select the most reliable CoT solutions
python3 post_process_cot.py
```
# Chain-of-Thought Solution Synthesis
This is the final step in our data synthesis framework, focused on generating step-by-step chain-of-thought (CoT) solutions for `<database, question, SQL query>` triplets.
## Step 1: Chain-of-Thought Generation
Create CoT solutions for each data sample.
1. Run `python3 generate_cot_synthesis_prompts.py` to prepare prompts for CoT generation.
2. Execute `python3 synthesize_cot.py` to generate CoT solutions for `<database, question, SQL query>` samples. (Note: Ensure the `llm_inference()` function is implemented to integrate your preferred LLM. For each prompt, we sample multiple CoT solutions with a temperature of `0.8`.)
## Step 2: Post-Processing
1. Run `python3 post_process_cot.py` to perform execution-based major voting, selecting the most reliable CoT solutions.
2. Save the final synthetic `<database, question, SQL query, CoT solution>` samples to `./results/synthetic_text2sql_dataset.json`.
import json
import re
from tqdm import tqdm
def remove_sql_comments(sql):
# Remove single-line comments
sql = re.sub(r'--.*', '', sql)
# Remove multi-line comments
sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL)
return sql.strip()
if __name__ == "__main__":
dataset = json.load(open("../question_synthesis/results/question_and_sql_pairs.json"))
tables = json.load(open("../database_synthesis/tables.json"))
print("len(tables):", len(tables))
prompts = []
db_id2ddls = dict()
for table in tables:
db_id2ddls[table["db_id"]] = table["ddls"]
print("len(db_id2ddls):", len(db_id2ddls))
    prompt_template = open("./prompt_templates/cot_synthesis_prompt_template.txt").read()
    for data in tqdm(dataset):
        if data["external_knowledge"] != "":
            question = data["external_knowledge"] + "\n" + data["question"]
        else:
            question = data["question"]
        data["cot_synthesis_prompt"] = prompt_template.format(
            schema = "\n\n".join(db_id2ddls[data["db_id"]]),
            question = question,
            sql = remove_sql_comments(data["sql"])
        )
with open("./prompts/cot_synthesis_prompts.json", "w", encoding="utf-8") as f:
f.write(json.dumps(dataset, indent=2, ensure_ascii=False))
import json
import re
import sqlite3
import os
from tqdm import tqdm
from func_timeout import func_timeout, FunctionTimedOut
import multiprocessing as mp
import sys
import ijson
import random
def parse_response(response):
pattern = r"```sql\s*(.*?)\s*```"
sql_blocks = re.findall(pattern, response, re.DOTALL)
if sql_blocks:
# extract the last SQL query in the response text and remove extra whitespace characters
last_sql = sql_blocks[-1].strip()
return last_sql
else:
# print("No SQL blocks found.")
return ""
def execute_sql(data_idx, db_file, sql):
conn = sqlite3.connect(db_file)
cursor = conn.cursor()
try:
cursor.execute(sql)
execution_res = cursor.fetchall()
execution_res = frozenset(execution_res) # `set` operation on execution results, and make `set` hashable
conn.rollback() # Roll back any changes
return data_idx, db_file, sql, execution_res, 1
except:
conn.rollback() # Ensure rollback on exception
return data_idx, db_file, sql, None, 0
finally:
conn.close()
def execute_sql_wrapper(data_idx, db_file, sql, timeout):
try:
res = func_timeout(timeout, execute_sql, args=(data_idx, db_file, sql))
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
res = (data_idx, db_file, sql, None, 0)
except Exception as e:
res = (data_idx, db_file, sql, None, 0)
return res
def execute_callback_execute_sqls(result):
execution_results.append(result)
def execute_sqls_parallel(db_files, sqls, num_cpus=1, timeout=1):
pool = mp.Pool(processes=num_cpus)
for data_idx, db_file, sql in zip(list(range(len(sqls))), db_files, sqls):
pool.apply_async(execute_sql_wrapper, args=(data_idx, db_file, sql, timeout), callback=execute_callback_execute_sqls)
pool.close()
pool.join()
def load_json_file(file):
dataset = []
with open(file, 'r', encoding='utf-8') as f:
objects = ijson.items(f, 'item')
for obj in objects:
dataset.append(obj)
return dataset
if __name__ == "__main__":
results = load_json_file("./results/cot_synthesis.json")
sampling_num = len(results[0]["responses"])
print("sampling_num:", sampling_num)
# execution results-guided major voting
major_voting_filter_num = 0
major_voting_results = []
process_batch_size = 10240
for pred_idx in tqdm(range(0, len(results), process_batch_size)):
batch_cot_results = results[pred_idx: pred_idx + process_batch_size]
batch_db_files = []
batch_sqls = []
execution_results = []
for cot_result in batch_cot_results:
batch_db_files.extend([os.path.join("../database_synthesis/synthetic_sqlite_databases", cot_result["db_id"], cot_result["db_id"] + ".sqlite")] * sampling_num)
batch_sqls.extend([parse_response(response) for response in cot_result["responses"]])
assert len(batch_db_files) == len(batch_sqls)
execute_sqls_parallel(batch_db_files, batch_sqls, 20, 2)
execution_results = sorted(execution_results, key = lambda x: x[0])
assert len(batch_cot_results) * sampling_num == len(execution_results)
for data_idx in range(len(batch_cot_results)):
cot_result = batch_cot_results[data_idx]
execution_results_in_one_sample = execution_results[sampling_num * data_idx: sampling_num * (data_idx + 1)]
assert len(cot_result["responses"]) == len(execution_results_in_one_sample)
major_voting_dict = dict()
for cot, execution_result in zip(cot_result["responses"], execution_results_in_one_sample):
if execution_result[-1] == 0: # invalid SQL queries
continue
if execution_result[-2] in major_voting_dict:
major_voting_dict[execution_result[-2]].append(cot)
else:
major_voting_dict[execution_result[-2]] = [cot]
# if the number of valid cots is less than 3, we discard current data sample
valid_cot_num = sum([len(cot_list) for cot_list in major_voting_dict.values()])
# print("valid_cot_num:", valid_cot_num)
if valid_cot_num < 3:
major_voting_filter_num += 1
continue
# find cots with the most vote count, based on the execution results
voting_key = max(major_voting_dict, key = lambda k: len(major_voting_dict[k]))
voting_cots = major_voting_dict[voting_key]
final_cot = random.choice(voting_cots)
major_voting_results.append(
{
"db_id": cot_result["db_id"],
"sql_complexity": cot_result["sql_complexity"],
"question_style": cot_result["question_style"],
"question": cot_result["question"],
"external_knowledge": cot_result["external_knowledge"],
"cot": final_cot,
"sql": parse_response(final_cot)
}
)
print("major_voting_filter_num:", major_voting_filter_num)
print("num of data samples (after execution-based major voting):", len(major_voting_results))
with open("results/synthetic_text2sql_dataset.json", "w", encoding="utf-8") as f:
f.write(json.dumps(major_voting_results, ensure_ascii=False, indent=2))
You are a senior data analyst specializing in SQL. Your task is to translate a natural language question into an executable SQLite query, providing a detailed reasoning trace.
You will also receive a reference solution from a colleague, which may or may not be correct. This extra information is intended to help you generate your answer, but do not mention the reference solution in any form.
The reference solution might include:
1. Unnecessary table and column selections.
2. Incorrect or excessive joins.
3. Misalignment with the question.
4. Opportunities for simplification.
Ensure the SQL query is presented in a Markdown code block with proper syntax highlighting, like this:
```sql
SELECT * FROM table;
```
[Database Schema]:
{schema}
[Natural Language Question]:
{question}
[Reference Solution]:
```sql
{sql}
```
Provide your step-by-step text-to-SQL solution here.
import argparse
import json
import openai
def llm_inference(model, base_url, dataset):
"""
Perform LLM inference to generate multiple responses for each prompt in the dataset.
Args:
model: The LLM used for inference.
dataset: A list of dictionaries.
Returns:
A list of dictionaries, where each dictionary includes the original data and the corresponding generated responses.
"""
client = openai.OpenAI(
base_url=base_url,
api_key="EMPTY"
)
prompts = [data["cot_synthesis_prompt"] for data in dataset]
    # Collect the generated responses for each prompt.
    # Each element in `responses_list` is a list of responses (strings) corresponding to a prompt.
    responses_list = []
    for prompt in prompts:
        # Sample several CoT solutions per prompt (the sampling count `n` is an adjustable assumption),
        # so that post_process_cot.py can apply execution-based major voting over them.
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=4196,
            temperature=0.8,
            n=8
        )
        responses_list.append([choice.message.content.strip() for choice in response.choices])
# Initialize an empty list to store the results
results = []
# Iterate through the dataset and the corresponding responses
for data, responses in zip(dataset, responses_list):
# Add the generated responses to the current data entry
data["responses"] = responses
# Append the updated data entry to the results
results.append(data)
return results
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model", type = str)
parser.add_argument("--base_url", type=str)
opt = parser.parse_args()
print(opt)
input_dataset = json.load(open("./prompts/cot_synthesis_prompts.json"))
output_file = "./results/cot_synthesis.json"
results = llm_inference(opt.model, opt.base_url, input_dataset)
with open(output_file, "w", encoding = "utf-8") as f:
f.write(json.dumps(results, indent = 2, ensure_ascii = False))
# Web Table-Driven Database Synthesis
## Prepare Web Tables
```bash
unzip web_tables.json.zip
```
## Step 1: Initial Database Generation
Generate initial databases from the web tables.
```bash
# Run to create the prompts used for database generation
mkdir prompts
python3 generate_schema_synthesis_prompts.py
```
```bash
# Generate the initial database schemas.
# A vLLM server is used here; you can modify the script to use another serving backend.
vllm serve model_name -tp 2
mkdir results
python3 synthesize_schema.py --model model_name --base_url vllm_serve_url(http://x.x.x.x:8000/v1)
```
## Step 2: Database Enhancement
Enhance the initially generated databases to increase complexity and realism.
```bash
# Run to create the prompts used for database enhancement
python3 generate_schema_enhancement_prompts.py
```
```bash
# Run to generate the enhanced database schemas
# vllm serve model_name -tp 2
python3 enhance_schema.py --model model_name --base_url vllm_serve_url(http://x.x.x.x:8000/v1)
```
## Step 3: Build SQLite Databases
```bash
# Run to build the SQLite databases, which are stored in the `synthetic_sqlite_databases` folder
python3 build_sqlite_databases.py
```
```bash
# Run to create the `tables.json` file, which contains detailed information about the synthetic databases, consistent with previous text-to-SQL datasets
python3 generate_tables_json.py
```
# Web Table-Driven Database Synthesis
This is the first step in our data synthesis framework, designed to generate realistic databases using web tables.
## Prepare Web Tables
Unzip `web_tables.json.zip` to access 19,935 high-quality web tables from [Tablib](https://arxiv.org/pdf/2310.07875).
## Step 1: Initial Database Generation
Generate an initial database from the web tables.
1. Run `python3 generate_schema_synthesis_prompts.py` to create prompts for database generation.
2. Run `python3 synthesize_schema.py` to generate initial database schemas. (Implement the `llm_inference()` function to use your preferred LLMs.)
## Step 2: Database Enhancement
Enhance the initially generated databases to increase complexity and realism.
1. Run `python3 generate_schema_enhancement_prompts.py` to create prompts for database enhancement.
2. Run `python3 enhance_schema.py` to generate enhanced database schemas. (Implement the `llm_inference()` function to use your preferred LLMs.)
## Step 3: Building SQLite Databases
Build SQLite databases based on the enhanced database schemas.
1. Run `python3 build_sqlite_databases.py` to construct SQLite databases, which are stored in the `synthetic_sqlite_databases` folder.
2. Run `python3 generate_tables_json.py` to create the `tables.json` file, containing detailed information about the synthetic databases, aligning with previous text-to-SQL datasets.
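For reference, each entry of the generated `tables.json` follows the structure assembled in `generate_tables_json.py`. A trimmed illustration (values invented, keys taken from the script):
```python
# Illustration of one tables.json entry; keys mirror generate_tables_json.py, values are made up.
example_entry = {
    "db_id": "example_database",
    "ddls": ['CREATE TABLE "users" (...)'],
    "table_names": ["users"],
    "table_names_original": ["users"],
    "column_names": [[-1, "*"], [0, "unique user identifier"]],  # [table index, column description]
    "column_names_original": [[-1, "*"], [0, "user_id"]],        # [table index, column name from the DDL]
    "column_types": ["text", "INTEGER"],
    "primary_keys": [1],    # indices into column_names_original
    "foreign_keys": [],     # pairs of [source column index, referenced column index]
}
```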
import json
from tqdm import tqdm
from sqlite_schema_parser import verify_schema
import random
if __name__ == "__main__":
enhanced_results = json.load(open("./results/schema_enhancement.json"))
final_schemas = []
error_case_num = 0
for result in tqdm(enhanced_results):
try:
domain = result["domain"]
schema = json.loads(result["enhanced_schema"])
assert "tables" in schema and "foreign_keys" in schema
tables = []
for table in schema["tables"]:
try:
assert "table_name" in table and "column_names" in table and \
"column_types" in table and "column_descriptions" in table
assert len(table["column_names"]) == len(table["column_types"]) == len(table["column_descriptions"])
tables.append(table)
except Exception as e:
pass
table_names_lower = [table["table_name"].lower() for table in tables]
foreign_keys = []
for foreign_key in schema["foreign_keys"]:
try:
assert "source_table" in foreign_key and "column_in_source_table" in foreign_key and \
"referenced_table" in foreign_key and "column_in_referenced_table" in foreign_key
assert foreign_key["source_table"].lower() in table_names_lower and \
foreign_key["referenced_table"].lower() in table_names_lower
foreign_keys.append(foreign_key)
except Exception as e:
pass
final_schemas.append(
{
"domain": domain,
"tables": tables,
"foreign_keys": foreign_keys
}
)
except Exception as e:
error_case_num += 1
# print(e)
print("error_case_num:", error_case_num)
db_ids = []
success_labels = []
for final_schema in tqdm(final_schemas):
db_id = final_schema["domain"].lower().replace("(", "_").replace(")", "_").replace("-", "_").replace(" ", "_").replace("*", "_").strip()
if len(db_id) > 75:
db_id = db_id[:75]
# resolve db_id conflict issues
while db_id in db_ids:
db_id += "_" + str(random.randint(0, 1000000000000))
success_label = verify_schema(final_schema, db_id)
if success_label:
db_ids.append(db_id)
success_labels.append(success_label)
print("success rate:", sum(success_labels)/len(success_labels))
import argparse
import json
import os
import re
import time
from json_repair import json_repair
import openai
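# Extract the JSON schema from the ```json fenced block of an LLM response, repairing malformed JSON with json_repair.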
def parse_response(response):
schema_pattern = r'```json\s*([\s\S]*?)\s*```'
try:
enhanced_schema_match = re.search(schema_pattern, response, re.DOTALL)
        enhanced_schema_str = enhanced_schema_match.group(1).strip() if enhanced_schema_match else None
enhanced_schema_dict = json_repair.loads(enhanced_schema_str)
return enhanced_schema_dict
except Exception as e:
print(response)
print("Parsing Exception:", str(e))
return None
def parse_prompt(prompt):
domain_pattern = r'(?<=\*\*Business Domain:\*\*)(.*?)(?=\*\*Business Scenario:\*\*)'
scenario_pattern = r'(?<=\*\*Business Scenario:\*\*)(.*?)(?=\*\*Initial Database Schema:\*\*)'
domain_match = re.search(domain_pattern, prompt, re.DOTALL)
domain = domain_match.group(0).strip() if domain_match else None
scenario_match = re.search(scenario_pattern, prompt, re.DOTALL)
scenario = scenario_match.group(0).strip() if scenario_match else None
return domain, scenario
def llm_inference(model, base_url, prompts):
'''
This function leverages a large language model (LLM) to generate responses for a given list of prompts.
You can integrate your preferred LLM within this function.
Args:
model: The LLM to be used for inference.
prompts: A list of prompts for which the LLM will generate responses.
Returns:
A list of dictionaries, each containing the original prompt, extracted domain and scenario,
and a JSON-formatted enhanced schema.
'''
client = openai.OpenAI(
base_url=base_url,
api_key="EMPTY"
)
# Generate responses using the LLM (each prompt corresponds to one response)
# responses = None # Replace this with the actual LLM call, e.g., model.generate(prompts, temperature=0, n=1)
responses = []
for prompt in prompts:
response = client.chat.completions.create(
model=model,
messages=[{"role":"user", "content": prompt}],
max_tokens=4196,
temperature=0.2
)
responses.append(response.choices[0].message.content.strip())
# Initialize a list to store the processed results
results = []
# Iterate over prompts and their corresponding responses
for prompt, response in zip(prompts, responses):
# Parse the response to get the enhanced schema
enhanced_schema_dict = parse_response(response)
if enhanced_schema_dict is None:
continue
# Extract domain and scenario from the prompt
domain, scenario = parse_prompt(prompt)
# Append the results with structured data
results.append({
"prompt": prompt,
"domain": domain,
"scenario": scenario,
"enhanced_schema": json.dumps(enhanced_schema_dict, indent=2, ensure_ascii=False)
})
return results
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model", type = str)
parser.add_argument("--base_url", type=str)
args = parser.parse_args()
print(args)
prompts = json.load(open("./prompts/prompts_schema_enhancement.json"))
output_file = "./results/schema_enhancement.json"
results = llm_inference(args.model, args.base_url, prompts)
with open(output_file, "w", encoding = "utf-8") as f:
f.write(json.dumps(results, indent = 2, ensure_ascii = False))
import json
import random
random.seed(42)
if __name__ == '__main__':
prompts = []
prompt_template = open("./prompt_templates/enhance_prompt.txt", "r", encoding = "utf-8").read()
schema_synthesis_results = json.load(open("./results/schema_synthesis.json"))
no_res_num = 0
for data in schema_synthesis_results:
try:
if data["generated_content"] == {}:
no_res_num += 1
continue
domain = data["generated_content"]["domain"]
scenario = data["generated_content"]["scenario"]
schema_str = data["generated_content"]["schema"]
prompts.append(
prompt_template.format(domain = domain, scenario = scenario, schema = schema_str).strip()
)
except Exception as e:
print(e)
print("no_res_num:", no_res_num)
print("len(prompts):", len(prompts))
random.shuffle(prompts)
with open("./prompts/prompts_schema_enhancement.json", "w", encoding="utf-8") as file:
file.write(json.dumps(prompts, ensure_ascii=False, indent=2))
import json
import random
import numpy as np
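# Sample an integer table count from a normal distribution, clipped to [lower_bound, upper_bound].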
def generate_a_normal_integer(mean = 10, std_dev = 4, lower_bound = 1, upper_bound = 20):
sample = np.random.normal(mean, std_dev)
sample = np.clip(sample, lower_bound, upper_bound)
return int(sample)
if __name__ == '__main__':
random.seed(42)
tables = json.load(open("web_tables.json", "r", encoding = "utf-8"))
prompt_template = open("./prompt_templates/schema_prompt.txt", "r", encoding = "utf-8").read()
prompts = []
for table in tables:
random_table_num = generate_a_normal_integer()
print(random_table_num)
prompt = prompt_template.format(
table_num = random_table_num,
table = table
)
prompts.append(prompt.strip())
random.shuffle(prompts)
with open("./prompts/prompts_schema_synthesis.json", "w", encoding = "utf-8") as file:
file.write(json.dumps(prompts, ensure_ascii = False, indent = 2))
import json
import sqlite3
import os
import re
from tqdm import tqdm
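# Return the CREATE TABLE statements for all tables in a SQLite database.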
def obtain_db_ddls(db_file_dir):
conn = sqlite3.connect(db_file_dir)
cursor = conn.cursor()
cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
create_statements = []
for table in tables:
_, create_statement = table
create_statements.append(create_statement)
cursor.close()
conn.close()
return create_statements
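# Return a table's primary-key columns as fully qualified "table"."column" strings.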
def obtain_pks(db_file_dir, table_name):
conn = sqlite3.connect(db_file_dir)
cursor = conn.cursor()
cursor.execute("SELECT name, type, pk FROM PRAGMA_TABLE_INFO('{}')".format(table_name))
results = cursor.fetchall()
# print(results)
column_names = [result[0] for result in results]
column_types = [result[1] for result in results]
pk_indicators = [result[2] for result in results]
pk_columns = [column_name for column_name, pk_indicator in zip(column_names, pk_indicators) if pk_indicator == 1]
return [f'"{table_name}"."{pk_column}"' for pk_column in pk_columns]
def obtain_fks(db_file_dir, table_name):
conn = sqlite3.connect(db_file_dir)
cursor = conn.cursor()
# obtain foreign keys in the current table
cursor.execute("SELECT * FROM pragma_foreign_key_list('{}');".format(table_name))
results = cursor.fetchall()
foreign_keys = []
for result in results:
if None not in [result[3], result[2], result[4]]:
foreign_keys.append([f'"{table_name}"."{result[3]}"', f'"{result[2]}"."{result[4]}"'])
return foreign_keys
if __name__ == "__main__":
db_ids = os.listdir("./synthetic_sqlite_databases")
tables = []
for db_id in tqdm(db_ids):
table = dict()
table["db_id"] = db_id
table["ddls"] = []
table["column_names"] = [[-1, "*"]]
table["column_names_original"] = [[-1, "*"]]
table["column_types"] = ["text"]
table["table_names"] = []
table["table_names_original"] = []
table["foreign_keys"] = []
table["primary_keys"] = []
db_file_dir = os.path.join("synthetic_sqlite_databases", db_id, db_id + ".sqlite")
ddls = obtain_db_ddls(db_file_dir)
# print("\n\n".join(ddls))
primary_keys_info = []
foreign_keys_info = []
table_column_names = ["*"]
for table_idx, ddl in enumerate(ddls):
if ddl.count("PRIMARY KEY") > 1:
print(ddl)
table["ddls"].append(ddl)
table_name_match = re.search(r'CREATE TABLE\s+"([^"]+)"', ddl)
table_name = table_name_match.group(1) if table_name_match else None
if table_name is None:
continue
table["table_names"].append(table_name)
table["table_names_original"].append(table_name)
column_infos = re.findall(r'"([^"]+)"\s+(\w+)\s*/\*\s*(.*?)\s*\*/', ddl)
# print(f"Table Name: {table_name}")
for column_name, column_type, comment in column_infos:
# print(f"Column Name: {column_name}, Type: {column_type}, Comment: {comment}")
table["column_names"].append([table_idx, comment]) # column_names is the semantic names (i.e., descriptions) of columns
table["column_names_original"].append([table_idx, column_name]) # column_names_original is the original names used in DDLs
table["column_types"].append(column_type)
table_column_names.append(f'"{table_name}"."{column_name}"')
primary_keys_info.append(obtain_pks(db_file_dir, table_name))
foreign_keys_info.extend(obtain_fks(db_file_dir, table_name))
for primary_key_info in primary_keys_info:
try:
if len(primary_key_info) == 1:
table["primary_keys"].append(table_column_names.index(primary_key_info[0]))
elif len(primary_key_info) > 1:
pk_idx_list = []
for primary_key_info_str in primary_key_info:
pk_idx_list.append(table_column_names.index(primary_key_info_str))
table["primary_keys"].append(pk_idx_list)
except Exception as e:
print(primary_key_info)
# print(db_id)
print(e)
for foreign_key_info in foreign_keys_info:
try:
table["foreign_keys"].append(
[table_column_names.index(foreign_key_info[0]), table_column_names.index(foreign_key_info[1])]
)
except Exception as e:
print(foreign_key_info)
# print(db_id)
print(e)
tables.append(table)
with open("tables.json", "w", encoding="utf-8") as f:
f.write(json.dumps(tables, ensure_ascii=False, indent=2))
**Task Overview:**
As a senior data analyst, your task is to enhance an initial database schema to provide a more detailed and realistic structure based on a given business scenario.
**Steps:**
1. **Analyze the Scenario:** Understand the provided business context.
2. **Identify Enhancements:** For each existing table, suggest new columns and explain their relevance. Be creative and thorough.
3. **Enrich the Schema:** Present the enriched schema in JSON format, ensuring proper primary and foreign key relationships.
**Business Domain:**
{domain}
**Business Scenario:**
{scenario}
**Initial Database Schema:**
```json
{schema}
```
**Output Format:**
Your output should provide the enriched database schema in JSON format:
```json
-- enriched database schema
```
Let's think step by step.