Commit 4f4ba442 authored by mashun1

omnisql
Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
SQLite
Database Schema:
CREATE TABLE badges (
Id integer, -- example: [1, 2]
UserId integer, -- example: [5, 6]
Name text, -- example: ['Teacher', 'Student']
`Date` datetime, -- example: ['2010-07-19 19:39:07.0', '2010-07-19 19:39:08.0']
PRIMARY KEY (Id),
CONSTRAINT fk_badges_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
CREATE TABLE comments (
Id integer, -- example: [1, 2]
PostId integer, -- example: [3, 5]
Score integer, -- example: [5, 0]
Text text, -- example: ['Could be a poster child fo argumentative', "Yes, R is nice- but WHY is it 'valuable'"]
CreationDate datetime, -- example: ['2010-07-19 19:15:52.0', '2010-07-19 19:16:14.0']
UserId integer, -- example: [13, 37]
UserDisplayName text, -- example: ['user28', 'Statprof']
PRIMARY KEY (Id),
CONSTRAINT fk_comments_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_comments_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
CREATE TABLE postHistory (
Id integer, -- example: [1, 2]
PostHistoryTypeId integer, -- example: [2, 1]
PostId integer, -- example: [1, 2]
RevisionGUID text, -- example: ['e58bf7fd-e60f-4c58-a6e4-dfc91cf98a69', '18bf9150-f1cb-432d-b7b7-26d2f8e33581']
CreationDate datetime, -- example: ['2010-07-19 19:12:12.0', '2010-07-19 19:12:57.0']
UserId integer, -- example: [8, 24]
Text text, -- example: ['How should I elicit prior distributions ', 'Eliciting priors from experts']
`Comment` text, -- example: ['more ', 'more', 'edited tags', 'add content from the comments;']
UserDisplayName text, -- example: ['User', 'user28', 'user209']
PRIMARY KEY (Id),
CONSTRAINT fk_posthistory_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_posthistory_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
CREATE TABLE postLinks (
Id integer, -- example: [108, 145]
CreationDate datetime, -- example: ['2010-07-21 14:47:33.0', '2010-07-23 16:30:41.0']
PostId integer, -- example: [395, 548]
RelatedPostId integer, -- example: [173, 539]
LinkTypeId integer, -- example: [1, 3]
PRIMARY KEY (Id),
CONSTRAINT fk_postlinks_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_postlinks_relatedpostid FOREIGN KEY (RelatedPostId) REFERENCES posts (Id)
);
CREATE TABLE posts (
Id integer, -- example: [1, 2]
PostTypeId integer, -- example: [1, 2]
AcceptedAnswerId integer, -- example: [15, 59]
CreaionDate datetime, -- Creation Date, example: ['2010-07-19 19:12:12.0', '2010-07-19 19:12:57.0']
Score integer, -- example: [23, 22]
ViewCount integer, -- example: [1278, 8198]
Body text, -- example: ['<p>How should I elicit prior distributio', '<p>In many different statistical methods']
OwnerUserId integer, -- example: [8, 24]
LasActivityDate datetime, -- Last Activity Date, example: ['2010-09-15 21:08:26.0', '2012-11-12 09:21:54.0']
Title text, -- example: ['Eliciting priors from experts', 'What is normality?']
Tags text, -- example: ['<bayesian><prior><elicitation>', '<distributions><normality>']
AnswerCount integer, -- example: [5, 7]
CommentCount integer, -- example: [1, 4]
FavoriteCount integer, -- example: [14, 8]
LastEditorUserId integer, -- example: [88, 183]
LastEditDate datetime, -- example: ['2010-08-07 17:56:44.0', '2011-02-12 05:50:03.0']
CommunityOwnedDate datetime, -- example: ['2010-07-19 19:13:28.0', '2010-07-19 19:14:43.0']
ParentId integer, -- example: [3, 7]
ClosedDate datetime, -- example: ['2010-07-19 20:19:46.0', '2010-08-05 13:06:12.0']
OwnerDisplayName text, -- example: ['User', 'user28', 'user209']
LastEditorDisplayName text, -- example: ['user28', 'user10525']
PRIMARY KEY (Id),
CONSTRAINT fk_posts_owneruserid FOREIGN KEY (OwnerUserId) REFERENCES users (Id),
CONSTRAINT fk_posts_lasteditoruserid FOREIGN KEY (LastEditorUserId) REFERENCES users (Id),
CONSTRAINT fk_posts_parentid FOREIGN KEY (ParentId) REFERENCES posts (Id)
);
CREATE TABLE tags (
Id integer, -- example: [1, 2]
TagName text, -- example: ['bayesian', 'prior']
`Count` integer, -- example: [1342, 168]
ExcerptPostId integer, -- example: [20258, 62158]
WikiPostId integer, -- example: [20257, 62157]
PRIMARY KEY (Id),
CONSTRAINT fk_tags_excerptpostid FOREIGN KEY (ExcerptPostId) REFERENCES posts (Id)
);
CREATE TABLE users (
Id integer, -- example: [-1, 2]
Reputation integer, -- example: [1, 101]
CreationDate datetime, -- example: ['2010-07-19 06:55:26.0', '2010-07-19 14:01:36.0']
DisplayName text, -- example: ['User', 'useR', 'user', 'Community', 'Geoff Dalgas']
LastAccessDate datetime, -- example: ['2010-07-19 06:55:26.0', '2013-11-12 22:07:23.0']
WebsiteUrl text, -- example: ['http://meta.stackexchange.com/', 'http://stackoverflow.com']
Location text, -- example: ['on the server farm', 'Corvallis, OR']
AboutMe text, -- example: ["<p>Hi, I'm not really a person.</p>\n\n<p>", '<p>Developer on the StackOverflow team. ']
Views integer, -- example: [0, 25]
UpVotes integer, -- example: [5007, 3]
DownVotes integer, -- example: [1920, 0]
AccountId integer, -- example: [-1, 2]
Age integer, -- example: [37, 35]
ProfileImageUrl text, -- example: ['http://i.stack.imgur.com/d1oHX.jpg', 'http://i.stack.imgur.com/km1pr.jpg']
PRIMARY KEY (Id)
);
CREATE TABLE votes (
Id integer, -- example: [1, 2]
PostId integer, -- example: [3, 2]
VoteTypeId integer, -- example: [2, 5]
CreationDate date, -- example: ['2010-07-19', '2010-07-20']
UserId integer, -- example: [58, 6]
BountyAmount integer, -- example: [50, 25]
PRIMARY KEY (Id),
CONSTRAINT fk_votes_postid FOREIGN KEY (PostId) REFERENCES posts (Id),
CONSTRAINT fk_votes_userid FOREIGN KEY (UserId) REFERENCES users (Id)
);
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
more than 10 views refers to Views > 10; created after the year 2013 refers to year (CreationDate) > 2013
How many users with more than 10 views created their account after the year 2013?
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
```sql
-- Your SQL query
```
Take a deep breath and think step by step to find the correct SQL query.
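For reference, a query consistent with the schema and hints above can be sanity-checked with Python's built-in sqlite3 module. This is a sketch only: the database file name `codebase_community.sqlite` is an assumption, and the year comparison follows the hint `year(CreationDate) > 2013`.

```python
import sqlite3

# Hypothetical path to the SQLite database behind the schema above.
conn = sqlite3.connect("codebase_community.sqlite")

# Users with more than 10 views whose accounts were created after 2013;
# STRFTIME('%Y', ...) returns text, so cast it for a numeric comparison.
sql = """
SELECT COUNT(*)
FROM users
WHERE Views > 10
  AND CAST(STRFTIME('%Y', CreationDate) AS INTEGER) > 2013;
"""
print(conn.execute(sql).fetchone()[0])
conn.close()
```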
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from pathlib import Path
project_dir = str(Path(__file__).resolve().parent.parent)
input_prompt_template = '''Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
SQLite
Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
{question}
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
```sql
-- Your SQL query
```
Take a deep breath and think step by step to find the correct SQL query.'''
db_details = """
CREATE TABLE cards (
id integer, -- unique id number identifying the cards, example: [41138, 1349]
artist text, -- example: ['Pete Venters', 'Volkan Baǵa']
asciiName text, -- example: ['El-Hajjaj', 'Junun Efreet']
availability text, -- example: ['mtgo,paper', 'paper']
borderColor text, -- example: ['black', 'white']
cardKingdomFoilId text, -- example: ['123094', '123095']
cardKingdomId text, -- example: ['122719', '122720']
colorIdentity text, -- example: ['W', 'B']
colorIndicator text, -- example: ['U', 'G']
colors text, -- example: ['W', 'B']
convertedManaCost real, -- example: [7.0, 5.0]
duelDeck text, -- example: ['a', 'b']
edhrecRank integer, -- rec Rank in edh, example: [15650, 12702]
faceConvertedManaCost real, -- example: [4.0, 5.0]
faceName text, -- example: ['Dusk', 'Dawn']
flavorName text, -- example: ['Godzilla, King of the Monsters', 'King Caesar, Ancient Guardian']
flavorText text, -- example: ['Every tear shed is a drop of immortality', 'The perfect antidote for a tightly packe']
frameEffects text, -- example: ['legendary', 'nyxtouched']
frameVersion text, -- example: ['2003', '1993']
hand text, -- example: ['1', '0']
hasAlternativeDeckLimit integer, -- example: [0, 1]
hasContentWarning integer, -- example: [0, 1]
hasFoil integer, -- example: [0, 1]
hasNonFoil integer, -- example: [1, 0]
isAlternative integer, -- example: [0, 1]
isFullArt integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isOversized integer, -- example: [0, 1]
isPromo integer, -- is Promotion, example: [0, 1]
isReprint integer, -- example: [1, 0]
isReserved integer, -- example: [0, 1]
isStarter integer, -- example: [0, 1]
isStorySpotlight integer, -- example: [0, 1]
isTextless integer, -- example: [0, 1]
isTimeshifted integer, -- example: [0, 1]
keywords text, -- example: ['First strike', 'Flying']
layout text, -- example: ['normal', 'aftermath']
leadershipSkills text, -- example: ["{'brawl': False, 'commander': True, 'oat", "{'brawl': False, 'commander': False, 'oa"]
life text, -- example: ['-5', '-1']
loyalty text, -- example: ['6', '3']
manaCost text, -- example: ['{5}{W}{W}', '{4}{W}']
mcmId text, -- example: ['16165', '16166']
mcmMetaId text, -- example: ['156', '176']
mtgArenaId text, -- example: ['74983', '74986']
mtgjsonV4Id text, -- example: ['ad41be73-582f-58ed-abd4-a88c1f616ac3', '9eb2e54c-a12b-5e88-a9c0-d8c84c52d59c']
mtgoFoilId text, -- example: ['27501', '26993']
mtgoId text, -- example: ['27500', '26992']
multiverseId text, -- example: ['130550', '129465']
name text, -- example: ["Ancestor's Chosen", 'Angel of Mercy']
number text, -- example: ['1', '2']
originalReleaseDate text, -- example: ['2012/12/1', '2006/12/1']
originalText text, -- example: ['First strike (This creature deals combat', "Flying (This creature can't be blocked e"]
originalType text, -- example: ['Creature - Human Cleric', 'Creature - Angel']
otherFaceIds text, -- example: ['87f0062a-8321-5c16-960e-a12ce1df5839', 'f9f10d34-071c-57a6-b58c-7553abad5c20']
power text, -- example: ['4', '3']
printings text, -- example: ['10E,JUD,UMA', '10E,8ED,9ED,DDC,DVD,IMA,INV,JMP,MB1,P02,']
promoTypes text, -- example: ['boxtopper,boosterfun', 'boosterfun']
purchaseUrls text, -- example: ["{'cardKingdom': 'https://mtgjson.com/lin"]
rarity text, -- example: ['uncommon', 'common']
scryfallId text, -- example: ['7a5cd03c-4227-4551-aa4b-7d119f0468b5', '8f7980d4-da43-4d6d-ad16-14b8a34ae91d']
scryfallIllustrationId text, -- example: ['be2f7173-c8b7-4172-a388-9b2c6b3c16e5', 'e4d6c53f-e936-4be8-8b70-47c2be863b20']
scryfallOracleId text, -- example: ['fc2ccab7-cab1-4463-b73d-898070136d74', 'a2daaf32-dbfe-4618-892e-0da24f63a44a']
setCode text, -- example: ['10E', '2ED']
side text, -- example: ['a', 'b']
subtypes text, -- example: ['Human,Cleric', 'Angel']
supertypes text, -- example: ['Legendary', 'Basic']
tcgplayerProductId text, -- example: ['15032', '15033']
text text, -- example: ['First strike (This creature deals combat', 'Flying\nWhen Angel of Mercy enters the ba']
toughness text, -- example: ['4', '3']
type text, -- example: ['Creature — Human Cleric', 'Creature — Angel']
types text, -- example: ['Creature', 'Instant']
uuid text, -- example: ['00010d56-fe38-5e35-8aed-518019aa36a5', '0001e0d0-2dcd-5640-aadc-a84765cf5fc9']
variations text, -- example: ['b7c19924-b4bf-56fc-aa73-f586e940bd42', '8fd4e2eb-3eb4-50ea-856b-ef638fa47f8a']
watermark text, -- example: ['set', 'set (HOU)', 'set (LGN)']
PRIMARY KEY (id)
);
CREATE TABLE foreign_data (
id integer, -- example: [1, 2]
flavorText text, -- example: ['„Es ist der Wille aller, und meine Hand,', '"La voluntad de todos, realizada por mi ']
`language` text, -- example: ['Italian', 'German', 'Spanish']
multiverseid integer, -- example: [148411, 150317]
name text, -- example: ['Ausgewählter der Ahnfrau', 'Elegido de la Antepasada']
text text, -- example: ['Erstschlag (Diese Kreatur fügt Kampfscha', 'Daña primero. (Esta criatura hace daño d']
type text, -- example: ['Kreatur — Mensch, Kleriker', 'Criatura — Clérigo humano']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_foreign_data_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE legalities (
id integer, -- example: [1, 2]
format text, -- example: ['commander', 'duel']
status text, -- example: ['Legal', 'Banned']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_legalities_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE sets (
id integer, -- example: [1, 2]
baseSetSize integer, -- example: [383, 302]
block text, -- example: ['Core Set', 'Mirrodin']
booster text, -- example: ["{'default': {'boosters': [{'contents': {"]
code text, -- example: ['10E', '2ED']
isFoilOnly integer, -- example: [0, 1]
isForeignOnly integer, -- example: [0, 1]
isNonFoilOnly integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isPartialPreview integer, -- example: [0, 1]
keyruneCode text, -- example: ['10E', '2ED']
mcmId integer, -- magic card market id, example: [74, 3204]
mcmIdExtras integer, -- magic card market ID Extras, example: [3209, 3459]
mcmName text, -- magic card market name, example: ['Tenth Edition', 'Double Masters']
mtgoCode text, -- magic the gathering online code, example: ['10E', '2XM']
name text, -- example: ['Tenth Edition', 'Unlimited Edition']
parentCode text, -- example: ['JMP', 'MH1']
releaseDate date, -- example: ['2007-07-13', '1993-12-01']
tcgplayerGroupId integer, -- example: [1, 115]
totalSetSize integer, -- example: [508, 302]
type text, -- example: ['core', 'masters']
PRIMARY KEY (id)
);
CREATE TABLE set_translations (
id integer, -- example: [1, 2]
`language` text, -- example: ['Italian', 'Chinese Simplified', 'Chinese Traditional']
setCode text, -- example: ['10E', '4ED']
translation text, -- example: ['核心系列第十版', 'Dixième édition']
PRIMARY KEY (id),
CONSTRAINT fk_set_translations_setcode FOREIGN KEY (setCode) REFERENCES sets (code)
);
CREATE TABLE rulings (
id integer, -- example: [1, 2]
`date` date, -- example: ['2007-07-15', '2007-02-01']
text text, -- example: ['You draw the card when Bandage resolves,', 'If you double a negative life total, you']
uuid text, -- example: ['6d268c95-c176-5766-9a46-c14f739aba1c', '56f4935b-f6c5-59b9-88bf-9bcce20247ce']
PRIMARY KEY (id),
CONSTRAINT fk_rulings_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
"""
question = """
Italian translation refers to language = 'Italian'; have a translation means translation is not null; base set number of under 100 refers to baseSetSize < 100
Among the sets of cards that have an Italian translation, how many of them have a base set number of under 100?
"""
prompt = input_prompt_template.format(db_details = db_details, question = question)
model_path = os.path.join(project_dir, "ckpts", "OmniSQL-7B")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16
).to("cuda:0")
chat_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt = True, tokenize = False
)
inputs = tokenizer([chat_prompt], return_tensors="pt")
inputs = inputs.to(model.device)
output_ids = model.generate(
**inputs,
eos_token_id = tokenizer.eos_token_id,
max_new_tokens = 2048
)
input_len = len(inputs.input_ids[0])
output_ids = output_ids[0][input_len:]
response = tokenizer.batch_decode([output_ids], skip_special_tokens = True)[0]
print(response)
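# The response wraps the final query in an sql code fence (see the prompt's
# output format). A minimal extraction sketch; the regex uses `{3} so that
# no literal code fence appears inside this snippet.
import re

def extract_sql(response: str) -> str:
    blocks = re.findall(r"`{3}sql\s*(.*?)\s*`{3}", response, re.DOTALL)
    return blocks[-1] if blocks else response.strip()

print(extract_sql(response))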
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os
from pathlib import Path
project_dir = str(Path(__file__).resolve().parent.parent)
input_prompt_template = '''Task Overview:
You are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.
Database Engine:
SQLite
Database Schema:
{db_details}
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
Question:
{question}
Instructions:
- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.
- The generated query should return all of the information asked in the question without any missing or extra information.
- Before generating the final SQL query, please think through the steps of how to write the query.
Output Format:
In your answer, please enclose the generated SQL query in a code block:
```sql
-- Your SQL query
```
Take a deep breath and think step by step to find the correct SQL query.'''
db_details = """
CREATE TABLE cards (
id integer, -- unique id number identifying the cards, example: [41138, 1349]
artist text, -- example: ['Pete Venters', 'Volkan Baǵa']
asciiName text, -- example: ['El-Hajjaj', 'Junun Efreet']
availability text, -- example: ['mtgo,paper', 'paper']
borderColor text, -- example: ['black', 'white']
cardKingdomFoilId text, -- example: ['123094', '123095']
cardKingdomId text, -- example: ['122719', '122720']
colorIdentity text, -- example: ['W', 'B']
colorIndicator text, -- example: ['U', 'G']
colors text, -- example: ['W', 'B']
convertedManaCost real, -- example: [7.0, 5.0]
duelDeck text, -- example: ['a', 'b']
edhrecRank integer, -- rec Rank in edh, example: [15650, 12702]
faceConvertedManaCost real, -- example: [4.0, 5.0]
faceName text, -- example: ['Dusk', 'Dawn']
flavorName text, -- example: ['Godzilla, King of the Monsters', 'King Caesar, Ancient Guardian']
flavorText text, -- example: ['Every tear shed is a drop of immortality', 'The perfect antidote for a tightly packe']
frameEffects text, -- example: ['legendary', 'nyxtouched']
frameVersion text, -- example: ['2003', '1993']
hand text, -- example: ['1', '0']
hasAlternativeDeckLimit integer, -- example: [0, 1]
hasContentWarning integer, -- example: [0, 1]
hasFoil integer, -- example: [0, 1]
hasNonFoil integer, -- example: [1, 0]
isAlternative integer, -- example: [0, 1]
isFullArt integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isOversized integer, -- example: [0, 1]
isPromo integer, -- is Promotion, example: [0, 1]
isReprint integer, -- example: [1, 0]
isReserved integer, -- example: [0, 1]
isStarter integer, -- example: [0, 1]
isStorySpotlight integer, -- example: [0, 1]
isTextless integer, -- example: [0, 1]
isTimeshifted integer, -- example: [0, 1]
keywords text, -- example: ['First strike', 'Flying']
layout text, -- example: ['normal', 'aftermath']
leadershipSkills text, -- example: ["{'brawl': False, 'commander': True, 'oat", "{'brawl': False, 'commander': False, 'oa"]
life text, -- example: ['-5', '-1']
loyalty text, -- example: ['6', '3']
manaCost text, -- example: ['{5}{W}{W}', '{4}{W}']
mcmId text, -- example: ['16165', '16166']
mcmMetaId text, -- example: ['156', '176']
mtgArenaId text, -- example: ['74983', '74986']
mtgjsonV4Id text, -- example: ['ad41be73-582f-58ed-abd4-a88c1f616ac3', '9eb2e54c-a12b-5e88-a9c0-d8c84c52d59c']
mtgoFoilId text, -- example: ['27501', '26993']
mtgoId text, -- example: ['27500', '26992']
multiverseId text, -- example: ['130550', '129465']
name text, -- example: ["Ancestor's Chosen", 'Angel of Mercy']
number text, -- example: ['1', '2']
originalReleaseDate text, -- example: ['2012/12/1', '2006/12/1']
originalText text, -- example: ['First strike (This creature deals combat', "Flying (This creature can't be blocked e"]
originalType text, -- example: ['Creature - Human Cleric', 'Creature - Angel']
otherFaceIds text, -- example: ['87f0062a-8321-5c16-960e-a12ce1df5839', 'f9f10d34-071c-57a6-b58c-7553abad5c20']
power text, -- example: ['4', '3']
printings text, -- example: ['10E,JUD,UMA', '10E,8ED,9ED,DDC,DVD,IMA,INV,JMP,MB1,P02,']
promoTypes text, -- example: ['boxtopper,boosterfun', 'boosterfun']
purchaseUrls text, -- example: ["{'cardKingdom': 'https://mtgjson.com/lin"]
rarity text, -- example: ['uncommon', 'common']
scryfallId text, -- example: ['7a5cd03c-4227-4551-aa4b-7d119f0468b5', '8f7980d4-da43-4d6d-ad16-14b8a34ae91d']
scryfallIllustrationId text, -- example: ['be2f7173-c8b7-4172-a388-9b2c6b3c16e5', 'e4d6c53f-e936-4be8-8b70-47c2be863b20']
scryfallOracleId text, -- example: ['fc2ccab7-cab1-4463-b73d-898070136d74', 'a2daaf32-dbfe-4618-892e-0da24f63a44a']
setCode text, -- example: ['10E', '2ED']
side text, -- example: ['a', 'b']
subtypes text, -- example: ['Human,Cleric', 'Angel']
supertypes text, -- example: ['Legendary', 'Basic']
tcgplayerProductId text, -- example: ['15032', '15033']
text text, -- example: ['First strike (This creature deals combat', 'Flying\nWhen Angel of Mercy enters the ba']
toughness text, -- example: ['4', '3']
type text, -- example: ['Creature — Human Cleric', 'Creature — Angel']
types text, -- example: ['Creature', 'Instant']
uuid text, -- example: ['00010d56-fe38-5e35-8aed-518019aa36a5', '0001e0d0-2dcd-5640-aadc-a84765cf5fc9']
variations text, -- example: ['b7c19924-b4bf-56fc-aa73-f586e940bd42', '8fd4e2eb-3eb4-50ea-856b-ef638fa47f8a']
watermark text, -- example: ['set', 'set (HOU)', 'set (LGN)']
PRIMARY KEY (id)
);
CREATE TABLE foreign_data (
id integer, -- example: [1, 2]
flavorText text, -- example: ['„Es ist der Wille aller, und meine Hand,', '"La voluntad de todos, realizada por mi ']
`language` text, -- example: ['Italian', 'German', 'Spanish']
multiverseid integer, -- example: [148411, 150317]
name text, -- example: ['Ausgewählter der Ahnfrau', 'Elegido de la Antepasada']
text text, -- example: ['Erstschlag (Diese Kreatur fügt Kampfscha', 'Daña primero. (Esta criatura hace daño d']
type text, -- example: ['Kreatur — Mensch, Kleriker', 'Criatura — Clérigo humano']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_foreign_data_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE legalities (
id integer, -- example: [1, 2]
format text, -- example: ['commander', 'duel']
status text, -- example: ['Legal', 'Banned']
uuid text, -- example: ['5f8287b1-5bb6-5f4c-ad17-316a40d5bb0c', '57aaebc1-850c-503d-9f6e-bb8d00d8bf7c']
PRIMARY KEY (id),
CONSTRAINT fk_legalities_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
CREATE TABLE sets (
id integer, -- example: [1, 2]
baseSetSize integer, -- example: [383, 302]
block text, -- example: ['Core Set', 'Mirrodin']
booster text, -- example: ["{'default': {'boosters': [{'contents': {"]
code text, -- example: ['10E', '2ED']
isFoilOnly integer, -- example: [0, 1]
isForeignOnly integer, -- example: [0, 1]
isNonFoilOnly integer, -- example: [0, 1]
isOnlineOnly integer, -- example: [0, 1]
isPartialPreview integer, -- example: [0, 1]
keyruneCode text, -- example: ['10E', '2ED']
mcmId integer, -- magic card market id, example: [74, 3204]
mcmIdExtras integer, -- magic card market ID Extras, example: [3209, 3459]
mcmName text, -- magic card market name, example: ['Tenth Edition', 'Double Masters']
mtgoCode text, -- magic the gathering online code, example: ['10E', '2XM']
name text, -- example: ['Tenth Edition', 'Unlimited Edition']
parentCode text, -- example: ['JMP', 'MH1']
releaseDate date, -- example: ['2007-07-13', '1993-12-01']
tcgplayerGroupId integer, -- example: [1, 115]
totalSetSize integer, -- example: [508, 302]
type text, -- example: ['core', 'masters']
PRIMARY KEY (id)
);
CREATE TABLE set_translations (
id integer, -- example: [1, 2]
`language` text, -- example: ['Italian', 'Chinese Simplified', 'Chinese Traditional']
setCode text, -- example: ['10E', '4ED']
translation text, -- example: ['核心系列第十版', 'Dixième édition']
PRIMARY KEY (id),
CONSTRAINT fk_set_translations_setcode FOREIGN KEY (setCode) REFERENCES sets (code)
);
CREATE TABLE rulings (
id integer, -- example: [1, 2]
`date` date, -- example: ['2007-07-15', '2007-02-01']
text text, -- example: ['You draw the card when Bandage resolves,', 'If you double a negative life total, you']
uuid text, -- example: ['6d268c95-c176-5766-9a46-c14f739aba1c', '56f4935b-f6c5-59b9-88bf-9bcce20247ce']
PRIMARY KEY (id),
CONSTRAINT fk_rulings_uuid FOREIGN KEY (uuid) REFERENCES cards (uuid)
);
This schema describes the database's structure, including tables, columns, primary keys, foreign keys, and any relevant relationships or constraints.
"""
question = """
Italian translation refers to language = 'Italian'; have a translation means translation is not null; base set number of under 100 refers to baseSetSize < 100
Among the sets of cards that have an Italian translation, how many of them have a base set number of under 100?
"""
prompt = input_prompt_template.format(db_details = db_details, question = question)
model_path = os.path.join(project_dir, "ckpts", "OmniSQL-7B")
tokenizer = AutoTokenizer.from_pretrained(model_path)
sampling_params = SamplingParams(
temperature = 0,
max_tokens = 2048,
n = 1
)
llm = LLM(
model = model_path,
dtype = "float16",
tensor_parallel_size = 1,
max_model_len = 8192,
gpu_memory_utilization = 0.92,
swap_space = 8,
enforce_eager = True,
disable_custom_all_reduce = True,
trust_remote_code = True
)
chat_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt = True, tokenize = False
)
outputs = llm.generate([chat_prompt], sampling_params)
for output in outputs:
responses = [o.text for o in output.outputs]
print(responses[0])
#!/bin/bash
curl http://10.16.5.2:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ckpts/OmniSQL-7B",
"messages": [{"role": "user", "content": "Task Overview:\nYou are a data science expert. Below, you are provided with a database schema and a natural language question. Your task is to understand the schema and generate a valid SQL query to answer the question.\n\nDatabase Engine:\nSQLite\n\nDatabase Schema:\nTable: cards(id, name, language, translation, baseSetSize)\n\nQuestion:\nAmong the sets of cards that have an Italian translation, how many of them have a base set number of under 100?\n\nInstructions:\n- Make sure you only output the information that is asked in the question. If the question asks for a specific column, make sure to only include that column in the SELECT clause, nothing more.\n- The generated query should return all of the information asked in the question without any missing or extra information.\n- Before generating the final SQL query, please think through the steps of how to write the query.\n\nOutput Format:\nIn your answer, please enclose the generated SQL query in a code block:\n```sql\n-- Your SQL query\n```\n\nTake a deep breath and think step by step to find the correct SQL query."}],
"max_tokens": 1024,
"temperature": 0
}'
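The same request can be issued from Python. Below is a sketch using the `requests` library against the same vLLM OpenAI-compatible endpoint; the host, model name, and sampling parameters are copied from the curl example above, and the full prompt is elided.

```python
import requests

# vLLM's OpenAI-compatible chat completions endpoint (same host as the curl example).
url = "http://10.16.5.2:8000/v1/chat/completions"
payload = {
    "model": "ckpts/OmniSQL-7B",
    "messages": [{"role": "user", "content": "Task Overview: ..."}],  # full prompt elided
    "max_tokens": 1024,
    "temperature": 0,
}
resp = requests.post(url, json=payload, timeout=600)
print(resp.json()["choices"][0]["message"]["content"])
```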
# Model unique identifier
modelCode=1499
# Model name
modelName=OmniSQL_pytorch
# Model description
modelDescription=Database question-answering model
# Application scenarios
appScenario=Training, inference, conversational Q&A, e-commerce, education, transportation, energy
# Framework type
frameType=Pytorch
# OmniSQL Training and Evaluation
## Environment Setup
All experiments were conducted using:
- **Anaconda 3**
- **Python 3.9.5**
- **8 x NVIDIA A800 80GB GPUs**
**Note:** A single A800 80GB GPU is sufficient for inference and evaluation. For training OmniSQL from scratch, 8 x A800 80GB GPUs are recommended.
## Dataset Preparation
### Download
Download the datasets from:
- [ModelScope-OmniSQL-datasets](https://modelscope.cn/datasets/seeklhy/OmniSQL-datasets/summary)
- [HuggingFace-OmniSQL-datasets](https://huggingface.co/datasets/seeklhy/OmniSQL-datasets)
The datasets include BIRD, Spider, ScienceBenchmark, EHRSQL, Spider2-SQLite, Spider-DK, Spider-Realistic, Spider-Syn, and SynSQL-2.5M. Unzip `data.zip` in this folder.
### Pre-processing
The pre-processed datasets are included in `data.zip` (see the `*.json` files). You can also reproduce the pre-processing steps if needed.
1. **Set Up Environment:**
```sh
conda create -n omnisql_process_data python=3.9.5
conda activate omnisql_process_data
apt-get update
apt-get install -y openjdk-11-jdk
pip3 install func_timeout ijson pyserini==0.22.1 faiss-cpu torch==2.1.0 numpy==1.24.3 nltk==3.8.1
python3 nltk_downloader.py
```
2. **Run Pre-processing Scripts:**
```sh
# Build BM25 index for database values
python3 build_contents_index.py
# Prepare input-output sequences
sh process_dataset.sh
```
**Note:** Processing SynSQL-2.5M may take over 24 hours due to its size (~2.5 million samples).
## Evaluation Reproduction
You can easily reproduce our evaluation results as follows:
1. **Set Up Environment:**
```sh
conda create -n omnisql_eval python=3.9.5
conda activate omnisql_eval
pip3 install vllm==0.6.3.post1 func_timeout tqdm matplotlib nltk==3.8.1 sqlparse
python3 nltk_downloader.py
```
2. **Download Evaluation Materials:**
Download Spider's test-suite databases and evaluation scripts from [test_suite_sql_eval.zip](https://drive.google.com/file/d/1iNa1WgA9tN_OFna08nq_tHZdXx9Lz2vO/view) and unzip `test_suite_sql_eval.zip` in this folder.
3. **Run Evaluation:**
```sh
python3 eval_open_source_models.py
```
Predicted SQL queries are saved in the `results` folder, and evaluation results (e.g., model accuracy) are stored in the `evaluation_results` folder.
## Training OmniSQL from Scratch
To train OmniSQL from scratch:
1. **Set Up Environment:**
```sh
conda create -n omnisql_train python=3.9.5
conda activate omnisql_train
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 transformers==4.45.1 accelerate==0.34.2 deepspeed==0.10.3 numpy==1.24.3 peft datasets tensorboard ijson
```
To speed up attention calculation, install flash-attention:
```sh
# Build from source (not recommended)
pip3 install flash-attn==2.5.8 --no-build-isolation
```
It's recommended to download a precompiled flash-attn wheel from [flash-attn-2.5.8](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8). Choose the appropriate `.whl` file for your environment: `flash_attn-2.5.8+cu{cuda_version}torch{torch_version}cxx11abiFALSE-cp{python_version}-cp{python_version}-linux_x86_64.whl`.
For example, if your CUDA version is 12.2, PyTorch version is 2.1, and Python version is 3.9.5, download `flash_attn-2.5.8+cu122torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl` and install it with `pip3 install`.
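A small sketch that prints the expected wheel filename for the current environment (assumes a CUDA build of PyTorch; the `cu` tag in the release assets uses only major.minor, e.g. `cu122`):
```python
import sys
import torch

# Derive the flash-attn 2.5.8 wheel name for this environment.
cuda = torch.version.cuda.replace(".", "")             # e.g. "12.2" -> "122"
torch_mm = ".".join(torch.__version__.split(".")[:2])  # e.g. "2.1"
py = f"cp{sys.version_info.major}{sys.version_info.minor}"  # e.g. "cp39"
print(f"flash_attn-2.5.8+cu{cuda}torch{torch_mm}"
      f"cxx11abiFALSE-{py}-{py}-linux_x86_64.whl")
```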
2. **Training Scripts:**
```sh
# train OmniSQL-7B using SynSQL-2.5M
sh train_omnisql_7b.sh
# train OmniSQL-14B using SynSQL-2.5M
sh train_omnisql_14b.sh
# train OmniSQL-32B using SynSQL-2.5M
sh train_omnisql_32b.sh
```
To train the full version of OmniSQL, you should manually merge the three training sets (`./data/train_synsql.json`, `./data/train_bird.json`, and `./data/train_spider.json`) and update the `DATASET_DIR` in the scripts. For OmniSQL-32B, you can merge LoRA adapters into the base model using `merge_lora_adapter.py`.
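A minimal sketch of the merge step, assuming each training file is a JSON array with records of the same shape (the output path `./data/train_full.json` is hypothetical; point `DATASET_DIR` at whatever path you choose):
```python
import json

# Merge the three training sets into one file for full OmniSQL training.
merged = []
for path in ["./data/train_synsql.json", "./data/train_bird.json", "./data/train_spider.json"]:
    with open(path, encoding="utf-8") as f:
        merged.extend(json.load(f))

with open("./data/train_full.json", "w", encoding="utf-8") as f:
    json.dump(merged, f, ensure_ascii=False)
print(f"merged {len(merged)} training samples")
```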
**Note:** Training OmniSQL from scratch is resource- and time-intensive. As reported in our paper, training OmniSQL-7B/14B/32B takes approximately 6, 12, and 20 days, respectively, on a single machine equipped with 8 NVIDIA A800 80GB GPUs. Please consider carefully whether you need to retrain. **We encourage using our open-sourced OmniSQL models directly, or continuing to train your own text-to-SQL model on a smaller dataset starting from OmniSQL.**
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 64
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 32
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 64
gradient_clipping: 1.0
offload_optimizer_device: 'cpu'
offload_param_device: 'cpu'
zero3_init_flag: false
zero_stage: 3
zero3_save_16bit_model: true
distributed_type: DEEPSPEED
downcast_bf16: 'true'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
import os
import json
import argparse
import evaluate_bird
import evaluate_spider2
import evaluate_spider
import matplotlib.pyplot as plt
from tqdm import tqdm
def visualize(eval_name, acc_dict, ylabel, file_path):
plt.figure(figsize=(10, 6))
ckpt_ids = list(range(len(acc_dict)))
values = list(acc_dict.values())
if isinstance(values[0], list): # Spider has two metrics: EX acc and TS acc
num_lines = len(values[0])
labels = ["EX", "TS"]
assert num_lines == len(labels)
for i in range(num_lines):
line_values = [v[i] for v in values]
plt.plot(ckpt_ids, line_values, marker='o', linestyle='-', label=labels[i])
else:
plt.plot(ckpt_ids, values, marker='o', linestyle='-', label="EX")
plt.title(eval_name)
plt.xlabel('ckpt-id')
plt.ylabel(ylabel)
plt.grid(True)
plt.legend()
plt.savefig(file_path)
plt.close()
def save_evaluation_results(file_path, acc_dict):
with open(file_path, "w", encoding="utf-8") as f:
f.write(json.dumps(acc_dict, indent=2, ensure_ascii=False))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_ckpt_dir", type = str, default = "./ckpts")
parser.add_argument('--multiple_models', action='store_true', help='Evaluate multiple models from a folder.')
parser.add_argument("--source", type = str, default = "bird")
parser.add_argument("--visible_devices", type = str, default = "0,1")
parser.add_argument("--input_file", type = str, help = "input file path (prompts)")
parser.add_argument("--eval_name", type = str, help = "name of the evaluation set")
parser.add_argument("--tensor_parallel_size", type = int, help = "the number of used GPUs", default = 1)
parser.add_argument("--n", type = int, help = "sampling number", default = 16)
parser.add_argument("--gold_file", type = str, help = "gold sql path")
parser.add_argument("--db_path", type = str, help = "database path")
parser.add_argument("--ts_db_path", type = str, default = "", help = "test suite database path (required by Spider)")
parser.add_argument("--gold_result_dir", type = str, help = "gold sql execution results (required by Spider2.0)")
parser.add_argument("--eval_standard", type = str, help = "evaluation standard (required by Spider2.0)")
opt = parser.parse_args()
print(opt)
assert opt.source in ["spider", "bird", "spider2.0"]
if opt.multiple_models:
ckpt_ids = os.listdir(opt.output_ckpt_dir)
ckpt_ids = sorted(ckpt_ids, key=lambda x: int(x.split("-")[1]))
print(ckpt_ids)
else:
ckpt_ids = [""]
greedy_search_acc_dict = dict()
pass_at_k_acc_dict = dict()
major_voting_acc_dict = dict()
os.makedirs(os.path.join("results", opt.eval_name), exist_ok=True)
os.makedirs(os.path.join("evaluation_results", opt.eval_name), exist_ok=True)
for ckpt_id in tqdm(ckpt_ids):
print("Evaluating ckpt:", ckpt_id)
if ckpt_id not in greedy_search_acc_dict.keys():
# greedy decoding
gs_pred_file = f"results/{opt.eval_name}/greedy_search_{ckpt_id}.json"
greedy_search_cmd = f"CUDA_VISIBLE_DEVICES={opt.visible_devices} python3 infer.py \
--pretrained_model_name_or_path {os.path.join(opt.output_ckpt_dir, ckpt_id)} \
--input_file {opt.input_file} \
--output_file {gs_pred_file} \
--tensor_parallel_size {opt.tensor_parallel_size} \
--n 1 \
--temperature 0.0"
os.system(greedy_search_cmd)
# evaluate greedy search
if opt.source == "spider2.0":
# warm up
evaluate_spider2.evaluate("greedy_search", opt.gold_result_dir, opt.eval_standard,
opt.gold_file, gs_pred_file, opt.db_path, True)
# record evaluation results
gs_acc, _ = evaluate_spider2.evaluate("greedy_search", opt.gold_result_dir, opt.eval_standard,
opt.gold_file, gs_pred_file, opt.db_path, True)
elif opt.source == "bird":
# warm up
evaluate_bird.run_eval(opt.gold_file, gs_pred_file, opt.db_path, "greedy_search", True)
# record evaluation results
gs_acc, _ = evaluate_bird.run_eval(opt.gold_file, gs_pred_file, opt.db_path, "greedy_search", True)
elif opt.source == "spider": # for "spider"
# warm up
evaluate_spider.run_spider_eval(opt.gold_file, gs_pred_file, opt.db_path,
opt.ts_db_path, "greedy_search", True)
# record evaluation results
ex_score, ts_score = evaluate_spider.run_spider_eval(opt.gold_file, gs_pred_file, opt.db_path,
opt.ts_db_path, "greedy_search", True)
if ts_score is None:
gs_acc = ex_score
else:
gs_acc = [ex_score, ts_score]
greedy_search_acc_dict[ckpt_id] = gs_acc
print(greedy_search_acc_dict)
visualize(opt.eval_name, greedy_search_acc_dict, "greedy_search",
os.path.join("evaluation_results", opt.eval_name, "greedy_search.png"))
save_evaluation_results(os.path.join("evaluation_results", opt.eval_name, "greedy_search.json"), greedy_search_acc_dict)
else:
print(f"skip {ckpt_id} greedy search")
if ckpt_id not in major_voting_acc_dict.keys():
# sampling
sampling_pred_file = f"results/{opt.eval_name}/sampling_{ckpt_id}.json"
sampling_cmd = f"CUDA_VISIBLE_DEVICES={opt.visible_devices} python3 infer.py \
--pretrained_model_name_or_path {os.path.join(opt.output_ckpt_dir, ckpt_id)} \
--input_file {opt.input_file} \
--output_file {sampling_pred_file} \
--tensor_parallel_size {opt.tensor_parallel_size} \
--n {opt.n} \
--temperature 0.8"
os.system(sampling_cmd)
# evaluate pass@k (we do not evaluate pass@k for spider and its variants)
if opt.source in ["bird", "spider2.0"]:
if opt.source == "spider2.0":
# warm up
evaluate_spider2.evaluate("pass@k", opt.gold_result_dir, opt.eval_standard,
opt.gold_file, sampling_pred_file, opt.db_path, True)
# record evaluation results
pass_at_k_acc, _ = evaluate_spider2.evaluate("pass@k", opt.gold_result_dir, opt.eval_standard,
opt.gold_file, sampling_pred_file, opt.db_path, True)
elif opt.source == "bird":
# warm up
evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "pass@k", True)
# record evaluation results
pass_at_k_acc, _ = evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "pass@k", True)
pass_at_k_acc_dict[ckpt_id] = pass_at_k_acc
print(pass_at_k_acc_dict)
visualize(opt.eval_name, pass_at_k_acc_dict, "pass_at_k",
os.path.join("evaluation_results", opt.eval_name, "pass_at_k.png"))
save_evaluation_results(os.path.join("evaluation_results", opt.eval_name, "pass_at_k.json"), pass_at_k_acc_dict)
# evaluate major voting
if opt.source == "spider2.0":
# warm up
evaluate_spider2.evaluate("major_voting", opt.gold_result_dir, opt.eval_standard,
opt.gold_file, sampling_pred_file, opt.db_path, True)
# record evaluation results
major_voting_acc, _ = evaluate_spider2.evaluate("major_voting", opt.gold_result_dir, opt.eval_standard,
opt.gold_file, sampling_pred_file, opt.db_path, True)
elif opt.source == "bird":
# warm up
evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "major_voting", True)
# record evaluation results
major_voting_acc, _ = evaluate_bird.run_eval(opt.gold_file, sampling_pred_file, opt.db_path, "major_voting", True)
else: # spider
# warm up
evaluate_spider.run_spider_eval(opt.gold_file, sampling_pred_file, opt.db_path,
opt.ts_db_path, "major_voting", True)
# record evaluation results
ex_score, ts_score = evaluate_spider.run_spider_eval(opt.gold_file, sampling_pred_file, opt.db_path,
opt.ts_db_path, "major_voting", True)
if ts_score is None:
major_voting_acc = ex_score
else:
major_voting_acc = [ex_score, ts_score]
major_voting_acc_dict[ckpt_id] = major_voting_acc
print(major_voting_acc_dict)
visualize(opt.eval_name, major_voting_acc_dict, "major_voting",
os.path.join("evaluation_results", opt.eval_name, "major_voting.png"))
save_evaluation_results(os.path.join("evaluation_results", opt.eval_name, "major_voting.json"), major_voting_acc_dict)
else:
print(f"skip {ckpt_id} pass at k and major voting")
import json
import os, shutil
import sqlite3
from func_timeout import func_set_timeout, FunctionTimedOut
from pathlib import Path
# get the database cursor for a sqlite database path
def get_cursor_from_path(sqlite_path):
try:
if not os.path.exists(sqlite_path):
print("Openning a new connection %s" % sqlite_path)
connection = sqlite3.connect(sqlite_path, check_same_thread = False)
except Exception as e:
print(sqlite_path)
raise e
connection.text_factory = lambda b: b.decode(errors="ignore")
cursor = connection.cursor()
return cursor
# execute predicted sql with a long time limit (for building the content index)
@func_set_timeout(3600)
def execute_sql(cursor, sql):
cursor.execute(sql)
return cursor.fetchall()
def remove_contents_of_a_folder(index_path):
# if index_path does not exist, then create it
os.makedirs(index_path, exist_ok = True)
# remove files in index_path
for filename in os.listdir(index_path):
file_path = os.path.join(index_path, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print('Failed to delete %s. Reason: %s' % (file_path, e))
def is_number(s):
try:
float(s)
return True
except ValueError:
return False
def build_content_index(db_file_path, index_path):
'''
create BM25 index for all string values in a database
'''
cursor = get_cursor_from_path(db_file_path)
results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
table_names = [result[0] for result in results]
all_column_contents = []
for table_name in table_names:
# skip SQLite system table: sqlite_sequence
if table_name == "sqlite_sequence":
continue
results = execute_sql(cursor, f"SELECT name FROM PRAGMA_TABLE_INFO('{table_name}')")
column_names_in_one_table = [result[0] for result in results]
for column_name in column_names_in_one_table:
try:
print(f"SELECT DISTINCT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL;")
results = execute_sql(cursor, f"SELECT DISTINCT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL;")
column_contents = [result[0] for result in results if isinstance(result[0], str) and not is_number(result[0])]
for c_id, column_content in enumerate(column_contents):
# remove empty and extremely-long contents
if len(column_content) != 0 and len(column_content) <= 40:
all_column_contents.append(
{
"id": "{}-**-{}-**-{}".format(table_name, column_name, c_id), # .lower()
"contents": column_content
}
)
except Exception as e:
print(str(e))
os.makedirs('./data/temp_db_index', exist_ok = True)
with open("./data/temp_db_index/contents.json", "w") as f:
f.write(json.dumps(all_column_contents, indent = 2, ensure_ascii = True))
# Building a BM25 Index (Direct Java Implementation), see https://github.com/castorini/pyserini/blob/master/docs/usage-index.md
cmd = f'python -m pyserini.index.lucene --collection JsonCollection --input ./data/temp_db_index --index "{index_path}" --generator DefaultLuceneDocumentGenerator --threads 16 --storePositions --storeDocvectors --storeRaw'
d = os.system(cmd)
print(d)
os.remove("./data/temp_db_index/contents.json")
if __name__ == "__main__":
dataset_info = {
# BIRD train
"bird_train": {"db_path": "./data/bird/train/train_databases", "index_path_prefix": "./data/bird/train/db_contents_index"},
# BIRD dev
"bird_dev": {"db_path": "./data/bird/dev_20240627/dev_databases", "index_path_prefix": "./data/bird/dev_20240627/db_contents_index"},
# Spider train-dev-test
"spider": {"db_path": "./data/spider/test_database", "index_path_prefix": "./data/spider/db_contents_index"},
# Spider2.0-SQLite
"spider2_sqlite": {"db_path": "./data/spider2_sqlite/databases", "index_path_prefix": "./data/spider2_sqlite/db_contents_index"},
# SynSQL-2.5M dataset
"SynSQL-2.5M": {"db_path": "./data/SynSQL-2.5M/databases", "index_path_prefix": "./data/SynSQL-2.5M/db_contents_index"},
# spider-dk
"spider_dk": {"db_path": "./data/Spider-DK/database", "index_path_prefix": "./data/Spider-DK/db_contents_index"},
# EHRSQL_dev
"EHRSQL_dev": {"db_path": "./data/EHRSQL/database", "index_path_prefix": "./data/EHRSQL/db_contents_index"},
# sciencebenchmark_dev
"sciencebenchmark_dev": {"db_path": "./data/sciencebenchmark/databases", "index_path_prefix": "./data/sciencebenchmark/db_contents_index"},
}
for dataset_name in dataset_info:
print(dataset_name)
db_path = dataset_info[dataset_name]["db_path"]
index_path_prefix = dataset_info[dataset_name]["index_path_prefix"]
remove_contents_of_a_folder(index_path_prefix)
# build content index
db_ids = os.listdir(db_path)
# db_ids = ["the_table's_domain_appears_to_be_related_to_demographic_and_employment_data"]
for db_id in db_ids:
db_file_path = os.path.join(db_path, db_id, db_id + ".sqlite")
if os.path.exists(db_file_path) and os.path.isfile(db_file_path):
print(f"The file '{db_file_path}' exists.")
build_content_index(
db_file_path,
os.path.join(index_path_prefix, db_id)
)
else:
print(f"The file '{db_file_path}' does not exist.")
import os
models = [
"/home/ckpts/OmniSQL-7B",
#"seeklhy/OmniSQL-14B",
#"seeklhy/OmniSQL-32B",
# "qwen/Qwen2.5-Coder-7B-Instruct",
# "qwen/Qwen2.5-Coder-14B-Instruct",
# "qwen/Qwen2.5-Coder-32B-Instruct",
# "qwen/Qwen2.5-7B-Instruct",
# "qwen/Qwen2.5-14B-Instruct",
# "qwen/Qwen2.5-32B-Instruct",
# "qwen/Qwen2.5-72B-Instruct",
# "meta-llama/Meta-Llama-3.1-8B-Instruct",
# "meta-llama/Meta-Llama-3.1-70B-Instruct",
# "infly/OpenCoder-8B-Instruct",
# "deepseek-ai/deepseek-coder-6.7b-instruct",
# "deepseek-ai/deepseek-coder-33b-instruct",
# "deepseek-ai/deepseek-v3",
# "ibm-granite/granite-34b-code-instruct-8k",
# "ibm-granite/granite-20b-code-instruct-8k",
# "ibm-granite/granite-8b-code-instruct-128k",
# "ibm-granite/granite-3.1-8b-instruct",
# "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
# "bigcode/starcoder2-15b-instruct-v0.1",
# "mistralai/Codestral-22B-v0.1",
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
]
visible_devices = "0,1" # visible devices for vLLM
tensor_parallel_size = len(visible_devices.split(","))
for model in models:
model_name = model.split("/")[-1].strip()
spider2_test_eval_name = f"{model_name}_test_spider2_sqlite"
spider2_test_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider2.0 --visible_devices {visible_devices} --input_file ./data/test_spider2_sqlite.json --eval_name {spider2_test_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider2_sqlite/test.json --db_path ./data/spider2_sqlite/databases/ --gold_result_dir ./data/spider2_sqlite/gold_exec_result/ --eval_standard ./data/spider2_sqlite/spider2_sqlite_eval.jsonl"
os.system(spider2_test_evaluation_cmd)
dev_bird_eval_name = f"{model_name}_dev_bird"
dev_bird_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source bird --visible_devices {visible_devices} --input_file ./data/dev_bird.json --eval_name {dev_bird_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/bird/dev_20240627/dev.json --db_path ./data/bird/dev_20240627/dev_databases"
os.system(dev_bird_evaluation_cmd)
dev_spider_eval_name = f"{model_name}_dev_spider"
dev_spider_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider.json --eval_name {dev_spider_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider/dev_gold.sql --db_path ./data/spider/database --ts_db_path ./test_suite_sql_eval/test_suite_database"
os.system(dev_spider_evaluation_cmd)
test_spider_eval_name = f"{model_name}_test_spider"
test_spider_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/test_spider.json --eval_name {test_spider_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider/test_gold.sql --db_path ./data/spider/test_database"
os.system(test_spider_evaluation_cmd)
spider_dk_eval_name = f"{model_name}_dev_spider_dk"
spider_dk_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider_dk.json --eval_name {spider_dk_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/Spider-DK/spider_dk_gold.sql --db_path ./data/Spider-DK/database"
os.system(spider_dk_evaluation_cmd)
spider_realistic_eval_name = f"{model_name}_dev_spider_realistic"
spider_realistic_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider_realistic.json --eval_name {spider_realistic_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/spider-realistic/spider_realistic_gold.sql --db_path ./data/spider/database --ts_db_path ./test_suite_sql_eval/test_suite_database"
os.system(spider_realistic_evaluation_cmd)
spider_syn_eval_name = f"{model_name}_dev_spider_syn"
spider_syn_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source spider --visible_devices {visible_devices} --input_file ./data/dev_spider_syn.json --eval_name {spider_syn_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/Spider-Syn/spider_syn_gold.sql --db_path ./data/spider/database --ts_db_path ./test_suite_sql_eval/test_suite_database"
os.system(spider_syn_evaluation_cmd)
dev_ehrsql_eval_name = f"{model_name}_dev_ehrsql"
dev_ehrsql_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source bird --visible_devices {visible_devices} --input_file ./data/dev_ehrsql.json --eval_name {dev_ehrsql_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/EHRSQL/dev.json --db_path ./data/EHRSQL/database"
os.system(dev_ehrsql_evaluation_cmd)
dev_sciencebenchmark_eval_name = f"{model_name}_dev_sciencebenchmark"
dev_sciencebenchmark_evaluation_cmd = f"python3 auto_evaluation.py --output_ckpt_dir {model} --source bird --visible_devices {visible_devices} --input_file ./data/dev_sciencebenchmark.json --eval_name {dev_sciencebenchmark_eval_name} --tensor_parallel_size {tensor_parallel_size} --n 8 --gold_file ./data/sciencebenchmark/dev.json --db_path ./data/sciencebenchmark/databases"
os.system(dev_sciencebenchmark_evaluation_cmd)
import sys
import sqlite3
import json
import argparse
import os
from func_timeout import func_timeout, FunctionTimedOut
from tqdm import tqdm
import multiprocessing as mp
import random
random.seed(42)
execution_results = None
evaluation_results = None
def parse_option():
parser = argparse.ArgumentParser()
parser.add_argument('--pred', type = str, default = "predict_dev.json")
parser.add_argument('--gold', type = str, default = "./bird/dev/dev.json")
parser.add_argument('--db_path', type = str, default = "./bird/dev/dev_databases")
parser.add_argument('--mode', type = str, default = "greedy_search")
opt = parser.parse_args()
return opt
def execute_sql(data_idx, db_file, sql):
conn = sqlite3.connect(db_file)
cursor = conn.cursor()
try:
conn.execute("BEGIN TRANSACTION;")
cursor.execute(sql)
execution_res = cursor.fetchall()
execution_res = frozenset(execution_res) # make set hashable
conn.rollback()
conn.close()
return data_idx, db_file, sql, execution_res, 1
# if len(execution_res) > 0:
# return data_idx, db_file, sql, execution_res, 1
# elif len(execution_res) == 0:
# return data_idx, db_file, sql, execution_res, 0
except:
conn.rollback()
conn.close()
return data_idx, db_file, sql, None, 0
def compare_sql(question_id, db_file, question, ground_truth, pred_sql):
conn = sqlite3.connect(db_file)
cursor = conn.cursor()
correctness = 0
try:
conn.execute("BEGIN TRANSACTION;")
cursor.execute(pred_sql)
predicted_res = cursor.fetchall()
cursor.execute(ground_truth)
ground_truth_res = cursor.fetchall()
print('Successfully executed')
if set(predicted_res) == set(ground_truth_res):
correctness = 1
conn.rollback()
except:
conn.rollback()
finally:
conn.close()
return question_id, db_file, question, ground_truth, pred_sql, correctness
def compare_sql_wrapper(args, timeout):
'''Wrap execute_sql for timeout'''
try:
result = func_timeout(timeout, compare_sql, args=args)
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
result = (*args, 0)
except Exception as e:
result = (*args, 0)
return result
def execute_sql_wrapper(data_idx, db_file, sql, timeout):
try:
res = func_timeout(timeout, execute_sql, args=(data_idx, db_file, sql))
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
print(f"Data index:{data_idx}\nSQL:\n{sql}\nTime Out!")
print("-"*30)
res = (data_idx, db_file, sql, None, 0)
except Exception as e:
res = (data_idx, db_file, sql, None, 0)
return res
def execute_callback_evaluate_sql(result):
'''Store the execution result in the collection'''
question_id, db_file, question, ground_truth, pred_sql, correctness = result
# evaluation_res = dict()
# evaluation_res['question_id'] = question_id
# evaluation_res["db_file"] = db_file
# evaluation_res["question"] = question
# evaluation_res["ground_truth"] = ground_truth
# evaluation_res["pred_sql"] = pred_sql
# evaluation_res["correctness"] = correctness
evaluation_results.append(
{
"question_id": question_id,
"db_file": db_file,
"question": question,
"ground_truth": ground_truth,
"pred_sql": pred_sql,
"correctness": correctness
}
)
print('Done:', question_id, correctness) # Print the progress
sys.stdout.flush()
sys.stderr.flush()
def execute_callback_execute_sqls(result):
data_idx, db_file, sql, query_result, valid = result
print('Done:', data_idx) # Print the progress
execution_results.append(
{
"data_idx": data_idx,
"db_file": db_file,
"sql": sql,
"query_result": query_result,
"valid": valid
}
)
def evaluate_sqls_parallel(db_files, questions, pred_sqls, ground_truth_sqls, num_cpus=1, timeout=1):
'''Execute the sqls in parallel'''
pool = mp.Pool(processes=num_cpus)
for question_id, db_file, question, pred_sql, ground_truth in zip([x for x in range(len(db_files))], db_files, questions, pred_sqls, ground_truth_sqls):
pool.apply_async(compare_sql_wrapper, args=((question_id, db_file, question, ground_truth, pred_sql), timeout), callback=execute_callback_evaluate_sql)
pool.close()
pool.join()
def execute_sqls_parallel(db_files, sqls, num_cpus=1, timeout=1):
pool = mp.Pool(processes=num_cpus)
for data_idx, db_file, sql in zip(list(range(len(sqls))), db_files, sqls):
pool.apply_async(execute_sql_wrapper, args=(data_idx, db_file, sql, timeout), callback=execute_callback_execute_sqls)
pool.close()
pool.join()
def mark_invalid_sqls(db_files, sqls):
global execution_results
execution_results = []
execute_sqls_parallel(db_files, sqls, num_cpus=20, timeout=10)
execution_results = sorted(execution_results, key=lambda x:x['data_idx'])
for idx, res in enumerate(execution_results):
if res["valid"] == 0:
sqls[idx] = "Error SQL"
return sqls
def major_voting(db_files, pred_sqls, sampling_num, return_random_one_when_all_errors=True):
global execution_results
mj_pred_sqls = []
execution_results = []
# execute all sampled SQL queries to obtain their execution results
execute_sqls_parallel(db_files, pred_sqls, num_cpus=20, timeout=10)
execution_results = sorted(execution_results, key=lambda x:x['data_idx'])
print("len(execution_results):", len(execution_results))
# perform major voting
for result_idx in range(0, len(execution_results), sampling_num):
major_voting_counting = dict()
execution_results_of_one_sample = execution_results[result_idx: result_idx + sampling_num]
# if no predicted SQLs are valid
if sum([res["valid"] for res in execution_results_of_one_sample]) == 0:
if return_random_one_when_all_errors:
mj_pred_sql = random.choice(execution_results_of_one_sample)["sql"] # select a random one to return
else:
mj_pred_sql = "Error SQL"
mj_pred_sqls.append(mj_pred_sql)
continue
for res in execution_results_of_one_sample:
if res["valid"] == 1: # skip invalid SQLs
if res["query_result"] in major_voting_counting:
major_voting_counting[res["query_result"]]["votes"] += 1
else:
major_voting_counting[res["query_result"]] = {"votes": 1, "sql": res["sql"]}
# find the SQL with the max votes
major_vote = max(major_voting_counting.values(), key=lambda x: x["votes"])
mj_pred_sql = major_vote["sql"]
mj_pred_sqls.append(mj_pred_sql)
return mj_pred_sqls
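# Sketch of the voting rule in isolation: predictions whose execution results
# agree (as hashable frozensets) pool their votes, and the SQL of the
# most-voted group wins. The entries below are illustrative stand-ins for real
# execution results, and _demo_major_voting_rule is not called by the pipeline.
def _demo_major_voting_rule():
    results = [
        {"valid": 1, "sql": "SELECT 1", "query_result": frozenset({(1,)})},
        {"valid": 1, "sql": "SELECT 1;", "query_result": frozenset({(1,)})},
        {"valid": 1, "sql": "SELECT 2", "query_result": frozenset({(2,)})},
    ]
    votes = dict()
    for res in results:
        if res["valid"] == 1:
            entry = votes.setdefault(res["query_result"], {"votes": 0, "sql": res["sql"]})
            entry["votes"] += 1
    return max(votes.values(), key=lambda x: x["votes"])["sql"]  # -> "SELECT 1"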
def run_eval(gold_file, pred_file, db_path, mode, save_pred_sqls, num_cpus=20, timeout=10):
global evaluation_results
gold = json.load(open(gold_file))
pred_results = json.load(open(pred_file))
db_files = [os.path.join(db_path, data["db_id"], data["db_id"] + ".sqlite") for data in gold]
questions = [data["question"] for data in gold]
pred_sql_key = "pred_sqls"
# pred_sql_key = "responses"
if "bird" in gold_file:
ground_truth_sqls = [data["SQL"] for data in gold]
else:
ground_truth_sqls = [data["query"] for data in gold]
if mode == "greedy_search":
pred_sqls = [res[pred_sql_key][0] for res in pred_results]
# save the (greedy-search) predicted SQL so we can check it out later
if save_pred_sqls:
with open(pred_file[:-5] + "_pred_greedy_search_sqls.json", "w", encoding="utf-8") as f:
                f.write(json.dumps(pred_sqls, indent=2, ensure_ascii=False))
assert len(pred_results) == len(pred_sqls) == len(db_files) == len(questions) == len(ground_truth_sqls)
evaluation_results = []
evaluate_sqls_parallel(db_files, questions, pred_sqls, ground_truth_sqls, num_cpus=num_cpus, timeout=timeout)
# sort evaluation_results by question_id
evaluation_results = sorted(evaluation_results, key=lambda x:x['question_id'])
evaluation_scores = [res["correctness"] for res in evaluation_results]
for res in evaluation_results:
if res["correctness"] == 0:
print("question:", res["question"])
print("GT:", res["ground_truth"])
print("Pred:", res["pred_sql"])
print("-"*30)
print("EX Accuracy (greedy search):", sum(evaluation_scores)/len(evaluation_scores))
return sum(evaluation_scores)/len(evaluation_scores), pred_sqls
elif mode == "major_voting":
sampling_num = len(pred_results[0][pred_sql_key])
print("sampling_num:", sampling_num)
db_files = []
for gold_data in gold:
db_files.extend([os.path.join(db_path, gold_data["db_id"], gold_data["db_id"] + ".sqlite")] * sampling_num)
pred_sqls = []
for pred_data in pred_results:
pred_sqls.extend(pred_data[pred_sql_key])
assert len(pred_sqls) == len(db_files)
mj_pred_sqls = major_voting(db_files, pred_sqls, sampling_num)
# save the (major-voting) predicted SQL so we can check it out later
if save_pred_sqls:
with open(pred_file[:-5] + "_pred_major_voting_sqls.json", "w", encoding="utf-8") as f:
                f.write(json.dumps(mj_pred_sqls, indent=2, ensure_ascii=False))
# reset db_files
db_files = []
for gold_data in gold:
db_files.append(os.path.join(db_path, gold_data["db_id"], gold_data["db_id"] + ".sqlite"))
assert len(mj_pred_sqls) == len(db_files) == len(questions) == len(ground_truth_sqls)
evaluation_results = []
evaluate_sqls_parallel(db_files, questions, mj_pred_sqls, ground_truth_sqls, num_cpus=num_cpus, timeout=timeout)
# sort evaluation_results by question_id
evaluation_results = sorted(evaluation_results, key=lambda x:x['question_id'])
evaluation_scores = [res["correctness"] for res in evaluation_results]
print("EX Accuracy (major voting):", sum(evaluation_scores)/len(evaluation_scores))
return sum(evaluation_scores)/len(evaluation_scores), mj_pred_sqls
elif mode == "pass@k":
all_scores = []
sampling_num = len(pred_results[0][pred_sql_key])
db_files = []
for gold_data in gold:
db_files.append(os.path.join(db_path, gold_data["db_id"], gold_data["db_id"] + ".sqlite"))
for sample_idx in range(sampling_num):
pred_sqls_for_specific_sample_idx = [pred_data[pred_sql_key][sample_idx] for pred_data in pred_results]
evaluation_results = []
evaluate_sqls_parallel(db_files, questions, pred_sqls_for_specific_sample_idx, ground_truth_sqls, num_cpus=num_cpus, timeout=timeout)
evaluation_results = sorted(evaluation_results, key=lambda x:x['question_id'])
evaluation_scores = [res["correctness"] for res in evaluation_results]
all_scores.append(evaluation_scores)
pass_at_k_scores = [1 if any(column) else 0 for column in zip(*all_scores)]
print(f"EX Accuracy (pass@{sampling_num}):", sum(pass_at_k_scores)/len(pass_at_k_scores))
return sum(pass_at_k_scores)/len(pass_at_k_scores), None
else:
raise ValueError("mode should be in [greedy_search, major_voting, pass@k]")
'''
python evaluate_bird.py --pred ./results/spider_dev_greedy_search_ckpt-5306.json --gold ../data/spider/dev.json --db_path ../data/spider/database/
python evaluate_bird.py --pred ./results/bird_dev_greedy_search_ckpt-5306.json --gold ../data/bird/dev_20240627/dev.json --db_path ../data/bird/dev_20240627/dev_databases/
'''
if __name__ == "__main__":
opt = parse_option()
run_eval(opt.gold, opt.pred, opt.db_path, opt.mode, False)
import json
import argparse
import os
import random
import re
from evaluate_bird import major_voting, mark_invalid_sqls
import tempfile
import subprocess
random.seed(42)
def parse_option():
parser = argparse.ArgumentParser()
    parser.add_argument('--pred', type=str, default="predict_dev.json")
    parser.add_argument('--gold', type=str, default="./data/spider/dev_gold.sql")
    parser.add_argument('--db_path', type=str, default="./data/spider/databases")
    parser.add_argument('--ts_db_path', type=str, default="")
    parser.add_argument('--mode', type=str, default="greedy_search")
opt = parser.parse_args()
return opt
def format_sql(sql):
sql = sql.strip()
# remove multi-line comments /* ... */
sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL)
# remove single-line comments --
sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
sql = sql.replace("\n", " ").replace("\t", " ")
sql = sql.strip()
if sql == "":
sql = "Error SQL"
return sql
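# A quick self-check of format_sql's normalization (inputs are illustrative):
# block and line comments are stripped, newlines and tabs are flattened to a
# single line, and an empty result falls back to the "Error SQL" sentinel.
def _demo_format_sql():
    assert format_sql("/* header */\nSELECT 1 -- trailing note\n") == "SELECT 1"
    assert format_sql("-- nothing but a comment") == "Error SQL"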
def run_spider_eval(gold_file, pred_file, db_path, ts_db_path, mode, save_pred_sqls):
assert mode in ["greedy_search", "major_voting"]
    gold_lines = open(gold_file).readlines()
    gold_sqls = [line.split("\t")[0].strip() for line in gold_lines]
    db_ids = [line.split("\t")[1].strip() for line in gold_lines]
pred = json.load(open(pred_file))
pred_sql_key = "pred_sqls"
# pred_sql_key = "responses"
pred_sqls = []
if mode == "greedy_search":
pred_sqls = [pred_data[pred_sql_key][0] for pred_data in pred]
assert len(pred_sqls) == len(db_ids)
db_files = [os.path.join(db_path, db_id, db_id + ".sqlite") for db_id in db_ids]
pred_sqls = mark_invalid_sqls(db_files, pred_sqls)
elif mode == "major_voting":
        # perform major voting by reusing the routine from the BIRD evaluation script
sampling_num = len(pred[0][pred_sql_key])
print("sampling_num:", sampling_num)
all_db_files = []
for db_id in db_ids:
all_db_files.extend([os.path.join(db_path, db_id, db_id + ".sqlite")] * sampling_num)
all_pred_sqls = []
for pred_data in pred:
all_pred_sqls.extend(pred_data[pred_sql_key])
assert len(all_db_files) == len(all_pred_sqls)
pred_sqls = major_voting(all_db_files, all_pred_sqls, sampling_num, False)
pred_sqls = [format_sql(pred_sql) for pred_sql in pred_sqls]
assert len(pred_sqls) == len(gold_sqls)
if save_pred_sqls:
with open(pred_file[:-5] + f"_pred_{mode}_sqls.json", "w", encoding="utf-8") as f:
            f.write(json.dumps(pred_sqls, indent=2, ensure_ascii=False))
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt", encoding="utf-8") as temp_file:
for pred_sql in pred_sqls:
temp_file.write(pred_sql + "\n")
temp_file_name = temp_file.name
print(temp_file_name)
print("Execution accuracy:")
cmd = f'python3 -u test_suite_sql_eval/evaluation.py --gold {gold_file} --pred {temp_file_name} --db {db_path} --etype exec'
print(cmd)
result = subprocess.run(cmd, shell=True, text=True, capture_output=True)
stdout = result.stdout
print(result.stderr)
# match the last num in the string
match = re.search(r'(\d+\.\d+)\s*$', stdout.strip())
ex_acc = float(match.group(1))
print(stdout)
print("ex_acc:", ex_acc)
ts_acc = None
if ts_db_path != "":
print("Test suit execution accuracy:")
cmd = f'python3 -u test_suite_sql_eval/evaluation.py --gold {gold_file} --pred {temp_file_name} --db {ts_db_path} --etype exec'
result = subprocess.run(cmd, shell=True, text=True, capture_output=True)
stdout = result.stdout
print(result.stderr)
# match the last num in the string
match = re.search(r'(\d+\.\d+)\s*$', stdout.strip())
ts_acc = float(match.group(1))
print(stdout)
print("ts_acc:", ts_acc)
os.remove(temp_file_name)
return ex_acc, ts_acc
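# Sketch of the stdout scraping used above: the test-suite evaluator is assumed
# to end its report with the accuracy as the last float in the output, which
# the regex extracts. The sample string and the guarded fallback are
# assumptions for illustration, not the evaluator's exact format.
def _demo_parse_trailing_float():
    sample = "count    1034\nexecution    0.734"
    match = re.search(r'(\d+\.\d+)\s*$', sample.strip())
    return float(match.group(1)) if match else None  # -> 0.734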
if __name__ == "__main__":
opt = parse_option()
run_spider_eval(opt.gold, opt.pred, opt.db_path, opt.ts_db_path, opt.mode, False)
# import debugpy; debugpy.connect(('127.0.0.1', 5688))
import json
import re
import pandas as pd
import math
# import duckdb
from typing import List, Union
import os
import os.path as osp
import argparse
# from google.cloud import bigquery
import shutil
import multiprocessing as mp
import sqlite3
from tqdm import tqdm
# import snowflake.connector
import logging
from func_timeout import func_timeout, FunctionTimedOut
import sys
from functools import partial
import tempfile
class TeeOutput:
    '''Mirror every write to both the console and a log file.'''
def __init__(self, filename):
self.console = sys.stdout
self.file = open(filename, 'w')
def write(self, message):
self.console.write(message)
self.file.write(message)
def flush(self):
self.console.flush()
self.file.flush()
def close(self):
self.file.close()
sys.stdout = TeeOutput('log.txt')
sys.stderr = sys.stdout
TOTAL_GB_PROCESSED = 0.0
byte_output_dict = {}
def load_jsonl_to_dict(jsonl_file):
data_dict = {}
with open(jsonl_file, 'r') as file:
for line in file:
item = json.loads(line.strip())
instance_id = item['instance_id']
data_dict[instance_id] = item
return data_dict
def load_json_list_to_dict(json_file_path):
with open(json_file_path, 'r', encoding='utf-8') as file:
data_list = json.load(file)
data_dict = {item['instance_id']: item for item in data_list}
return data_dict
def compare_multi_pandas_table(pred, multi_gold, multi_condition_cols, multi_ignore_order):
# print('multi_condition_cols', multi_condition_cols)
# print("len(multi_condition_cols)", len(multi_condition_cols))
    if multi_condition_cols == [] or multi_condition_cols == [[]] or multi_condition_cols == [None] or multi_condition_cols is None:
multi_condition_cols = [[] for _ in range(len(multi_gold))]
elif len(multi_gold) > 1 and not all(isinstance(sublist, list) for sublist in multi_condition_cols):
multi_condition_cols = [multi_condition_cols for _ in range(len(multi_gold))]
# multi_ignore_order = [multi_ignore_order for _ in range(len(multi_gold))]
assert len(multi_gold) == len(multi_condition_cols) == len(multi_ignore_order)
for i, gold in enumerate(multi_gold):
if compare_pandas_table(pred, gold, multi_condition_cols[i], multi_ignore_order[i]):
return 1
return 0
def compare_pandas_table(pred, gold, condition_cols=[], ignore_order=False):
    """Check whether every (selected) gold column appears among the predicted columns.

    Args:
        pred (DataFrame): predicted query result.
        gold (DataFrame): gold query result.
        condition_cols (list, optional): positional indices of the gold columns
            that must be matched; [] selects all columns. Defaults to [].
        ignore_order (bool, optional): if True, compare each column as a
            multiset, ignoring row order. Defaults to False.
    """
# print('condition_cols', condition_cols)
tolerance = 1e-2
def vectors_match(v1, v2, tol=tolerance, ignore_order_=False):
if ignore_order_:
v1, v2 = (sorted(v1, key=lambda x: (x is None, str(x), isinstance(x, (int, float)))),
sorted(v2, key=lambda x: (x is None, str(x), isinstance(x, (int, float)))))
if len(v1) != len(v2):
return False
for a, b in zip(v1, v2):
if pd.isna(a) and pd.isna(b):
continue
elif isinstance(a, (int, float)) and isinstance(b, (int, float)):
if not math.isclose(float(a), float(b), abs_tol=tol):
return False
elif a != b:
return False
return True
if condition_cols != []:
gold_cols = gold.iloc[:, condition_cols]
else:
gold_cols = gold
pred_cols = pred
t_gold_list = gold_cols.transpose().values.tolist()
t_pred_list = pred_cols.transpose().values.tolist()
score = 1
    for gold_vec in t_gold_list:
        if not any(vectors_match(gold_vec, pred_vec, ignore_order_=ignore_order) for pred_vec in t_pred_list):
            score = 0
return score
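# A minimal sketch of the matching rule above: the comparison passes when every
# gold column appears (within the 1e-2 numeric tolerance) among the predicted
# columns; with ignore_order=True row order is irrelevant. The frames are
# illustrative assumptions.
def _demo_compare_pandas_table():
    gold = pd.DataFrame({"a": [1, 2]})
    pred = pd.DataFrame({"extra": [0.0, 0.0], "a_shuffled": [2, 1]})
    return compare_pandas_table(pred, gold, ignore_order=True)  # -> 1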
def get_sqlite_result(db_file_path, query, save_dir=None, file_name="result.csv", chunksize=500):
conn = sqlite3.connect(db_file_path)
    # snapshot the on-disk database into an in-memory copy, so the source
    # file is only read and never locked by the query below
    memory_conn = sqlite3.connect(':memory:')
    conn.backup(memory_conn)
try:
if save_dir:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
for i, chunk in enumerate(pd.read_sql_query(query, memory_conn, chunksize=chunksize)):
mode = 'a' if i > 0 else 'w'
header = i == 0
chunk.to_csv(os.path.join(save_dir, file_name), mode=mode, header=header, index=False)
else:
df = pd.read_sql_query(query, memory_conn)
return True, df
except Exception as e:
print(f"An error occurred: {e}")
return False, str(e)
finally:
memory_conn.close()
conn.close()
return True, None
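# Sketch of the helper above on a scratch database: the on-disk file is
# snapshotted into memory before querying, so the source database is never
# locked or mutated. The table and the _demo_* name are illustrative
# assumptions.
def _demo_get_sqlite_result():
    fd, db_path_ = tempfile.mkstemp(suffix=".sqlite")
    os.close(fd)
    conn = sqlite3.connect(db_path_)
    conn.execute("CREATE TABLE t (v INTEGER)")
    conn.execute("INSERT INTO t VALUES (7)")
    conn.commit()
    conn.close()
    ok, df = get_sqlite_result(db_path_, "SELECT v FROM t")
    os.remove(db_path_)
    return ok, df  # -> (True, DataFrame with one row where v == 7)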
def evaluate_spider2sql(gold_result_dir, eval_standard_dict, gold, pred_sqls, db_path, temp_dir):
instance_id2db_id = dict()
for gt_data in gold:
instance_id2db_id[gt_data["instance_id"]] = gt_data["db_id"]
instance_id2pred_sql_query = dict()
for gt_data, pred_sql in zip(gold, pred_sqls):
instance_id2pred_sql_query[gt_data["instance_id"]] = pred_sql
eval_ids = list(eval_standard_dict.keys())
assert len(gold) == len(pred_sqls) == len(eval_ids)
output_results = []
for instance_id in tqdm(eval_ids):
print(f">>>Evaluating {instance_id}...")
        if instance_id not in instance_id2pred_sql_query:
            raise ValueError(f"instance id '{instance_id}' not in instance_id2pred_sql_query")
        if instance_id not in instance_id2db_id:
            raise ValueError(f"instance id '{instance_id}' not in instance_id2db_id")
error_info = None
pred_sql_query = instance_id2pred_sql_query[instance_id]
db_file_path = os.path.join(db_path, instance_id2db_id[instance_id], instance_id2db_id[instance_id] + ".sqlite")
exe_flag, dbms_error_info = get_sqlite_result(db_file_path, pred_sql_query, temp_dir, f"{instance_id}_pred.csv")
        if not exe_flag:
score = 0
error_info = dbms_error_info
else:
pred_pd = pd.read_csv(os.path.join(temp_dir, f"{instance_id}_pred.csv"))
pattern = re.compile(rf'^{re.escape(instance_id)}(_[a-z])?\.csv$')
all_files = os.listdir(gold_result_dir)
csv_files = [file for file in all_files if pattern.match(file)]
            if len(csv_files) == 1:
                # read the single matching file (it may carry an "_a"-style suffix)
                gold_pd = pd.read_csv(os.path.join(gold_result_dir, csv_files[0]))
try:
score = compare_pandas_table(pred_pd, gold_pd, eval_standard_dict.get(instance_id)['condition_cols'], eval_standard_dict.get(instance_id)['ignore_order'])
except Exception as e:
print(f"An error occurred: {e}")
score = 0
error_info = 'Python Script Error:' + str(e)
                if score == 0 and error_info is None:
                    error_info = 'Result Error'
            elif len(csv_files) > 1:
                gold_pds = [pd.read_csv(os.path.join(gold_result_dir, file)) for file in csv_files]
                score = compare_multi_pandas_table(pred_pd, gold_pds, eval_standard_dict.get(instance_id)['condition_cols'], eval_standard_dict.get(instance_id)['ignore_order'])
                if score == 0 and error_info is None:
                    error_info = 'Result Error'
            else:
                # no gold result file matched; `score` would otherwise be undefined
                score = 0
                error_info = 'Gold Result Missing'
output_results.append(
{
"instance_id": instance_id,
"score": score,
"pred_sql": pred_sql_query,
"error_info": error_info
}
)
print({item['instance_id']: item['score'] for item in output_results})
final_acc = sum([item['score'] for item in output_results]) / len(output_results)
print(f"Final score: {final_acc}")
print("Correct Instance ID:")
for item in output_results:
if item["score"] == 1:
print(item["instance_id"])
return output_results, final_acc
def execute_sql(data_idx, db_file, sql):
conn = sqlite3.connect(db_file)
cursor = conn.cursor()
try:
conn.execute("BEGIN TRANSACTION;")
cursor.execute(sql)
execution_res = cursor.fetchall()
execution_res = frozenset(execution_res) # make set hashable
conn.rollback()
conn.close()
return {"data_idx": data_idx, "sql": sql, "execution_res": execution_res, "valid_flag": 1}
    except Exception:
conn.rollback()
conn.close()
return {"data_idx": data_idx, "sql": sql, "execution_res": None, "valid_flag": 0}
def execute_sql_wrapper(data_idx, db_file, sql, timeout):
try:
res = func_timeout(timeout, execute_sql, args=(data_idx, db_file, sql))
except KeyboardInterrupt:
sys.exit(0)
except FunctionTimedOut:
res = {"data_idx": data_idx, "sql": sql, "execution_res": None, "valid_flag": 0}
    except Exception:
res = {"data_idx": data_idx, "sql": sql, "execution_res": None, "valid_flag": 0}
return res
def execute_callback_execute_sqls(result, all_execution_results):
# print("Done:", result["data_idx"])
all_execution_results.append(result)
def execute_sqls_parallel(all_db_files, all_sqls, all_execution_results, num_cpus=10, timeout=30):
    pool = mp.Pool(processes=num_cpus)
    callback_with_results = partial(execute_callback_execute_sqls, all_execution_results=all_execution_results)
    for data_idx, (db_file, sql) in enumerate(zip(all_db_files, all_sqls)):
        pool.apply_async(execute_sql_wrapper, args=(data_idx, db_file, sql, timeout), callback=callback_with_results)
pool.close()
pool.join()
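# Sketch of the parallel executor on a scratch database: one valid and one
# broken query are dispatched to the pool, and the callback collects one result
# dict per query. The database setup and the _demo_* name are illustrative
# assumptions.
def _demo_execute_sqls_parallel():
    fd, db_path_ = tempfile.mkstemp(suffix=".sqlite")
    os.close(fd)
    conn = sqlite3.connect(db_path_)
    conn.execute("CREATE TABLE t (v INTEGER)")
    conn.commit()
    conn.close()
    results = []
    execute_sqls_parallel([db_path_, db_path_], ["SELECT v FROM t", "SELECT nonsense"], results, num_cpus=2, timeout=5)
    os.remove(db_path_)
    return sorted(res["valid_flag"] for res in results)  # -> [0, 1]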
def evaluate(mode, gold_result_dir, eval_standard, gold_file, pred_file, db_path, save_pred_sqls):
eval_standard_dict = load_jsonl_to_dict(eval_standard)
gold = json.load(open(gold_file))
pred = json.load(open(pred_file))
pred_sql_key = "pred_sqls"
# pred_sql_key = "responses"
sampling_num = len(pred[0][pred_sql_key])
print(f"sampling_num: {sampling_num}")
all_db_files = []
all_pred_sqls = []
for gold_data, pred_data in tqdm(zip(gold, pred)):
db_file = os.path.join(db_path, gold_data["db_id"], gold_data["db_id"] + ".sqlite")
for sample_idx in range(sampling_num):
all_db_files.append(db_file)
all_pred_sqls.append(pred_data[pred_sql_key][sample_idx])
# obtain execution results of all predicted SQL queries
all_execution_results = []
execute_sqls_parallel(all_db_files, all_pred_sqls, all_execution_results, num_cpus=40, timeout=10)
all_execution_results = sorted(all_execution_results, key=lambda x: x["data_idx"])
print([res["data_idx"] for res in all_execution_results])
pred_sqls = []
for idx in range(len(gold)):
execution_results = all_execution_results[idx*sampling_num: (idx+1)*sampling_num]
if mode == "greedy_search":
# For greedy_search calculation, pred_sqls is a list of SQL query strings.
assert len(execution_results) == len(pred[0][pred_sql_key]) == sampling_num == 1
if execution_results[0]["valid_flag"] == 1:
pred_sqls.append(execution_results[0]["sql"])
else:
pred_sqls.append("Error SQL qeury")
elif mode == "major_voting":
# For major_voting calculation, pred_sqls is a list of SQL query strings.
assert len(execution_results) == len(pred[0][pred_sql_key]) == sampling_num
            # none of the predicted SQLs is valid
if sum(res["valid_flag"] for res in execution_results) == 0:
pred_sqls.append("Error SQL qeury")
continue
major_voting_counting = dict()
for res in execution_results:
if res["valid_flag"] == 0:
continue
if res["execution_res"] in major_voting_counting:
major_voting_counting[res["execution_res"]][0] += 1
else:
major_voting_counting[res["execution_res"]] = [1, res["sql"]]
major_vote = max(major_voting_counting.values(), key=lambda x: x[0])
mj_pred_sql = major_vote[1]
pred_sqls.append(mj_pred_sql)
elif mode == "pass@k":
# For pass@k calculation, pred_sqls is a list where each element is a list of SQL query strings.
assert len(execution_results) == len(pred[0][pred_sql_key]) == sampling_num
pred_sqls.append([res["sql"] if res["valid_flag"] == 1 else "Error SQL query" for res in execution_results])
assert len(pred_sqls) == len(gold) == len(pred)
if mode in ["greedy_search", "major_voting"]:
temp_dir = tempfile.mkdtemp(prefix="temp-") # , dir="./"
print("temp_dir:", temp_dir)
output_results, final_acc = evaluate_spider2sql(
gold_result_dir,
eval_standard_dict,
gold,
pred_sqls,
db_path,
temp_dir
)
if save_pred_sqls:
suffix = "-pred-greedy-search-sqls.json" if mode == "greedy_search" else "-pred-major-voting-sqls.json"
with open(pred_file[:-5] + suffix, "w", encoding="utf-8") as f:
                f.write(json.dumps(pred_sqls, indent=2, ensure_ascii=False))
print(f"{mode} ACC: {final_acc}")
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
return final_acc, pred_sqls
elif mode == "pass@k":
all_scores = []
for sample_idx in range(sampling_num):
temp_dir = tempfile.mkdtemp(prefix="temp-") # , dir="./"
print("temp_dir:", temp_dir)
pred_sqls_for_specific_sample_idx = [sqls[sample_idx] for sqls in pred_sqls]
output_results, _ = evaluate_spider2sql(
gold_result_dir,
eval_standard_dict,
gold,
pred_sqls_for_specific_sample_idx,
db_path,
temp_dir
)
all_scores.append([item['score'] for item in output_results])
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
print(all_scores)
pass_at_k_scores = [1 if any(column) else 0 for column in zip(*all_scores)]
final_acc = sum(pass_at_k_scores)/len(pass_at_k_scores)
print(pass_at_k_scores)
print(f"{mode} ACC: {final_acc}")
return final_acc, None
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run evaluations for NLP models.")
parser.add_argument("--mode", type=str, choices=["greedy_search", "major_voting", "pass@k"])
parser.add_argument("--pred", type=str, default="../results/Qwen2.5-Coder-7B-Instruct-spider2.0-test-greedy.json", help="Predicted result directory")
    parser.add_argument('--gold', type=str, default="./data/spider2.0/test.json")
    parser.add_argument('--gold_result_dir', type=str, default="./data/spider2.0/gold_exec_result")
    parser.add_argument('--eval_standard', type=str, default="./data/spider2.0/spider2lite_eval.jsonl")
    parser.add_argument('--db_path', type=str, default="./data/spider2.0/databases")
args = parser.parse_args()
    evaluate(args.mode, args.gold_result_dir, args.eval_standard, args.gold, args.pred, args.db_path, False)