from share import *
import config

import cv2
import einops
import numpy as np
import torch
import random
import matplotlib.pyplot as plt

from PIL import Image
from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler


# Edge detector used by process() to turn the input image into the control map.
apply_canny = CannyDetector()

# Module-level model construction: this runs at import time, so importing this
# file loads the checkpoint and requires a CUDA device (see `.cuda()` below).
# Build the model from its YAML config (by file name, a ControlNet on SD v1.5),
# load the Canny ControlNet checkpoint, then move the whole model to the GPU.
# NOTE(review): weights are loaded with location='cuda' while the module is still
# on CPU — presumably load_state_dict maps the tensors to that device and
# nn.Module.load_state_dict then copies them in; confirm in cldm.model.
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
model = model.cuda()
# DDIM sampler wrapping the model; this single instance is reused by every process() call.
ddim_sampler = DDIMSampler(model)


def _load_input_image(input_image):
    """Return ``input_image`` as a numpy array, reading it from disk if it is a path.

    Paths are read the same way the ``__main__`` block reads them
    (``np.array(Image.open(path))``); no mode conversion is done here, the
    array goes straight to ``HWC3``. Any non-path value is returned unchanged.
    """
    import os

    if isinstance(input_image, (str, os.PathLike)):
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(input_image) as pil_image:
            return np.array(pil_image)
    return input_image


def process(input_image="test_imgs/bird.png",
            prompt="bird",
            a_prompt="best quality, extremely detailed",
            n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
            num_samples=4,
            image_resolution=512,
            ddim_steps=20,
            guess_mode=False,
            strength=1,
            scale=9.0,
            seed=-1,
            eta=0,
            low_threshold=100,
            high_threshold=200):
    """Sample images conditioned on the Canny edge map of ``input_image``.

    Args:
        input_image: An image array accepted by ``HWC3``, or a path to an image
            file, which is loaded with PIL first. Before this change the
            default path string was handed directly to ``HWC3`` and could not
            work.
        prompt: Text prompt for the conditional branch.
        a_prompt: Extra text appended to ``prompt`` (joined with ", ").
        n_prompt: Prompt used for the unconditional branch of guidance.
        num_samples: Number of images sampled in one batch.
        image_resolution: Resolution passed to ``resize_image``.
        ddim_steps: Number of DDIM sampling steps.
        guess_mode: If True, the control map is dropped from the
            unconditional branch and the 13 control scales decay
            geometrically instead of all being ``strength``.
        strength: Multiplier applied to every control scale.
        scale: Unconditional-guidance scale passed to the sampler.
        seed: RNG seed; ``-1`` picks a random seed in [0, 65535].
        eta: DDIM ``eta`` passed to the sampler.
        low_threshold: First threshold forwarded to the Canny detector.
        high_threshold: Second threshold forwarded to the Canny detector.

    Returns:
        A list whose first element is the inverted edge map
        (``255 - detected_map``), followed by ``num_samples`` uint8 arrays in
        height x width x channel layout.
    """
    input_image = _load_input_image(input_image)

    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, _ = img.shape

        detected_map = apply_canny(img, low_threshold, high_threshold)
        detected_map = HWC3(detected_map)

        # Scale the edge map to [0, 1], replicate it once per sample and
        # switch to channels-first for the model.
        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        # Latent shape: 4 channels at 1/8 of the image resolution.
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        # Guess mode: index 12 gets the full strength and index 0 gets
        # 0.825**12 ~= 0.0994 * strength. 0.825 looks chosen as the base that
        # keeps the smallest scale just under 10%, since 0.826**12 ~= 0.1009.
        # (The upstream "magic number" comment said 0.01; that was wrong.)
        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        # Decode the latents and map the decoder output to uint8 pixels
        # (x * 127.5 + 127.5, clipped to [0, 255]).
        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]

    return [255 - detected_map] + results


def arange_imgs(images, save_path="test_results/test.png"):
    """Lay ``images`` out side by side in one row and save the figure.

    Args:
        images: Non-empty sequence of images accepted by ``Axes.imshow``.
        save_path: Destination file. It defaults to the path this function
            always wrote to. Missing parent directories are created.

    Raises:
        ValueError: If ``images`` is empty, because a zero-column grid cannot
            be built.
    """
    from pathlib import Path

    if len(images) == 0:
        raise ValueError("arange_imgs() needs at least one image")

    # squeeze=False keeps ``axes`` 2-D even for a single image. With the
    # default squeezing, one column yields a bare Axes and indexing it fails.
    fig, axes = plt.subplots(nrows=1, ncols=len(images), figsize=(12, 4), squeeze=False)
    try:
        for ax, image in zip(axes[0], images):
            ax.imshow(image)
            ax.axis('off')

        fig.tight_layout()
        out_path = Path(save_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line interface: each option feeds one keyword argument of
    # process(); --positive_prompt/--negative_prompt map to a_prompt/n_prompt.
    cli = ArgumentParser()
    cli.add_argument("--input_image", type=str, default="test_imgs/bird.png", help="输入图片")  # help text reads "input image"
    cli.add_argument("--prompt", type=str, default="bird")
    cli.add_argument("--positive_prompt", type=str, default="best quality, extremely detailed")
    cli.add_argument("--negative_prompt", type=str, default="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality")
    cli.add_argument("--num_samples", type=int, default=4)
    cli.add_argument("--image_resolution", type=int, default=512)
    cli.add_argument("--ddim_steps", type=int, default=20)
    cli.add_argument("--guess_mode", default=False, action="store_true")
    cli.add_argument("--strength", type=float, default=1)
    cli.add_argument("--scale", type=float, default=9)
    cli.add_argument("--seed", type=int, default=-1)
    cli.add_argument("--eta", type=float, default=0)
    cli.add_argument("--low_threshold", type=float, default=100)
    cli.add_argument("--high_threshold", type=float, default=200)
    opts = cli.parse_args()

    # Source picture as a numpy array; process() derives the edge map from it.
    source_pixels = np.array(Image.open(opts.input_image))

    sampling_options = {
        "prompt": opts.prompt,
        "a_prompt": opts.positive_prompt,
        "n_prompt": opts.negative_prompt,
        "num_samples": opts.num_samples,
        "image_resolution": opts.image_resolution,
        "ddim_steps": opts.ddim_steps,
        "guess_mode": opts.guess_mode,
        "strength": opts.strength,
        "scale": opts.scale,
        "seed": opts.seed,
        "eta": opts.eta,
        "low_threshold": opts.low_threshold,
        "high_threshold": opts.high_threshold,
    }
    outputs = process(source_pixels, **sampling_options)

    # outputs[0] is the inverted edge map; the remaining entries are samples.
    arange_imgs(outputs)
    