# https://github.com/Randl/MobileNetV3-pytorch/blob/master/dropblock.py
import torch
import torch.nn.functional as F
from torch import nn

__all__ = ['DropBlock2D', 'reset_dropblock']


class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.

    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole contiguous blocks of a feature map removes semantic
    information more effectively than regular (per-element) dropout.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop.
        share_channel (bool): if True, sample one spatial mask and share it
            across all channels.

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`

    .. _DropBlock: A regularization method for convolutional networks:
       https://arxiv.org/abs/1810.12890
    """

    def __init__(self, drop_prob, block_size, share_channel=False):
        super(DropBlock2D, self).__init__()
        self.register_buffer('i', torch.zeros(1, dtype=torch.int64))
        self.register_buffer('drop_prob', drop_prob * torch.ones(1, dtype=torch.float32))
        self.inited = False
        self.step_size = 0.0
        self.start_step = 0
        self.nr_steps = 0
        self.block_size = block_size
        self.share_channel = share_channel

    def reset(self):
        """Stop DropBlock: zero the drop probability and rewind the step counter."""
        self.inited = True
        self.i[0] = 0
        # assign in place: replacing the registered buffer with a plain float
        # would raise a TypeError from nn.Module.__setattr__
        self.drop_prob[0] = 0.0

    def reset_steps(self, start_step, nr_steps, start_value=0, stop_value=None):
        """Linearly ramp drop_prob from start_value to stop_value over nr_steps steps."""
        self.inited = True
        stop_value = self.drop_prob.item() if stop_value is None else stop_value
        self.i[0] = 0
        self.drop_prob[0] = start_value
        self.step_size = (stop_value - start_value) / nr_steps
        self.nr_steps = nr_steps
        self.start_step = start_step

    def forward(self, x):
        if not self.training:
            return x
        if self.inited:
            # advance the schedule even while drop_prob is still zero;
            # otherwise a ramp starting at 0 would never leave its first step
            self.step()
        if self.drop_prob.item() == 0.:
            return x
        # get gamma value
        gamma = self._compute_gamma(x)
        # sample mask and place on input device; cast to x.dtype because
        # max_pool2d does not accept the bool tensor the comparison produces
        if self.share_channel:
            mask = (torch.rand(*x.shape[2:], device=x.device, dtype=x.dtype) < gamma).to(x.dtype).unsqueeze(0).unsqueeze(0)
        else:
            mask = (torch.rand(*x.shape[1:], device=x.device, dtype=x.dtype) < gamma).to(x.dtype).unsqueeze(0)
        # compute block mask
        block_mask, kept = self._compute_block_mask(mask)
        # apply block mask
        out = x * block_mask
        # rescale so the expected activation magnitude is unchanged;
        # clamp guards the (unlikely) case where every element was dropped
        out = out * (block_mask.numel() / kept.clamp(min=1.0)).to(out)
        return out

    def _compute_block_mask(self, mask):
        # expand each dropped centre into a block_size x block_size block
        block_mask = F.max_pool2d(mask,
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)
        if self.block_size % 2 == 0:
            # an even kernel with padding block_size // 2 grows the map by one
            # pixel per spatial dim; crop back to the input size
            block_mask = block_mask[:, :, :-1, :-1]
        kept = block_mask.numel() - block_mask.sum().to(torch.float32)
        block_mask = 1 - block_mask
        return block_mask, kept

    def _compute_gamma(self, x):
        # Eq. (1) of the DropBlock paper: scale drop_prob by the block area
        # and by the ratio of the full map to the valid region of block centres
        _, c, h, w = x.size()
        gamma = self.drop_prob.item() / (self.block_size ** 2) * (h * w) / \
            ((w - self.block_size + 1) * (h - self.block_size + 1))
        return gamma

    def step(self):
        assert self.inited
        idx = self.i.item()
        # inclusive lower bound so the ramp reaches stop_value after exactly
        # nr_steps increments
        if self.start_step <= idx < self.start_step + self.nr_steps:
            self.drop_prob += self.step_size
        self.i += 1

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        idx_key = prefix + 'i'
        drop_prob_key = prefix + 'drop_prob'
        # the scheduler buffers are deliberately not saved (see
        # _save_to_state_dict), so fall back to the module's current values
        if idx_key not in state_dict:
            state_dict[idx_key] = torch.zeros(1, dtype=torch.int64)
        if drop_prob_key not in state_dict:
            state_dict[drop_prob_key] = self.drop_prob.detach().clone()
        super(DropBlock2D, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Skip saving the scheduler buffers so checkpoints stay schedule-agnostic."""
        pass

    def extra_repr(self):
        return 'drop_prob={}, step_size={}'.format(self.drop_prob.item(), self.step_size)


def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
    """Install a linear drop_prob schedule on every DropBlock2D submodule.

    Example::

        from functools import partial
        apply_drop_prob = partial(reset_dropblock, 0,
                                  epochs * iters_per_epoch, 0.0, 0.1)
        net.apply(apply_drop_prob)
    """
    if isinstance(m, DropBlock2D):
        print('resetting dropblock')
        m.reset_steps(start_step, nr_steps, start_value, stop_value)
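
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# wiring DropBlock2D into a model and ramping drop_prob during training. The
# network, input shape, and step counts below are illustrative assumptions,
# not values from the source repository.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from functools import partial

    net = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        DropBlock2D(drop_prob=0.1, block_size=5),
        nn.ReLU(inplace=True),
    )

    # Linearly ramp drop_prob from 0.0 to 0.1 over a hypothetical
    # 10 epochs x 100 iterations = 1000 training steps.
    epochs, iters_per_epoch = 10, 100
    net.apply(partial(reset_dropblock, 0, epochs * iters_per_epoch, 0.0, 0.1))

    net.train()
    x = torch.randn(2, 3, 32, 32)
    for _ in range(3):
        # each forward pass in training mode advances the schedule by one step
        out = net(x)
    print(out.shape)  # torch.Size([2, 16, 32, 32])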