Commit 4f4ba442 authored by mashun1's avatar mashun1
Browse files

omnisql

parents
Pipeline #2643 canceled with stages
Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
SQLite
Database Schema:
CREATE TABLE badges (
Id integer, -- example: [1, 2]
UserId integer, -- example: [5, 6]
Name text, -- example: ['Teacher', 'Student']
`Date` datetime, -- example: ['2010-07-19 19:39:07.0', '2010-07-19 19:39:08.0']
PRIMARY KEY (Id),
CONSTRAINT fk_badges_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
CREATE TABLE comments (
Id integer, -- example: [1, 2]
PostId integer, -- example: [3, 5]
Score integer, -- example: [5, 0]
Text text, -- example: ['Could be a poster child fo argumentative', "Yes, R is nice- but WHY is it 'valuable'"]
CreationDate datetime, -- example: ['2010-07-19 19:15:52.0', '2010-07-19 19:16:14.0']
UserId integer, -- example: [13, 37]
UserDisplayName text, -- example: ['user28', 'Statprof']
PRIMARY KEY (Id),
CONSTRAINT fk_comments_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_comments_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
CREATE TABLE postHistory (
Id integer, -- example: [1, 2]
PostHistoryTypeId integer, -- example: [2, 1]
PostId integer, -- example: [1, 2]
RevisionGUID text, -- example: ['e58bf7fd-e60f-4c58-a6e4-dfc91cf98a69', '18bf9150-f1cb-432d-b7b7-26d2f8e33581']
CreationDate datetime, -- example: ['2010-07-19 19:12:12.0', '2010-07-19 19:12:57.0']
UserId integer, -- example: [8, 24]
Text text, -- example: ['How should I elicit prior distributions ', 'Eliciting priors from experts']
`Comment` text, -- example: ['more ', 'more', 'edited tags', 'add content from the comments;']
UserDisplayName text, -- example: ['User', 'user28', 'user209']
PRIMARY KEY (Id),
CONSTRAINT fk_posthistory_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_posthistory_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
CREATE TABLE postLinks (
Id integer, -- example: [108, 145]
CreationDate datetime, -- example: ['2010-07-21 14:47:33.0', '2010-07-23 16:30:41.0']
PostId integer, -- example: [395, 548]
RelatedPostId integer, -- example: [173, 539]
LinkTypeId integer, -- example: [1, 3]
PRIMARY KEY (Id),
CONSTRAINT fk_postlinks_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_postlinks_relatedpostid FOREIGN KEY (RelatedPostId) REFERENCES posts (Id)
);
CREATE TABLE posts (
Id integer, -- example: [1, 2]
PostTypeId integer, -- example: [1, 2]
AcceptedAnswerId integer, -- example: [15, 59]
CreaionDate datetime, -- Creation Date, example: ['2010-07-19 19:12:12.0', '2010-07-19 19:12:57.0']
Score integer, -- example: [23, 22]
ViewCount integer, -- example: [1278, 8198]
Body text, -- example: ['<p>How should I elicit prior distributio', '<p>In many different statistical methods']
OwnerUserId integer, -- example: [8, 24]
LasActivityDate datetime, -- Last Activity Date, example: ['2010-09-15 21:08:26.0', '2012-11-12 09:21:54.0']
Title text, -- example: ['Eliciting priors from experts', 'What is normality?']
Tags text, -- example: ['<bayesian><prior><elicitation>', '<distributions><normality>']
AnswerCount integer, -- example: [5, 7]
CommentCount integer, -- example: [1, 4]
FavoriteCount integer, -- example: [14, 8]
LastEditorUserId integer, -- example: [88, 183]
LastEditDate datetime, -- example: ['2010-08-07 17:56:44.0', '2011-02-12 05:50:03.0']
CommunityOwnedDate datetime, -- example: ['2010-07-19 19:13:28.0', '2010-07-19 19:14:43.0']
ParentId integer, -- example: [3, 7]
ClosedDate datetime, -- example: ['2010-07-19 20:19:46.0', '2010-08-05 13:06:12.0']
OwnerDisplayName text, -- example: ['User', 'user28', 'user209']
LastEditorDisplayName text, -- example: ['user28', 'user10525']
PRIMARY KEY (Id),
CONSTRAINT fk_posts_owneruserid FOREIGN KEY (OwnerUserId) REFERENCES users (Id),
CONSTRAINT fk_posts_lasteditoruserid FOREIGN KEY (LastEditorUserId) REFERENCES users (Id),
CONSTRAINT fk_posts_parentid FOREIGN KEY (ParentId) REFERENCES posts (Id)
);
CREATE TABLE tags (
Id integer, -- example: [1, 2]
TagName text, -- example: ['bayesian', 'prior']
`Count` integer, -- example: [1342, 168]
ExcerptPostId integer, -- example: [20258, 62158]
WikiPostId integer, -- example: [20257, 62157]
PRIMARY KEY (Id),
CONSTRAINT fk_tags_excerptpostid FOREIGN KEY (ExcerptPostId) REFERENCES posts (Id)
);
CREATE TABLE users (
Id integer, -- example: [-1, 2]
Reputation integer, -- example: [1, 101]
CreationDate datetime, -- example: ['2010-07-19 06:55:26.0', '2010-07-19 14:01:36.0']
DisplayName text, -- example: ['User', 'useR', 'user', 'Community', 'Geoff Dalgas']
LastAccessDate datetime, -- example: ['2010-07-19 06:55:26.0', '2013-11-12 22:07:23.0']
WebsiteUrl text, -- example: ['http://meta.stackexchange.com/', 'http://stackoverflow.com']
Location text, -- example: ['on the server farm', 'Corvallis, OR']
AboutMe text, -- example: ["<p>Hi, I'm not really a person.</p>\n\n<p>", '<p>Developer on the StackOverflow team. ']
Views integer, -- example: [0, 25]
UpVotes integer, -- example: [5007, 3]
DownVotes integer, -- example: [1920, 0]
AccountId integer, -- example: [-1, 2]
Age integer, -- example: [37, 35]
ProfileImageUrl text, -- example: ['http://i.stack.imgur.com/d1oHX.jpg', 'http://i.stack.imgur.com/km1pr.jpg']
PRIMARY KEY (Id)
);
CREATE TABLE votes (
Id integer, -- example: [1, 2]
PostId integer, -- example: [3, 2]
VoteTypeId integer, -- example: [2, 5]
CreationDate date, -- example: ['2010-07-19', '2010-07-20']
UserId integer, -- example: [58, 6]
BountyAmount integer, -- example: [50, 25]
PRIMARY KEY (Id),
CONSTRAINT fk_votes_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_votes_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
more than 10 views refers to Views > 10; created after the year 2013 refers to year (CreationDate) > 2013
How many users with more than 10 views created their account after the year 2013?
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
```sql
-- Your SQL query
```
Take a deep breath and think step by step to find the correct SQL query.
\ No newline at end of file
# Single-example OmniSQL text-to-SQL inference using Hugging Face transformers.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from pathlib import Path
# Project root: the parent of the directory containing this script.
project_dir = str(Path(__file__).resolve().parent.parent)
# OmniSQL input prompt template; {db_details} and {question} are filled in via
# str.format below. The template text is part of the model contract — do not
# reword it without re-validating model outputs.
input_prompt_template = '''Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
SQLite
Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
{question}
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
```
-- Your SQL query
```
Take a deep breath and think step by step to find the correct SQL query.'''
db_details = """
CREATE TABLE cards (
id integer, -- unique id number identifying the cards, example: [41138, 1349]
artist text, -- example: ['Pete Venters', 'Volkan Baǵa']
asciiName text, -- example: ['El-Hajjaj', 'Junun Efreet']
availability text, -- example: ['mtgo,paper', 'paper']
borderColor text, -- example: ['black', 'white']
cardKingdomFoilId text, -- example: ['123094', '123095']
cardKingdomId text, -- example: ['122719', '122720']
colorIdentity text, -- example: ['W', 'B']
colorIndicator text, -- example: ['U', 'G']
colors text, -- example: ['W', 'B']
convertedManaCost real, -- example: [7.0, 5.0]
duelDeck text, -- example: ['a', 'b']
edhrecRank integer, -- rec Rank in edh, example: [15650, 12702]
faceConvertedManaCost real, -- example: [4.0, 5.0]
faceName text, -- example: ['Dusk', 'Dawn']
flavorName text, -- example: ['Godzilla, King of the Monsters', 'King Caesar, Ancient Guardian']
flavorText text, -- example: ['Every tear shed is a drop of immortality', 'The perfect antidote for a tightly packe']
frameEffects text, -- example: ['legendary', 'nyxtouched']
frameVersion text, -- example: ['2003', '1993']
hand text, -- example: ['1', '0']
hasAlternativeDeckLimit integer, -- example: [0, 1]
hasContentWarning integer, -- example: [0, 1]
hasFoil integer, -- example: [0, 1]
hasNonFoil integer, -- example: [1, 0]
isAlternative integer, -- example: [0, 1]
isFullArt integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isOversized integer, -- example: [0, 1]
isPromo integer, -- is Promotion, example: [0, 1]
isReprint integer, -- example: [1, 0]
isReserved integer, -- example: [0, 1]
isStarter integer, -- example: [0, 1]
isStorySpotlight integer, -- example: [0, 1]
isTextless integer, -- example: [0, 1]
isTimeshifted integer, -- example: [0, 1]
keywords text, -- example: ['First strike', 'Flying']
layout text, -- example: ['normal', 'aftermath']
leadershipSkills text, -- example: ["{'brawl': False, 'commander': True, 'oat", "{'brawl': False, 'commander': False, 'oa"]
life text, -- example: ['-5', '-1']
loyalty text, -- example: ['6', '3']
manaCost text, -- example: ['{5}{W}{W}', '{4}{W}']
mcmId text, -- example: ['16165', '16166']
mcmMetaId text, -- example: ['156', '176']
mtgArenaId text, -- example: ['74983', '74986']
mtgjsonV4Id text, -- example: ['ad41be73-582f-58ed-abd4-a88c1f616ac3', '9eb2e54c-a12b-5e88-a9c0-d8c84c52d59c']
mtgoFoilId text, -- example: ['27501', '26993']
mtgoId text, -- example: ['27500', '26992']
multiverseId text, -- example: ['130550', '129465']
name text, -- example: ["Ancestor's Chosen", 'Angel of Mercy']
number text, -- example: ['1', '2']
originalReleaseDate text, -- example: ['2012/12/1', '2006/12/1']
originalText text, -- example: ['First strike (This creature deals combat', "Flying (This creature can't be blocked e"]
originalType text, -- example: ['Creature - Human Cleric', 'Creature - Angel']
otherFaceIds text, -- example: ['87f0062a-8321-5c16-960e-a12ce1df5839', 'f9f10d34-071c-57a6-b58c-7553abad5c20']
power text, -- example: ['4', '3']
printings text, -- example: ['10E,JUD,UMA', '10E,8ED,9ED,DDC,DVD,IMA,INV,JMP,MB1,P02,']
promoTypes text, -- example: ['boxtopper,boosterfun', 'boosterfun']
purchaseUrls text, -- example: ["{'cardKingdom': 'https://mtgjson.com/lin"]
rarity text, -- example: ['uncommon', 'common']
scryfallId text, -- example: ['7a5cd03c-4227-4551-aa4b-7d119f0468b5', '8f7980d4-da43-4d6d-ad16-14b8a34ae91d']
scryfallIllustrationId text, -- example: ['be2f7173-c8b7-4172-a388-9b2c6b3c16e5', 'e4d6c53f-e936-4be8-8b70-47c2be863b20']
scryfallOracleId text, -- example: ['fc2ccab7-cab1-4463-b73d-898070136d74', 'a2daaf32-dbfe-4618-892e-0da24f63a44a']
setCode text, -- example: ['10E', '2ED']
side text, -- example: ['a', 'b']
subtypes text, -- example: ['Human,Cleric', 'Angel']
supertypes text, -- example: ['Legendary', 'Basic']
tcgplayerProductId text, -- example: ['15032', '15033']
text text, -- example: ['First strike (This creature deals combat', 'Flying\nWhen Angel of Mercy enters the ba']
toughness text, -- example: ['4', '3']
type text, -- example: ['Creature — Human Cleric', 'Creature — Angel']
types text, -- example: ['Creature', 'Instant']
uuid text, -- example: ['00010d56-fe38-5e35-8aed-518019aa36a5', '0001e0d0-2dcd-5640-aadc-a84765cf5fc9']
variations text, -- example: ['b7c19924-b4bf-56fc-aa73-f586e940bd42', '8fd4e2eb-3eb4-50ea-856b-ef638fa47f8a']
watermark text, -- example: ['set', 'set (HOU)', 'set (LGN)']
PRIMARY KEY (id)
);
CREATE TABLE foreign_data (
id integer, -- example: [1, 2]
flavorText text, -- example: ['„Es ist der Wille aller, und meine Hand,', '"La voluntad de todos, realizada por mi ']
`language` text, -- example: ['Italian', 'German', 'Spanish']
multiverseid integer, -- example: [148411, 150317]
name text, -- example: ['Ausgewählter der Ahnfrau', 'Elegido de la Antepasada']
text text, -- example: ['Erstschlag (Diese Kreatur fügt Kampfscha', 'Daña primero. (Esta criatura hace daño d']
type text, -- example: ['Kreatur — Mensch, Kleriker', 'Criatura — Clérigo humano']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_foreign_data_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE legalities (
id integer, -- example: [1, 2]
format text, -- example: ['commander', 'duel']
status text, -- example: ['Legal', 'Banned']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_legalities_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE sets (
id integer, -- example: [1, 2]
baseSetSize integer, -- example: [383, 302]
block text, -- example: ['Core Set', 'Mirrodin']
booster text, -- example: ["{'default': {'boosters': [{'contents': {"]
code text, -- example: ['10E', '2ED']
isFoilOnly integer, -- example: [0, 1]
isForeignOnly integer, -- example: [0, 1]
isNonFoilOnly integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isPartialPreview integer, -- example: [0, 1]
keyruneCode text, -- example: ['10E', '2ED']
mcmId integer, -- magic card market id, example: [74, 3204]
mcmIdExtras integer, -- magic card market ID Extras, example: [3209, 3459]
mcmName text, -- magic card market name, example: ['Tenth Edition', 'Double Masters']
mtgoCode text, -- magic the gathering online code, example: ['10E', '2XM']
name text, -- example: ['Tenth Edition', 'Unlimited Edition']
parentCode text, -- example: ['JMP', 'MH1']
releaseDate date, -- example: ['2007-07-13', '1993-12-01']
tcgplayerGroupId integer, -- example: [1, 115]
totalSetSize integer, -- example: [508, 302]
type text, -- example: ['core', 'masters']
PRIMARY KEY (id)
);
CREATE TABLE set_translations (
id integer, -- example: [1, 2]
`language` text, -- example: ['Italian', 'Chinese Simplified', 'Chinese Traditional']
setCode text, -- example: ['10E', '4ED']
translation text, -- example: ['核心系列第十版', 'Dixième édition']
PRIMARY KEY (id),
CONSTRAINT fk_set_translations_setcode FOREIGN KEY (setCode) REFERENCES sets (code)
);
CREATE TABLE rulings (
id integer, -- example: [1, 2]
`date` date, -- example: ['2007-07-15', '2007-02-01']
text text, -- example: ['You draw the card when Bandage resolves,', 'If you double a negative life total, you']
uuid text, -- example: ['6d268c95-c176-5766-9a46-c14f739aba1c', '56f4935b-f6c5-59b9-88bf-9bcce20247ce']
PRIMARY KEY (id),
CONSTRAINT fk_rulings_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
"""
question = """
Italian translation refers to language = 'Italian'; have a translation means translation is not null; base set number of under 100 refers to baseSetSize < 10
Among the sets of cards that have an Italian translation, how many of them have a base set number of under 100?
"""
# Build the final prompt and run a single inference pass with transformers.
prompt = input_prompt_template.format(db_details=db_details, question=question)

# Load the OmniSQL-7B checkpoint from <project_root>/ckpts/OmniSQL-7B.
model_path = os.path.join(project_dir, "ckpts", "OmniSQL-7B")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16
).to("cuda:0")

# Render the chat template to a plain string, then tokenize it ourselves.
chat_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=False
)
inputs = tokenizer([chat_prompt], return_tensors="pt").to(model.device)

# Generate up to 2048 new tokens; sampling behavior comes from the
# checkpoint's generation config (no overrides here).
generated = model.generate(
    **inputs,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=2048
)

# Drop the prompt tokens so only the newly generated completion is decoded.
prompt_length = inputs.input_ids.shape[1]
completion_ids = generated[0][prompt_length:]
response = tokenizer.batch_decode([completion_ids], skip_special_tokens=True)[0]
print(response)
# Single-example OmniSQL text-to-SQL inference using the vLLM offline engine.
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os
from pathlib import Path
# Project root: the parent of the directory containing this script.
project_dir = str(Path(__file__).resolve().parent.parent)
# OmniSQL input prompt template (same text as the transformers script);
# {db_details} and {question} are filled in via str.format below.
input_prompt_template = '''Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
SQLite
Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
{question}
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
```
-- Your SQL query
```
Take a deep breath and think step by step to find the correct SQL query.'''
db_details = """
CREATE TABLE cards (
id integer, -- unique id number identifying the cards, example: [41138, 1349]
artist text, -- example: ['Pete Venters', 'Volkan Baǵa']
asciiName text, -- example: ['El-Hajjaj', 'Junun Efreet']
availability text, -- example: ['mtgo,paper', 'paper']
borderColor text, -- example: ['black', 'white']
cardKingdomFoilId text, -- example: ['123094', '123095']
cardKingdomId text, -- example: ['122719', '122720']
colorIdentity text, -- example: ['W', 'B']
colorIndicator text, -- example: ['U', 'G']
colors text, -- example: ['W', 'B']
convertedManaCost real, -- example: [7.0, 5.0]
duelDeck text, -- example: ['a', 'b']
edhrecRank integer, -- rec Rank in edh, example: [15650, 12702]
faceConvertedManaCost real, -- example: [4.0, 5.0]
faceName text, -- example: ['Dusk', 'Dawn']
flavorName text, -- example: ['Godzilla, King of the Monsters', 'King Caesar, Ancient Guardian']
flavorText text, -- example: ['Every tear shed is a drop of immortality', 'The perfect antidote for a tightly packe']
frameEffects text, -- example: ['legendary', 'nyxtouched']
frameVersion text, -- example: ['2003', '1993']
hand text, -- example: ['1', '0']
hasAlternativeDeckLimit integer, -- example: [0, 1]
hasContentWarning integer, -- example: [0, 1]
hasFoil integer, -- example: [0, 1]
hasNonFoil integer, -- example: [1, 0]
isAlternative integer, -- example: [0, 1]
isFullArt integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isOversized integer, -- example: [0, 1]
isPromo integer, -- is Promotion, example: [0, 1]
isReprint integer, -- example: [1, 0]
isReserved integer, -- example: [0, 1]
isStarter integer, -- example: [0, 1]
isStorySpotlight integer, -- example: [0, 1]
isTextless integer, -- example: [0, 1]
isTimeshifted integer, -- example: [0, 1]
keywords text, -- example: ['First strike', 'Flying']
layout text, -- example: ['normal', 'aftermath']
leadershipSkills text, -- example: ["{'brawl': False, 'commander': True, 'oat", "{'brawl': False, 'commander': False, 'oa"]
life text, -- example: ['-5', '-1']
loyalty text, -- example: ['6', '3']
manaCost text, -- example: ['{5}{W}{W}', '{4}{W}']
mcmId text, -- example: ['16165', '16166']
mcmMetaId text, -- example: ['156', '176']
mtgArenaId text, -- example: ['74983', '74986']
mtgjsonV4Id text, -- example: ['ad41be73-582f-58ed-abd4-a88c1f616ac3', '9eb2e54c-a12b-5e88-a9c0-d8c84c52d59c']
mtgoFoilId text, -- example: ['27501', '26993']
mtgoId text, -- example: ['27500', '26992']
multiverseId text, -- example: ['130550', '129465']
name text, -- example: ["Ancestor's Chosen", 'Angel of Mercy']
number text, -- example: ['1', '2']
originalReleaseDate text, -- example: ['2012/12/1', '2006/12/1']
originalText text, -- example: ['First strike (This creature deals combat', "Flying (This creature can't be blocked e"]
originalType text, -- example: ['Creature - Human Cleric', 'Creature - Angel']
otherFaceIds text, -- example: ['87f0062a-8321-5c16-960e-a12ce1df5839', 'f9f10d34-071c-57a6-b58c-7553abad5c20']
power text, -- example: ['4', '3']
printings text, -- example: ['10E,JUD,UMA', '10E,8ED,9ED,DDC,DVD,IMA,INV,JMP,MB1,P02,']
promoTypes text, -- example: ['boxtopper,boosterfun', 'boosterfun']
purchaseUrls text, -- example: ["{'cardKingdom': 'https://mtgjson.com/lin"]
rarity text, -- example: ['uncommon', 'common']
scryfallId text, -- example: ['7a5cd03c-4227-4551-aa4b-7d119f0468b5', '8f7980d4-da43-4d6d-ad16-14b8a34ae91d']
scryfallIllustrationId text, -- example: ['be2f7173-c8b7-4172-a388-9b2c6b3c16e5', 'e4d6c53f-e936-4be8-8b70-47c2be863b20']
scryfallOracleId text, -- example: ['fc2ccab7-cab1-4463-b73d-898070136d74', 'a2daaf32-dbfe-4618-892e-0da24f63a44a']
setCode text, -- example: ['10E', '2ED']
side text, -- example: ['a', 'b']
subtypes text, -- example: ['Human,Cleric', 'Angel']
supertypes text, -- example: ['Legendary', 'Basic']
tcgplayerProductId text, -- example: ['15032', '15033']
text text, -- example: ['First strike (This creature deals combat', 'Flying\nWhen Angel of Mercy enters the ba']
toughness text, -- example: ['4', '3']
type text, -- example: ['Creature — Human Cleric', 'Creature — Angel']
types text, -- example: ['Creature', 'Instant']
uuid text, -- example: ['00010d56-fe38-5e35-8aed-518019aa36a5', '0001e0d0-2dcd-5640-aadc-a84765cf5fc9']
variations text, -- example: ['b7c19924-b4bf-56fc-aa73-f586e940bd42', '8fd4e2eb-3eb4-50ea-856b-ef638fa47f8a']
watermark text, -- example: ['set', 'set (HOU)', 'set (LGN)']
PRIMARY KEY (id)
);
CREATE TABLE foreign_data (
id integer, -- example: [1, 2]
flavorText text, -- example: ['„Es ist der Wille aller, und meine Hand,', '"La voluntad de todos, realizada por mi ']
`language` text, -- example: ['Italian', 'German', 'Spanish']
multiverseid integer, -- example: [148411, 150317]
name text, -- example: ['Ausgewählter der Ahnfrau', 'Elegido de la Antepasada']
text text, -- example: ['Erstschlag (Diese Kreatur fügt Kampfscha', 'Daña primero. (Esta criatura hace daño d']
type text, -- example: ['Kreatur — Mensch, Kleriker', 'Criatura — Clérigo humano']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_foreign_data_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE legalities (
id integer, -- example: [1, 2]
format text, -- example: ['commander', 'duel']
status text, -- example: ['Legal', 'Banned']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_legalities_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE sets (
id integer, -- example: [1, 2]
baseSetSize integer, -- example: [383, 302]
block text, -- example: ['Core Set', 'Mirrodin']
booster text, -- example: ["{'default': {'boosters': [{'contents': {"]
code text, -- example: ['10E', '2ED']
isFoilOnly integer, -- example: [0, 1]
isForeignOnly integer, -- example: [0, 1]
isNonFoilOnly integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isPartialPreview integer, -- example: [0, 1]
keyruneCode text, -- example: ['10E', '2ED']
mcmId integer, -- magic card market id, example: [74, 3204]
mcmIdExtras integer, -- magic card market ID Extras, example: [3209, 3459]
mcmName text, -- magic card market name, example: ['Tenth Edition', 'Double Masters']
mtgoCode text, -- magic the gathering online code, example: ['10E', '2XM']
name text, -- example: ['Tenth Edition', 'Unlimited Edition']
parentCode text, -- example: ['JMP', 'MH1']
releaseDate date, -- example: ['2007-07-13', '1993-12-01']
tcgplayerGroupId integer, -- example: [1, 115]
totalSetSize integer, -- example: [508, 302]
type text, -- example: ['core', 'masters']
PRIMARY KEY (id)
);
CREATE TABLE set_translations (
id integer, -- example: [1, 2]
`language` text, -- example: ['Italian', 'Chinese Simplified', 'Chinese Traditional']
setCode text, -- example: ['10E', '4ED']
translation text, -- example: ['核心系列第十版', 'Dixième édition']
PRIMARY KEY (id),
CONSTRAINT fk_set_translations_setcode FOREIGN KEY (setCode) REFERENCES sets (code)
);
CREATE TABLE rulings (
id integer, -- example: [1, 2]
`date` date, -- example: ['2007-07-15', '2007-02-01']
text text, -- example: ['You draw the card when Bandage resolves,', 'If you double a negative life total, you']
uuid text, -- example: ['6d268c95-c176-5766-9a46-c14f739aba1c', '56f4935b-f6c5-59b9-88bf-9bcce20247ce']
PRIMARY KEY (id),
CONSTRAINT fk_rulings_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
"""
question = """
Italian translation refers to language = 'Italian'; have a translation means translation is not null; base set number of under 100 refers to baseSetSize < 10
Among the sets of cards that have an Italian translation, how many of them have a base set number of under 100?
"""
# Fill the template and run one offline-inference example with vLLM.
prompt = input_prompt_template.format(db_details=db_details, question=question)
model_path = os.path.join(project_dir, "ckpts", "OmniSQL-7B")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Deterministic decoding: temperature 0, one sample, up to 2048 new tokens.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=2048,
    n=1
)

# Single-GPU engine; eager mode and disabled custom all-reduce keep startup
# simple, 0.92 GPU memory utilization with 8 GB of CPU swap space.
llm = LLM(
    model=model_path,
    dtype="float16",
    tensor_parallel_size=1,
    max_model_len=8192,
    gpu_memory_utilization=0.92,
    swap_space=8,
    enforce_eager=True,
    disable_custom_all_reduce=True,
    trust_remote_code=True
)

# Render the chat template to a plain string; vLLM tokenizes internally.
chat_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=False
)

