Commit ce0e5303 authored by bailuo

init

# Refer to utils/config.py for definition and options.
version = "mam_vitb_8gpu"
dist = true
wandb = false
[model]
trimap_channel = 3
mask_channel = 1
batch_size = 10
freeze_seg = true
self_refine_width1 = 30
self_refine_width2 = 15
[model.arch]
seg = "sam_vit_b"
m2m = "sam_decoder_deep"
[train]
total_step = 20000
warmup_step = 4000
val_step = 0
clip_grad = true
G_lr = 1e-3
rec_weight = 1.0
lap_weight = 1.0
# Uncomment to resume training
#resume_checkpoint = "path/to/checkpoint"
#reset_lr = false
[data]
d646_fg = "path/to/Distinctions-646/Train_ori/FG"
d646_pha = "path/to/Distinctions-646/Train_ori/GT"
aim_fg = "path/to/AIM/Combined_Dataset/fg"
aim_pha = "path/to/AIM/Combined_Dataset/alpha"
human2k_fg = "path/to/Human2K/Train/FG"
human2k_pha = "path/to/Human2K/Train/Alpha"
am2k_fg = "path/to/AM2k/train/fg"
am2k_pha = "path/to/AM2k/train/mask"
rim_img = "path/to/RefMatte/train/img_full"
rim_pha = "path/to/RefMatte/train/mask"
coco_bg = "path/to/COCO/train2017"
bg20k_bg = "path/to/Matting/BG20k/full"
workers = 4
crop_size = 1024
cutmask_prob = 0.25
pha_ratio = 0.5
augmentation = true
random_interp = true
real_world_aug = false
[log]
tensorboard_path = "./logs/tensorboard"
tensorboard_step = 100
tensorboard_image_step = 2000
logging_path = "./logs/stdout"
logging_step = 10
logging_level = "INFO"
checkpoint_path = "./checkpoints/"
checkpoint_step = 2000
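# A minimal sketch of how a config like this parses (assumption: a standard TOML
# parser; the project's actual loader is utils/config.py, which is not shown here):
#
#   import tomllib                                # stdlib in Python 3.11+
#   with open("mam_vitb_8gpu.toml", "rb") as f:   # hypothetical filename
#       cfg = tomllib.load(f)
#   cfg["model"]["arch"]["seg"]                   # -> "sam_vit_b"
#   cfg["train"]["G_lr"]                          # -> 0.001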
# Refer to utils/config.py for definition and options.
version = "mam_vith_8gpu"
dist = true
wandb = false
[model]
trimap_channel = 3
mask_channel = 1
batch_size = 10
freeze_seg = true
self_refine_width1 = 30
self_refine_width2 = 15
[model.arch]
seg = "sam_vit_h"
m2m = "sam_decoder_deep"
[train]
total_step = 20000
warmup_step = 4000
val_step = 0
clip_grad = true
G_lr = 1e-3
rec_weight = 1.0
lap_weight = 1.0
# Uncomment to resume training
#resume_checkpoint = "path/to/checkpoint"
#reset_lr = false
[data]
d646_fg = "path/to/Distinctions-646/Train_ori/FG"
d646_pha = "path/to/Distinctions-646/Train_ori/GT"
aim_fg = "path/to/AIM/Combined_Dataset/fg"
aim_pha = "path/to/AIM/Combined_Dataset/alpha"
human2k_fg = "path/to/Human2K/Train/FG"
human2k_pha = "path/to/Human2K/Train/Alpha"
am2k_fg = "path/to/AM2k/train/fg"
am2k_pha = "path/to/AM2k/train/mask"
rim_img = "path/to/RefMatte/train/img_full"
rim_pha = "path/to/RefMatte/train/mask"
coco_bg = "path/to/COCO/train2017"
bg20k_bg = "path/to/Matting/BG20k/full"
workers = 4
crop_size = 1024
cutmask_prob = 0.25
pha_ratio = 0.5
augmentation = true
random_interp = true
real_world_aug = false
[log]
tensorboard_path = "./logs/tensorboard"
tensorboard_step = 100
tensorboard_image_step = 2000
logging_path = "./logs/stdout"
logging_step = 10
logging_level = "INFO"
checkpoint_path = "./checkpoints/"
checkpoint_step = 2000
# Refer to utils/config.py for definition and options.
version = "mam_vitl_8gpu"
dist = true
wandb = false
[model]
trimap_channel = 3
mask_channel = 1
batch_size = 10
freeze_seg = true
self_refine_width1 = 30
self_refine_width2 = 15
[model.arch]
seg = "sam_vit_l"
m2m = "sam_decoder_deep"
[train]
total_step = 20000
warmup_step = 4000
val_step = 0
clip_grad = true
G_lr = 1e-3
rec_weight = 1.0
lap_weight = 1.0
# Uncomment to resume training
#resume_checkpoint = "path/to/checkpoint"
#reset_lr = false
[data]
d646_fg = "path/to/Distinctions-646/Train_ori/FG"
d646_pha = "path/to/Distinctions-646/Train_ori/GT"
aim_fg = "path/to/AIM/Combined_Dataset/fg"
aim_pha = "path/to/AIM/Combined_Dataset/alpha"
human2k_fg = "path/to/Human2K/Train/FG"
human2k_pha = "path/to/Human2K/Train/Alpha"
am2k_fg = "path/to/AM2k/train/fg"
am2k_pha = "path/to/AM2k/train/mask"
rim_img = "path/to/RefMatte/train/img_full"
rim_pha = "path/to/RefMatte/train/mask"
coco_bg = "path/to/COCO/train2017"
bg20k_bg = "path/to/Matting/BG20k/full"
workers = 4
crop_size = 1024
cutmask_prob = 0.25
pha_ratio = 0.5
augmentation = true
random_interp = true
real_world_aug = false
[log]
tensorboard_path = "./logs/tensorboard"
tensorboard_step = 100
tensorboard_image_step = 2000
logging_path = "./logs/stdout"
logging_step = 10
logging_level = "INFO"
checkpoint_path = "./checkpoints/"
checkpoint_step = 2000
import os
import math
import numbers
import random
import logging
import numpy as np
import imgaug.augmenters as iaa
import torch
from torch.utils.data import Dataset
from torch.nn import functional as F
from torchvision import transforms
from utils import CONFIG
from random import randint
import warnings
warnings.filterwarnings("ignore")
import cv2
interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
def maybe_random_interp(cv2_interp):
if CONFIG.data.random_interp:
return np.random.choice(interp_list)
else:
return cv2_interp
class ToTensor(object):
"""
Convert ndarrays in sample to Tensors with normalization.
"""
def __init__(self, phase="test", real_world_aug = False):
self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
self.phase = phase
if real_world_aug:
self.RWA = iaa.SomeOf((1, None), [
iaa.LinearContrast((0.6, 1.4)),
iaa.JpegCompression(compression=(0, 60)),
iaa.GaussianBlur(sigma=(0.0, 3.0)),
iaa.AdditiveGaussianNoise(scale=(0, 0.1*255))
], random_order=True)
else:
self.RWA = None
def get_box_from_alpha(self, alpha_final):
bi_mask = np.zeros_like(alpha_final)
bi_mask[alpha_final>0.5] = 1
#bi_mask[alpha_final<=0.5] = 0
fg_set = np.where(bi_mask != 0)
if len(fg_set[1]) == 0 or len(fg_set[0]) == 0:
x_min = random.randint(1, 511)
x_max = random.randint(1, 511) + x_min
y_min = random.randint(1, 511)
y_max = random.randint(1, 511) + y_min
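            # Fallback when alpha has no foreground: sample a random box. The 1..511
            # bounds look tuned for ~512 px inputs; crop_size in the configs is 1024.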
else:
x_min = np.min(fg_set[1])
x_max = np.max(fg_set[1])
y_min = np.min(fg_set[0])
y_max = np.max(fg_set[0])
bbox = np.array([x_min, y_min, x_max, y_max])
#cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0,255,0), 2)
#cv2.imwrite('../outputs/test.jpg', image)
#cv2.imwrite('../outputs/test_gt.jpg', alpha_single)
return bbox
def __call__(self, sample):
        # convert BGR images to RGB
image, alpha, trimap = sample['image'][:,:,::-1], sample['alpha'], sample['trimap']
alpha[alpha < 0 ] = 0
alpha[alpha > 1] = 1
bbox = self.get_box_from_alpha(alpha)
if self.phase == 'train' and self.RWA is not None and np.random.rand() < 0.5:
image[image > 255] = 255
image[image < 0] = 0
image = np.round(image).astype(np.uint8)
image = np.expand_dims(image, axis=0)
image = self.RWA(images=image)
image = image[0, ...]
# swap color axis because
# numpy image: H x W x C
# torch image: C X H X W
image = image.transpose((2, 0, 1)).astype(np.float32)
alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
trimap[trimap < 85] = 0
trimap[trimap >= 170] = 2
trimap[trimap >= 85] = 1
#image = cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255,0,0), 3)
#cv2.imwrite(os.path.join('outputs', 'img_bbox.png'), image.astype('uint8'))
# normalize image
image /= 255.
if self.phase == "train":
            # convert BGR images to RGB
fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
sample['fg'] = torch.from_numpy(fg).sub_(self.mean).div_(self.std)
bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
sample['bg'] = torch.from_numpy(bg).sub_(self.mean).div_(self.std)
del sample['image_name']
sample['boxes'] = torch.from_numpy(bbox).to(torch.float)[None,...]
sample['image'], sample['alpha'], sample['trimap'] = \
torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)
sample['image'] = sample['image'].sub_(self.mean).div_(self.std)
if CONFIG.model.trimap_channel == 3:
sample['trimap'] = F.one_hot(sample['trimap'], num_classes=3).permute(2,0,1).float()
elif CONFIG.model.trimap_channel == 1:
sample['trimap'] = sample['trimap'][None,...].float()
else:
raise NotImplementedError("CONFIG.model.trimap_channel can only be 3 or 1")
return sample
class RandomAffine(object):
"""
    Random affine transformation of the foreground and alpha (rotation, translation, scale, shear, flip)
"""
def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
"degrees should be a list or tuple and it must be of length 2."
self.degrees = degrees
if translate is not None:
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
"translate should be a list or tuple and it must be of length 2."
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
"scale should be a list or tuple and it must be of length 2."
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
if isinstance(shear, numbers.Number):
if shear < 0:
raise ValueError("If shear is a single number, it must be positive.")
self.shear = (-shear, shear)
else:
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
"shear should be a list or tuple and it must be of length 2."
self.shear = shear
else:
self.shear = shear
self.resample = resample
self.fillcolor = fillcolor
self.flip = flip
@staticmethod
def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
"""Get parameters for affine transformation
Returns:
sequence: params to be passed to the affine transformation
"""
angle = random.uniform(degrees[0], degrees[1])
if translate is not None:
max_dx = translate[0] * img_size[0]
max_dy = translate[1] * img_size[1]
translations = (np.round(random.uniform(-max_dx, max_dx)),
np.round(random.uniform(-max_dy, max_dy)))
else:
translations = (0, 0)
if scale_ranges is not None:
scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
random.uniform(scale_ranges[0], scale_ranges[1]))
else:
scale = (1.0, 1.0)
if shears is not None:
shear = random.uniform(shears[0], shears[1])
else:
shear = 0.0
if flip is not None:
            flip = (np.random.rand(2) < flip).astype(int) * 2 - 1  # np.int was removed in NumPy 1.24
return angle, translations, scale, shear, flip
def __call__(self, sample):
fg, alpha = sample['fg'], sample['alpha']
rows, cols, ch = fg.shape
        # fg is an ndarray, so fg.size is the element count; pass (width, height) instead
        if np.maximum(rows, cols) < 1024:
            params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, (cols, rows))
        else:
            params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, (cols, rows))
center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
M = self._get_inverse_affine_matrix(center, *params)
M = np.array(M).reshape((2, 3))
fg = cv2.warpAffine(fg, M, (cols, rows),
flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
alpha = cv2.warpAffine(alpha, M, (cols, rows),
flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
sample['fg'], sample['alpha'] = fg, alpha
return sample
    @staticmethod
def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
        # Helper method to compute the inverse matrix of an affine transformation.
        # As explained in PIL.Image.rotate, we need to compute the INVERSE of the
        # affine transformation matrix: M = T * C * RSS * C^-1
        # where T is the translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1],
        # C is the translation matrix that keeps the center: [1, 0, cx | 0, 1, cy | 0, 0, 1],
        # and RSS is the rotation-scale-shear matrix.
        # This differs from the original torchvision function: the order is changed to
        # flip -> scale -> rotation -> shear, and x and y have different scale factors:
        # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f  -sin(a + shear)*scale_y  0 ]
        #                           [ sin(a)*scale_x*f           cos(a)*scale_y          0 ]
        #                           [ 0                          0                       1 ]
        # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
angle = math.radians(angle)
shear = math.radians(shear)
scale_x = 1.0 / scale[0] * flip[0]
scale_y = 1.0 / scale[1] * flip[1]
# Inverted rotation matrix with scale and shear
d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
matrix = [
math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
-math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
]
matrix = [m / d for m in matrix]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += center[0]
matrix[5] += center[1]
return matrix
class RandomJitter(object):
"""
    Randomly jitter the hue, saturation, and value of the foreground
"""
def __call__(self, sample):
fg, alpha = sample['fg'], sample['alpha']
# if alpha is all 0 skip
if np.all(alpha==0):
return sample
# convert to HSV space, convert to float32 image to keep precision during space conversion.
fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
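        # For float32 input, OpenCV's HSV conversion yields H in [0, 360) and
        # S, V in [0, 1], which is why the hue wraps at 360 below.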
# Hue noise
hue_jitter = np.random.randint(-40, 40)
fg[:, :, 0] = np.remainder(fg[:, :, 0].astype(np.float32) + hue_jitter, 360)
# Saturation noise
sat_bar = fg[:, :, 1][alpha > 0].mean()
sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
sat = fg[:, :, 1]
sat = np.abs(sat + sat_jitter)
sat[sat>1] = 2 - sat[sat>1]
fg[:, :, 1] = sat
# Value noise
val_bar = fg[:, :, 2][alpha > 0].mean()
val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
val = fg[:, :, 2]
val = np.abs(val + val_jitter)
val[val>1] = 2 - val[val>1]
fg[:, :, 2] = val
# convert back to BGR space
fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR)
sample['fg'] = fg*255
return sample
class RandomHorizontalFlip(object):
"""
    Randomly flip the image and label horizontally
"""
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, sample):
fg, alpha = sample['fg'], sample['alpha']
if np.random.uniform(0, 1) < self.prob:
fg = cv2.flip(fg, 1)
alpha = cv2.flip(alpha, 1)
sample['fg'], sample['alpha'] = fg, alpha
return sample
class RandomCrop(object):
"""
    Randomly crop the image in a sample; crop centers are drawn from the unknown
    (trimap == 128) region away from the borders, with a fallback resize when no
    unknown area is available.
    :param output_size (tuple or int): Desired output size. If int, a square crop is made.
"""
    def __init__(self, output_size=(CONFIG.data.crop_size, CONFIG.data.crop_size)):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
        self.margin = self.output_size[0] // 2  # use the normalized tuple, not the raw argument
self.logger = logging.getLogger("Logger")
def __call__(self, sample):
fg, alpha, trimap, name = sample['fg'], sample['alpha'], sample['trimap'], sample['image_name']
bg = sample['bg']
h, w = trimap.shape
bg = cv2.resize(bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
if w < self.output_size[0]+1 or h < self.output_size[1]+1:
ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w
# self.logger.warning("Size of {} is {}.".format(name, (h, w)))
while h < self.output_size[0]+1 or w < self.output_size[1]+1:
fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
interpolation=maybe_random_interp(cv2.INTER_NEAREST))
trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
h, w = trimap.shape
small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
self.margin//4:(w-self.margin)//4] == 128)))
unknown_num = len(unknown_list)
if len(unknown_list) < 10:
left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
else:
idx = np.random.randint(unknown_num)
left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
if len(np.where(trimap==128)[0]) == 0:
self.logger.error("{} does not have enough unknown area for crop. Resized to target size."
"left_top: {}".format(name, left_top))
fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC))
sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'bg': bg_crop})
return sample
class GenTrimap(object):
def __init__(self):
self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,100)]
def __call__(self, sample):
alpha = sample['alpha']
h, w = alpha.shape
max_kernel_size = max(30, int((min(h,w) / 2048) * 30))
### generate trimap
        fg_mask = (alpha + 1e-5).astype(int).astype(np.uint8)      # np.int was removed in NumPy 1.24
        bg_mask = (1 - alpha + 1e-5).astype(int).astype(np.uint8)
fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
trimap = np.ones_like(alpha) * 128
trimap[fg_mask == 1] = 255
trimap[bg_mask == 1] = 0
trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
sample['trimap'] = trimap
return sample
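def _demo_trimap():
    # Illustrative only (not called anywhere): the GenTrimap logic above with a
    # fixed kernel size; the real class samples the kernel size per call.
    alpha = np.zeros((64, 64), dtype=np.float32)
    alpha[16:48, 16:48] = 1.0                                 # toy square foreground
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    fg_mask = cv2.erode((alpha + 1e-5).astype(int).astype(np.uint8), kernel)
    bg_mask = cv2.erode((1 - alpha + 1e-5).astype(int).astype(np.uint8), kernel)
    trimap = np.ones_like(alpha) * 128                        # unknown by default
    trimap[fg_mask == 1] = 255                                # confident foreground
    trimap[bg_mask == 1] = 0                                  # confident background
    return trimap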
class Composite(object):
def __call__(self, sample):
fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
alpha[alpha < 0 ] = 0
alpha[alpha > 1] = 1
fg[fg < 0 ] = 0
fg[fg > 255] = 255
bg[bg < 0 ] = 0
bg[bg > 255] = 255
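        # Standard compositing equation: I = alpha * F + (1 - alpha) * B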
image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])
sample['image'] = image
return sample
class Composite_Seg(object):
def __call__(self, sample):
fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
fg[fg < 0 ] = 0
fg[fg > 255] = 255
image = fg
sample['image'] = image
return sample
class DataGenerator(Dataset):
def __init__(self, phase="train"):
self.phase = phase
self.crop_size = CONFIG.data.crop_size
self.pha_ratio = CONFIG.data.pha_ratio
self.coco_bg = [os.path.join(CONFIG.data.coco_bg, name) for name in sorted(os.listdir(CONFIG.data.coco_bg))]
self.coco_num = len(self.coco_bg)
self.bg20k_bg = [os.path.join(CONFIG.data.bg20k_bg, name) for name in sorted(os.listdir(CONFIG.data.bg20k_bg))]
self.bg20k_num = len(self.bg20k_bg)
self.d646_fg = [os.path.join(CONFIG.data.d646_fg, name) for name in sorted(os.listdir(CONFIG.data.d646_fg))]
self.d646_pha = [os.path.join(CONFIG.data.d646_pha, name) for name in sorted(os.listdir(CONFIG.data.d646_pha))]
self.d646_num = len(self.d646_fg)
self.aim_fg = [os.path.join(CONFIG.data.aim_fg, name) for name in sorted(os.listdir(CONFIG.data.aim_fg))]
self.aim_pha = [os.path.join(CONFIG.data.aim_pha, name) for name in sorted(os.listdir(CONFIG.data.aim_pha))]
self.aim_num = len(self.aim_fg)
self.am2k_fg = [os.path.join(CONFIG.data.am2k_fg, name) for name in sorted(os.listdir(CONFIG.data.am2k_fg))]
self.am2k_pha = [os.path.join(CONFIG.data.am2k_pha, name) for name in sorted(os.listdir(CONFIG.data.am2k_pha))]
self.am2k_num = len(self.am2k_fg)
self.human2k_fg = [os.path.join(CONFIG.data.human2k_fg, name) for name in sorted(os.listdir(CONFIG.data.human2k_fg))]
self.human2k_pha = [os.path.join(CONFIG.data.human2k_pha, name) for name in sorted(os.listdir(CONFIG.data.human2k_pha))]
self.human2k_num = len(self.human2k_fg)
self.rim_img = [os.path.join(CONFIG.data.rim_img, name) for name in sorted(os.listdir(CONFIG.data.rim_img))]
self.rim_pha = [os.path.join(CONFIG.data.rim_pha, name) for name in sorted(os.listdir(CONFIG.data.rim_pha))]
self.rim_num = len(self.rim_img)
self.transform_imagematte = transforms.Compose(
[RandomAffine(degrees=30, scale=[0.8, 1.5], shear=10, flip=0.5),
GenTrimap(),
RandomCrop((self.crop_size, self.crop_size)),
RandomJitter(),
Composite(),
ToTensor(phase="train", real_world_aug=CONFIG.data.real_world_aug)])
self.transform_spd = transforms.Compose(
[RandomAffine(degrees=30, scale=[0.8, 1.5], shear=10, flip=0.5),
GenTrimap(),
RandomCrop((self.crop_size, self.crop_size)),
#RandomJitter(),
Composite_Seg(),
ToTensor(phase="train", real_world_aug=CONFIG.data.real_world_aug)])
def __getitem__(self, idx):
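        # The random draws below are independent: half the samples are composited
        # foregrounds (within that half: Human2K 25%, AM2K 37.5%, D646 ~28.1%,
        # AIM ~9.4%), and half are RefMatte images, which go through Composite_Seg
        # (no background compositing).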
if random.random() < 0.5:
bg = cv2.imread(self.coco_bg[idx])
else:
bg = cv2.imread(self.bg20k_bg[idx % self.bg20k_num])
if random.random() < 0.5:
if random.random() < 0.25:
fg = cv2.imread(self.human2k_fg[idx % self.human2k_num])
alpha = cv2.imread(self.human2k_pha[idx % self.human2k_num], 0).astype(np.float32)/255
fg, alpha = self._composite_fg_human2k(fg, alpha, idx)
image_name = os.path.split(self.human2k_fg[idx % self.human2k_num])[-1]
elif random.random() < 0.5:
fg = cv2.imread(self.am2k_fg[idx % self.am2k_num])
alpha = cv2.imread(self.am2k_pha[idx % self.am2k_num], 0).astype(np.float32)/255
fg, alpha = self._composite_fg_am2k(fg, alpha, idx)
image_name = os.path.split(self.am2k_fg[idx % self.am2k_num])[-1]
elif random.random() < 0.75:
fg = cv2.imread(self.d646_fg[idx % self.d646_num])
alpha = cv2.imread(self.d646_pha[idx % self.d646_num], 0).astype(np.float32)/255
fg, alpha = self._composite_fg_646(fg, alpha, idx)
image_name = os.path.split(self.d646_fg[idx % self.d646_num])[-1]
else:
fg = cv2.imread(self.aim_fg[idx % self.aim_num])
alpha = cv2.imread(self.aim_pha[idx % self.aim_num], 0).astype(np.float32)/255
fg, alpha = self._composite_fg_aim(fg, alpha, idx)
image_name = os.path.split(self.aim_fg[idx % self.aim_num])[-1]
sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name}
sample = self.transform_imagematte(sample)
else:
fg = cv2.imread(self.rim_img[idx % self.rim_num])
alpha = cv2.imread(self.rim_pha[idx % self.rim_num], 0).astype(np.float32)/255
if np.random.rand() < 0.25:
fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
image_name = os.path.split(self.rim_img[idx % self.rim_num])[-1]
sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name}
sample = self.transform_spd(sample)
return sample
def _composite_fg_human2k(self, fg, alpha, idx):
if np.random.rand() < 0.5:
idx2 = np.random.randint(self.human2k_num) + idx
fg2 = cv2.imread(self.human2k_fg[idx2 % self.human2k_num])
alpha2 = cv2.imread(self.human2k_pha[idx2 % self.human2k_num], 0).astype(np.float32)/255.
h, w = alpha.shape
fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
if np.any(alpha_tmp < 1):
fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
                # Two layers at 50% transparency overlap to 25% transparency: alpha = 1 - 0.5 * 0.5 = 0.75
alpha = alpha_tmp
fg = fg.astype(np.uint8)
if np.random.rand() < 0.25:
fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
return fg, alpha
def _composite_fg_am2k(self, fg, alpha, idx):
if np.random.rand() < 0.5:
idx2 = np.random.randint(self.am2k_num) + idx
fg2 = cv2.imread(self.am2k_fg[idx2 % self.am2k_num])
alpha2 = cv2.imread(self.am2k_pha[idx2 % self.am2k_num], 0).astype(np.float32)/255.
h, w = alpha.shape
fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
if np.any(alpha_tmp < 1):
fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
                # Two layers at 50% transparency overlap to 25% transparency: alpha = 1 - 0.5 * 0.5 = 0.75
alpha = alpha_tmp
fg = fg.astype(np.uint8)
if np.random.rand() < 0.25:
fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
return fg, alpha
def _composite_fg_646(self, fg, alpha, idx):
if np.random.rand() < 0.5:
idx2 = np.random.randint(self.d646_num) + idx
fg2 = cv2.imread(self.d646_fg[idx2 % self.d646_num])
alpha2 = cv2.imread(self.d646_pha[idx2 % self.d646_num], 0).astype(np.float32)/255.
h, w = alpha.shape
fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
if np.any(alpha_tmp < 1):
fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
                # Two layers at 50% transparency overlap to 25% transparency: alpha = 1 - 0.5 * 0.5 = 0.75
alpha = alpha_tmp
fg = fg.astype(np.uint8)
if np.random.rand() < 0.25:
fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
return fg, alpha
def _composite_fg_aim(self, fg, alpha, idx):
if np.random.rand() < 0.5:
idx2 = np.random.randint(self.aim_num) + idx
fg2 = cv2.imread(self.aim_fg[idx2 % self.aim_num])
alpha2 = cv2.imread(self.aim_pha[idx2 % self.aim_num], 0).astype(np.float32)/255.
h, w = alpha.shape
fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
if np.any(alpha_tmp < 1):
fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
                # Two layers at 50% transparency overlap to 25% transparency: alpha = 1 - 0.5 * 0.5 = 0.75
alpha = alpha_tmp
fg = fg.astype(np.uint8)
if np.random.rand() < 0.25:
fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
return fg, alpha
def __len__(self):
return len(self.coco_bg)
import os
import glob
import logging
import functools
import numpy as np
import pdb
class ImageFile(object):
def __init__(self, phase='train'):
self.logger = logging.getLogger("Logger")
self.phase = phase
self.rng = np.random.RandomState(0)
def _get_valid_names(self, *dirs, shuffle=True):
# Extract valid names
name_sets = [self._get_name_set(d) for d in dirs]
# Reduce
def _join_and(a, b):
return a & b
valid_names = list(functools.reduce(_join_and, name_sets))
if shuffle:
self.rng.shuffle(valid_names)
if len(valid_names) == 0:
            self.logger.error('No valid images')
else:
self.logger.info('{}: {} foreground/images are valid'.format(self.phase.upper(), len(valid_names)))
return valid_names
@staticmethod
def _get_name_set(dir_name):
path_list = glob.glob(os.path.join(dir_name, '*'))
name_set = set()
for path in path_list:
name = os.path.basename(path)
name = os.path.splitext(name)[0]
name_set.add(name)
return name_set
@staticmethod
def _list_abspath(data_dir, ext, data_list):
return [os.path.join(data_dir, name + ext)
for name in data_list]
class ImageFileTrain(ImageFile):
def __init__(self,
alpha_dir="train_alpha",
fg_dir="train_fg",
bg_dir="train_bg",
alpha_ext=".jpg",
fg_ext=".jpg",
bg_ext=".jpg"):
super(ImageFileTrain, self).__init__(phase="train")
self.alpha_dir = alpha_dir
self.fg_dir = fg_dir
self.bg_dir = bg_dir
self.alpha_ext = alpha_ext
self.fg_ext = fg_ext
self.bg_ext = bg_ext
self.logger.debug('Load Training Images From Folders')
#self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir)
#self.valid_bg_list = [os.path.splitext(name)[0] for name in os.listdir(self.bg_dir)]
#self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_fg_list)
#self.fg = self._list_abspath(self.fg_dir, self.fg_ext, self.valid_fg_list)
#self.bg = self._list_abspath(self.bg_dir, self.bg_ext, self.valid_bg_list)
self.alpha = [os.path.join(self.alpha_dir, name) for name in sorted(os.listdir(self.alpha_dir))]
self.fg = [os.path.join(self.fg_dir, name) for name in sorted(os.listdir(self.fg_dir))]
self.bg = [os.path.join(self.bg_dir, name) for name in os.listdir(self.bg_dir)]
def __len__(self):
return len(self.alpha)
class ImageFileTest(ImageFile):
def __init__(self,
alpha_dir="test_alpha",
merged_dir="test_merged",
trimap_dir="test_trimap",
alpha_ext=".png",
merged_ext=".png",
trimap_ext=".png"):
super(ImageFileTest, self).__init__(phase="test")
self.alpha_dir = alpha_dir
self.merged_dir = merged_dir
self.trimap_dir = trimap_dir
self.alpha_ext = alpha_ext
self.merged_ext = merged_ext
self.trimap_ext = trimap_ext
self.logger.debug('Load Testing Images From Folders')
self.valid_image_list = self._get_valid_names(self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False)
self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_image_list)
self.merged = self._list_abspath(self.merged_dir, self.merged_ext, self.valid_image_list)
self.trimap = self._list_abspath(self.trimap_dir, self.trimap_ext, self.valid_image_list)
def __len__(self):
return len(self.alpha)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] %(levelname)s: %(message)s', datefmt='%m-%d %H:%M:%S')
train_data = ImageFileTrain(alpha_dir="/home/liyaoyi/dataset/Adobe/train/alpha",
fg_dir="/home/liyaoyi/dataset/Adobe/train/fg",
bg_dir="/home/Data/coco/images2017",
alpha_ext=".png",
fg_ext=".jpg",
bg_ext=".jpg")
test_data = ImageFileTest(alpha_dir="/home/liyaoyi/dataset/Adobe/test/alpha",
merged_dir="/home/liyaoyi/dataset/Adobe/test/merged",
trimap_dir="/home/liyaoyi/dataset/Adobe/test/trimap",
alpha_ext=".png",
merged_ext=".jpg",
trimap_ext=".png")
print(train_data.alpha[0], train_data.fg[0], train_data.bg[0])
print(len(train_data.alpha), len(train_data.fg), len(train_data.bg))
print(test_data.alpha[0], test_data.merged[0], test_data.trimap[0])
print(len(test_data.alpha), len(test_data.merged), len(test_data.trimap))
import torch
class Prefetcher():
"""
Modified from the data_prefetcher in https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
"""
def __init__(self, loader):
self.orig_loader = loader
self.stream = torch.cuda.Stream()
self.next_sample = None
def preload(self):
try:
self.next_sample = next(self.loader)
except StopIteration:
self.next_sample = None
return
with torch.cuda.stream(self.stream):
for key, value in self.next_sample.items():
if isinstance(value, torch.Tensor):
self.next_sample[key] = value.cuda(non_blocking=True)
def __next__(self):
torch.cuda.current_stream().wait_stream(self.stream)
sample = self.next_sample
if sample is not None:
for key, value in sample.items():
if isinstance(value, torch.Tensor):
sample[key].record_stream(torch.cuda.current_stream())
self.preload()
        else:
            # Behave like a standard DataLoader: raise StopIteration when exhausted.
            raise StopIteration("No samples left in loader. Usage: "
                                "`iterator = iter(Prefetcher(loader)); data = next(iterator)`")
return sample
def __iter__(self):
self.loader = iter(self.orig_loader)
self.preload()
return self
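# Minimal usage sketch (illustrative; the DataLoader settings are assumptions, not
# taken from this repo). pin_memory=True is what allows .cuda(non_blocking=True)
# to overlap the host-to-device copy with compute:
#
#   loader = torch.utils.data.DataLoader(dataset, batch_size=10,
#                                        num_workers=4, pin_memory=True)
#   for sample in Prefetcher(loader):   # __iter__ starts the first preload
#       image = sample['image']         # tensors are already on the GPU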
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10
ENV DEBIAN_FRONTEND=noninteractive
# COPY requirements.txt requirements.txt
# RUN pip3 install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
import numpy as np
import os, sys, tqdm, cv2
from scipy.optimize import linear_sum_assignment
from metrics import BatchMetric
import pdb
def match(pred, gt):
# pred: (n,h,w)
# gt: (m,h,w)
n, h, w = pred.shape
m, h, w = gt.shape
pred_mask = (pred>0)
gt_mask = (gt>0)
    # (n, m) pairwise IoU between predictions and ground truths
union = np.logical_or(pred_mask[:,None,:,:], gt_mask[None,:,:,:]).sum(axis=(2,3))
inter = np.logical_and(pred_mask[:,None,:,:], gt_mask[None,:,:,:]).sum(axis=(2,3))
iou = inter / (union + 1e-8)
# matched_idx = np.argmax(iou, axis=0) # m
# matched_iou = np.max(iou, axis=0) # m
# return matched_idx, matched_iou
return iou
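# Shape note: pred_mask[:, None] is (n, 1, h, w) and gt_mask[None, :] is
# (1, m, h, w), so after summing over the pixel axes the returned iou is (n, m).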
def mad(pred, gt):
pred_mask = (pred>0)
gt_mask = (gt>0)
union_mask = np.logical_or(pred_mask, gt_mask)
error = np.abs(pred-gt) * union_mask.astype(np.float32)
error = error.sum(axis=(1,2)) / (union_mask.sum(axis=(1,2)) + 1.)
score = 1 - np.minimum(error,1)
return score
def similarity(pred, gt):
metric = BatchMetric('cuda')
mask = np.logical_or(pred>0, gt>0) * 128
sad, mad, mse, grad, conn = metric.run(pred*255, gt*255, mask=mask, calc_mad=True)
mad = 1 - np.minimum(mad * 10, 1.0)
mse = 1 - np.minimum(mse * 10, 1.0)
grad = 1 - np.minimum(grad * 10, 1.0)
conn = 1 - np.minimum(conn * 10, 1.0)
score = [mad.sum(), mse.sum(), grad.sum(), conn.sum()]
return score
def compute_stats_per_image(pred, gt, thresh_list, func=mad):
# matched_idx, matched_iou = match(pred, gt)
tp_list, fp_list, fn_list = [], [], []
MQ_list = []
if len(pred)>0 and len(gt)>0:
iou_matrix = match(pred, gt)
matched_i, matched_j = linear_sum_assignment(1-iou_matrix)
matched_iou = iou_matrix[matched_i, matched_j]
# score = func(pred[matched_i], gt[matched_j]) * matched_iou
for thresh in thresh_list:
tp = (matched_iou>=thresh).sum()
fp = pred.shape[0] - tp
fn = gt.shape[0] - tp
tp_list.append(tp)
fp_list.append(fp)
fn_list.append(fn)
if tp>0:
tp_idx = np.where(matched_iou>=thresh)
tp_i, tp_j = matched_i[tp_idx], matched_j[tp_idx]
score_list = similarity(pred[tp_i], gt[tp_j])
score_list.append(matched_iou[tp_idx].sum())
MQ_list.append(score_list)
else:
MQ_list.append([0,0,0,0,0])
elif len(pred) == 0:
for thresh in thresh_list:
tp_list.append(0)
fp_list.append(0)
fn_list.append(len(gt))
MQ_list.append([0,0,0,0,0])
else:
for thresh in thresh_list:
tp_list.append(0)
fp_list.append(len(pred))
fn_list.append(0)
MQ_list.append([0,0,0,0,0])
return tp_list, fp_list, fn_list, MQ_list
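def _demo_matching():
    # Illustrative only (not called anywhere): linear_sum_assignment minimizes
    # total cost, so passing 1 - IoU picks the pairing with maximum total IoU.
    iou = np.array([[0.9, 0.1],
                    [0.2, 0.7],
                    [0.0, 0.3]])                 # 3 predictions x 2 ground truths
    rows, cols = linear_sum_assignment(1 - iou)  # rows = [0, 1], cols = [0, 1]
    return iou[rows, cols]                       # array([0.9, 0.7])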
def compute_stats(pred_folder, gt_folder, thresh_list):
n_thresh = len(thresh_list)
IMQ_list = [] # n_thresh, n_IMQ, n_instances
_MQ_list = []
_RQ_list = []
for i in range(n_thresh):
IMQ_list.append([0]*5)
_MQ_list.append([0]*5)
_RQ_list.append([0]*5)
TP, FP, FN = [0]*n_thresh, [0]*n_thresh, [0]*n_thresh
for item in tqdm.tqdm(sorted(os.listdir(gt_folder))):
if not os.path.exists(os.path.join(pred_folder, item)):
continue
pred_images = [cv2.imread(os.path.join(pred_folder, item, im), 0)/255. for im in os.listdir(os.path.join(pred_folder, item))]
gt_images = [cv2.imread(os.path.join(gt_folder, item, im), 0)/255. for im in os.listdir(os.path.join(gt_folder, item))]
if len(pred_images)>0:
pred_items = np.stack(pred_images, axis=0)
else:
pred_items = []
if len(gt_images)>0:
gt_items = np.stack(gt_images, axis=0)
else:
gt_items = []
tp_list, fp_list, fn_list, MQ_list = compute_stats_per_image(pred_items, gt_items, thresh_list)
for i in range(0, n_thresh):
TP[i] += tp_list[i]
FP[i] += fp_list[i]
FN[i] += fn_list[i]
for j in range(len(MQ_list[i])):
IMQ_list[i][j] += MQ_list[i][j]
_MQ_list[i][j] += MQ_list[i][j]
for i in range(0, n_thresh):
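        # TP + FP/2 + FN/2 is the panoptic-quality denominator; _RQ_list below is
        # the recognition-quality term TP / (TP + FP/2 + FN/2).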
coeff = 1.0 / (TP[i] + 0.5*FP[i] + 0.5*FN[i] + 1e-6)
for j in range(len(IMQ_list[0])):
IMQ_list[i][j] = IMQ_list[i][j] * coeff
            _MQ_list[i][j] = _MQ_list[i][j] / max(float(TP[i]), 1.0)  # guard against TP == 0
_RQ_list[i][j] = coeff * TP[i]
return IMQ_list, _MQ_list, _RQ_list
if __name__ == "__main__":
pred_folder = sys.argv[1]
gt_folder = sys.argv[2]
#thresh_list = [0.5, 0.75]
thresh_list = [0.5]
IMQ_list, MQ_list, RQ_list = compute_stats(pred_folder, gt_folder, thresh_list)
for IMQ, MQ, RQ, thresh in zip(IMQ_list, MQ_list, RQ_list, thresh_list):
print("IMQ/MQ/RQ on Threshold = {}".format(thresh))
for i, name in enumerate(['MAD', 'MSE', 'Grad', 'Conn', 'IoU']):
print('{} = {}/{}/{}'.format(name, IMQ[i], MQ[i], RQ[i]))
import numpy as np
import os, sys, tqdm, cv2
from scipy.optimize import linear_sum_assignment
from metrics import BatchMetric
def match(pred, gt):
# pred: (n,h,w)
# gt: (m,h,w)
n, h, w = pred.shape
m, h, w = gt.shape
pred_mask = (pred>0)
gt_mask = (gt>0)
    # (n, m) pairwise IoU between predictions and ground truths
union = np.logical_or(pred_mask[:,None,:,:], gt_mask[None,:,:,:]).sum(axis=(2,3))
inter = np.logical_and(pred_mask[:,None,:,:], gt_mask[None,:,:,:]).sum(axis=(2,3))
iou = inter / (union + 1e-8)
# matched_idx = np.argmax(iou, axis=0) # m
# matched_iou = np.max(iou, axis=0) # m
# return matched_idx, matched_iou
return iou
def mad(pred, gt):
pred_mask = (pred>0)
gt_mask = (gt>0)
union_mask = np.logical_or(pred_mask, gt_mask)
error = np.abs(pred-gt) * union_mask.astype(np.float32)
error = error.sum(axis=(1,2)) / (union_mask.sum(axis=(1,2)) + 1.)
score = 1 - np.minimum(error,1)
return score
def similarity(pred, gt):
metric = BatchMetric('cuda')
mask = np.logical_or(pred>0, gt>0) * 128
mad, mse = metric.run_quick(pred*255, gt*255, mask=mask)
mad = 1 - np.minimum(mad * 10, 1.0)
mse = 1 - np.minimum(mse * 10, 1.0)
#grad = 1 - np.minimum(grad * 10, 1.0)
#conn = 1 - np.minimum(conn * 10, 1.0)
score = [mad.sum(), mse.sum()]
return score
def compute_stats_per_image(pred, gt, thresh_list, func=mad):
# matched_idx, matched_iou = match(pred, gt)
tp_list, fp_list, fn_list = [], [], []
MQ_list = []
if len(pred)>0 and len(gt)>0:
iou_matrix = match(pred, gt)
matched_i, matched_j = linear_sum_assignment(1-iou_matrix)
matched_iou = iou_matrix[matched_i, matched_j]
# score = func(pred[matched_i], gt[matched_j]) * matched_iou
for thresh in thresh_list:
tp = (matched_iou>=thresh).sum()
fp = pred.shape[0] - tp
fn = gt.shape[0] - tp
tp_list.append(tp)
fp_list.append(fp)
fn_list.append(fn)
if tp>0:
tp_idx = np.where(matched_iou>=thresh)
tp_i, tp_j = matched_i[tp_idx], matched_j[tp_idx]
score_list = similarity(pred[tp_i], gt[tp_j])
#score_list.append(matched_iou[tp_idx].sum())
MQ_list.append(score_list)
else:
MQ_list.append([0,0])
elif len(pred) == 0:
for thresh in thresh_list:
tp_list.append(0)
fp_list.append(0)
fn_list.append(len(gt))
MQ_list.append([0,0])
else:
for thresh in thresh_list:
tp_list.append(0)
fp_list.append(len(pred))
fn_list.append(0)
MQ_list.append([0,0])
return tp_list, fp_list, fn_list, MQ_list
def compute_stats(pred_folder, gt_folder, thresh_list):
n_thresh = len(thresh_list)
IMQ_list = [] # n_thresh, n_IMQ, n_instances
_MQ_list = []
_RQ_list = []
for i in range(n_thresh):
IMQ_list.append([0]*2)
_MQ_list.append([0]*2)
_RQ_list.append([0]*2)
TP, FP, FN = [0]*n_thresh, [0]*n_thresh, [0]*n_thresh
for item in tqdm.tqdm(sorted(os.listdir(gt_folder))):
#if not os.path.exists(os.path.join(pred_folder, item)):
# continue
#pred_images = [cv2.imread(os.path.join(pred_folder, item, im), 0)/255. for im in os.listdir(os.path.join(pred_folder, item))]
#gt_images = [cv2.imread(os.path.join(gt_folder, item, im), 0)/255. for im in os.listdir(os.path.join(gt_folder, item))]
pred_images = [cv2.imread(os.path.join(pred_folder, item), 0)/255.]
gt_images = [cv2.imread(os.path.join(gt_folder, item), 0)/255.]
if len(pred_images)>0:
pred_items = np.stack(pred_images, axis=0)
else:
pred_items = []
if len(gt_images)>0:
gt_items = np.stack(gt_images, axis=0)
else:
gt_items = []
tp_list, fp_list, fn_list, MQ_list = compute_stats_per_image(pred_items, gt_items, thresh_list)
for i in range(0, n_thresh):
TP[i] += tp_list[i]
FP[i] += fp_list[i]
FN[i] += fn_list[i]
for j in range(len(MQ_list[i])):
IMQ_list[i][j] += MQ_list[i][j]
_MQ_list[i][j] += MQ_list[i][j]
for i in range(0, n_thresh):
coeff = 1.0 / (TP[i] + 0.5*FP[i] + 0.5*FN[i] + 1e-6)
for j in range(len(IMQ_list[0])):
IMQ_list[i][j] = IMQ_list[i][j] * coeff
            _MQ_list[i][j] = _MQ_list[i][j] / max(float(TP[i]), 1.0)  # guard against TP == 0
_RQ_list[i][j] = coeff * TP[i]
return IMQ_list, _MQ_list, _RQ_list
if __name__ == "__main__":
pred_folder = sys.argv[1]
gt_folder = sys.argv[2]
#thresh_list = [0.5, 0.75]
thresh_list = [0.5]
IMQ_list, MQ_list, RQ_list = compute_stats(pred_folder, gt_folder, thresh_list)
for IMQ, MQ, RQ, thresh in zip(IMQ_list, MQ_list, RQ_list, thresh_list):
print("IMQ/MQ/RQ on Threshold = {}".format(thresh))
for i, name in enumerate(['MAD', 'MSE']):
print('{} = {}/{}/{}'.format(name, IMQ[i], MQ[i], RQ[i]))
import os
import cv2
import sys
import numpy as np
sys.path.insert(0, './utils')
from evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss, compute_connectivity_error
import argparse
from tqdm import tqdm
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pred-dir', type=str, default='path/to/outputs/am2k', help="pred alpha dir")
parser.add_argument('--label-dir', type=str, default='path/to/AM2k/validation/mask/', help="GT alpha dir")
parser.add_argument('--detailmap-dir', type=str, default='path/to/AM2k/validation/trimap/', help="trimap dir")
args = parser.parse_args()
mse_loss = []
sad_loss = []
mad_loss = []
grad_loss = []
conn_loss = []
    ### loss_unknown only considers the unknown region (trimap == 128), as trimap-based methods do
mse_loss_unknown = []
sad_loss_unknown = []
for img in tqdm(os.listdir(args.label_dir)):
print(img)
#pred = cv2.imread(os.path.join(args.pred_dir, img.replace('.png', '.jpg')), 0).astype(np.float32)
pred = cv2.imread(os.path.join(args.pred_dir, img), 0).astype(np.float32)
label = cv2.imread(os.path.join(args.label_dir, img), 0).astype(np.float32)
detailmap = cv2.imread(os.path.join(args.detailmap_dir, img), 0).astype(np.float32)
        detailmap[detailmap > 0] = 128  # merge foreground and unknown into one detail region
mse_loss_unknown_ = compute_mse_loss(pred, label, detailmap)
sad_loss_unknown_ = compute_sad_loss(pred, label, detailmap)[0]
        detailmap[...] = 128  # treat the whole image as unknown for whole-image metrics
mse_loss_ = compute_mse_loss(pred, label, detailmap)
sad_loss_ = compute_sad_loss(pred, label, detailmap)[0]
mad_loss_ = compute_mad_loss(pred, label, detailmap)
grad_loss_ = compute_gradient_loss(pred, label, detailmap)
conn_loss_ = compute_connectivity_error(pred, label, detailmap)
        print('Whole Image: MSE:', mse_loss_, ' SAD:', sad_loss_, ' MAD:', mad_loss_, ' Grad:', grad_loss_, ' Conn:', conn_loss_)
print('Detail Region: MSE:', mse_loss_unknown_, ' SAD:', sad_loss_unknown_)
mse_loss_unknown.append(mse_loss_unknown_)
sad_loss_unknown.append(sad_loss_unknown_)
mse_loss.append(mse_loss_)
sad_loss.append(sad_loss_)
mad_loss.append(mad_loss_)
grad_loss.append(grad_loss_)
conn_loss.append(conn_loss_)
print('Average:')
print('Whole Image: MSE:', np.array(mse_loss).mean(), ' SAD:', np.array(sad_loss).mean(), ' MAD:', np.array(mad_loss).mean(), ' Grad:', np.array(grad_loss).mean(), ' Conn:', np.array(conn_loss).mean())
print('Detail Region: MSE:', np.array(mse_loss_unknown).mean(), ' SAD:', np.array(sad_loss_unknown).mean())
import os
import sys
import cv2
import numpy as np
sys.path.insert(0, './utils')
from evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss
import argparse
from tqdm import tqdm
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--pred-dir', type=str, default='path/to/outputs/pm10k/', help="pred alpha dir")
parser.add_argument('--label-dir', type=str, default='path/to/P3M-10k/validation/P3M-500-NP/mask', help="GT alpha dir")
parser.add_argument('--detailmap-dir', type=str, default='path/to/P3M-10k/validation/P3M-500-NP/trimap', help="trimap dir")
args = parser.parse_args()
mse_loss = []
sad_loss = []
mad_loss = []
grad_loss = []
#conn_loss = []
    ### loss_unknown only considers the unknown region (trimap == 128), as trimap-based methods do
mse_loss_unknown = []
sad_loss_unknown = []
for img in tqdm(os.listdir(args.label_dir)):
print(img)
#pred = cv2.imread(os.path.join(args.pred_dir, img.replace('.png', '.jpg')), 0).astype(np.float32)
pred = cv2.imread(os.path.join(args.pred_dir, img), 0).astype(np.float32)
label = cv2.imread(os.path.join(args.label_dir, img), 0).astype(np.float32)
detailmap = cv2.imread(os.path.join(args.detailmap_dir, img), 0).astype(np.float32)
        detailmap[detailmap > 0] = 128  # merge foreground and unknown into one detail region
mse_loss_unknown_ = compute_mse_loss(pred, label, detailmap)
sad_loss_unknown_ = compute_sad_loss(pred, label, detailmap)[0]
        detailmap[...] = 128  # treat the whole image as unknown for whole-image metrics
mse_loss_ = compute_mse_loss(pred, label, detailmap)
sad_loss_ = compute_sad_loss(pred, label, detailmap)[0]
mad_loss_ = compute_mad_loss(pred, label, detailmap)
grad_loss_ = compute_gradient_loss(pred, label, detailmap)
#conn_loss_ = compute_connectivity_error(pred, label, detailmap)
        print('Whole Image: MSE:', mse_loss_, ' SAD:', sad_loss_, ' MAD:', mad_loss_, ' Grad:', grad_loss_)
print('Detail Region: MSE:', mse_loss_unknown_, ' SAD:', sad_loss_unknown_)
mse_loss_unknown.append(mse_loss_unknown_)
sad_loss_unknown.append(sad_loss_unknown_)
mse_loss.append(mse_loss_)
sad_loss.append(sad_loss_)
mad_loss.append(mad_loss_)
grad_loss.append(grad_loss_)
#conn_loss.append(conn_loss_)
print('Average:')
print('Whole Image: MSE:', np.array(mse_loss).mean(), ' SAD:', np.array(sad_loss).mean(), ' MAD:', np.array(mad_loss).mean(), ' Grad:', np.array(grad_loss).mean())
print('Detail Region: MSE:', np.array(mse_loss_unknown).mean(), ' SAD:', np.array(sad_loss_unknown).mean())