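# Collect the caption text for each given VidGen_1M video and record where the
# corresponding video-latent and text-feature files are to be saved.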
import os
import json

from pathlib import Path
from argparse import ArgumentParser


def get_video_text(video_root,
                   caption_json_path,
                   save_root,
                   video_latent_root,
                   text_fea_root):
    """Write a JSONL manifest pairing each video with its caption and feature paths.

    Every ``.mp4`` found under ``video_root`` is matched to a caption by its
    file stem, and one JSON object per video is written to
    ``<save_root>/video_text.jsonl`` with the keys ``video``, ``text``,
    ``latent`` and ``text_fea``. All recorded paths are absolute. Only the
    paths are recorded here; the ``.pt`` feature files themselves are not
    created by this function.

    Args:
        video_root: Directory searched for ``.mp4`` files, either directly
            inside it or at most two sub-directory levels below it.
        caption_json_path: UTF-8 JSON file containing a list of objects, each
            with a ``vid`` key (matched against the video file stem) and a
            ``caption`` key.
        save_root: Directory that receives ``video_text.jsonl``; created if
            it does not exist.
        video_latent_root: Directory the ``latent`` paths point into
            (``<vid>.pt``); created if it does not exist.
        text_fea_root: Directory the ``text_fea`` paths point into
            (``<vid>-text.pt``); created if it does not exist.

    Returns:
        None.

    Warns:
        UserWarning: If one or more videos have no entry in the caption file.
            Such videos are skipped instead of aborting the whole export.
    """
    import warnings  # local import: only needed for the missing-caption notice

    video_root = Path(video_root)

    # At most three directory levels are supported: the root itself plus two
    # levels of sub-directories.
    video_path_list = [
        *video_root.glob("*.mp4"),
        *video_root.glob("*/*.mp4"),
        *video_root.glob("*/*/*.mp4"),
    ]
    # Keyed by file stem; if two files share a stem, the later one wins.
    vid_path = {p.stem: str(p.resolve()) for p in video_path_list}

    # Explicit UTF-8 so captions decode the same way regardless of the locale.
    with open(caption_json_path, "r", encoding="utf-8") as f:
        captions = json.load(f)

    vid_caption = {d['vid']: d['caption'] for d in captions}

    # The manifest directory must exist before the file can be opened.
    os.makedirs(save_root, exist_ok=True)
    os.makedirs(video_latent_root, exist_ok=True)
    os.makedirs(text_fea_root, exist_ok=True)

    missing = []
    manifest_path = os.path.join(save_root, "video_text.jsonl")
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 rather than with the platform default encoding.
    with open(manifest_path, "w", encoding="utf-8") as f:
        for vid, vpath in vid_path.items():
            if vid not in vid_caption:
                # A video without a caption must not abort the export and
                # leave a truncated manifest behind.
                missing.append(vid)
                continue
            text = vid_caption[vid]
            latent_path = str(Path(os.path.join(video_latent_root, f"{vid}.pt")).resolve())
            text_fea_path = str(Path(os.path.join(text_fea_root, f"{vid}-text.pt")).resolve())
            record = {"video": vpath, "text": text, "latent": latent_path, "text_fea": text_fea_path}
            f.write(json.dumps(record, ensure_ascii=False) + '\n')

    if missing:
        preview = ", ".join(missing[:10])
        suffix = ", ..." if len(missing) > 10 else ""
        warnings.warn(
            f"{len(missing)} video(s) have no caption in {caption_json_path} "
            f"and were skipped: {preview}{suffix}",
            stacklevel=2,
        )
    

if __name__ == "__main__":
    # Command-line entry point. Every option is mandatory: get_video_text
    # cannot run with any of them missing, so argparse rejects an incomplete
    # command line up front instead of letting a None value fail later.
    parser = ArgumentParser(
        description="Build video_text.jsonl, pairing each video with its "
                    "caption and its feature file paths."
    )

    parser.add_argument("--video_root", type=str, required=True,
                        help="directory searched for .mp4 files")

    parser.add_argument("-cjp", "--caption_json_path", type=str, required=True,
                        help="JSON file holding a list of objects with 'vid' and 'caption' keys")

    parser.add_argument("-sr", "--save_root", type=str, required=True,
                        help="directory in which video_text.jsonl is written")

    parser.add_argument("-vlr", "--video_latent_root", type=str, required=True,
                        help="directory the recorded video latent paths (<vid>.pt) point into")

    parser.add_argument("-tfr", "--text_fea_root", type=str, required=True,
                        help="directory the recorded text feature paths (<vid>-text.pt) point into")

    args = parser.parse_args()

    get_video_text(args.video_root,
                   args.caption_json_path,
                   args.save_root,
                   args.video_latent_root,
                   args.text_fea_root)
    
    