import os import json import pandas as pd import cv2 import numpy as np # ====================== # 自定义配置项(只需改这里即可切换数据集) # ====================== DATASET_NAME = "chartqa" # 图像目录和 JSON 中的前缀 IMAGE_DIR = DATASET_NAME # 图像保存目录 JSON_FILE = f"llamafactory_{DATASET_NAME}.json" # 创建保存图像的目录 os.makedirs(IMAGE_DIR, exist_ok=True) def process_parquet_file(parquet_file, dataset): """处理单个 Parquet 文件并更新全局数据集""" df = pd.read_parquet(parquet_file) print(f"Processing file: {parquet_file}") global valid_index # 使用全局计数器 for index, row in df.iterrows(): try: # 提取图像 bytes 数据 image_bytes = row['image']['bytes'] # 使用 OpenCV 解码图像 image_np = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR) if image_np is None: raise ValueError(f"Failed to decode image at index {index}") # 构造图像路径 image_path = f"{IMAGE_DIR}/{valid_index}.jpg" # 保存图像 cv2.imwrite(image_path, image_np) # 构造 messages 结构 messages = [ {"role": "user", "content": f"{row['query']}"}, {"role": "assistant", "content": str(row['label'][0])} ] # 添加到数据集 dataset.append({ "messages": messages, "images": [image_path] }) # 更新全局计数器 valid_index += 1 except Exception as e: print(f"Error processing row {index} in {parquet_file}: {e}") def process_all_train_parquet_files(data_dir): """批量处理 data 目录下所有以 train 开头的 Parquet 文件""" parquet_files = [ os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.startswith("train") and f.endswith(".parquet") ] if not parquet_files: print("❌ 未找到以 'train' 开头的 Parquet 文件") return # 存储所有样本的全局数据集 dataset = [] # 全局计数器,确保图像文件名唯一 global valid_index valid_index = 0 for parquet_file in parquet_files: process_parquet_file(parquet_file, dataset) # 生成 JSON 文件 with open(JSON_FILE, 'w', encoding='utf-8') as f: json.dump(dataset, f, ensure_ascii=False, indent=2) print(f"\n✅ JSON 数据集已保存至:{JSON_FILE}") print(f"📁 总共提取图像数量:{len(dataset)} 张") # 示例调用 data_dir = "/workspace/datasets/ChartQA/data" process_all_train_parquet_files(data_dir)