# One prompt in, one RequestOutput back; print the first completion of each.
for request_output in llm.generate([chat_prompt], sampling_params):
    completions = [candidate.text for candidate in request_output.outputs]
    print(completions[0])
#!/bin/bash
# Smoke-test an OpenAI-compatible chat-completions server (e.g. `vllm serve`)
# hosting OmniSQL-7B with a single text-to-SQL request.
# NOTE(review): the server address 10.16.5.2:8000 and the model name
# "ckpts/OmniSQL-7B" are hard-coded — adjust per deployment.
curl http://10.16.5.2:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ckpts/OmniSQL-7B",
"messages": [{"role": "user", "content": "Task Overview:\nYou are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.\n\nDatabase Engine:\nSQLite\n\nDatabase Schema:\nTable: cards(id, name, language, translation, baseSetSize)\n\nQuestion:\nAmong the sets of cards that have an Italian translation, how many of them have a base set number of under 100?\n\nInstructions:\n- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.\n- The generated query should return all of the information asked in the question without any missing or extra information.\n- Before generating the final SQL query, please think through the steps of how to write the query.\n\nOutput Format:\nIn your answer, please enclose the generated SQL query in a code block:\n```sql\n-- Your SQL query\n```\n\nTake a deep breath and think step by step to find the correct SQL query."}],
"max_tokens": 1024,
"temperature": 0
}'
\ No newline at end of file
# Unique model identifier
modelCode=1499
# Model name
modelName=OmniSQL_pytorch
# Model description (value reads: "database question-answering model")
modelDescription=数据库问答模型
# Application scenarios (value reads: training, inference, dialogue QA,
# e-commerce, education, transportation, energy)
appScenario=训练,推理,对话问答,电商,教育,交通,能源
# Framework type
frameType=Pytorch
# OmniSQL Training and Evaluation
## Environment Setup
All experiments were conducted using:
- **Anaconda 3**
- **Python 3.9.5**
- **8 x NVIDIA A800 80GB GPUs**
**Note:** A single A800 80GB GPU is sufficient for inference and evaluation. For training OmniSQL from scratch, 8 x A800 80GB GPUs are recommended.
## Dataset Preparation
### Download
Download the datasets from:
- [ModelScope-OmniSQL-datasets](https://modelscope.cn/datasets/seeklhy/OmniSQL-datasets/summary)
- [HuggingFace-OmniSQL-datasets](https://huggingface.co/datasets/seeklhy/OmniSQL-datasets)
The datasets include BIRD, Spider, ScienceBenchmark, EHRSQL, Spider2-SQLite, Spider-DK, Spider-Realistic, Spider-Syn, and SynSQL-2.5M. Unzip `data.zip` in this folder.
### Pre-processing
The pre-processed datasets are included in `data.zip` (see the `*.json` files). You can also reproduce the pre-processing steps if needed.
1. **Set Up Environment:**
```sh
conda create -n omnisql_process_data python=3.9.5
conda activate omnisql_process_data
apt-get update
apt-get install -y openjdk-11-jdk
pip3 install func_timeout ijson pyserini==0.22.1 faiss-cpu torch==2.1.0 numpy==1.24.3 nltk==3.8.1
python3 nltk_downloader.py
```
2. **Run Pre-processing Scripts:**
```sh
# Build BM25 index for database values
python3 build_contents_index.py
# Prepare input-output sequences
sh process_dataset.sh
```
**Note:** Processing SynSQL-2.5M may take over 24 hours due to its size (~2.5 million samples).
## Evaluation Reproduction
You can easily reproduce our evaluation results as follows:
1. **Set Up Environment:**
```sh
conda create -n omnisql_eval python=3.9.5
conda activate omnisql_eval
pip3 install vllm==0.6.3.post1 func_timeout tqdm matplotlib nltk==3.8.1 sqlparse
python3 nltk_downloader.py
```
2. **Download Evaluation Materials:**
Download Spider's test-suite databases and evaluation scripts from [test_suite_sql_eval.zip](https://drive.google.com/file/d/1iNa1WgA9tN_OFna08nq_tHZdXx9Lz2vO/view) and unzip `test_suite_sql_eval.zip` in this folder.
3. **Run Evaluation:**
```sh
python3 eval_open_source_models.py
```
Predicted SQL queries are saved in the `results` folder, and evaluation results (e.g., model accuracy) are stored in the `evaluation_results` folder.
## Training OmniSQL from Scratch
To train OmniSQL from scratch:
1. **Set Up Environment:**
```sh
conda create -n omnisql_train python=3.9.5
conda activate omnisql_train
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 transformers==4.45.1 accelerate==0.34.2 deepspeed==0.10.3 numpy==1.24.3 peft datasets tensorboard ijson
```
To speed up attention calculation, install flash-attention:
```bash
# Build from source (not recommended)
pip3 install flash-attn==2.5.8 --no-build-isolation
```
It's recommended to download a precompiled flash-attn Wheel from [flash-attn-2.5.8](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8). Choose the appropriate `.whl` file based on your environment: `flash_attn-2.5.8+cu{cuda_version}torch{torch_version}cxx11abiFALSE-cp{python_version}-cp{python_version}-linux_x86_64.whl`.
For example, if your CUDA version is 12.2, PyTorch version is 2.1, and Python version is 3.9.5, download `flash_attn-2.5.8+cu122torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl` and install it using `pip3 install`.
2. **Training Scripts:**
```sh
# train OmniSQL-7B using SynSQL-2.5M
sh train_omnisql_7b.sh
# train OmniSQL-14B using SynSQL-2.5M
sh train_omnisql_14b.sh
# train OmniSQL-32B using SynSQL-2.5M
sh train_omnisql_32b.sh
```
To train the full version of OmniSQL, you should manually merge the three training sets (`./data/train_synsql.json`, `./data/train_bird.json`, and `./data/train_spider.json`) and update the `DATASET_DIR` in the scripts. For OmniSQL-32B, you can merge LoRA adapters into the base model using `merge_lora_adapter.py`.
**Note:** Training OmniSQL from scratch is resource and time-intensive. As reported in our paper, training OmniSQL-7B/14B/32B requires approximately 6, 12, and 20 days, respectively, on a single machine equipped with 8 NVIDIA A800 80GB GPUs. Please consider whether you need to train them again. **We encourage using our open-sourced OmniSQL models directly or continuing to train your text-to-SQL model with a smaller dataset based on OmniSQL.**
\ No newline at end of file
# accelerate launch config: single-node, 8-process DeepSpeed ZeRO stage-3
# training in bf16, gradient accumulation of 64, no CPU offload.
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 64
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
\ No newline at end of file
# accelerate launch config: identical to the gradient-accumulation-64 variant
# except gradient_accumulation_steps is 32 (ZeRO stage-3, bf16, 8 processes).
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 32
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
\ No newline at end of file
# accelerate launch config: ZeRO stage-3 with optimizer state and parameters
# offloaded to CPU (larger-model variant), gradient accumulation of 64, bf16.
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 64
  gradient_clipping: 1.0
  offload_optimizer_device: 'cpu'
  offload_param_device: 'cpu'
  zero3_init_flag: false
  zero_stage: 3
  zero3_save_16bit_model: true
distributed_type: DEEPSPEED
# NOTE(review): the sibling configs use downcast_bf16: 'no'; accelerate
# documents 'no'/'yes' for this key, so 'true' may be unintended — confirm.
downcast_bf16: 'true'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
import os
import json
import argparse
import evaluate_bird
import evaluate_spider2
import evaluate_spider
import matplotlib.pyplot as plt
from tqdm import tqdm
def visualize(eval_name, acc_dict, ylabel, file_path):
    """Plot per-checkpoint accuracy and save the figure to ``file_path``.

    ``acc_dict`` maps ckpt-id -> accuracy, where each value is either a
    single number (one "EX" line) or a two-element list ([EX, TS], as
    produced by Spider evaluation — one line per metric).
    """
    plt.figure(figsize=(10, 6))
    xs = list(range(len(acc_dict)))
    ys = list(acc_dict.values())
    if isinstance(ys[0], list):  # Spider has two metrics: EX acc and TS acc
        metric_labels = ["EX", "TS"]
        assert len(ys[0]) == len(metric_labels)
        for idx, metric_label in enumerate(metric_labels):
            per_metric = [pair[idx] for pair in ys]
            plt.plot(xs, per_metric, marker='o', linestyle='-', label=metric_label)
    else:
        plt.plot(xs, ys, marker='o', linestyle='-', label="EX")
    plt.title(eval_name)
    plt.xlabel('ckpt-id')
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.legend()
    plt.savefig(file_path)
    plt.close()
def save_evaluation_results(file_path, acc_dict):
    """Write the checkpoint->accuracy mapping to ``file_path`` as pretty JSON."""
    serialized = json.dumps(acc_dict, indent=2, ensure_ascii=False)
    with open(file_path, "w", encoding="utf-8") as out_f:
        out_f.write(serialized)
if __name__ == "__main__":
    # Driver: run vLLM inference (via infer.py) and then evaluation for one model,
    # or for every checkpoint inside a folder, on Spider / BIRD / Spider2.0 benchmarks.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_ckpt_dir", type = str, default = "./ckpts")
    parser.add_argument('--multiple_models', action='store_true', help='Evaluate multiple models from a folder.')
    parser.add_argument("--source", type = str, default = "bird")
    parser.add_argument("--visible_devices", type = str, default = "0,1")
    parser.add_argument("--input_file", type = str, help = "input file path (prompts)")
    parser.add_argument("--eval_name", type = str, help = "name of the evaluation set")
    parser.add_argument("--tensor_parallel_size", type = int, help = "the number of used GPUs", default = 1)
    parser.add_argument("--n", type = int, help = "sampling number", default = 16)
    parser.add_argument("--gold_file", type = str, help = "gold sql path")
    parser.add_argument("--db_path", type = str, help = "database path")
    parser.add_argument("--ts_db_path", type = str, default = "", help = "test suite database path (required by Spider)")
    parser.add_argument("--gold_result_dir", type = str, help = "gold sql execution results (required by Spider2.0)")
    parser.add_argument("--eval_standard", type = str, help = "evaluation standard (required by Spider2.0)")

    opt = parser.parse_args()
    print(opt)
    assert opt.source in ["spider", "bird", "spider2.0"]

    if opt.multiple_models:
        # Checkpoint folders are expected to be named "<prefix>-<step>";
        # sort numerically by the step number after the first "-".
        ckpt_ids = os.listdir(opt.output_ckpt_dir)
        ckpt_ids = sorted(ckpt_ids, key=lambda x: int(x.split("-")[1]))
        print(ckpt_ids)
    else:
        # Single model: an empty ckpt id makes os.path.join resolve to the dir itself.
        ckpt_ids = [""]

    # accumulated accuracy per checkpoint, keyed by ckpt id
    greedy_search_acc_dict = dict()
    pass_at_k_acc_dict = dict()
    major_voting_acc_dict = dict()

    os.makedirs(os.path.join("results", opt.eval_name), exist_ok=True)
    os.makedirs(os.path.join("evaluation_results", opt.eval_name), exist_ok=True)

    for ckpt_id in tqdm(ckpt_ids):
        print("Evaluating ckpt:", ckpt_id)
        if ckpt_id not in greedy_search_acc_dict.keys():
            # greedy decoding
            gs_pred_file = f"results/{opt.eval_name}/greedy_search_{ckpt_id}.json"
            greedy_search_cmd = f"CUDA_VISIBLE_DEVICES={opt.visible_devices} python3 infer.py \
                --pretrained_model_name_or_path {os.path.join(opt.output_ckpt_dir, ckpt_id)} \
                --input_file {opt.input_file} \
                --output_file {gs_pred_file} \
                --tensor_parallel_size {opt.tensor_parallel_size} \
                --n 1 \
                --temperature 0.0"
            os.system(greedy_search_cmd)
            # evaluate greedy search
            # NOTE: each evaluator is deliberately called twice — the first call is
            # labeled "warm up" and its result discarded; the second is recorded.
            if opt.source == "spider2.0":
                # warm up
                evaluate_spider2.evaluate("greedy_search", opt.gold_result_dir, opt.eval_standard,
                    opt.gold_file, gs_pred_file, opt.db_path, True)
                # record evaluation results
                gs_acc, _ = evaluate_spider2.evaluate("greedy_search", opt.gold_result_dir, opt.eval_standard,
                    opt.gold_file, gs_pred_file, opt.db_path, True)
            elif opt.source == "bird":
                # warm up
                evaluate_bird.run_eval(opt.gold_file, gs_pred_file, opt.db_path, "greedy_search", True)
                # record evaluation results
                gs_acc, _ = evaluate_bird.run_eval(opt.gold_file, gs_pred_file, opt.db_path, "greedy_search", True)
            elif opt.source == "spider": # for "spider"
                # warm up
                evaluate_spider.run_spider_eval(opt.gold_file, gs_pred_file, opt.db_path,
                    opt.ts_db_path, "greedy_search", True)
                # record evaluation results
                ex_score, ts_score = evaluate_spider.run_spider_eval(opt.gold_file, gs_pred_file, opt.db_path,
                    opt.ts_db_path, "greedy_search", True)
                # ts_score is None when no test-suite DB was given; keep a scalar then,
                # otherwise record both metrics as [EX, TS].
                if ts_score is None:
                    gs_acc = ex_score
                else:
                    gs_acc = [ex_score, ts_score]
            greedy_search_acc_dict[ckpt_id] = gs_acc
            print(greedy_search_acc_dict)
            visualize(opt.eval_name, greedy_search_acc_dict, "greedy_search",
                os.path.join("evaluation_results", opt.eval_name, "greedy_search.png"))
            save_evaluation_results(os.path.join("evaluation_results", opt.eval_name, "greedy_search.json"), greedy_search_acc_dict)
        else:
            print(f"skip {ckpt_id} greedy search")

        if ckpt_id not in major_voting_acc_dict.keys():
            # sampling (n candidates at temperature 0.8, used for pass@k and major voting)
            sampling_pred_file = f"results/{opt.eval_name}/sampling_{ckpt_id}.json"
            sampling_cmd = f"CUDA_VISIBLE_DEVICES={opt.visible_devices} python3 infer.py \
                --pretrained_model_name_or_path {os.path.join(opt.output_ckpt_dir, ckpt_id)} \
                --input_file {opt.input_file} \
                --output_file {sampling_pred_file} \
                --tensor_parallel_size {opt.tensor_parallel_size} \
                --n {opt.n} \
                --temperature 0.8"
            os.system(sampling_cmd)
            # evaluate pass@k (we do not evaluate pass@k for spider and its variants)
            if opt.source in ["bird", "spider2.0"]:
                if opt.source == "spider2.0":
                    # warm up
                    evaluate_spider2.evaluate("pass@k", opt.gold_result_dir, opt.eval_standard,
                        opt.gold_file, sampling_pred_file, opt.db_path, True)
                    # record evaluation results
                    pass_at_k_acc, _ = evaluate_spider2.evaluate("pass@k", opt.gold_result_dir, opt.eval_standard,
                        opt.gold_file, sampling_pred_file, opt.db_path, True)
                elif opt.source == "bird":
                    # warm up
                    evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "pass@k", True)
                    # record evaluation results
                    pass_at_k_acc, _ = evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "pass@k", True)
                pass_at_k_acc_dict[ckpt_id] = pass_at_k_acc
                print(pass_at_k_acc_dict)
                visualize(opt.eval_name, pass_at_k_acc_dict, "pass_at_k",
                    os.path.join("evaluation_results", opt.eval_name, "pass_at_k.png"))
                save_evaluation_results(os.path.join("evaluation_results", opt.eval_name, "pass_at_k.json"), pass_at_k_acc_dict)
            # evaluate major voting
            if opt.source == "spider2.0":
                # warm up
                evaluate_spider2.evaluate("major_voting", opt.gold_result_dir, opt.eval_standard,
                    opt.gold_file, sampling_pred_file, opt.db_path, True)
                # record evaluation results
                major_voting_acc, _ = evaluate_spider2.evaluate("major_voting", opt.gold_result_dir, opt.eval_standard,
                    opt.gold_file, sampling_pred_file, opt.db_path, True)
            elif opt.source == "bird":
                # warm up
                evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "major_voting", True)
                # record evaluation results
                major_voting_acc, _ = evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "major_voting", True)
            else: # spider
                # warm up
                evaluate_spider.run_spider_eval(opt.gold_file, sampling_pred_file, opt.db_path,
                    opt.ts_db_path, "major_voting", True)
                # record evaluation results
                ex_score, ts_score = evaluate_spider.run_spider_eval(opt.gold_file, sampling_pred_file, opt.db_path,
                    opt.ts_db_path, "major_voting", True)
                if ts_score is None:
                    major_voting_acc = ex_score
                else:
                    major_voting_acc = [ex_score, ts_score]
            major_voting_acc_dict[ckpt_id] = major_voting_acc
            print(major_voting_acc_dict)
            visualize(opt.eval_name, major_voting_acc_dict, "major_voting",
                os.path.join("evaluation_results", opt.eval_name, "major_voting.png"))
            save_evaluation_results(os.path.join("evaluation_results", opt.eval_name, "major_voting.json"), major_voting_acc_dict)
        else:
            print(f"skip {ckpt_id} pass at k and major voting")
