prepare_for_training.py 1.4 KB
Newer Older
ai_public's avatar
ai_public committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
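# 将数据转换为json格式,此脚本适用于cc3m
# Convert image/caption pairs to JSON (this script targets CC3M).
# Each caption file `<stem>.txt` is paired with the image `<stem>.png`,
# `<stem>.jpg` or `<stem>.jpeg` in the same directory; "image_file" is the
# image path prefixed with the given data root. Output format:
# [
#    {"text": "a dog", "image_file": "dog.jpg"}
# ]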
import json

from pathlib import Path


def convert_to_json(data_root: str,
                    save_path: str) -> list:
    """Pair caption ``.txt`` files with same-stem images and dump them as JSON.

    Every ``<stem>.txt`` directly inside *data_root* (no recursion) is matched
    with an image ``<stem>.png`` / ``<stem>.jpg`` / ``<stem>.jpeg`` in the same
    directory. Stems that have only a caption or only an image are skipped.

    Args:
        data_root: Directory holding the image and caption files.
        save_path: Destination of the JSON file, a list of
            ``{"text": <stripped caption>, "image_file": <image path>}``
            records. ``image_file`` is the image path joined onto *data_root*.

    Returns:
        The list of records that was written, sorted by file stem.

    Raises:
        NotADirectoryError: If *data_root* is not an existing directory.
    """
    root = Path(data_root)
    if not root.is_dir():
        # Without this check a mistyped path silently produces an empty "[]".
        raise NotADirectoryError(f"data_root is not a directory: {data_root}")

    text_path_mapping = {
        txt_path.stem: txt_path for txt_path in root.glob("*.txt")
    }

    # If one stem exists with several image extensions, the later pattern
    # wins (.jpeg over .jpg over .png), matching the original precedence.
    image_path_mapping = {
        image_path.stem: image_path
        for pattern in ("*.png", "*.jpg", "*.jpeg")
        for image_path in root.glob(pattern)
    }

    # Sort so the output order is reproducible: iterating a set of strings
    # depends on the per-process hash seed.
    keys = sorted(text_path_mapping.keys() & image_path_mapping.keys())

    results = []
    for key in keys:
        # Captions may contain non-ASCII text; don't rely on the locale encoding.
        text = text_path_mapping[key].read_text(encoding="utf-8").strip()
        results.append({"text": text, "image_file": str(image_path_mapping[key])})

    # ensure_ascii=False emits raw Unicode, so the file must be opened as UTF-8;
    # a non-UTF-8 platform default (e.g. cp1252/GBK on Windows) may fail to encode it.
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False)

    return results


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Pair <stem>.txt captions with <stem>.png/.jpg/.jpeg images "
                    "and write them to a JSON file.")

    # Both flags are mandatory: if either is omitted, convert_to_json would
    # otherwise crash with an opaque TypeError on Path(None) / open(None).
    # Directory containing the image and caption files.
    parser.add_argument("--data_root", type=str, required=True, help="图像-文本存储位置")

    # Path of the JSON file to write.
    parser.add_argument("--save_path", type=str, required=True, help="json文件存储位置")

    args = parser.parse_args()

    convert_to_json(args.data_root, args.save_path)