convert_chartqa.py 2.7 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import json
import pandas as pd
import cv2
import numpy as np

# ======================
# 自定义配置项(只需改这里即可切换数据集)
# ======================
DATASET_NAME = "chartqa"  # 图像目录和 JSON 中的前缀
IMAGE_DIR = DATASET_NAME    # 图像保存目录
JSON_FILE = f"llamafactory_{DATASET_NAME}.json"

# 创建保存图像的目录
os.makedirs(IMAGE_DIR, exist_ok=True)

def process_parquet_file(parquet_file, dataset):
    """处理单个 Parquet 文件并更新全局数据集"""
    df = pd.read_parquet(parquet_file)
    print(f"Processing file: {parquet_file}")

    global valid_index  # 使用全局计数器

    for index, row in df.iterrows():
        try:
            # 提取图像 bytes 数据
            image_bytes = row['image']['bytes']

            # 使用 OpenCV 解码图像
            image_np = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
            if image_np is None:
                raise ValueError(f"Failed to decode image at index {index}")

            # 构造图像路径
            image_path = f"{IMAGE_DIR}/{valid_index}.jpg"

            # 保存图像
            cv2.imwrite(image_path, image_np)

            # 构造 messages 结构
            messages = [
                {"role": "user", "content": f"<image>{row['query']}"},
                {"role": "assistant", "content": str(row['label'][0])}
            ]

            # 添加到数据集
            dataset.append({
                "messages": messages,
                "images": [image_path]
            })

            # 更新全局计数器
            valid_index += 1

        except Exception as e:
            print(f"Error processing row {index} in {parquet_file}: {e}")

def process_all_train_parquet_files(data_dir):
    """批量处理 data 目录下所有以 train 开头的 Parquet 文件"""
    parquet_files = [
        os.path.join(data_dir, f) for f in os.listdir(data_dir)
        if f.startswith("train") and f.endswith(".parquet")
    ]

    if not parquet_files:
        print("❌ 未找到以 'train' 开头的 Parquet 文件")
        return

    # 存储所有样本的全局数据集
    dataset = []

    # 全局计数器,确保图像文件名唯一
    global valid_index
    valid_index = 0

    for parquet_file in parquet_files:
        process_parquet_file(parquet_file, dataset)

    # 生成 JSON 文件
    with open(JSON_FILE, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)

    print(f"\n✅ JSON 数据集已保存至:{JSON_FILE}")
    print(f"📁 总共提取图像数量:{len(dataset)} 张")

# 示例调用
data_dir = "/workspace/datasets/ChartQA/data"
process_all_train_parquet_files(data_dir)