\ No newline at end of file
import json
import os, shutil
import sqlite3
from func_timeout import func_set_timeout, FunctionTimedOut
from pathlib import Path
# get the database cursor for a sqlite database path
def get_cursor_from_path(sqlite_path):
    """Open the SQLite database at ``sqlite_path`` and return a cursor.

    If the file does not exist, sqlite3 implicitly creates a new database
    (a message is printed in that case). The connection's text_factory is
    set to decode text leniently so non-UTF-8 bytes in stored strings do
    not raise UnicodeDecodeError. Raises whatever sqlite3.connect raises,
    after printing the offending path.
    """
    try:
        if not os.path.exists(sqlite_path):
            # Fixed typo in the log message ("Openning" -> "Opening").
            print("Opening a new connection %s" % sqlite_path)
        # check_same_thread=False allows the cursor to be used from other threads.
        connection = sqlite3.connect(sqlite_path, check_same_thread = False)
    except Exception as e:
        print(sqlite_path)
        raise e
    connection.text_factory = lambda b: b.decode(errors="ignore")
    cursor = connection.cursor()
    return cursor
# execute predicted sql with a long time limitation (for building content index)
@func_set_timeout(3600)
def execute_sql(cursor, sql):
    """Execute ``sql`` on ``cursor`` and return all fetched rows.

    Wrapped with a 3600-second limit via func_timeout; a query exceeding it
    raises FunctionTimedOut (callers catch this as a generic Exception).
    """
    cursor.execute(sql)
    return cursor.fetchall()
