import os
import json
import pandas as pd
import cv2
import numpy as np

# ======================
# 自定义配置项（只需改这里即可切换数据集）
# ======================
DATASET_NAME = "chartqa"  # 图像目录和 JSON 中的前缀
IMAGE_DIR = DATASET_NAME    # 图像保存目录
JSON_FILE = f"llamafactory_{DATASET_NAME}.json"

# 创建保存图像的目录
os.makedirs(IMAGE_DIR, exist_ok=True)

def process_parquet_file(parquet_file, dataset):
    """处理单个 Parquet 文件并更新全局数据集"""
    df = pd.read_parquet(parquet_file)
    print(f"Processing file: {parquet_file}")

    global valid_index  # 使用全局计数器

    for index, row in df.iterrows():
        try:
            # 提取图像 bytes 数据
            image_bytes = row['image']['bytes']

            # 使用 OpenCV 解码图像
            image_np = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
            if image_np is None:
                raise ValueError(f"Failed to decode image at index {index}")

            # 构造图像路径
            image_path = f"{IMAGE_DIR}/{valid_index}.jpg"

            # 保存图像
            cv2.imwrite(image_path, image_np)

            # 构造 messages 结构
            messages = [
                {"role": "user", "content": f"<image>{row['query']}"},
                {"role": "assistant", "content": str(row['label'][0])}
            ]

            # 添加到数据集
            dataset.append({
                "messages": messages,
                "images": [image_path]
            })

            # 更新全局计数器
            valid_index += 1

        except Exception as e:
            print(f"Error processing row {index} in {parquet_file}: {e}")

def process_all_train_parquet_files(data_dir):
    """批量处理 data 目录下所有以 train 开头的 Parquet 文件"""
    parquet_files = [
        os.path.join(data_dir, f) for f in os.listdir(data_dir)
        if f.startswith("train") and f.endswith(".parquet")
    ]

    if not parquet_files:
        print("❌ 未找到以 'train' 开头的 Parquet 文件")
        return

    # 存储所有样本的全局数据集
    dataset = []

    # 全局计数器，确保图像文件名唯一
    global valid_index
    valid_index = 0

    for parquet_file in parquet_files:
        process_parquet_file(parquet_file, dataset)

    # 生成 JSON 文件
    with open(JSON_FILE, 'w', encoding='utf-8') as f:
        json.dump(dataset, f, ensure_ascii=False, indent=2)

    print(f"\n✅ JSON 数据集已保存至：{JSON_FILE}")
    print(f"📁 总共提取图像数量：{len(dataset)} 张")

# 示例调用
data_dir = "/workspace/datasets/ChartQA/data"
process_all_train_parquet_files(data_dir)

