# convert_llava_pretrain_to_wds.py
#
# Converts the LLaVA-Pretrain dataset (a JSON annotation file plus the image
# files it references) into WebDataset tar shards.
import json
import os
import webdataset as wds

from tqdm import tqdm

# Root directory of the LLaVA-Pretrain dataset. This is a placeholder and must
# be replaced with a real path before the script is run.
llava_pretrain_dir = '<path_to_LLaVA-Pretrain>'

# Paths to the dataset files
json_file = os.path.join(llava_pretrain_dir, 'blip_laion_cc_sbu_558k.json')
output = os.path.join(llava_pretrain_dir, 'wds')

# Create the shard directory. Attempting the mkdir and tolerating
# FileExistsError avoids the check-then-create race of an exists() test, while
# still failing loudly when the dataset root itself is missing.
try:
    os.mkdir(output)
except FileExistsError:
    pass

# Load data
# The encoding is pinned so decoding does not depend on the platform's locale
# default (JSON interchange text is UTF-8).
with open(json_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Write one sample per annotation entry; ShardWriter starts a new
# 'pretrain-<n>.tar' file after every 10000 samples.
with wds.ShardWriter(os.path.join(output, 'pretrain-%d.tar'), maxcount=10000) as shard_writer:
    for entry in tqdm(data):
        # entry['image'] is resolved relative to the dataset root.
        with open(os.path.join(llava_pretrain_dir, entry['image']), "rb") as img_file:
            image_data = img_file.read()
        sample = {
            "__key__": entry['id'],
            # Raw file bytes are stored as-is (no decoding or re-encoding).
            "jpg": image_data,
            "json": json.dumps(entry['conversations']).encode("utf-8"),
        }
        shard_writer.write(sample)

print("Dataset successfully converted to wds")