def remove_contents_of_a_folder(index_path):
    """Ensure ``index_path`` exists and is empty.

    Creates the directory when missing, then removes every regular file,
    symlink, and subdirectory inside it. Deletion failures are printed and
    swallowed rather than raised.
    """
    os.makedirs(index_path, exist_ok = True)
    for entry in os.listdir(index_path):
        entry_path = os.path.join(index_path, entry)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (entry_path, e))
def is_number(s):
    """Return True if ``s`` can be parsed as a float, else False.

    Also catches TypeError so that non-string, non-numeric inputs (e.g.
    None, lists) return False instead of raising — the original caught
    only ValueError and would crash on such inputs.
    """
    try:
        float(s)
        return True
    except (ValueError, TypeError):
        return False
def build_content_index(db_file_path, index_path):
    '''
    create BM25 index for all string values in a database

    Collects every distinct, non-numeric string value (1..40 chars) from every
    column of every user table into ./data/temp_db_index/contents.json, then
    shells out to pyserini to build a Lucene BM25 index at ``index_path``.
    The temporary contents.json is removed afterwards.
    '''
    cursor = get_cursor_from_path(db_file_path)
    # enumerate all tables in the SQLite database
    results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [result[0] for result in results]
    all_column_contents = []
    for table_name in table_names:
        # skip SQLite system table: sqlite_sequence
        if table_name == "sqlite_sequence":
            continue
        results = execute_sql(cursor, f"SELECT name FROM PRAGMA_TABLE_INFO('{table_name}')")
        column_names_in_one_table = [result[0] for result in results]
        for column_name in column_names_in_one_table:
            try:
                print(f"SELECT DISTINCT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL;")
                results = execute_sql(cursor, f"SELECT DISTINCT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL;")
                # keep only textual values that do not parse as numbers
                column_contents = [result[0] for result in results if isinstance(result[0], str) and not is_number(result[0])]
                for c_id, column_content in enumerate(column_contents):
                    # remove empty and extremely-long contents
                    if len(column_content) != 0 and len(column_content) <= 40:
                        all_column_contents.append(
                            {
                                # doc id encodes table, column, and position: "<table>-**-<column>-**-<idx>"
                                "id": "{}-**-{}-**-{}".format(table_name, column_name, c_id), # .lower()
                                "contents": column_content
                            }
                        )
            except Exception as e:
                # e.g. a timeout from execute_sql — skip this column and keep going
                print(str(e))
    os.makedirs('./data/temp_db_index', exist_ok = True)
    with open("./data/temp_db_index/contents.json", "w") as f:
        f.write(json.dumps(all_column_contents, indent = 2, ensure_ascii = True))
    # Building a BM25 Index (Direct Java Implementation), see https://github.com/castorini/pyserini/blob/master/docs/usage-index.md
    cmd = f'python -m pyserini.index.lucene --collection JsonCollection --input ./data/temp_db_index --index "{index_path}" --generator DefaultLuceneDocumentGenerator --threads 16 --storePositions --storeDocvectors --storeRaw'
    d = os.system(cmd)
    # exit status of the indexing command (0 on success)
    print(d)
    os.remove("./data/temp_db_index/contents.json")
