"""Some helper functions for PyTorch, including:
    - EarlyStopping: stop training once a monitored metric stops improving.
    - Cutout: randomly mask out a square patch of an image tensor.
    - data_transforms_cifar10 / data_transforms_mnist: train/valid transforms.
    - get_mean_and_std: calculate the per-channel mean and std of a dataset.
    - init_params: net parameter initialization.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision.transforms as transforms


class EarlyStopping:
    """EarlyStopping class to keep NN from overfitting.

    Call :meth:`step` once per epoch with the monitored metric; it returns
    ``True`` when training should stop.

    Args:
        mode: ``"min"`` if lower metric values are better, ``"max"`` if
            higher values are better.
        min_delta: minimum change relative to the best value that counts as
            an improvement. Absolute by default; a percentage of ``|best|``
            when ``percentage`` is true.
        patience: number of consecutive non-improving epochs tolerated before
            ``step`` returns ``True``. ``0`` disables early stopping.
        percentage: interpret ``min_delta`` as a percentage of the best value.

    Raises:
        ValueError: if ``mode`` is not ``"min"`` or ``"max"``.
    """

    # pylint: disable=E0202
    # E0202 (method-hidden): ``step`` is overridden on the instance when
    # patience == 0.
    def __init__(self, mode="min", min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: every value counts as an improvement
            # and step() never asks to stop.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one epoch's metric and report whether to stop.

        Arguments:
            metrics {float} -- metric value for the current epoch

        Returns:
            bool -- True if training should stop: the metric is NaN, or it
            has not improved for ``patience`` consecutive epochs.
        """
        # Check NaN before seeding ``best``: a NaN best would make every later
        # comparison False, so no metric could ever count as an improvement.
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Set ``self.is_better(a, best)`` from mode / min_delta / percentage."""
        if mode not in {"min", "max"}:
            raise ValueError(f"mode {mode} is unknown!")

        if not percentage:
            if mode == "min":
                self.is_better = lambda a, best: a < best - min_delta
            else:
                self.is_better = lambda a, best: a > best + min_delta
            return

        # abs() keeps the required margin pointing in the "better" direction
        # even when the best value is negative.
        if mode == "min":
            self.is_better = lambda a, best: a < best - (abs(best) * min_delta / 100)
        else:
            self.is_better = lambda a, best: a > best + (abs(best) * min_delta / 100)


class Cutout:
    """Randomly mask out one square patch from an image tensor.

    Args:
        length (int): side length (in pixels) of the square patch. The patch
            spans ``2 * (length // 2)`` pixels per side (odd lengths lose one
            pixel) and is clipped at the image borders.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: ``img`` with one patch, centred at a uniformly random
            pixel, set to zero in every channel. The input tensor is
            modified in place and also returned.
        """
        h_img, w_img = img.size(1), img.size(2)
        mask = np.ones((h_img, w_img), np.float32)

        y_img = np.random.randint(h_img)
        x_img = np.random.randint(w_img)

        half = self.length // 2
        y1_img = np.clip(y_img - half, 0, h_img)
        y2_img = np.clip(y_img + half, 0, h_img)
        x1_img = np.clip(x_img - half, 0, w_img)
        x2_img = np.clip(x_img + half, 0, w_img)

        mask[y1_img:y2_img, x1_img:x2_img] = 0.0
        # Broadcast the (H, W) mask over all channels.
        mask = torch.from_numpy(mask).expand_as(img)
        img *= mask
        return img


def _build_transforms(crop_size, mean, std, args):
    """Build the (train, valid) transform pair shared by the dataset helpers.

    Args:
        crop_size: output size passed to ``RandomCrop`` (with padding=4).
        mean, std: per-channel normalization statistics.
        args: object with ``cutout`` (bool) and ``cutout_length`` (int).

    Returns:
        (train_transform, valid_transform)
    """
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(crop_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    if args.cutout:
        # Appended after Normalize, so masked pixels are 0 in normalized space.
        train_transform.transforms.append(Cutout(args.cutout_length))

    valid_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    return train_transform, valid_transform


def data_transforms_cifar10(args):
    """data_transforms for cifar10 dataset.

    Train: 32px random crop (padding 4), horizontal flip, normalize, and
    optional Cutout. Valid: normalize only.
    """
    cifar_mean = [0.49139968, 0.48215827, 0.44653124]
    cifar_std = [0.24703233, 0.24348505, 0.26158768]
    return _build_transforms(32, cifar_mean, cifar_std, args)


def data_transforms_mnist(args, mnist_mean=None, mnist_std=None):
    """data_transforms for mnist dataset.

    Train: 28px random crop (padding 4), horizontal flip, normalize, and
    optional Cutout. Valid: normalize only. ``mnist_mean`` / ``mnist_std``
    default to ``[0.5]``.
    """
    if mnist_mean is None:
        mnist_mean = [0.5]
    if mnist_std is None:
        mnist_std = [0.5]
    return _build_transforms(28, mnist_mean, mnist_std, args)


def get_mean_and_std(dataset):
    """Compute the per-channel mean and std value of dataset.

    Both statistics are computed per image and then averaged over images, so
    ``std`` is the mean of per-image stds (not the std over all pixels). The
    number of channels is taken from the data.

    Args:
        dataset: map-style dataset yielding ``(image, target)`` pairs; images
            presumably ``(C, H, W)`` tensors — TODO confirm against callers.

    Returns:
        (mean, std): two 1-D float tensors with one entry per channel.

    Raises:
        ValueError: if ``dataset`` yields no samples.
    """
    # Averages do not depend on order, so there is no need to shuffle.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=2
    )
    mean = None
    std = None
    num_samples = 0
    print("==> Computing mean and std..")
    for inputs, _ in dataloader:
        # batch_size=1: inputs[0] is (C, H, W); flatten to (C, H*W) so each
        # row holds one channel's pixels.
        channels = inputs[0].flatten(1)
        if mean is None:
            mean = torch.zeros(channels.size(0))
            std = torch.zeros(channels.size(0))
        mean += channels.mean(dim=1)
        std += channels.std(dim=1)
        num_samples += 1

    if num_samples == 0:
        raise ValueError("dataset is empty")

    mean.div_(num_samples)
    std.div_(num_samples)
    return mean, std


def init_params(net):
    """Init layer parameters.

    - Conv2d: Kaiming-normal weights (``mode="fan_out"``), zero bias.
    - BatchNorm2d: weight 1, bias 0 (when those parameters exist).
    - Linear: normal weights with std 1e-3, zero bias.
    Other module types are left untouched.
    """
    for module in net.modules():
        if isinstance(module, nn.Conv2d):
            init.kaiming_normal_(module.weight, mode="fan_out")
            # Use ``is not None``: truth-testing a multi-element tensor raises.
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif isinstance(module, nn.BatchNorm2d):
            # weight/bias may be None (e.g. BatchNorm2d(affine=False)).
            if module.weight is not None:
                init.constant_(module.weight, 1)
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif isinstance(module, nn.Linear):
            init.normal_(module.weight, std=1e-3)
            if module.bias is not None:
                init.constant_(module.bias, 0)