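# Pipeline implemented by this script:
# 1  Read the base64-encoded images
# 2  Save each original image
# 3  Resize each image by a scale factor
# 4  Save each resized image
# 5  Update the format of the message data file
# 6  Save the updated file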
import base64
from PIL import Image
from io import BytesIO
import json
import os
import numpy as np
import cv2
import re


def decode_image(data_uri):
    """Decode a base64-encoded image string into a PIL image.

    Args:
        data_uri: Either a data URI such as
            ``data:image/png;base64,<payload>`` or a bare base64 payload
            with no ``data:...,`` header.

    Returns:
        The ``PIL.Image.Image`` opened from the decoded bytes.

    Raises:
        binascii.Error: If the payload is not valid base64.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            that Pillow can identify.
    """
    # Everything before the first comma is the "data:image/png;base64"
    # header. partition() never fails to unpack, so a string without any
    # header is handled too: the whole input is then the payload (the base64
    # alphabet contains no comma, so a bare payload is never split).
    _header, sep, payload = data_uri.partition(",")
    if not sep:
        payload = data_uri

    img_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(img_bytes))



def scale_action_coordinates(content_str: str, scale: float) -> str:
    """Scale every ``start_box='(x,y)'`` coordinate found in *content_str*.

    Each occurrence is rescaled from its own coordinates, so a string with
    several ``start_box`` entries keeps them distinct. Scaled values are
    truncated with ``int()`` and written back without a space after the
    comma.

    Args:
        content_str: Text that may contain ``start_box='(x,y)'`` fragments.
        scale: Multiplier applied to both x and y.

    Returns:
        The text with all ``start_box`` coordinates scaled; unchanged when
        there is no match.
    """
    # NOTE(review): only start_box is rewritten. If the action strings can
    # carry other coordinate fields (e.g. an end box for drag actions), those
    # are left unscaled -- confirm against the data format.
    pattern = r"start_box='\((\d+),\s*(\d+)\)'"

    def _rescale(match):
        # Called once per occurrence, so every box uses its own x/y instead
        # of all boxes being overwritten with the first one's result.
        x, y = map(int, match.groups())
        return f"start_box='({int(x * scale)},{int(y * scale)})'"

    return re.sub(pattern, _rescale, content_str)


def resize_image(img, scale: float = 0.7):
    """Return a copy of *img* resized by *scale* using Lanczos interpolation.

    Args:
        img: Source PIL image. Images that are not already in ``RGB`` mode
            (palette, grayscale, RGBA, ...) are converted to RGB first.
        scale: Multiplier applied to both width and height; must be positive.

    Returns:
        A new RGB ``PIL.Image.Image`` whose sides are ``int(side * scale)``,
        clamped to at least one pixel.

    Raises:
        ValueError: If *scale* is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")

    # Normalise to 3-channel RGB so the array is always (H, W, 3). Without
    # this, palette images would have their colour *indices* interpolated and
    # single-channel images would not match the RGB output of this function.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_np = np.array(img)

    # int() truncation can reach 0 for tiny images or small scales, which
    # cv2.resize rejects; keep at least one pixel per side.
    new_w = max(1, int(img_np.shape[1] * scale))
    new_h = max(1, int(img_np.shape[0] * scale))

    # Resizing interpolates each channel independently, so the channel order
    # is irrelevant and no RGB<->BGR round trip is needed.
    resized_np = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    return Image.fromarray(resized_np)


def generate(data_path: str,
             save_root: str,
             messages_file_name: str,
             scale: float = 0.7):
    """Extract, resize and re-reference the images of a message file.

    Loads a JSON list of message dicts from *data_path*, then:

    * For each ``user`` message whose ``content`` is a list, every part that
      carries an ``image_url`` is decoded, saved as
      ``<save_root>/images/image_<n>.png``, resized by *scale* and saved as
      ``<save_root>/resized_images/resized_image_<n>.png``. In that part the
      ``image_url`` key is removed, ``type`` is set to ``"image"`` and
      ``url`` is set to the resized file's path.
    * For each ``assistant`` message whose ``content`` is a string, the
      ``start_box`` coordinates are scaled by the same factor. Non-string
      content (e.g. ``None`` or a list of parts) is left untouched.

    The modified messages are written to
    ``<save_root>/<messages_file_name>``.

    Args:
        data_path: Path of the input JSON messages file.
        save_root: Directory receiving the images and the output JSON.
        messages_file_name: File name of the output JSON inside *save_root*.
        scale: Resize factor for images and coordinates.
    """
    with open(data_path, "r", encoding="utf8") as f:
        messages = json.load(f)

    ori_img_save_dir = os.path.join(save_root, "images")
    resized_img_save_dir = os.path.join(save_root, "resized_images")
    os.makedirs(ori_img_save_dir, exist_ok=True)
    os.makedirs(resized_img_save_dir, exist_ok=True)

    # Running counter shared by all messages so file names never collide.
    idx = 0

    for block in messages:
        # "or" also covers an explicit null role in the JSON.
        role = (block.get("role") or "").lower()
        content = block.get("content")

        if role == "assistant":
            # Only plain-text content can hold action strings; passing None
            # or a list to the regex helper would raise TypeError.
            if isinstance(content, str):
                block["content"] = scale_action_coordinates(content, scale)
            continue

        if role != "user" or not isinstance(content, list):
            continue

        for sub_block in content:
            if not isinstance(sub_block, dict):
                continue
            image_url = sub_block.get("image_url")
            if not image_url:
                continue
            # Accept both the {"url": ...} object form and a bare string.
            url = image_url if isinstance(image_url, str) else image_url.get("url", "")
            if not url:
                continue

            original_path = os.path.join(ori_img_save_dir, f"image_{idx}.png")
            resized_path = os.path.join(resized_img_save_dir, f"resized_image_{idx}.png")

            original_img = decode_image(url)
            resized_img = resize_image(original_img, scale=scale)
            original_img.save(original_path)
            resized_img.save(resized_path)

            # Swap the inline base64 payload for a reference to the file.
            del sub_block["image_url"]
            sub_block["type"] = "image"
            sub_block["url"] = resized_path

            idx += 1

    messages_save_path = os.path.join(save_root, messages_file_name)
    with open(messages_save_path, "w", encoding="utf8") as f:
        json.dump(messages, f, indent=4, ensure_ascii=False)
        
    
if __name__ == "__main__":
    # Script entry point: process the test messages with the default scale
    # and write the images and the converted JSON under "data".
    generate(
        data_path="data/test_messages.json",
        save_root="data",
        messages_file_name="test_messages_07.json",
    )

    