if __name__ == "__main__":
    # Map each supported benchmark to its database folder and to the folder
    # that will hold the per-database BM25 content indexes.
    dataset_info = {
        # BIRD train
        "bird_train": {"db_path": "./data/bird/train/train_databases", "index_path_prefix": "./data/bird/train/db_contents_index"},
        # BIRD dev
        "bird_dev": {"db_path": "./data/bird/dev_20240627/dev_databases", "index_path_prefix": "./data/bird/dev_20240627/db_contents_index"},
        # Spider train-dev-test
        "spider": {"db_path": "./data/spider/test_database", "index_path_prefix": "./data/spider/db_contents_index"},
        # Spider2.0-SQLite
        "spider2_sqlite": {"db_path": "./data/spider2_sqlite/databases", "index_path_prefix": "./data/spider2_sqlite/db_contents_index"},
        # SynSQL-2.5M dataset
        "SynSQL-2.5M": {"db_path": "./data/SynSQL-2.5M/databases", "index_path_prefix": "./data/SynSQL-2.5M/db_contents_index"},
        # spider-dk
        "spider_dk": {"db_path": "./data/Spider-DK/database", "index_path_prefix": "./data/Spider-DK/db_contents_index"},
        # EHRSQL_dev
        "EHRSQL_dev": {"db_path": "./data/EHRSQL/database", "index_path_prefix": "./data/EHRSQL/db_contents_index"},
        # sciencebenchmark_dev
        "sciencebenchmark_dev": {"db_path": "./data/sciencebenchmark/databases", "index_path_prefix": "./data/sciencebenchmark/db_contents_index"},
    }
    for dataset_name in dataset_info:
        print(dataset_name)
        db_path = dataset_info[dataset_name]["db_path"]
        index_path_prefix = dataset_info[dataset_name]["index_path_prefix"]
        # start from an empty index folder (created if missing)
        remove_contents_of_a_folder(index_path_prefix)
        # build content index
        # each database is expected at <db_path>/<db_id>/<db_id>.sqlite
        db_ids = os.listdir(db_path)
        # db_ids = ["the_table's_domain_appears_to_be_related_to_demographic_and_employment_data"]
        for db_id in db_ids:
            db_file_path = os.path.join(db_path, db_id, db_id + ".sqlite")
            if os.path.exists(db_file_path) and os.path.isfile(db_file_path):
                print(f"The file '{db_file_path}' exists.")
                build_content_index(
                    db_file_path,
                    os.path.join(index_path_prefix, db_id)
                )
            else:
                print(f"The file '{db_file_path}' does not exist.")
