build_contents_index.py 5.89 KB
Newer Older
mashun1's avatar
omnisql  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import os, shutil
import sqlite3
import subprocess
import sys
from pathlib import Path

from func_timeout import func_set_timeout, FunctionTimedOut

def get_cursor_from_path(sqlite_path):
    """Open the SQLite database at ``sqlite_path`` and return a cursor on it.

    If the path does not exist yet a notice is printed; ``sqlite3.connect``
    then creates a new, empty database file there. The connection is opened
    with ``check_same_thread=False`` and a text factory that decodes stored
    bytes (UTF-8 by default), dropping undecodable bytes instead of raising.
    On a connection failure the offending path is printed and the error is
    re-raised.
    """
    try:
        missing_on_disk = not os.path.exists(sqlite_path)
        if missing_on_disk:
            print("Openning a new connection %s" % sqlite_path)
        conn = sqlite3.connect(sqlite_path, check_same_thread=False)
    except Exception:
        print(sqlite_path)
        raise

    def _lenient_decode(raw):
        # Corrupt / non-UTF-8 text in benchmark DBs should not crash a scan.
        return raw.decode(errors="ignore")

    conn.text_factory = _lenient_decode
    return conn.cursor()

# Run one SQL statement under a generous one-hour limit: full DISTINCT scans
# of large columns while building the content index can be slow.
@func_set_timeout(3600)
def execute_sql(cursor, sql):
    """Execute ``sql`` on ``cursor`` and return every result row as a list.

    The ``func_set_timeout`` decorator aborts the call (raising
    ``func_timeout.FunctionTimedOut``) if it runs longer than 3600 seconds.
    """
    cursor.execute(sql)
    rows = cursor.fetchall()
    return rows

def remove_contents_of_a_folder(index_path):
    """Empty the directory ``index_path``, creating it first if it is missing.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. A failure on one entry is reported on stdout and skipped so
    the remaining entries are still cleared.
    """
    os.makedirs(index_path, exist_ok=True)

    for entry_name in os.listdir(index_path):
        entry_path = os.path.join(index_path, entry_name)
        try:
            # Check symlinks before directories so a link to a directory is
            # unlinked rather than having its target tree deleted.
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
                continue
            if os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as err:
            print(f"Failed to delete {entry_path}. Reason: {err}")

