import torch
import time
import os

from typing import Optional
from pathlib import Path

from diffusers.utils import load_image
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting, \
                      StableDiffusionXLInpaintPipeline, AutoPipelineForImage2Image, \
                      StableDiffusionXLImg2ImgPipeline, DiffusionPipeline


def _load_fp16(pipeline_cls, path, **extra_kwargs):
    # Load `pipeline_cls` from `path` with fp16 safetensors weights and move it to CUDA.
    return pipeline_cls.from_pretrained(
        path,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
        **extra_kwargs,
    ).to("cuda")


def get_pipeline(base_path: Optional[str] = None,
                 refiner_path: Optional[str] = None,
                 mode: str = "t2i"):
    """Build the diffusers pipeline(s) required by ``mode``.

    ``mode`` is matched by substring, checked in this order:
      - contains ``"t2i"``        -> text-to-image
      - contains ``"i2i"``        -> image-to-image (base only; ``"wr"`` is ignored)
      - contains ``"inpainting"`` -> inpainting
    For text-to-image and inpainting, a ``"wr"`` substring (e.g. ``"t2i_wr"``,
    ``"inpainting_wr"``) additionally loads a refiner as a second stage.

    Args:
        base_path: Location of the base model passed to ``from_pretrained``.
        refiner_path: Location of the refiner model; only used by ``wr`` modes.
        mode: Task selector, see above.

    Returns:
        ``[base_pipeline]`` or ``[base_pipeline, refiner_pipeline]``; every
        pipeline is loaded in float16 and moved to ``"cuda"``.

    Raises:
        NotImplementedError: if ``mode`` matches none of the supported tasks.
        ValueError: if ``base_path`` (or, for ``wr`` modes, ``refiner_path``)
            is ``None``.
    """
    if "t2i" in mode:
        task = "t2i"
    elif "i2i" in mode:
        task = "i2i"
    elif "inpainting" in mode:
        task = "inpainting"
    else:
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    # The image-to-image path never loads a refiner, regardless of "wr".
    with_refiner = "wr" in mode and task != "i2i"

    # Fail fast before any weights are loaded.
    if base_path is None:
        raise ValueError(f"base_path is required for mode {mode!r}")
    if with_refiner and refiner_path is None:
        raise ValueError(f"refiner_path is required for mode {mode!r}")

    if task == "i2i":
        return [_load_fp16(AutoPipelineForImage2Image, base_path)]

    if not with_refiner:
        auto_cls = AutoPipelineForText2Image if task == "t2i" else AutoPipelineForInpainting
        return [_load_fp16(auto_cls, base_path)]

    pipeline_cls = DiffusionPipeline if task == "t2i" else StableDiffusionXLInpaintPipeline
    base_pipeline = _load_fp16(pipeline_cls, base_path)
    # Reuse the base's second text encoder and VAE modules in the refiner
    # instead of loading separate copies.
    refiner_pipeline = _load_fp16(
        pipeline_cls,
        refiner_path,
        text_encoder_2=base_pipeline.text_encoder_2,
        vae=base_pipeline.vae,
    )
    return [base_pipeline, refiner_pipeline]


def save_img(img,
             save_root,
             mode):
    """Write ``img`` to ``<save_root>/<mode>/<unix timestamp>.png``.

    The per-mode subdirectory is created if it does not exist yet. ``img`` only
    needs a ``save(path)`` method (e.g. a PIL image).
    """
    target_dir = os.path.join(save_root, mode)
    os.makedirs(target_dir, exist_ok=True)

    # Timestamp-based name so successive runs do not overwrite each other.
    file_name = f"{time.time()}.png"
    img.save(os.path.join(target_dir, file_name))
    

def t2i(pipelines: list,
        prompt: str):
    """Generate one image from ``prompt``.

    With exactly one pipeline, the base model is called with its default
    settings. Otherwise the first pipeline runs 40 steps with
    ``denoising_end=0.8`` and returns latents. The second pipeline then
    resumes from those latents with ``denoising_start=0.8``.
    """
    if len(pipelines) != 1:
        steps = 40
        handoff = 0.8
        latents = pipelines[0](prompt=prompt,
                               num_inference_steps=steps,
                               denoising_end=handoff,
                               output_type="latent").images
        refined = pipelines[1](prompt=prompt,
                               num_inference_steps=steps,
                               denoising_start=handoff,
                               image=latents)
        return refined.images[0]

    # Base-only generation.
    return pipelines[0](prompt=prompt).images[0]
        

def i2i(pipelines: list,
        prompt: str,
        img_path: str):
    """Transform the image at ``img_path`` according to ``prompt``.

    Only a single (base) pipeline is supported; it is called with
    ``strength=0.8`` and ``guidance_scale=10.5``.

    Raises:
        NotImplementedError: if ``pipelines`` does not hold exactly one pipeline.
    """
    if len(pipelines) != 1:
        # A base + refiner image-to-image flow is not implemented.
        raise NotImplementedError

    source_image = load_image(img_path)
    result = pipelines[0](prompt=prompt,
                          image=source_image,
                          strength=0.8,
                          guidance_scale=10.5)
    return result.images[0]


def inpainting(pipelines: list,
               prompt: str,
               img_path: str,
               mask_path: str):
    """Inpaint the image at ``img_path`` in the region given by ``mask_path``.

    With exactly one pipeline, it is called with ``strength=0.85`` and
    ``guidance_scale=12.5``. Otherwise the first pipeline runs 75 steps with
    ``denoising_end=0.7`` and returns latents. The second pipeline continues
    from those latents with ``denoising_start=0.7``, using the same mask.
    """
    source_image = load_image(img_path)
    mask = load_image(mask_path)

    # Arguments passed identically to every stage.
    shared = {"prompt": prompt, "mask_image": mask}

    if len(pipelines) == 1:
        return pipelines[0](image=source_image,
                            strength=0.85,
                            guidance_scale=12.5,
                            **shared).images[0]

    steps, handoff = 75, 0.7
    latents = pipelines[0](image=source_image,
                           num_inference_steps=steps,
                           denoising_end=handoff,
                           output_type="latent",
                           **shared).images
    return pipelines[1](image=latents,
                        num_inference_steps=steps,
                        denoising_start=handoff,
                        **shared).images[0]


def inference(args):
    """Load the pipeline(s) for ``args.mode``, generate one image and save it.

    ``args`` must provide the attributes defined by this script's argument
    parser: ``mode``, ``prompt``, ``base_path``, ``refiner_path``,
    ``img_path``, ``mask_path`` and ``save_root``.

    The mode and required input paths are validated before any model weights
    are loaded, so a bad invocation fails immediately.

    Raises:
        NotImplementedError: if ``args.mode`` matches no supported task.
        ValueError: if an image-to-image or inpainting mode is requested
            without ``img_path``, or inpainting without ``mask_path``.
    """
    mode = args.mode

    # Same substring precedence as get_pipeline: t2i, then i2i, then inpainting.
    if "t2i" in mode:
        task = "t2i"
    elif "i2i" in mode:
        task = "i2i"
    elif "inpainting" in mode:
        task = "inpainting"
    else:
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    if task in ("i2i", "inpainting") and not args.img_path:
        raise ValueError(f"--img_path is required for mode {mode!r}")
    if task == "inpainting" and not args.mask_path:
        raise ValueError(f"--mask_path is required for mode {mode!r}")

    pipelines = get_pipeline(args.base_path,
                             args.refiner_path,
                             mode)

    if task == "t2i":
        img = t2i(pipelines, args.prompt)
    elif task == "i2i":
        img = i2i(pipelines, args.prompt, args.img_path)
    else:
        img = inpainting(pipelines, args.prompt, args.img_path, args.mask_path)

    save_img(img, args.save_root, mode)
    


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Default model locations: a pretrained_models/ directory next to this script.
    models_dir = Path(__file__).resolve().parent / "pretrained_models"
    default_base_path = str(models_dir / "stable-diffusion-xl-base-1.0")
    default_refiner_path = str(models_dir / "stable-diffusion-xl-refiner-1.0")

    parser = ArgumentParser(
        description="Run a diffusers text-to-image, image-to-image or "
                    "inpainting pipeline and save the result."
    )

    parser.add_argument("--base_path", default=default_base_path, type=str,
                        help="Location of the base model.")

    parser.add_argument("--refiner_path", default=default_refiner_path, type=str,
                        help="Location of the refiner model; only used by the *_wr modes.")

    parser.add_argument("--prompt", default="a panda is playing a ball", type=str,
                        help="Text prompt.")

    parser.add_argument("--img_path", type=str, default="",
                        help="Input image; needed for i2i and inpainting modes.")

    parser.add_argument("--mask_path", type=str, default="",
                        help="Mask image; needed for inpainting modes.")

    parser.add_argument("--save_root", default="./results", type=str,
                        help="Output root; images go to <save_root>/<mode>/<timestamp>.png.")

    parser.add_argument("--mode", type=str, required=True,
                        help="One of: t2i, t2i_wr, i2i, inpainting, inpainting_wr "
                             "(the _wr variants add a refiner stage).")

    args = parser.parse_args()

    inference(args)