\ No newline at end of file
This diff is collapsed.
import os
# Models to evaluate; commented-out entries can be re-enabled as needed.
models = [
    "/home/ckpts/OmniSQL-7B",
    #"seeklhy/OmniSQL-14B",
    #"seeklhy/OmniSQL-32B",
    # "qwen/Qwen2.5-Coder-7B-Instruct",
    # "qwen/Qwen2.5-Coder-14B-Instruct",
    # "qwen/Qwen2.5-Coder-32B-Instruct",
    # "qwen/Qwen2.5-7B-Instruct",
    # "qwen/Qwen2.5-14B-Instruct",
    # "qwen/Qwen2.5-32B-Instruct",
    # "qwen/Qwen2.5-72B-Instruct",
    # "meta-llama/Meta-Llama-3.1-8B-Instruct",
    # "meta-llama/Meta-Llama-3.1-70B-Instruct",
    # "infly/OpenCoder-8B-Instruct",
    # "deepseek-ai/deepseek-coder-6.7b-instruct",
    # "deepseek-ai/deepseek-coder-33b-instruct",
    # "deepseek-ai/deepseek-v3",
    # "ibm-granite/granite-34b-code-instruct-8k",
    # "ibm-granite/granite-20b-code-instruct-8k",
    # "ibm-granite/granite-8b-code-instruct-128k",
    # "ibm-granite/granite-3.1-8b-instruct",
    # "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
    # "bigcode/starcoder2-15b-instruct-v0.1",
    # "mistralai/Codestral-22B-v0.1",
    # "mistralai/Mixtral-8x7B-Instruct-v0.1",
]

