import torch import torch.nn as nn import torch.nn.functional as F class NoiseSythesis(nn.Module): def __init__(self, p, scale=0.01, shift_scale=(8,5), scaling_size=(0.1,0.1), canvas_size=(200, 100), bbox_type='sce', poly_coord_dim=2, bbox_coord_dim=2, quantify=True): super(NoiseSythesis, self).__init__() self.p = p self.scale = scale self.bbox_type = bbox_type self.quantify = quantify self.poly_coord_dim = poly_coord_dim self.bbox_coord_dim = bbox_coord_dim self.transforms = [self.random_shifting, self.random_scaling] # self.transforms = [self.random_scaling] self.register_buffer('canvas_size', torch.tensor(canvas_size)) self.register_buffer('shift_scale', torch.tensor(shift_scale).float()) self.register_buffer('scaling_size', torch.tensor(scaling_size)) def random_scaling(self, bbox): ''' bbox: B, paramter_num, 2 ''' device = bbox.device dtype = bbox.dtype B = bbox.shape[0] noise = (torch.rand(B, device=device)*2-1)[:,None,None] # [-1,1] scale = self.scaling_size.to(device) scale = (noise * scale) + 1 scaled_bbox = bbox * scale # recenterization coffset = scaled_bbox.mean(-2) - bbox.float().mean(-2) scaled_bbox = scaled_bbox - coffset[:,None] return scaled_bbox.round().type(dtype) def random_shifting(self, bbox): ''' bbox: B, paramter_num, 2 ''' device = bbox.device batch_size = bbox.shape[0] shift_scale = self.shift_scale scale = (bbox.max(1)[0] - bbox.min(1)[0]) * 0.1 scale = torch.where(scale < shift_scale, scale, shift_scale) noise = (torch.rand(batch_size, 2, device=device)*2-1) # [-1,1] offset = (noise * scale).round().type(bbox.dtype) shifted_bbox = bbox + offset[:, None] return shifted_bbox def gaussian_noise_bbox(self, bbox): dtype = bbox.dtype batch_size = bbox.shape[0] scale = (self.canvas_size * self.scale)[:self.bbox_coord_dim] noisy_bbox = torch.normal(bbox.type(torch.float), scale) if self.quantify: noisy_bbox = noisy_bbox.round().type(dtype) # prevent out of bound case for i in range(self.bbox_coord_dim): noisy_bbox[...,i] =\ torch.clamp(noisy_bbox[...,0],1,self.canvas_size[i]) else: noisy_bbox = noisy_bbox.type(torch.float) return noisy_bbox def gaussian_noise_poly(self, polyline, polyline_mask): device = polyline.device batchsize = polyline.shape[0] scale = self.canvas_size * self.scale polyline = F.pad(polyline,(0,self.poly_coord_dim-1)) polyline = polyline.view(batchsize,-1, self.poly_coord_dim) mask = F.pad(polyline_mask[:,1:],(0,self.poly_coord_dim)) noisy_polyline = torch.normal(polyline.type(torch.float), scale) if self.quantify: noisy_polyline = noisy_polyline.round().type(polyline.dtype) # prevent out of bound case for i in range(self.poly_coord_dim): noisy_polyline[...,i] =\ torch.clamp(noisy_polyline[...,i],0,self.canvas_size[i]) else: noisy_polyline = noisy_polyline.type(torch.float) noisy_polyline = noisy_polyline.view(batchsize,-1) * mask noisy_polyline = noisy_polyline[:,:-(self.poly_coord_dim-1)] return noisy_polyline def random_apply(self, bbox): for t in self.transforms: if self.p < torch.rand(1): continue bbox = t(bbox) # prevent out of bound case bbox[...,0] =\ torch.clamp(bbox[...,0],0,self.canvas_size[0]) bbox[...,1] =\ torch.clamp(bbox[...,1],0,self.canvas_size[1]) return bbox def simple_aug(self, batch): # augment bbox if self.bbox_type in ['sce', 'xyxy']: fbbox = batch['bbox_flat'] seq_len = fbbox.shape[0] bbox = fbbox.view(seq_len, -1, 2) bbox = self.gaussian_noise_bbox(bbox) fbbox_aug = bbox.view(seq_len, -1) aug_mask = torch.rand(fbbox.shape,device=fbbox.device) fbbox = torch.where(aug_mask