import os
import json

from pathlib import Path


def get_mapping(path_list):
    """Build a lookup from each path's stem to its resolved path string.

    The stem is the final path component without its last suffix
    (``Path("dir/a.jpg").stem == "a"``). Each path is resolved to an absolute
    path before being stringified. When two paths share a stem, the one that
    appears later in ``path_list`` wins.
    """
    stem_to_path = {}
    for path in path_list:
        stem_to_path[path.stem] = str(path.resolve())
    return stem_to_path


def generate_image_text(image_root: str,
                        prompt_root: str,
                        save_root: str,
                        *,
                        image_suffixes=(".jpg", ".jpeg", ".png")):
    """Pair images with prompt JSON files by file stem and append them as JSONL.

    Files directly inside ``image_root`` whose suffix (compared
    case-insensitively) is in ``image_suffixes`` are matched with ``*.json``
    files directly inside ``prompt_root`` that share the same stem. For each
    match, one line ``{"image": <resolved image path>, "text": <prompt>}`` is
    appended to ``<save_root>/image_text.jsonl``. Stems that exist on only one
    side are skipped. Records are written in sorted stem order.

    Args:
        image_root: Directory scanned (non-recursively) for image files.
        prompt_root: Directory scanned (non-recursively) for ``*.json`` files;
            each is expected to hold a JSON object with a ``"prompt"`` key.
        save_root: Output directory; created if it does not exist.
        image_suffixes: Accepted image extensions, including the leading dot,
            matched case-insensitively.

    Raises:
        KeyError: If a matched prompt file has no ``"prompt"`` key.
        json.JSONDecodeError: If a matched prompt file is not valid JSON.
    """
    image_root = Path(image_root)
    prompt_root = Path(prompt_root)
    save_root = Path(save_root)

    # Case-insensitive suffix check so ".JPG", ".jpeg", ".PNG" etc. are found
    # on case-sensitive filesystems too.
    wanted_suffixes = {suffix.lower() for suffix in image_suffixes}
    image_path_list = [p for p in image_root.glob("*")
                       if p.is_file() and p.suffix.lower() in wanted_suffixes]
    prompt_path_list = [p for p in prompt_root.glob("*.json") if p.is_file()]

    image_path_mapping = get_mapping(image_path_list)
    prompt_path_mapping = get_mapping(prompt_path_list)

    # Sorted so the output order is reproducible: iteration order of a set of
    # str keys changes between runs because of hash randomization.
    keys = sorted(image_path_mapping.keys() & prompt_path_mapping.keys())

    save_root.mkdir(parents=True, exist_ok=True)
    # Append mode is kept for backward compatibility: rerunning adds lines
    # rather than overwriting. The file is opened once instead of per record.
    # UTF-8 is explicit because ensure_ascii=False emits raw non-ASCII text.
    with open(save_root / "image_text.jsonl", "a", encoding="utf-8") as out:
        for key in keys:
            # utf-8-sig also accepts prompt files that start with a BOM.
            with open(prompt_path_mapping[key], "r", encoding="utf-8-sig") as f:
                text = json.load(f)["prompt"]
            record = {"image": image_path_mapping[key], "text": text}
            out.write(json.dumps(record, ensure_ascii=False) + "\n")
            

if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Pair images with JSON prompt files that share a file stem "
                    "and write the pairs to image_text.jsonl under --save_root."
    )

    # All three are required: without required=True a missing flag becomes
    # None and only fails later with an opaque TypeError.
    parser.add_argument("--image_root", type=str, required=True,
                        help="Directory containing the image files.")

    parser.add_argument("--prompt_root", type=str, required=True,
                        help="Directory containing the *.json prompt files.")

    parser.add_argument("--save_root", type=str, required=True,
                        help="Directory where image_text.jsonl is written.")

    args = parser.parse_args()

    generate_image_text(args.image_root, args.prompt_root, args.save_root)
