update_dataset_suffix.py
#!/usr/bin/env python3
import argparse
import glob
import hashlib
import json
import os
import re
from multiprocessing import Pool
from typing import List, Union

from mmengine.config import Config, ConfigDict


# from opencompass.utils import get_prompt_hash
# Copied from opencompass.utils.get_prompt_hash so this script can run in CI
# without importing opencompass
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Get the hash of the prompt configuration.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration.

    Returns:
        str: The hash of the prompt configuration.
    """
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) == 1:
            dataset_cfg = dataset_cfg[0]
        else:
            hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
            hash_object = hashlib.sha256(hashes.encode())
            return hash_object.hexdigest()
    if 'reader_cfg' in dataset_cfg.infer_cfg:
        # new config
        reader_cfg = dict(
            type='DatasetReader',
            input_columns=dataset_cfg.infer_cfg.reader_cfg.input_columns,
            output_column=dataset_cfg.infer_cfg.reader_cfg.output_column)
        dataset_cfg.infer_cfg.reader = reader_cfg
        if 'train_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'index_split'] = dataset_cfg.infer_cfg.reader_cfg.train_split
        if 'test_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
        for k, v in dataset_cfg.infer_cfg.items():
            # Keep only the class name so module paths don't affect the hash;
            # skip entries (such as reader_cfg itself) that carry no 'type'
            if isinstance(v, dict) and 'type' in v:
                dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
    # A compromise for hash consistency: fix_id_list may sit on either the
    # retriever or the inferencer, so normalize it onto the inferencer
    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
    # Hash the canonical JSON dump of the normalized infer_cfg
    d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
    hash_object = hashlib.sha256(d_json.encode())
    return hash_object.hexdigest()
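

# Illustrative sketch, not part of the original script: a minimal dataset
# config in the shape get_prompt_hash() expects. Field values here are
# invented; real OpenCompass configs carry full prompt templates and
# retriever settings.
def _demo_get_prompt_hash():
    cfg = ConfigDict(
        dict(infer_cfg=dict(
            prompt_template=dict(type='PromptTemplate',
                                 template='Q: {question}\nA:'),
            retriever=dict(type='ZeroRetriever'),
            inferencer=dict(type='GenInferencer'),
        )))
    # The first 6 hex characters are what the filename suffix uses
    return get_prompt_hash(cfg)[:6]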


# Compute the 6-character prompt-hash suffix for a dataset config file
def get_hash(path):
    cfg = Config.fromfile(path)
    for k in cfg.keys():
        if k.endswith('_datasets'):
            return get_prompt_hash(cfg[k])[:6]
    print(f'Could not find *_datasets in {path}')
    return None
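
# For reference (hypothetical file contents): get_hash() expects the config
# module to define a list variable whose name ends with '_datasets', e.g.
#
#     gsm8k_datasets = [
#         dict(abbr='gsm8k',
#              type='GSM8KDataset',
#              reader_cfg=...,
#              infer_cfg=...,
#              eval_cfg=...),
#     ]
#
# Config.fromfile() loads the module, and the first such key is hashed.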


def check_and_rename(filepath):
    """Return (old_path, new_path) if the filename's hash suffix is stale,
    else (None, None)."""
    base_name = os.path.basename(filepath)
    match = re.match(r'(.*)_(gen|ppl)_(.*)\.py$', base_name)
    if match:
        dataset, mode, old_hash = match.groups()
        new_hash = get_hash(filepath)
        if not new_hash:
            return None, None
        if old_hash != new_hash:
            new_name = f'{dataset}_{mode}_{new_hash}.py'
            new_file = os.path.join(os.path.dirname(filepath), new_name)
            print(f'Rename {filepath} to {new_file}')
            return filepath, new_file
    return None, None
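
# Worked example (hypothetical hashes): a config named
# 'configs/datasets/gsm8k/gsm8k_gen_1d7fe4.py' whose prompt hash now computes
# to 'a58960' yields the pair
# ('configs/datasets/gsm8k/gsm8k_gen_1d7fe4.py',
#  'configs/datasets/gsm8k/gsm8k_gen_a58960.py');
# files whose hash is unchanged, or whose names do not follow the
# '{dataset}_{gen|ppl}_{hash}.py' convention, yield (None, None).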


def update_imports(data):
    python_file, name_pairs = data
    # The file may itself have been renamed away in the rename step
    if not os.path.exists(python_file):
        return
    with open(python_file, 'r') as file:
        filedata = file.read()
    new_data = filedata
    for filepath, new_file in name_pairs:
        old_name = os.path.basename(filepath)[:-3]
        new_name = os.path.basename(new_file)[:-3]
        # Replace the old module name with the new one
        new_data = new_data.replace(old_name, new_name)
    if new_data != filedata:
        with open(python_file, 'w') as file:
            file.write(new_data)

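# Illustrative before/after (hypothetical module names): if
# 'gsm8k_gen_1d7fe4.py' is renamed to 'gsm8k_gen_a58960.py', an import line
# such as
#
#     from .gsm8k.gsm8k_gen_1d7fe4 import gsm8k_datasets
#
# becomes
#
#     from .gsm8k.gsm8k_gen_a58960 import gsm8k_datasets
#
# Because this is a raw substring replacement, any occurrence of the old
# module name in the file is rewritten, not just import statements.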

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('python_files', nargs='*')
    args = parser.parse_args()

    root_folder = 'configs/datasets'
    if args.python_files:
        python_files = [
            i for i in args.python_files if i.startswith(root_folder)
        ]
    else:
        python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)

    # Use multiprocessing to speed up the check and rename process
    with Pool(16) as p:
        name_pairs = p.map(check_and_rename, python_files)
    name_pairs = [pair for pair in name_pairs if pair[0] is not None]
    if not name_pairs:
        return
    # Apply the renames in parallel
    with Pool(16) as p:
        p.starmap(os.rename, name_pairs)

    # Re-scan after renaming: the original paths may no longer exist
    python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)
    # Rewrite imports of every renamed module across all dataset configs
    update_data = [(python_file, name_pairs) for python_file in python_files]
    with Pool(16) as p:
        p.map(update_imports, update_data)


if __name__ == '__main__':
    main()
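
# Usage: run from the repository root (root_folder is hardcoded to
# 'configs/datasets'). With no arguments, every dataset config is scanned;
# otherwise only the given paths are considered, e.g.
#
#     python update_dataset_suffix.py
#     python update_dataset_suffix.py configs/datasets/gsm8k/gsm8k_gen_1d7fe4.py
#
# Paths outside configs/datasets are silently skipped.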