# 日期: 2024/11
# 作者: 马顺
# 机构: sugon

import random
import torch
import PIL

from PIL import Image, ImageEnhance
from typing import Optional, Union, List, Dict
from torchvision import transforms as T
import torchvision.transforms.functional as TF


class DataAugment:
    """Paired augmentation for a dict of images that share one random draw.

    One set of random decisions and colour factors is sampled per call and
    reused for every entry of ``images``, so related images stay consistent:

    * ``'person'`` and ``'cloth'``: contrast, brightness, hue and saturation
      jitter, each enabled independently with its own probability and applied
      cumulatively in that fixed order.
    * every key: horizontal flip with probability ``horizontal_flip_p``.
    * ``'person'``, ``'mask'`` and ``extra_condition_key``: a translation of up
      to ``max_shift`` of the image width / height, with probability
      ``shift_p``.

    Args:
        brightness, contrast, saturation, hue: jitter ranges forwarded to
            ``torchvision.transforms.ColorJitter``.
        brightness_p, contrast_p, saturation_p, hue_p: probability of applying
            the corresponding colour adjustment.
        shift_p: probability of applying the translation.
        horizontal_flip_p: probability of applying the horizontal flip.
        max_shift: maximum translation as a fraction of the image size; the
            x / y offsets are drawn uniformly from ``[-max_shift, max_shift]``.
    """

    # Keys that receive colour jitter.
    _COLOR_KEYS = ('person', 'cloth')

    def __init__(self,
                 brightness: float = 0.5,
                 contrast: float = 0.3,
                 saturation: float = 0.5,
                 hue: float = 0.5,
                 brightness_p: float = 0.5,
                 contrast_p: float = 0.5,
                 saturation_p: float = 0.5,
                 hue_p: float = 0.5,
                 shift_p: float = 0.5,
                 horizontal_flip_p: float = 0.5,
                 max_shift: float = 0.2):
        # Only used as a container for the validated jitter ranges; the
        # factors are sampled manually in __call__ so every image shares them.
        self.color_jitter = T.ColorJitter(brightness=brightness,
                                          contrast=contrast,
                                          saturation=saturation,
                                          hue=hue)

        self.brightness_p = brightness_p
        self.contrast_p = contrast_p
        self.saturation_p = saturation_p
        self.hue_p = hue_p
        self.shift_p = shift_p
        self.horizontal_flip_p = horizontal_flip_p
        self.max_shift = max_shift

    @staticmethod
    def _image_size(img):
        """Return ``(width, height)`` of a PIL image or a ``(..., H, W)`` tensor."""
        if isinstance(img, torch.Tensor):
            return img.shape[-1], img.shape[-2]
        # PIL exposes ``size`` as (width, height).
        width, height = img.size
        return width, height

    def __call__(self, images: Dict, extra_condition_key) -> Dict:
        """Augment ``images`` in place and return the same dict.

        Args:
            images: mapping from a name (e.g. ``'person'``, ``'cloth'``,
                ``'mask'``) to an image accepted by
                ``torchvision.transforms.functional``.
            extra_condition_key: an additional key that should be shifted
                together with ``'person'`` and ``'mask'``.

        Returns:
            The same ``images`` dict, with its values replaced by the
            augmented images.
        """
        # Sample one set of colour factors for the whole call. The sampled
        # application order (first element) is intentionally unused; the
        # adjustments below run in a fixed order.
        _, b, c, s, h = T.ColorJitter.get_params(self.color_jitter.brightness,
                                                 self.color_jitter.contrast,
                                                 self.color_jitter.saturation,
                                                 self.color_jitter.hue)
        # One coin toss per augmentation, shared by every image in the dict.
        random_hflip = random.random()
        random_brightness = random.random()
        random_contrast = random.random()
        random_saturation = random.random()
        random_hue = random.random()
        random_shift = random.random()

        shift_valx = random.uniform(-self.max_shift, self.max_shift)
        shift_valy = random.uniform(-self.max_shift, self.max_shift)

        # (enabled, op, factor). A factor is None when ColorJitter disabled
        # that range (e.g. brightness=0), in which case the op is skipped.
        color_ops = (
            (random_contrast < self.contrast_p, TF.adjust_contrast, c),
            (random_brightness < self.brightness_p, TF.adjust_brightness, b),
            (random_hue < self.hue_p, TF.adjust_hue, h),
            (random_saturation < self.saturation_p, TF.adjust_saturation, s),
        )
        do_flip = random_hflip < self.horizontal_flip_p
        do_shift = random_shift < self.shift_p
        shift_keys = ('person', 'mask', extra_condition_key)

        for key, img in images.items():
            if key in self._COLOR_KEYS:
                # Chain the adjustments so each one builds on the previous.
                for enabled, op, factor in color_ops:
                    if enabled and factor is not None:
                        img = op(img, factor)

            if do_flip:
                # Flip the (possibly jittered) image, not the original.
                img = TF.hflip(img)

            if do_shift and key in shift_keys:
                width, height = self._image_size(img)
                img = TF.affine(img,
                                angle=0,
                                translate=[shift_valx * width, shift_valy * height],
                                scale=1,
                                shear=0)

            # Replacing the value of an existing key is safe during iteration.
            images[key] = img

        return images
