# 将数据路径放入json文件
from pathlib import Path
from glob import glob

import os
import json
import random


def make_mapping(path_list):
    """Index paths by their stem.

    Args:
        path_list: Iterable of ``pathlib.Path`` objects.

    Returns:
        dict mapping each path's stem (file name without its final suffix) to
        ``str(path.resolve())``, i.e. the absolute path as a string. If two
        paths share a stem, the one appearing later in ``path_list`` wins.
    """
    mapping = {}
    for candidate in path_list:
        mapping[candidate.stem] = str(candidate.resolve())
    return mapping


# Image extensions collected from each input directory. Order matters: on a
# stem collision the file listed later wins (see _list_images).
_IMAGE_SUFFIXES = (".png", ".jpg", ".jpeg")


def _list_images(root):
    """Return the image files directly inside ``root`` (non-recursive).

    Extensions in ``_IMAGE_SUFFIXES`` are matched case-insensitively, so
    ``IMG.PNG`` is found as well as ``img.png``. Files are ordered by extension
    (in ``_IMAGE_SUFFIXES`` order), then by name. ``make_mapping`` keeps the
    last file for a repeated stem, so this order gives the same precedence as
    globbing ``*.png``, ``*.jpg``, ``*.jpeg`` one after another.

    Raises:
        FileNotFoundError: if ``root`` is not an existing directory.
    """
    if not root.is_dir():
        # Globbing a missing directory silently yields nothing; fail loudly instead.
        raise FileNotFoundError(f"Image directory not found: {root}")
    images = (p for p in root.iterdir()
              if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES)
    return sorted(images, key=lambda p: (_IMAGE_SUFFIXES.index(p.suffix.lower()), p.name))


def prepare_json(person_image_root,
                 cloth_image_root,
                 mask_root,
                 extra_condition_image_root,
                 extra_condition_key,
                 eval_nums: int,
                 save_root: str,
                 *,
                 seed=None):
    """Pair images across four directories by file stem and write a train/eval split as JSONL.

    A sample is kept only if a file with the same stem exists in all four
    roots. ``eval_nums`` distinct samples are drawn at random and written to
    ``<save_root>/eval_data.jsonl``. All remaining samples go to
    ``<save_root>/train_data.jsonl``. Each line is a JSON object with the keys
    ``person_img_path``, ``cloth_img_path``, ``mask_img_path`` and
    ``extra_condition_key``; each value is an absolute file path. Lines are
    written in sorted stem order.

    Args:
        person_image_root: Directory of person images.
        cloth_image_root: Directory of cloth images.
        mask_root: Directory of mask images.
        extra_condition_image_root: Directory of extra-condition images.
        extra_condition_key: JSON key under which the extra-condition image
            path is stored. If it equals one of the fixed keys, it overrides
            that entry.
        eval_nums: Number of samples to hold out for evaluation. Clamped to
            ``[0, number of matched samples]``.
        save_root: Output directory; created if it does not exist.
        seed: Optional seed for a private RNG, which makes the split
            reproducible. When ``None``, the module-level ``random`` generator
            is used, so ``random.seed(...)`` by the caller also works.

    Raises:
        FileNotFoundError: if any input root is not an existing directory.
    """
    person_map = make_mapping(_list_images(Path(person_image_root)))
    cloth_map = make_mapping(_list_images(Path(cloth_image_root)))
    mask_map = make_mapping(_list_images(Path(mask_root)))
    extra_map = make_mapping(_list_images(Path(extra_condition_image_root)))

    # Sort so the split depends only on the RNG state. Set iteration order of
    # str keys changes from process to process (hash randomization).
    keys = sorted(person_map.keys() & cloth_map.keys()
                  & mask_map.keys() & extra_map.keys())

    rng = random if seed is None else random.Random(seed)
    # sample() draws WITHOUT replacement. choices() could repeat an index and
    # yield fewer than eval_nums distinct samples. Clamp so sample() never raises.
    n_eval = max(0, min(eval_nums, len(keys)))
    eval_key_set = set(rng.sample(keys, n_eval))
    train_keys = [k for k in keys if k not in eval_key_set]
    eval_keys = [k for k in keys if k in eval_key_set]

    def _record(key):
        # Build one JSONL row; the extra key is added last, as before.
        return {
            'person_img_path': person_map[key],
            'cloth_img_path': cloth_map[key],
            'mask_img_path': mask_map[key],
            extra_condition_key: extra_map[key],
        }

    os.makedirs(save_root, exist_ok=True)
    for file_name, split_keys in (("train_data.jsonl", train_keys),
                                  ("eval_data.jsonl", eval_keys)):
        # ensure_ascii=False emits raw non-ASCII characters, so pin UTF-8
        # instead of relying on the platform default encoding.
        with open(os.path.join(save_root, file_name), "w", encoding="utf-8") as f:
            f.writelines(json.dumps(_record(k), ensure_ascii=False) + '\n'
                         for k in split_keys)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Pair person/cloth/mask/extra-condition images by file stem and "
                    "write train_data.jsonl and eval_data.jsonl into --save_root.")

    # Paths and the JSON key are mandatory. Without them, prepare_json would
    # fail with an opaque TypeError or silently write a "null" JSON key.
    parser.add_argument("--person_image_root", type=str, required=True,
                        help="Directory containing the person images.")

    parser.add_argument("--cloth_image_root", type=str, required=True,
                        help="Directory containing the cloth images.")

    parser.add_argument("--mask_root", type=str, required=True,
                        help="Directory containing the mask images.")

    parser.add_argument("--extra_condition_image_root", type=str, required=True,
                        help="Directory containing the extra-condition images.")

    parser.add_argument("--extra_condition_key", type=str, required=True,
                        help="JSON key under which the extra-condition image path is stored.")

    parser.add_argument("--eval_nums", type=int, default=8,
                        help="Number of samples held out for eval_data.jsonl (default: 8).")

    parser.add_argument("--save_root", type=str, required=True,
                        help="Directory in which the .jsonl files are written.")

    args = parser.parse_args()

    prepare_json(args.person_image_root,
                 args.cloth_image_root,
                 args.mask_root,
                 args.extra_condition_image_root,
                 args.extra_condition_key,
                 args.eval_nums,
                 args.save_root)

