# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------'

import math
import random

import numpy as np


class MaskingGenerator:
    """Generate random 2-D binary masks built from rectangular blocks.

    Each call produces an ``int32`` array of shape ``(height, width)`` that
    contains exactly ``num_masking_patches`` ones.  Rectangles with a random
    area and aspect ratio are stamped onto the grid until the target count is
    reached or no rectangle can be placed; any remaining difference is settled
    by flipping individual cells chosen uniformly at random.

    Rectangle sampling uses the stdlib ``random`` module, the final per-cell
    adjustment uses ``np.random``; both must be seeded for reproducible output.
    """

    def __init__(
            self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        """Configure the generator.

        Args:
            input_size: Grid size.  A tuple or list is unpacked as
                ``(height, width)``; any other value is used for both sides.
            num_masking_patches: Exact number of cells set to 1 in every mask.
            min_num_patches: Lower bound of the area sampled for one rectangle.
            max_num_patches: Upper bound on the number of cells a single
                rectangle may newly mask.  Defaults to ``num_masking_patches``.
            min_aspect: Smallest height/width ratio sampled for a rectangle.
            max_aspect: Largest height/width ratio.  Defaults to
                ``1 / min_aspect``.

        Raises:
            ValueError: If ``num_masking_patches`` is negative or larger than
                the number of cells in the grid.
        """
        # Lists are accepted alongside tuples; previously a list was duplicated
        # into a pair of lists and broke the arithmetic below.
        if not isinstance(input_size, (tuple, list)):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        # Fail early: such a target can never be met and would otherwise only
        # surface as an opaque sampling error inside __call__.
        if not 0 <= num_masking_patches <= self.num_patches:
            raise ValueError(
                f"num_masking_patches must be in [0, {self.num_patches}], "
                f"got {num_masking_patches}")
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space (see _mask).
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self) -> str:
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        """Return the mask shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches) -> int:
        """Try to stamp one random rectangle onto ``mask`` in place.

        Up to 10 candidate rectangles are sampled.  The first one that fits
        strictly inside the grid and would newly mask between 1 and
        ``max_mask_patches`` cells is applied.

        Args:
            mask: 2-D array modified in place; cells equal to 0 are unmasked.
            max_mask_patches: Maximum number of cells this call may newly mask.

        Returns:
            Number of cells switched from 0 to 1 (0 if every attempt failed).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            # Strict comparison: rectangles spanning the full height or width
            # are rejected.
            if not (w < self.width and h < self.height):
                continue

            top = random.randint(0, self.height - h)
            left = random.randint(0, self.width - w)

            # Basic slicing yields a view, so writes below land in ``mask``.
            region = mask[top: top + h, left: left + w]
            num_masked = region.sum()
            # Overlap with already-masked cells is allowed as long as the
            # rectangle adds at least one new cell and stays within budget.
            if 0 < h * w - num_masked <= max_mask_patches:
                unmasked = region == 0
                region[unmasked] = 1
                delta += int(unmasked.sum())

            if delta > 0:
                break
        return delta

    def __call__(self):
        """Return a fresh mask with exactly ``num_masking_patches`` ones.

        Returns:
            ``np.ndarray`` of dtype ``int32`` and shape ``get_shape()``.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=np.int32)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            remaining = self.num_masking_patches - mask_count
            max_mask_patches = min(remaining, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No rectangle could be placed; fall through to the per-cell
                # adjustment below.
                break
            mask_count += delta

        # Keep the number of masked cells fixed at self.num_masking_patches.
        if mask_count > self.num_masking_patches:
            # Too many cells masked: unmask a random subset of them.
            delta = mask_count - self.num_masking_patches
            mask_x, mask_y = mask.nonzero()
            to_vis = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_vis], mask_y[to_vis]] = 0
        elif mask_count < self.num_masking_patches:
            # Too few cells masked: mask a random subset of the unmasked ones.
            delta = self.num_masking_patches - mask_count
            mask_x, mask_y = (mask == 0).nonzero()
            to_mask = np.random.choice(mask_x.shape[0], delta, replace=False)
            mask[mask_x[to_mask], mask_y[to_mask]] = 1

        assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}"

        return mask


if __name__ == '__main__':
    # Stress test: drop into the debugger if a mask ever has the wrong count.
    import pdb
    generator = MaskingGenerator(input_size=14, num_masking_patches=118, min_num_patches=16, )
    for i in range(10000000):
        mask = generator()
        if mask.sum() != 118:
            pdb.set_trace()
            print(mask)
            print(mask.sum())