def is_number(s):
    """Return True if ``s`` can be converted with ``float()``, else False.

    Because this defers to ``float``, surrounding whitespace is allowed and
    strings such as ``"nan"``, ``"inf"`` and ``"1_000"`` also count as numbers.
    Inputs ``float`` cannot handle at all (e.g. ``None``) return False rather
    than raising ``TypeError``.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True

def _quote_identifier(name):
    """Backtick-quote a SQLite identifier, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def _collect_column_contents(cursor):
    """Gather indexable string values from every user table reachable by ``cursor``.

    Returns a list of ``{"id": ..., "contents": ...}`` dicts, one per distinct,
    non-NULL, non-numeric string value of length 1-40. Ids have the form
    ``"<table>-**-<column>-**-<n>"``. Per-table and per-column query failures
    are printed and skipped so one bad table/column does not abort the DB.
    """
    results = execute_sql(cursor, "SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [result[0] for result in results]

    all_column_contents = []
    for table_name in table_names:
        # Names beginning with "sqlite_" are reserved for SQLite's internal
        # tables (sqlite_sequence, sqlite_stat1, ...), which hold no user data.
        if table_name.startswith("sqlite_"):
            continue

        # The table name is embedded as a string literal, so escape quotes;
        # otherwise a name like "o'brien" makes the query a syntax error.
        escaped_table = table_name.replace("'", "''")
        try:
            results = execute_sql(cursor, f"SELECT name FROM PRAGMA_TABLE_INFO('{escaped_table}')")
        except Exception as e:
            print(str(e))
            continue
        column_names_in_one_table = [result[0] for result in results]

        quoted_table = _quote_identifier(table_name)
        for column_name in column_names_in_one_table:
            quoted_column = _quote_identifier(column_name)
            sql = f"SELECT DISTINCT {quoted_column} FROM {quoted_table} WHERE {quoted_column} IS NOT NULL;"
            try:
                print(sql)
                results = execute_sql(cursor, sql)
            except Exception as e:
                print(str(e))
                continue

            column_contents = [result[0] for result in results if isinstance(result[0], str) and not is_number(result[0])]
            for c_id, column_content in enumerate(column_contents):
                # remove empty and extremely-long contents
                if 0 < len(column_content) <= 40:
                    all_column_contents.append(
                        {
                            "id": "{}-**-{}-**-{}".format(table_name, column_name, c_id),
                            "contents": column_content,
                        }
                    )
    return all_column_contents


def build_content_index(db_file_path, index_path, temp_dir="./data/temp_db_index", threads=16):
    '''
    Create a BM25 (Lucene) index over the string values of one SQLite database.

    The collected values are written as a Pyserini JsonCollection to
    ``<temp_dir>/contents.json``, indexed into ``index_path`` by running
    ``pyserini.index.lucene`` with the current Python interpreter, and the
    JSON file is deleted afterwards (even if indexing fails).

    Args:
        db_file_path: path of the ``.sqlite`` file to scan.
        index_path: output directory for the Lucene index.
        temp_dir: staging directory passed as ``--input``. Every file in it
            is indexed, and ``contents.json`` has a fixed name, so this
            directory must not be shared by concurrent runs.
        threads: value passed to pyserini's ``--threads`` option.
    '''
    cursor = get_cursor_from_path(db_file_path)
    try:
        all_column_contents = _collect_column_contents(cursor)
    finally:
        # The connection is only needed for the scan; close it so building
        # indexes for many databases does not accumulate open handles.
        cursor.connection.close()

    os.makedirs(temp_dir, exist_ok=True)
    contents_path = os.path.join(temp_dir, "contents.json")
    with open(contents_path, "w", encoding="utf-8") as f:
        json.dump(all_column_contents, f, indent=2, ensure_ascii=True)

    # Building a BM25 Index (Direct Java Implementation), see https://github.com/castorini/pyserini/blob/master/docs/usage-index.md
    # An argument list (no shell) keeps paths with quotes/spaces/$ intact, and
    # sys.executable runs pyserini from the same environment as this script.
    cmd = [
        sys.executable, "-m", "pyserini.index.lucene",
        "--collection", "JsonCollection",
        "--input", temp_dir,
        "--index", index_path,
        "--generator", "DefaultLuceneDocumentGenerator",
        "--threads", str(threads),
        "--storePositions", "--storeDocvectors", "--storeRaw",
    ]
    try:
        d = subprocess.run(cmd).returncode
        print(d)
    finally:
        os.remove(contents_path)

if __name__ == "__main__":
    # Build one content index per database, for every dataset listed here.
    # Each database is expected at <db_path>/<db_id>/<db_id>.sqlite and its
    # index is written to <index_path_prefix>/<db_id>.
    dataset_info = {
        # BIRD train
        "bird_train": {"db_path": "./data/bird/train/train_databases", "index_path_prefix": "./data/bird/train/db_contents_index"},
        # BIRD dev
        "bird_dev": {"db_path": "./data/bird/dev_20240627/dev_databases", "index_path_prefix": "./data/bird/dev_20240627/db_contents_index"},
        # Spider train-dev-test
        "spider": {"db_path": "./data/spider/test_database", "index_path_prefix": "./data/spider/db_contents_index"},
        # Spider2.0-SQLite
        "spider2_sqlite": {"db_path": "./data/spider2_sqlite/databases", "index_path_prefix": "./data/spider2_sqlite/db_contents_index"},
        # SynSQL-2.5M dataset
        "SynSQL-2.5M": {"db_path": "./data/SynSQL-2.5M/databases", "index_path_prefix": "./data/SynSQL-2.5M/db_contents_index"},
        # spider-dk
        "spider_dk": {"db_path": "./data/Spider-DK/database", "index_path_prefix": "./data/Spider-DK/db_contents_index"},
        # EHRSQL_dev
        "EHRSQL_dev": {"db_path": "./data/EHRSQL/database", "index_path_prefix": "./data/EHRSQL/db_contents_index"},
        # sciencebenchmark_dev
        "sciencebenchmark_dev": {"db_path": "./data/sciencebenchmark/databases", "index_path_prefix": "./data/sciencebenchmark/db_contents_index"},
    }

    for dataset_name, paths in dataset_info.items():
        print(dataset_name)
        db_path = paths["db_path"]
        index_path_prefix = paths["index_path_prefix"]

        # Skip datasets that are not present instead of crashing the whole run
        # in os.listdir; checking before the wipe also leaves any previously
        # built index for that dataset untouched.
        if not os.path.isdir(db_path):
            print(f"The directory '{db_path}' does not exist, skipping {dataset_name}.")
            continue

        remove_contents_of_a_folder(index_path_prefix)
        # build content index; sorted() gives a deterministic processing order
        for db_id in sorted(os.listdir(db_path)):
            db_file_path = os.path.join(db_path, db_id, db_id + ".sqlite")
            if os.path.isfile(db_file_path):
                print(f"The file '{db_file_path}' exists.")
                build_content_index(
                    db_file_path,
                    os.path.join(index_path_prefix, db_id)
                )
            else:
                print(f"The file '{db_file_path}' does not exist.")