# Date: 2024/11
# Author: Ma Shun (马顺)
# Organization: sugon
import random
from typing import Optional, Union, List, Dict

import PIL
import torch
from PIL import Image, ImageEnhance
from torchvision import transforms as T
import torchvision.transforms.functional as TF


class DataAugment:
    """Apply one shared random augmentation to a dict of related images.

    A single set of random parameters is drawn per call and reused for every
    entry of the dict, so that entries stay spatially aligned with each other
    (same flip decision, same shift fractions) and share the same colour
    factors.

    Args:
        brightness, contrast, saturation, hue: Jitter strengths forwarded
            unchanged to ``torchvision.transforms.ColorJitter``.
        brightness_p, contrast_p, saturation_p, hue_p: Probability of
            applying the corresponding colour adjustment on a given call.
        shift_p: Probability of applying the random translation.
        horizontal_flip_p: Probability of applying the horizontal flip.
    """

    # Keys that receive colour jitter.
    _COLOR_KEYS = ('person', 'cloth')
    # Maximum translation, as a fraction of the image width / height.
    _MAX_SHIFT = 0.2

    def __init__(self,
                 brightness: float = 0.5,
                 contrast: float = 0.3,
                 saturation: float = 0.5,
                 hue: float = 0.5,
                 brightness_p: float = 0.5,
                 contrast_p: float = 0.5,
                 saturation_p: float = 0.5,
                 hue_p: float = 0.5,
                 shift_p: float = 0.5,
                 horizontal_flip_p: float = 0.5):
        # ColorJitter is only used to validate/normalise the ranges and to
        # sample factors from them; the adjustments themselves are applied
        # one by one in __call__ so each can be gated by its own probability.
        self.color_jitter = T.ColorJitter(brightness=brightness,
                                          contrast=contrast,
                                          saturation=saturation,
                                          hue=hue)
        self.brightness_p = brightness_p
        self.contrast_p = contrast_p
        self.saturation_p = saturation_p
        self.hue_p = hue_p
        self.shift_p = shift_p
        self.horizontal_flip_p = horizontal_flip_p

    @staticmethod
    def _width_height(image):
        """Return ``(width, height)`` for a tensor or a PIL image."""
        if isinstance(image, torch.Tensor):
            # Tensor images are laid out as (..., H, W).
            return image.shape[-1], image.shape[-2]
        # PIL images expose ``size`` as (width, height).
        width, height = image.size
        return width, height

    def __call__(self, images: Dict, extra_condition_key) -> Dict:
        """Augment every image in ``images`` with one shared random draw.

        Colour jitter (contrast, brightness, hue, saturation, in that order)
        is applied to the ``'person'`` and ``'cloth'`` entries only. The
        horizontal flip is applied to every entry. The translation is applied
        to ``'person'``, ``'mask'`` and ``extra_condition_key``.

        Args:
            images: Mapping from key to image. It is updated in place.
            extra_condition_key: Key of an additional entry that should be
                translated together with ``'person'`` and ``'mask'``.

        Returns:
            The same ``images`` dict, with augmented values.
        """
        jitter = self.color_jitter
        # The sampled application order is ignored; the order here is fixed.
        _, b, c, s, h = T.ColorJitter.get_params(
            jitter.brightness, jitter.contrast, jitter.saturation, jitter.hue)

        # Draw order is kept stable so seeded runs make the same decisions.
        random_hflip = random.random()
        random_brightness = random.random()
        random_contrast = random.random()
        random_saturation = random.random()
        random_hue = random.random()
        random_shift = random.random()
        shift_valx = random.uniform(-self._MAX_SHIFT, self._MAX_SHIFT)
        shift_valy = random.uniform(-self._MAX_SHIFT, self._MAX_SHIFT)

        # get_params yields None for a factor whose range is disabled
        # (e.g. brightness=0), so such adjustments must be skipped.
        color_ops = (
            (random_contrast < self.contrast_p, TF.adjust_contrast, c),
            (random_brightness < self.brightness_p, TF.adjust_brightness, b),
            (random_hue < self.hue_p, TF.adjust_hue, h),
            (random_saturation < self.saturation_p, TF.adjust_saturation, s),
        )
        do_hflip = random_hflip < self.horizontal_flip_p
        do_shift = random_shift < self.shift_p
        shift_keys = ('person', 'mask', extra_condition_key)

        # Only existing keys are reassigned, so iterating a snapshot of the
        # keys while writing back into the dict is safe.
        for key in list(images):
            image = images[key]

            if key in self._COLOR_KEYS:
                # Chain on the running result: feeding the untouched input to
                # every op would keep only the last adjustment.
                for enabled, adjust, factor in color_ops:
                    if enabled and factor is not None:
                        image = adjust(image, factor)

            if do_hflip:
                image = TF.hflip(image)

            if do_shift and key in shift_keys:
                # translate is [horizontal, vertical]: scale x by the width
                # and y by the height.
                width, height = self._width_height(image)
                image = TF.affine(image,
                                  angle=0,
                                  translate=[shift_valx * width,
                                             shift_valy * height],
                                  scale=1,
                                  shear=0)

            images[key] = image

        return images