aug.py 2.81 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Date: 2024/11
# Author: Ma Shun (马顺)
# Organization: Sugon

import random
import torch
import PIL

from PIL import Image, ImageEnhance
from typing import Optional, Union, List, Dict
from torchvision import transforms as T
import torchvision.transforms.functional as TF


class DataAugment:
    """Apply one shared random augmentation to a dict of related images.

    Every call draws a single set of random decisions and parameters and
    applies them to all entries of the dict, so the entries stay consistent
    with each other:

    * colour jitter (contrast, brightness, hue, saturation, in that order)
      is applied only to the ``'person'`` and ``'cloth'`` entries;
    * horizontal flip is applied to every entry;
    * a random translation of up to 20% of the image size per axis is
      applied to ``'person'``, ``'mask'`` and the entry named by
      ``extra_condition_key``.
    """

    def __init__(self,
                 brightness: float = 0.5,
                 contrast: float = 0.3,
                 saturation: float = 0.5,
                 hue: float = 0.5,
                 brightness_p: float = 0.5,
                 contrast_p: float = 0.5,
                 saturation_p: float = 0.5,
                 hue_p: float = 0.5,
                 shift_p: float = 0.5,
                 horizontal_flip_p: float = 0.5):
        """Store jitter magnitudes and per-operation probabilities.

        Args:
            brightness: Jitter magnitude forwarded to ``T.ColorJitter``.
            contrast: Jitter magnitude forwarded to ``T.ColorJitter``.
            saturation: Jitter magnitude forwarded to ``T.ColorJitter``.
            hue: Jitter magnitude forwarded to ``T.ColorJitter``.
            brightness_p: Probability of applying the brightness change.
            contrast_p: Probability of applying the contrast change.
            saturation_p: Probability of applying the saturation change.
            hue_p: Probability of applying the hue change.
            shift_p: Probability of applying the random translation.
            horizontal_flip_p: Probability of flipping horizontally.
        """
        # The ColorJitter instance is never called directly; it only holds
        # the sampling ranges that __call__ passes to get_params.
        self.color_jitter = T.ColorJitter(brightness=brightness,
                                          contrast=contrast,
                                          saturation=saturation,
                                          hue=hue)

        self.brightness_p = brightness_p
        self.contrast_p = contrast_p
        self.saturation_p = saturation_p
        self.hue_p = hue_p
        self.shift_p = shift_p
        self.horizontal_flip_p = horizontal_flip_p

    @staticmethod
    def _image_size(image):
        """Return ``(width, height)`` of a PIL image or a ``[..., H, W]`` tensor."""
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
            return int(width), int(height)
        # PIL reports ``size`` as (width, height).
        width, height = image.size
        return int(width), int(height)

    def __call__(self, images: Dict, extra_condition_key) -> Dict:
        """Augment ``images`` in place and return the same dict.

        Args:
            images: Mapping from entry name to image. The keys ``'person'``,
                ``'cloth'`` and ``'mask'`` get the special handling
                described on the class; other keys are only flipped
                (and shifted if equal to ``extra_condition_key``).
            extra_condition_key: Name of an additional entry that must be
                translated together with ``'person'`` and ``'mask'``.

        Returns:
            The ``images`` dict that was passed in, with augmented values.
        """
        # The permutation index returned by get_params is not used: the
        # colour operations are always applied in the fixed order below.
        _, b, c, s, h = T.ColorJitter.get_params(self.color_jitter.brightness,
                                                 self.color_jitter.contrast,
                                                 self.color_jitter.saturation,
                                                 self.color_jitter.hue)
        # Draw every decision exactly once so that all entries receive the
        # same augmentation (the draw order is kept stable for seeding).
        random_hflip = random.random()
        random_brightness = random.random()
        random_contrast = random.random()
        random_saturation = random.random()
        random_hue = random.random()
        random_shift = random.random()

        shift_valx = random.uniform(-0.2, 0.2)
        shift_valy = random.uniform(-0.2, 0.2)

        do_hflip = random_hflip < self.horizontal_flip_p
        do_shift = random_shift < self.shift_p

        # (enabled, op, factor) in application order. A factor is None when
        # the corresponding jitter magnitude is zero; such ops are skipped.
        color_ops = (
            (random_contrast < self.contrast_p, TF.adjust_contrast, c),
            (random_brightness < self.brightness_p, TF.adjust_brightness, b),
            (random_hue < self.hue_p, TF.adjust_hue, h),
            (random_saturation < self.saturation_p, TF.adjust_saturation, s),
        )
        color_keys = ('person', 'cloth')
        shift_keys = ('person', 'mask', extra_condition_key)

        for key in list(images):
            # Chain every step on the running result so that no earlier
            # augmentation is overwritten by a later one.
            image = images[key]

            if key in color_keys:
                for enabled, adjust, factor in color_ops:
                    if enabled and factor is not None:
                        image = adjust(image, factor)

            if do_hflip:
                image = TF.hflip(image)

            if do_shift and key in shift_keys:
                width, height = self._image_size(image)
                # Horizontal shift scales with width, vertical with height;
                # TF.affine documents translate as integer pixel offsets.
                offset = [int(round(shift_valx * width)),
                          int(round(shift_valy * height))]
                image = TF.affine(image, angle=0, translate=offset,
                                  scale=1, shear=0)

            images[key] = image

        return images