visible_devices = "0,1" # visible devices for vLLM
# one tensor-parallel shard per visible GPU
tensor_parallel_size = len(visible_devices.split(","))

# For each model, invoke auto_evaluation.py once per benchmark. Spider2.0 uses
# --source spider2.0; Spider and its variants use --source spider; BIRD, EHRSQL,
# and ScienceBenchmark all use the BIRD-style evaluator (--source bird).
for model in models:
    model_name = model.split("/")[-1].strip()

    # Spider2.0-SQLite (test split)
    spider2_test_eval_name = f"{model_name}_test_spider2_sqlite"
    spider2_test_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider2.0 --visible_devices {visible_devices} --input_file ./data/test_spider2_sqlite.json --eval_name {spider2_test_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider2_sqlite/test.json --db_path ./data/spider2_sqlite/databases/ --gold_result_dir ./data/spider2_sqlite/gold_exec_result/ --eval_standard ./data/spider2_sqlite/spider2_sqlite_eval.jsonl"
    os.system(spider2_test_evaluation_cmd)

    # BIRD (dev split)
    dev_bird_eval_name = f"{model_name}_dev_bird"
    dev_bird_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source bird --visible_devices {visible_devices} --input_file ./data/dev_bird.json --eval_name {dev_bird_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/bird/dev_20240627/dev.json --db_path ./data/bird/dev_20240627/dev_databases"
    os.system(dev_bird_evaluation_cmd)

    # Spider (dev split, with test-suite databases for TS accuracy)
    dev_spider_eval_name = f"{model_name}_dev_spider"
    dev_spider_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider.json --eval_name {dev_spider_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider/dev_gold.sql --db_path ./data/spider/database --ts_db_path ./test_suite_sql_eval/test_suite_database"
    os.system(dev_spider_evaluation_cmd)

    # Spider (test split; no test-suite databases)
    test_spider_eval_name = f"{model_name}_test_spider"
    test_spider_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/test_spider.json --eval_name {test_spider_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider/test_gold.sql --db_path ./data/spider/test_database"
    os.system(test_spider_evaluation_cmd)

    # Spider-DK
    spider_dk_eval_name = f"{model_name}_dev_spider_dk"
    spider_dk_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider_dk.json --eval_name {spider_dk_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/Spider-DK/spider_dk_gold.sql --db_path ./data/Spider-DK/database"
    os.system(spider_dk_evaluation_cmd)

    # Spider-Realistic
    spider_realistic_eval_name = f"{model_name}_dev_spider_realistic"
    spider_realistic_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider_realistic.json --eval_name {spider_realistic_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider-realistic/spider_realistic_gold.sql --db_path ./data/spider/database --ts_db_path ./test_suite_sql_eval/test_suite_database"
    os.system(spider_realistic_evaluation_cmd)

    # Spider-Syn
    spider_syn_eval_name = f"{model_name}_dev_spider_syn"
    spider_syn_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider_syn.json --eval_name {spider_syn_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/Spider-Syn/spider_syn_gold.sql --db_path ./data/spider/database --ts_db_path ./test_suite_sql_eval/test_suite_database"
    os.system(spider_syn_evaluation_cmd)

    # EHRSQL (dev split; evaluated with the BIRD-style evaluator)
    dev_ehrsql_eval_name = f"{model_name}_dev_ehrsql"
    dev_ehrsql_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source bird --visible_devices {visible_devices} --input_file ./data/dev_ehrsql.json --eval_name {dev_ehrsql_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/EHRSQL/dev.json --db_path ./data/EHRSQL/database"
    os.system(dev_ehrsql_evaluation_cmd)

    # ScienceBenchmark (dev split; evaluated with the BIRD-style evaluator)
    dev_sciencebenchmark_eval_name = f"{model_name}_dev_sciencebenchmark"
    dev_sciencebenchmark_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source bird --visible_devices {visible_devices} --input_file ./data/dev_sciencebenchmark.json --eval_name {dev_sciencebenchmark_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/sciencebenchmark/dev.json --db_path ./data/sciencebenchmark/databases"
    os.system(dev_sciencebenchmark_evaluation_cmd)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment