# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/saic-vul/ritm_interactive_segmentation
The copyright of saic-vul/ritm_interactive_segmentation is as follows:
MIT License [see LICENSE for details]
"""

import cv2
import numpy as np
from copy import deepcopy


class Clicker(object):
    """Keep a list of positive/negative clicks and optionally simulate new ones.

    When a ground-truth mask is supplied, ``make_next_click`` compares a
    predicted mask against it and places the next click at the interior-most
    pixel of the larger error region (false negatives -> positive click,
    false positives -> negative click).

    Args:
        gt_mask: Optional ground-truth label array. Pixels equal to 1 are
            treated as foreground; pixels equal to ``ignore_label`` are
            excluded when computing errors.
        init_clicks: Optional iterable of ``Click`` objects to add on creation.
        ignore_label: Label value in ``gt_mask`` that marks ignored pixels.
        click_indx_offset: Value added to every click's ``indx``.
    """

    def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1,
                 click_indx_offset=0):
        self.click_indx_offset = click_indx_offset
        if gt_mask is not None:
            self.gt_mask = gt_mask == 1
            self.not_ignore_mask = gt_mask != ignore_label
        else:
            self.gt_mask = None
            # Keep the attribute defined even without a ground truth.
            self.not_ignore_mask = None

        self.reset_clicks()

        if init_clicks is not None:
            for click in init_clicks:
                self.add_click(click)

    def make_next_click(self, pred_mask):
        """Simulate one click from ``pred_mask`` and record it.

        Requires that the clicker was created with a ``gt_mask``.
        """
        assert self.gt_mask is not None
        click = self._get_next_click(pred_mask)
        self.add_click(click)

    def get_clicks(self, clicks_limit=None):
        """Return the first ``clicks_limit`` clicks (all clicks if None)."""
        return self.clicks_list[:clicks_limit]

    def _get_next_click(self, pred_mask, padding=True):
        """Build the next simulated ``Click`` for ``pred_mask``.

        The click goes to the pixel with the largest distance-transform value
        inside either the false-negative or the false-positive region,
        whichever reaches the larger maximum; ties produce a negative click.
        Already-clicked pixels are zeroed out so they are not chosen again
        while any other candidate has a positive distance.
        """
        fn_mask = np.logical_and(
            np.logical_and(self.gt_mask, np.logical_not(pred_mask)),
            self.not_ignore_mask,
        )
        fp_mask = np.logical_and(
            np.logical_and(np.logical_not(self.gt_mask), pred_mask),
            self.not_ignore_mask,
        )

        if padding:
            # Zero-pad so the image border acts as background for the
            # distance transform; the padding is cropped off again below.
            fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
            fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")

        # maskSize 0 is cv2.DIST_MASK_PRECISE.
        fn_mask_dt = cv2.distanceTransform(
            fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
        fp_mask_dt = cv2.distanceTransform(
            fp_mask.astype(np.uint8), cv2.DIST_L2, 0)

        if padding:
            fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
            fp_mask_dt = fp_mask_dt[1:-1, 1:-1]

        fn_mask_dt = fn_mask_dt * self.not_clicked_map
        fp_mask_dt = fp_mask_dt * self.not_clicked_map

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        is_positive = fn_max_dist > fp_max_dist
        if is_positive:
            coords_y, coords_x = np.where(
                fn_mask_dt == fn_max_dist)  # coords is [y, x]
        else:
            coords_y, coords_x = np.where(
                fp_mask_dt == fp_max_dist)  # coords is [y, x]

        return Click(is_positive=is_positive,
                     coords=(coords_y[0], coords_x[0]))

    def add_click(self, click):
        """Append ``click``, assigning its ``indx`` and updating counters.

        Note: ``click.indx`` is overwritten in place on the passed object.
        """
        coords = click.coords

        click.indx = (self.click_indx_offset + self.num_pos_clicks +
                      self.num_neg_clicks)
        if click.is_positive:
            self.num_pos_clicks += 1
        else:
            self.num_neg_clicks += 1

        self.clicks_list.append(click)
        if self.gt_mask is not None:
            # coords is (y, x); mark the pixel so it is not picked again.
            self.not_clicked_map[coords[0], coords[1]] = False

    def _remove_last_click(self):
        """Pop the most recent click and undo its bookkeeping."""
        click = self.clicks_list.pop()
        coords = click.coords

        if click.is_positive:
            self.num_pos_clicks -= 1
        else:
            self.num_neg_clicks -= 1

        if self.gt_mask is not None:
            self.not_clicked_map[coords[0], coords[1]] = True

    def reset_clicks(self):
        """Drop all clicks and reset counters and the not-clicked map."""
        if self.gt_mask is not None:
            # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool)

        self.num_pos_clicks = 0
        self.num_neg_clicks = 0

        self.clicks_list = []

    def get_state(self):
        """Return a deep copy of the current click list."""
        return deepcopy(self.clicks_list)

    def set_state(self, state):
        """Replace all clicks with those in ``state`` (re-indexing them)."""
        self.reset_clicks()
        for click in state:
            self.add_click(click)

    def __len__(self):
        return len(self.clicks_list)


class Click:
    """A single user click.

    Args:
        is_positive: True for a foreground click, False for background.
        coords: Click position as ``(y, x)``.
        indx: Ordinal index, normally assigned by ``Clicker.add_click``.
    """

    def __init__(self, is_positive, coords, indx=None):
        self.is_positive = is_positive
        self.coords = coords
        self.indx = indx

    @property
    def coords_and_indx(self):
        """Return ``(y, x, indx)``."""
        return (*self.coords, self.indx)

    def copy(self, **kwargs):
        """Return a deep copy with the given attributes overridden."""
        self_copy = deepcopy(self)
        for k, v in kwargs.items():
            setattr(self_copy, k, v)
        return self_copy