Commit ce0e5303 authored by bailuo's avatar bailuo
Browse files

init

parents
Pipeline #2003 failed with stages
in 0 seconds
# ------------------------------------------------------------------------
# Modified from MGMatting (https://github.com/yucornetto/MGMatting)
# ------------------------------------------------------------------------
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as nn_utils
import torch.backends.cudnn as cudnn
from torch.nn import SyncBatchNorm
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.parallel import DistributedDataParallel
import utils
from utils import CONFIG
import networks
import wandb
import cv2
class Trainer(object):
    def __init__(self,
                 train_dataloader,
                 test_dataloader,
                 logger,
                 tb_logger):
        """
        Matting trainer (modified from MGMatting).

        :param train_dataloader: yields batches as dicts with keys
            'image', 'alpha', 'trimap' and 'boxes' (consumed in train())
        :param test_dataloader: dataloader used for evaluation
        :param logger: text logger
        :param tb_logger: tensorboard logger wrapper
        """
        # Let cudnn search for the fastest conv algorithms (fixed input sizes).
        cudnn.benchmark = True
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.logger = logger
        self.tb_logger = tb_logger
        # Config sections used throughout the trainer.
        self.model_config = CONFIG.model
        self.train_config = CONFIG.train
        self.log_config = CONFIG.log
        # Per-term training losses; entries left at None are skipped when summing.
        self.loss_dict = {'rec_os8': None,
                          'comp_os8': None,
                          'rec_os1': None,
                          'comp_os1': None,
                          'smooth_l1':None,
                          'grad':None,
                          'gabor':None,
                          'lap_os8': None,
                          'lap_os1': None,
                          'rec_os4': None,
                          'comp_os4': None,
                          'lap_os4': None,}
        # Per-term test losses.
        self.test_loss_dict = {'rec': None,
                               'smooth_l1':None,
                               'mse':None,
                               'sad':None,
                               'grad':None,
                               'gabor':None}
        # Fixed convolution kernels used by the auxiliary losses.
        self.grad_filter = torch.tensor(utils.get_gradfilter()).cuda()
        self.gabor_filter = torch.tensor(utils.get_gaborfilter(16)).cuda()
        # 5x5 binomial approximation of a Gaussian, normalized to sum to 1.
        self.gauss_filter = torch.tensor([[1., 4., 6., 4., 1.],
                                          [4., 16., 24., 16., 4.],
                                          [6., 24., 36., 24., 6.],
                                          [4., 16., 24., 16., 4.],
                                          [1., 4., 6., 4., 1.]]).cuda()
        self.gauss_filter /= 256.
        # repeat(1, 1, 1, 1) promotes the 5x5 kernel to a (1,1,5,5) conv weight.
        self.gauss_filter = self.gauss_filter.repeat(1, 1, 1, 1)
        self.build_model()
        self.resume_step = None
        self.best_loss = 1e+8
        utils.print_network(self.G, CONFIG.version)
        # Resume from a checkpoint when requested; otherwise optionally load
        # ImageNet-pretrained encoder weights.
        if self.train_config.resume_checkpoint:
            self.logger.info('Resume checkpoint: {}'.format(self.train_config.resume_checkpoint))
            self.restore_model(self.train_config.resume_checkpoint)
        if self.model_config.imagenet_pretrain and self.train_config.resume_checkpoint is None:
            self.logger.info('Load Imagenet Pretrained: {}'.format(self.model_config.imagenet_pretrain_path))
            if self.model_config.arch.encoder == "vgg_encoder":
                utils.load_VGG_pretrain(self.G, self.model_config.imagenet_pretrain_path)
            else:
                utils.load_imagenet_pretrain(self.G, self.model_config.imagenet_pretrain_path)
def build_model(self):
    """
    Build the generator, its optimizer, the (optional) distributed wrappers
    and the LR scheduler. Must run before restore_model() so that the
    optimizer/scheduler exist when checkpoint state is loaded.
    """
    self.G = networks.get_generator_m2m(seg=self.model_config.arch.seg, m2m=self.model_config.arch.m2m)
    self.G.cuda()
    if CONFIG.dist:
        self.logger.info("Using pytorch synced BN")
        # BatchNorm layers must be converted before wrapping with DDP.
        self.G = SyncBatchNorm.convert_sync_batchnorm(self.G)
    self.G_optimizer = torch.optim.Adam(self.G.parameters(),
                                        lr = self.train_config.G_lr,
                                        betas = [self.train_config.beta1, self.train_config.beta2])
    if CONFIG.dist:
        # SyncBatchNorm only supports DistributedDataParallel with single GPU per process
        self.G = DistributedDataParallel(self.G, device_ids=[CONFIG.local_rank], output_device=CONFIG.local_rank, find_unused_parameters=True)
    else:
        self.G = nn.DataParallel(self.G)
    self.build_lr_scheduler()
def build_lr_scheduler(self):
    """Attach a cosine-annealing LR schedule covering the post-warmup steps."""
    decay_steps = self.train_config.total_step - self.train_config.warmup_step
    self.G_scheduler = lr_scheduler.CosineAnnealingLR(self.G_optimizer, T_max=decay_steps)
def reset_grad(self):
    """Clear any gradients accumulated on the generator's optimizer."""
    self.G_optimizer.zero_grad()
def restore_model(self, resume_checkpoint):
    """
    Restore the trained generator.

    Model weights are loaded strictly. Optimizer and scheduler state are
    restored only when train_config.reset_lr is False; when the checkpoint
    carries no scheduler state, a fresh cosine schedule is created over the
    remaining steps.

    :param resume_checkpoint: File name of checkpoint (without the .pth suffix)
    :return:
    """
    pth_path = os.path.join(self.log_config.checkpoint_path, '{}.pth'.format(resume_checkpoint))
    # Map tensors onto this process's GPU regardless of the device used at save time.
    checkpoint = torch.load(pth_path, map_location = lambda storage, loc: storage.cuda(CONFIG.gpu))
    self.resume_step = checkpoint['iter']
    self.logger.info('Loading the trained models from step {}...'.format(self.resume_step))
    self.G.load_state_dict(checkpoint['state_dict'], strict=True)
    if not self.train_config.reset_lr:
        if 'opt_state_dict' in checkpoint.keys():
            try:
                self.G_optimizer.load_state_dict(checkpoint['opt_state_dict'])
            except ValueError as ve:
                # e.g. parameter-group mismatch; keep the fresh optimizer.
                self.logger.error("{}".format(ve))
        else:
            self.logger.info('No Optimizer State Loaded!!')
        if 'lr_state_dict' in checkpoint.keys():
            try:
                self.G_scheduler.load_state_dict(checkpoint['lr_state_dict'])
            except ValueError as ve:
                self.logger.error("{}".format(ve))
        else:
            # No scheduler state in the checkpoint: rebuild the cosine
            # schedule spanning only the steps that remain.
            self.G_scheduler = lr_scheduler.CosineAnnealingLR(self.G_optimizer,
                                                              T_max=self.train_config.total_step - self.resume_step - 1)
    if 'loss' in checkpoint.keys():
        self.best_loss = checkpoint['loss']
def train(self):
    """
    Main optimization loop.

    Each step draws a batch, runs the generator at three output strides
    (os1/os4/os8), accumulates the configured losses, backpropagates with
    optional adaptive gradient clipping, and periodically logs and saves
    checkpoints.

    Fixes over the previous revision:
    - the dataloader restart now catches only StopIteration (a bare
      ``except:`` also hid genuine dataloader errors),
    - the checkpoint log message reported the builtin ``iter`` instead of
      the current ``step``.
    """
    data_iter = iter(self.train_dataloader)

    if self.train_config.resume_checkpoint:
        start = self.resume_step + 1
    else:
        start = 0

    # Moving statistics for adaptive gradient clipping.
    moving_max_grad = 0
    moving_grad_moment = 0.999
    max_grad = 0

    for step in range(start, self.train_config.total_step + 1):
        try:
            image_dict = next(data_iter)
        except StopIteration:
            # The sampler is finite: restart it when exhausted.
            data_iter = iter(self.train_dataloader)
            image_dict = next(data_iter)

        image, alpha, trimap, bbox = image_dict['image'], image_dict['alpha'], image_dict['trimap'], image_dict['boxes']
        image = image.cuda()
        alpha = alpha.cuda()
        trimap = trimap.cuda()
        bbox = bbox.cuda()

        # train() of DistributedDataParallel has no return
        self.G.train()
        log_info = ""
        loss = 0

        # ===== Update Learning Rate =====
        if step < self.train_config.warmup_step and self.train_config.resume_checkpoint is None:
            cur_G_lr = utils.warmup_lr(self.train_config.G_lr, step + 1, self.train_config.warmup_step)
            utils.update_lr(cur_G_lr, self.G_optimizer)
        else:
            self.G_scheduler.step()
            cur_G_lr = self.G_scheduler.get_lr()[0]

        # ===== Forward G =====
        pred = self.G(image, bbox)
        alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred['alpha_os1'], pred['alpha_os4'], pred['alpha_os8']
        mask = pred['mask']

        weight_os8 = utils.get_unknown_tensor(mask)
        weight_os8[...] = 1

        # Curriculum for the os4/os1 supervision region: whole image during
        # warmup, then a mix of mask/trimap unknowns, finally trimap unknowns
        # mixed with regions derived from the coarser-stride predictions.
        if step < self.train_config.warmup_step:
            weight_os4 = utils.get_unknown_tensor(mask)
            weight_os1 = utils.get_unknown_tensor(mask)
            weight_os4[...] = 1
            weight_os1[...] = 1
        elif step < self.train_config.warmup_step * 3:
            if random.randint(0, 1) == 0:
                weight_os4 = utils.get_unknown_tensor(mask)
                weight_os1 = utils.get_unknown_tensor(mask)
            else:
                weight_os4 = utils.get_unknown_tensor(trimap)
                weight_os1 = utils.get_unknown_tensor(trimap)
        else:
            if random.randint(0, 1) == 0:
                weight_os4 = utils.get_unknown_tensor(trimap)
                weight_os1 = utils.get_unknown_tensor(trimap)
            else:
                weight_os4 = utils.get_unknown_tensor_from_pred(alpha_pred_os8, rand_width=CONFIG.model.self_refine_width1, train_mode=True)
                weight_os1 = utils.get_unknown_tensor_from_pred(alpha_pred_os4, rand_width=CONFIG.model.self_refine_width2, train_mode=True)

        # ===== Calculate Losses =====
        if self.train_config.rec_weight > 0:
            self.loss_dict['rec_os1'] = self.regression_loss(alpha_pred_os1, alpha, loss_type='l1', weight=weight_os1) * 2 / 5.0 * self.train_config.rec_weight
            self.loss_dict['rec_os4'] = self.regression_loss(alpha_pred_os4, alpha, loss_type='l1', weight=weight_os4) * 1 / 5.0 * self.train_config.rec_weight
            self.loss_dict['rec_os8'] = self.regression_loss(alpha_pred_os8, alpha, loss_type='l1', weight=weight_os8) * 1 / 5.0 * self.train_config.rec_weight

        if self.train_config.comp_weight > 0:
            # NOTE(review): fg_norm / bg_norm are never defined in this scope,
            # so setting comp_weight > 0 raises NameError. The fg/bg tensors
            # presumably need to come from the dataloader batch — confirm the
            # intended source before enabling this loss.
            self.loss_dict['comp_os1'] = self.composition_loss(alpha_pred_os1, fg_norm, bg_norm, image, weight=weight_os1) * 2 / 5.0 * self.train_config.comp_weight
            self.loss_dict['comp_os4'] = self.composition_loss(alpha_pred_os4, fg_norm, bg_norm, image, weight=weight_os4) * 1 / 5.0 * self.train_config.comp_weight
            self.loss_dict['comp_os8'] = self.composition_loss(alpha_pred_os8, fg_norm, bg_norm, image, weight=weight_os8) * 1 / 5.0 * self.train_config.comp_weight

        if self.train_config.lap_weight > 0:
            self.loss_dict['lap_os1'] = self.lap_loss(logit=alpha_pred_os1, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os1) * 2 / 5.0 * self.train_config.lap_weight
            self.loss_dict['lap_os4'] = self.lap_loss(logit=alpha_pred_os4, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os4) * 1 / 5.0 * self.train_config.lap_weight
            self.loss_dict['lap_os8'] = self.lap_loss(logit=alpha_pred_os8, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os8) * 1 / 5.0 * self.train_config.lap_weight

        for loss_key in self.loss_dict.keys():
            if self.loss_dict[loss_key] is not None:
                loss += self.loss_dict[loss_key]

        # ===== Back Propagate =====
        self.reset_grad()
        loss.backward()

        # ===== Clip Large Gradient =====
        if self.train_config.clip_grad:
            if moving_max_grad == 0:
                # First step: measure the norm without real clipping.
                moving_max_grad = nn_utils.clip_grad_norm_(self.G.parameters(), 1e+6)
                max_grad = moving_max_grad
            else:
                # Clip to 2x the running maximum, then update the EMA.
                max_grad = nn_utils.clip_grad_norm_(self.G.parameters(), 2 * moving_max_grad)
                moving_max_grad = moving_max_grad * moving_grad_moment + max_grad * (
                    1 - moving_grad_moment)

        # ===== Update Parameters =====
        self.G_optimizer.step()

        # ===== Write Log and Tensorboard =====
        # stdout log
        if step % self.log_config.logging_step == 0:
            # reduce losses from GPUs
            if CONFIG.dist:
                self.loss_dict = utils.reduce_tensor_dict(self.loss_dict, mode='mean')
                loss = utils.reduce_tensor(loss)
            # create logging information
            for loss_key in self.loss_dict.keys():
                if self.loss_dict[loss_key] is not None:
                    log_info += loss_key.upper() + ": {:.4f}, ".format(self.loss_dict[loss_key])
            if CONFIG.wandb and CONFIG.local_rank == 0:
                for loss_key in self.loss_dict.keys():
                    if self.loss_dict[loss_key] is not None:
                        wandb.log({'lr': cur_G_lr, 'total_loss': loss, loss_key.upper(): self.loss_dict[loss_key]}, step=step)
            self.logger.debug("Image tensor shape: {}. Trimap tensor shape: {}".format(image.shape, trimap.shape))
            log_info = "[{}/{}], ".format(step, self.train_config.total_step) + log_info
            log_info += "lr: {:6f}".format(cur_G_lr)
            self.logger.info(log_info)

        # tensorboard
        if step % self.log_config.tensorboard_step == 0 or step == start:
            self.tb_logger.scalar_summary('Loss', loss, step)
            # detailed losses
            for loss_key in self.loss_dict.keys():
                if self.loss_dict[loss_key] is not None:
                    self.tb_logger.scalar_summary('Loss_' + loss_key.upper(),
                                                  self.loss_dict[loss_key], step)
            self.tb_logger.scalar_summary('LearnRate', cur_G_lr, step)
            if self.train_config.clip_grad:
                self.tb_logger.scalar_summary('Moving_Max_Grad', moving_max_grad, step)
                self.tb_logger.scalar_summary('Max_Grad', max_grad, step)

        # checkpoint (rank 0 only)
        if (step % self.log_config.checkpoint_step == 0 or step == self.train_config.total_step) \
                and CONFIG.local_rank == 0 and (step > start):
            # Bugfix: previously logged the builtin `iter` instead of `step`.
            self.logger.info('Saving the trained models from step {}...'.format(step))
            self.save_model("model_step_{}".format(step), step, loss)

    torch.cuda.empty_cache()
def save_model(self, checkpoint_name, iter, loss):
    """Serialize step counter, loss, generator weights and optimizer/scheduler state."""
    destination = os.path.join(self.log_config.checkpoint_path, '{}.pth'.format(checkpoint_name))
    payload = {
        'iter': iter,
        'loss': loss,
        'state_dict': self.G.module.m2m.state_dict(),
        'opt_state_dict': self.G_optimizer.state_dict(),
        'lr_state_dict': self.G_scheduler.state_dict(),
    }
    torch.save(payload, destination)
@staticmethod
def regression_loss(logit, target, loss_type='l1', weight=None):
    """
    Alpha reconstruction loss.

    :param logit: predicted alpha
    :param target: ground-truth alpha
    :param loss_type: "l1" or "l2"
    :param weight: optional tensor with shape [N,1,H,W] of per-pixel weights
    :return: scalar loss tensor
    :raises NotImplementedError: for any other loss_type
    """
    if weight is not None:
        # Weighted variant: sum of masked errors normalized by the weight mass.
        denom = torch.sum(weight) + 1e-8
        if loss_type == 'l1':
            return F.l1_loss(logit * weight, target * weight, reduction='sum') / denom
        if loss_type == 'l2':
            return F.mse_loss(logit * weight, target * weight, reduction='sum') / denom
        raise NotImplementedError("NotImplemented loss type {}".format(loss_type))
    if loss_type == 'l1':
        return F.l1_loss(logit, target)
    if loss_type == 'l2':
        return F.mse_loss(logit, target)
    raise NotImplementedError("NotImplemented loss type {}".format(loss_type))
@staticmethod
def smooth_l1(logit, target, weight):
    """Smooth L1: sqrt(diff^2 + 1e-6) summed over pixels, normalized by weight mass."""
    diff = logit * weight - target * weight
    per_pixel = torch.sqrt(diff * diff + 1e-6)
    return torch.sum(per_pixel) / (torch.sum(weight) + 1e-8)
@staticmethod
def mse(logit, target, weight):
    """Weighted mean-squared error, delegated to regression_loss with loss_type='l2'."""
    return Trainer.regression_loss(logit, target, loss_type='l2', weight=weight)
@staticmethod
def sad(logit, target, weight):
    """Sum of absolute differences over weighted pixels, reported in thousands."""
    total = F.l1_loss(logit * weight, target * weight, reduction='sum')
    return total / 1000
@staticmethod
def composition_loss(alpha, fg, bg, image, weight, loss_type='l1'):
    """
    Alpha composition loss: re-composite fg/bg with the predicted alpha
    and compare the result against the original image.
    """
    recomposed = fg * alpha + bg * (1 - alpha)
    return Trainer.regression_loss(recomposed, image, loss_type=loss_type, weight=weight)
@staticmethod
def gabor_loss(logit, target, gabor_filter, loss_type='l2', weight=None):
    """Regression loss between Gabor-filtered prediction and target."""
    filtered_pred = F.conv2d(logit, weight=gabor_filter, padding=2)
    filtered_gt = F.conv2d(target, weight=gabor_filter, padding=2)
    return Trainer.regression_loss(filtered_pred, filtered_gt, loss_type=loss_type, weight=weight)
@staticmethod
def grad_loss(logit, target, grad_filter, loss_type='l1', weight=None):
    """Regression loss between the gradient magnitudes of prediction and target."""
    g_pred = F.conv2d(logit, weight=grad_filter, padding=1)
    g_gt = F.conv2d(target, weight=grad_filter, padding=1)
    # Channel-wise L2 magnitude; the epsilon keeps sqrt differentiable at 0.
    mag_pred = torch.sqrt((g_pred * g_pred).sum(dim=1, keepdim=True) + 1e-8)
    mag_gt = torch.sqrt((g_gt * g_gt).sum(dim=1, keepdim=True) + 1e-8)
    return Trainer.regression_loss(mag_pred, mag_gt, loss_type=loss_type, weight=weight)
@staticmethod
def lap_loss(logit, target, gauss_filter, loss_type='l1', weight=None):
    '''
    Laplacian pyramid loss: regression loss on each of 5 pyramid levels,
    with level i weighted by 2**i.

    Based on FBA Matting implementation:
    https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270

    Fix: the up-sampling previously allocated its zero tensors with
    ``torch.zeros(...).cuda()``, which crashed on CPU tensors and broke
    multi-GPU placement; ``new_zeros`` keeps the input's device and dtype
    (identical behavior on the original single-default-GPU path).
    '''
    def conv_gauss(x, kernel):
        # Reflect-pad by 2 so the 5x5 Gaussian keeps spatial size.
        x = F.pad(x, (2, 2, 2, 2), mode='reflect')
        x = F.conv2d(x, kernel, groups=x.shape[1])
        return x
    def downsample(x):
        # Drop every other row and column.
        return x[:, :, ::2, ::2]
    def upsample(x, kernel):
        # Zero-interleave rows and columns, then blur with 4x the kernel
        # to compensate for the inserted zeros.
        N, C, H, W = x.shape
        cc = torch.cat([x, x.new_zeros(N, C, H, W)], dim = 3)
        cc = cc.view(N, C, H*2, W)
        cc = cc.permute(0,1,3,2)
        cc = torch.cat([cc, cc.new_zeros(N, C, W, H*2)], dim = 3)
        cc = cc.view(N, C, W*2, H*2)
        x_up = cc.permute(0,1,3,2)
        return conv_gauss(x_up, kernel=4*gauss_filter)
    def lap_pyramid(x, kernel, max_levels=3):
        # Each level stores the difference between the image and its
        # blurred-downsampled-upsampled reconstruction.
        current = x
        pyr = []
        for level in range(max_levels):
            filtered = conv_gauss(current, kernel)
            down = downsample(filtered)
            up = upsample(down, kernel)
            diff = current - up
            pyr.append(diff)
            current = down
        return pyr
    def weight_pyramid(x, max_levels=3):
        # Per-level weights: the weight map downsampled to each resolution.
        current = x
        pyr = []
        for level in range(max_levels):
            down = downsample(current)
            pyr.append(current)
            current = down
        return pyr
    pyr_logit = lap_pyramid(x = logit, kernel = gauss_filter, max_levels = 5)
    pyr_target = lap_pyramid(x = target, kernel = gauss_filter, max_levels = 5)
    if weight is not None:
        pyr_weight = weight_pyramid(x = weight, max_levels = 5)
        return sum(Trainer.regression_loss(A[0], A[1], loss_type=loss_type, weight=A[2]) * (2**i) for i, A in enumerate(zip(pyr_logit, pyr_target, pyr_weight)))
    else:
        return sum(Trainer.regression_loss(A[0], A[1], loss_type=loss_type, weight=None) * (2**i) for i, A in enumerate(zip(pyr_logit, pyr_target)))
\ No newline at end of file
from .logger import *
from .config import *
from .util import *
from .evaluate import *
\ No newline at end of file
from easydict import EasyDict
# Base default config. Every leaf value below can be overridden by a custom
# TOML config via load_config(); unknown keys are rejected there.
CONFIG = EasyDict({})
# to indicate this is a default setting, should not be changed by user
CONFIG.is_default = True
CONFIG.version = "baseline"
CONFIG.phase = "train"
# distributed training
CONFIG.dist = False
CONFIG.wandb = False
# global variables which will be assigned in the runtime
CONFIG.local_rank = 0
CONFIG.gpu = 0
CONFIG.world_size = 1

# Model config
CONFIG.model = EasyDict({})
# use pretrained checkpoint as encoder
CONFIG.model.freeze_seg = True
CONFIG.model.multi_scale = False
CONFIG.model.imagenet_pretrain = False
CONFIG.model.imagenet_pretrain_path = "/path/to/data/model_best_resnet34_En_nomixup.pth"
CONFIG.model.batch_size = 16
# one-hot or class, choice: [3, 1]
CONFIG.model.mask_channel = 1
CONFIG.model.trimap_channel = 3
# hyper-parameter for refinement
CONFIG.model.self_refine_width1 = 30
CONFIG.model.self_refine_width2 = 15
CONFIG.model.self_mask_width = 10

# Model -> Architecture config
CONFIG.model.arch = EasyDict({})
# definitions in networks/encoders/__init__.py and networks/decoders/__init__.py
CONFIG.model.arch.encoder = "res_shortcut_encoder_29"
CONFIG.model.arch.decoder = "res_shortcut_decoder_22"
CONFIG.model.arch.m2m = "sam_decoder_deep"
CONFIG.model.arch.seg = "sam"
# predefined for GAN structure
CONFIG.model.arch.discriminator = None

# Dataloader config
CONFIG.data = EasyDict({})
CONFIG.data.cutmask_prob = 0
CONFIG.data.workers = 0
CONFIG.data.pha_ratio = 0.5
# data path for training and validation in training phase
CONFIG.data.train_fg = None
CONFIG.data.train_alpha = None
CONFIG.data.train_bg = None
CONFIG.data.test_merged = None
CONFIG.data.test_alpha = None
CONFIG.data.test_trimap = None
CONFIG.data.d646_fg = None
CONFIG.data.d646_pha = None
CONFIG.data.aim_fg = None
CONFIG.data.aim_pha = None
CONFIG.data.human2k_fg = None
CONFIG.data.human2k_pha = None
CONFIG.data.am2k_fg = None
CONFIG.data.am2k_pha = None
CONFIG.data.rim_pha = None
CONFIG.data.rim_img = None
CONFIG.data.coco_bg = None
CONFIG.data.bg20k_bg = None
# feed forward image size (untested)
CONFIG.data.crop_size = 1024
# composition of two foregrounds, affine transform, crop and HSV jitter
CONFIG.data.real_world_aug = False
CONFIG.data.augmentation = True
CONFIG.data.random_interp = True

### Benchmark config (evaluation dataset locations)
CONFIG.benchmark = EasyDict({})
CONFIG.benchmark.him2k_img = '/path/to/data/HIM2K/images/natural'
CONFIG.benchmark.him2k_alpha = '/path/to/data/HIM2K/alphas/natural'
CONFIG.benchmark.him2k_comp_img = '/path/to/data/HIM2K/images/comp'
CONFIG.benchmark.him2k_comp_alpha = '/path/to/data/HIM2K/alphas/comp'
CONFIG.benchmark.rwp636_img = '/path/to/data/RealWorldPortrait-636/image'
CONFIG.benchmark.rwp636_alpha = '/path/to/data/RealWorldPortrait-636/alpha'
CONFIG.benchmark.ppm100_img = '/path/to/data/PPM-100/image'
CONFIG.benchmark.ppm100_alpha = '/path/to/data/PPM-100/matte'
CONFIG.benchmark.pm10k_img = '/path/to/data/P3M-10k/validation/P3M-500-NP/original_image'
CONFIG.benchmark.pm10k_alpha = '/path/to/data/P3M-10k/validation/P3M-500-NP/mask'
CONFIG.benchmark.am2k_img = '/path/to/data/AM2k/validation/original'
CONFIG.benchmark.am2k_alpha = '/path/to/data/AM2k/validation/mask'
CONFIG.benchmark.rw100_img = '/path/to/data/RefMatte_RW_100/image_all'
CONFIG.benchmark.rw100_alpha = '/path/to/data/RefMatte_RW_100/mask'
CONFIG.benchmark.rw100_text = '/path/to/data/RefMatte_RW_100/refmatte_rw100_label.json'
CONFIG.benchmark.rw100_index = '/path/to/data/RefMatte_RW_100/eval_index_expression.json'

# Training config
CONFIG.train = EasyDict({})
CONFIG.train.total_step = 100000
CONFIG.train.warmup_step = 5000
CONFIG.train.val_step = 1000
# basic learning rate of optimizer
CONFIG.train.G_lr = 1e-3
# beta1 and beta2 for Adam
CONFIG.train.beta1 = 0.5
CONFIG.train.beta2 = 0.999
# weight of different losses
CONFIG.train.rec_weight = 1
CONFIG.train.comp_weight = 0
CONFIG.train.lap_weight = 1
# clip large gradient
CONFIG.train.clip_grad = True
# resume the training (checkpoint file name)
CONFIG.train.resume_checkpoint = None
# reset the learning rate (this option will reset the optimizer and learning rate scheduler and ignore warmup)
CONFIG.train.reset_lr = False

# Logging config
CONFIG.log = EasyDict({})
CONFIG.log.tensorboard_path = "./logs/tensorboard"
CONFIG.log.tensorboard_step = 100
# save less images to save disk space
CONFIG.log.tensorboard_image_step = 500
CONFIG.log.logging_path = "./logs/stdout"
CONFIG.log.logging_step = 10
CONFIG.log.logging_level = "DEBUG"
CONFIG.log.checkpoint_path = "./checkpoints"
CONFIG.log.checkpoint_step = 10000
def load_config(custom_config, default_config=CONFIG, prefix="CONFIG"):
    """
    Recursively overwrite *default_config* entries with *custom_config*.

    :param default_config: config to mutate in place (defaults to global CONFIG)
    :param custom_config: parsed from config/config.toml
    :param prefix: dotted prefix used in error messages
    :return: None
    :raises NotImplementedError: on keys absent from the default config
    :raises ValueError: on dict/scalar shape mismatches between the two configs
    """
    if "is_default" in default_config:
        default_config.is_default = False
    for key in custom_config.keys():
        full_key = "{}.{}".format(prefix, key)
        if key not in default_config:
            raise NotImplementedError("Unknown config key: {}".format(full_key))
        custom_is_dict = isinstance(custom_config[key], dict)
        default_is_dict = isinstance(default_config[key], dict)
        if custom_is_dict and default_is_dict:
            # Recurse into matching sub-sections.
            load_config(default_config=default_config[key],
                        custom_config=custom_config[key],
                        prefix=full_key)
        elif custom_is_dict:
            raise ValueError("{}: Expected {}, got dict instead.".format(full_key, type(custom_config[key])))
        elif default_is_dict:
            raise ValueError("{}: Expected dict, got {} instead.".format(full_key, type(custom_config[key])))
        else:
            default_config[key] = custom_config[key]
import os
from shutil import copyfile
### list based on https://github.com/senguptaumd/Background-Matting/blob/master/Data_adobe/train_data_list.txt
# Read the whitelist of "solid" training samples. The context manager fixes
# the file-handle leak of the original bare open().
with open('train_data_list.txt') as list_file:
    train_list = list_file.read().splitlines()

### training
src_fg_path = '/export/ccvl12b/qihang/MGMatting/data/Combined_Dataset/Training_set/fg/'
src_alpha_path = '/export/ccvl12b/qihang/MGMatting/data/Combined_Dataset/Training_set/alpha/'
dst_fg_path = '/export/ccvl12b/qihang/MGMatting/data/Combined_Dataset_Solid/Training_set/fg/'
dst_alpha_path = '/export/ccvl12b/qihang/MGMatting/data/Combined_Dataset_Solid/Training_set/alpha/'

# exist_ok=True avoids the check-then-create race of the original
# os.path.exists + os.makedirs pair.
os.makedirs(dst_fg_path, exist_ok=True)
os.makedirs(dst_alpha_path, exist_ok=True)

# Copy each listed foreground together with its alpha matte.
for f in train_list:
    copyfile(src_fg_path + f, dst_fg_path + f)
    copyfile(src_alpha_path + f, dst_alpha_path + f)
"""
Reimplement evaluation.mat provided by Adobe in python
Output of `compute_gradient_loss` is slightly different from the MATLAB version provided by Adobe (less than 0.1%)
Output of `compute_connectivity_error` is smaller than the MATLAB version (~5%, maybe MATLAB has a different algorithm)
So do not report results calculated by these functions in your paper.
Evaluate your inference with the MATLAB file `DIM_evaluation_code/evaluate.m`.
by Yaoyi Li
"""
import scipy.ndimage
import numpy as np
from skimage.measure import label
import scipy.ndimage.morphology
def gauss(x, sigma):
    """Value of the 1-D Gaussian PDF with std *sigma* at *x*."""
    return np.exp(-x ** 2 / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))


def dgauss(x, sigma):
    """First derivative of the 1-D Gaussian with std *sigma* at *x*."""
    return -x * gauss(x, sigma) / (sigma ** 2)


def gaussgradient(im, sigma):
    """Image gradients (gx, gy) via Gaussian-derivative kernel convolution."""
    epsilon = 1e-2
    # Kernel half-width so that truncated tails stay below epsilon.
    halfsize = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))).astype(np.int32)
    size = 2 * halfsize + 1
    hx = np.zeros((size, size))
    for row in range(size):
        for col in range(size):
            hx[row, col] = gauss(row - halfsize, sigma) * dgauss(col - halfsize, sigma)
    # L2-normalize the kernel.
    hx = hx / np.sqrt(np.sum(np.abs(hx) * np.abs(hx)))
    hy = hx.transpose()
    gx = scipy.ndimage.convolve(im, hx, mode='nearest')
    gy = scipy.ndimage.convolve(im, hy, mode='nearest')
    return gx, gy
def compute_gradient_loss(pred, target, trimap):
    """Squared gradient-magnitude error inside the unknown (trimap==128) region, /1000."""
    pred = pred / 255.0
    target = target / 255.0
    pred_x, pred_y = gaussgradient(pred, 1.4)
    target_x, target_y = gaussgradient(target, 1.4)
    pred_amp = np.sqrt(pred_x ** 2 + pred_y ** 2)
    target_amp = np.sqrt(target_x ** 2 + target_y ** 2)
    unknown = trimap == 128
    error_map = np.square(pred_amp - target_amp)
    return np.sum(error_map[unknown]) / 1000.
def getLargestCC(segmentation):
    """
    Return a boolean mask of the largest connected component of *segmentation*.

    Components are labeled with 1-connectivity. Fix: the previous
    ``argmax(bincount(labels.flat))`` included the background label 0 in the
    size ranking, so when background pixels outnumbered every component the
    function returned the background mask instead of a component. Label 0 is
    now excluded whenever any foreground component exists; an all-background
    input keeps the original behavior (all-True mask).
    """
    labels = label(segmentation, connectivity=1)
    counts = np.bincount(labels.flat)
    if counts.size > 1:
        # Ignore background (label 0) when at least one component exists.
        counts[0] = 0
    largestCC = labels == np.argmax(counts)
    return largestCC
def compute_connectivity_error(pred, target, trimap, step=0.1):
    """
    Connectivity error over the unknown (trimap==128) region, /1000.

    Fix: ``np.float`` and ``np.int`` were deprecated aliases for the builtins
    and were removed in NumPy >= 1.24, so this function crashed on modern
    NumPy; the builtins ``float``/``int`` are exact drop-in replacements.
    """
    pred = pred / 255.0
    target = target / 255.0
    thresh_steps = list(np.arange(0, 1 + step, step))
    # l_map holds, per pixel, the last threshold at which the pixel was still
    # connected to the largest mutual component (-1 = still connected).
    l_map = np.ones_like(pred, dtype=float) * -1
    for i in range(1, len(thresh_steps)):
        pred_alpha_thresh = (pred >= thresh_steps[i]).astype(int)
        target_alpha_thresh = (target >= thresh_steps[i]).astype(int)
        omega = getLargestCC(pred_alpha_thresh * target_alpha_thresh).astype(int)
        flag = ((l_map == -1) & (omega == 0)).astype(int)
        l_map[flag == 1] = thresh_steps[i - 1]
    l_map[l_map == -1] = 1
    pred_d = pred - l_map
    target_d = target - l_map
    # Only distances above the 0.15 tolerance contribute.
    pred_phi = 1 - pred_d * (pred_d >= 0.15).astype(int)
    target_phi = 1 - target_d * (target_d >= 0.15).astype(int)
    loss = np.sum(np.abs(pred_phi - target_phi)[trimap == 128])
    return loss / 1000.
def compute_mse_loss(pred, target, trimap):
    """Mean squared error over the unknown (trimap==128) region, alpha in [0,1]."""
    unknown = trimap == 128
    diff = (pred - target) / 255.0
    return np.sum((diff ** 2) * unknown) / (np.sum(unknown) + 1e-8)
def compute_sad_loss(pred, target, trimap):
    """SAD over the unknown region and the region size, both in thousands."""
    unknown = trimap == 128
    diff = np.abs((pred - target) / 255.0)
    return np.sum(diff * unknown) / 1000, np.sum(unknown) / 1000
def compute_mad_loss(pred, target, trimap):
    """Mean absolute difference over the unknown (trimap==128) region."""
    unknown = trimap == 128
    diff = np.abs((pred - target) / 255.0)
    return np.sum(diff * unknown) / (np.sum(unknown) + 1e-8)
import os
import cv2
import torch
import logging
import datetime
import numpy as np
from pprint import pprint
from utils import util
from utils.config import CONFIG
from tensorboardX import SummaryWriter
# Map config-file logging-level names to the stdlib logging constants
# (looked up via LEVELS[logging_level.upper()] in get_logger()).
LEVELS = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}
def make_color_wheel():
    """
    Build the 55x3 color wheel used for optical-flow visualization.
    (from https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_ops.py)

    Each segment holds one channel at 255 while ramping another channel
    up (+1) or down (-1), producing RY/YG/GC/CB/BM/MR transitions.
    """
    RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
    segments = [
        # (count, constant channel, ramp channel, ramp direction)
        (RY, 0, 1, +1),   # red -> yellow
        (YG, 1, 0, -1),   # yellow -> green
        (GC, 1, 2, +1),   # green -> cyan
        (CB, 2, 1, -1),   # cyan -> blue
        (BM, 2, 0, +1),   # blue -> magenta
        (MR, 0, 2, -1),   # magenta -> red
    ]
    ncols = sum(count for count, _, _, _ in segments)
    colorwheel = np.zeros([ncols, 3])
    col = 0
    for count, const_ch, ramp_ch, direction in segments:
        colorwheel[col:col + count, const_ch] = 255
        ramp = np.floor(255 * np.arange(0, count) / count)
        colorwheel[col:col + count, ramp_ch] = ramp if direction > 0 else 255 - ramp
        col += count
    return colorwheel


COLORWHEEL = make_color_wheel()
def compute_color(u,v):
    # from https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_ops.py
    # Map flow components (u, v) to an RGB image via the color wheel:
    # flow direction selects (and interpolates between) wheel entries,
    # flow magnitude controls saturation.
    h, w = u.shape
    img = np.zeros([h, w, 3])
    # Zero out NaNs; they render black via the (1 - nanIdx) factor below.
    nanIdx = np.isnan(u) | np.isnan(v)
    u[nanIdx] = 0
    v[nanIdx] = 0
    colorwheel = COLORWHEEL
    ncols = np.size(colorwheel, 0)
    rad = np.sqrt(u**2+v**2)
    # Angle normalized to (-1, 1], then mapped to a fractional wheel index.
    a = np.arctan2(-v, -u) / np.pi
    fk = (a+1) / 2 * (ncols - 1) + 1
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols+1] = 1   # wrap around the wheel
    f = fk - k0             # interpolation fraction between entries k0 and k1
    for i in range(np.size(colorwheel,1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0-1] / 255
        col1 = tmp[k1-1] / 255
        # Linear blend between adjacent wheel colors.
        col = (1-f) * col0 + f * col1
        idx = rad <= 1
        # Desaturate small magnitudes toward white.
        col[idx] = 1-rad[idx]*(1-col[idx])
        notidx = np.logical_not(idx)
        # Dim out-of-range magnitudes.
        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
    return img
def flow_to_image(flow):
    """Convert a 2-channel (u, v) flow field to an RGB visualization."""
    # part from https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_ops.py
    u = flow[0, :, :]
    v = flow[1, :, :]
    rad = np.sqrt(u ** 2 + v ** 2)
    # Normalize by the largest magnitude (floored at -1 as in the original).
    maxrad = max(-1, np.max(rad))
    eps = np.finfo(float).eps
    return compute_color(u / (maxrad + eps), v / (maxrad + eps))
def put_text(image, text, position=(10, 20)):
    """Resize a CHW image to 512x512, draw *text* on it, and return it as CHW."""
    hwc = image.transpose([1, 2, 0])
    hwc = cv2.resize(hwc, (512, 512), interpolation=cv2.INTER_NEAREST)
    annotated = cv2.putText(hwc, text, position, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 0, thickness=2)
    return annotated.transpose([2, 0, 1])
class TensorBoardLogger(object):
    def __init__(self, tb_log_dir, exp_string):
        """
        Initialize summary writer.

        Only rank 0 creates the directories and the SummaryWriter; on other
        ranks self.writer stays None and every logging method is a no-op.
        """
        self.exp_string = exp_string
        self.tb_log_dir = tb_log_dir
        self.val_img_dir = os.path.join(self.tb_log_dir, 'val_image')
        if CONFIG.local_rank == 0:
            util.make_dir(self.tb_log_dir)
            util.make_dir(self.val_img_dir)
            self.writer = SummaryWriter(self.tb_log_dir+'/' + self.exp_string)
        else:
            self.writer = None

    def scalar_summary(self, tag, value, step, phase='train'):
        # Tag becomes e.g. "Train/Loss"; only rank 0 writes.
        if CONFIG.local_rank == 0:
            sum_name = '{}/{}'.format(phase.capitalize(), tag)
            self.writer.add_scalar(sum_name, value, step)

    def image_summary(self, image_set, step, phase='train', save_val=True):
        """
        Record image in tensorboard.
        The input image should be a numpy array with shape (C, H, W) like a torch tensor
        :param image_set: dict of images
        :param step:
        :param phase:
        :param save_val: save images in folder in validation or testing
        :return:
        """
        if CONFIG.local_rank == 0:
            for tag, image_numpy in image_set.items():
                sum_name = '{}/{}'.format(phase.capitalize(), tag)
                # CHW -> HWC for cv2.resize, then back to CHW for add_image.
                image_numpy = image_numpy.transpose([1, 2, 0])
                image_numpy = cv2.resize(image_numpy, (360, 360), interpolation=cv2.INTER_NEAREST)
                if len(image_numpy.shape) == 2:
                    # Single-channel result came back 2-D; restore a channel axis.
                    image_numpy = image_numpy[None, :,:]
                else:
                    image_numpy = image_numpy.transpose([2, 0, 1])
                self.writer.add_image(sum_name, image_numpy, step)
            if (phase=='test') and save_val:
                # Concatenate all images of the set side by side and dump to disk.
                tags = list(image_set.keys())
                image_pack = self._reshape_rgb(image_set[tags[0]])
                image_pack = cv2.resize(image_pack, (512, 512), interpolation=cv2.INTER_NEAREST)
                for tag in tags[1:]:
                    image = self._reshape_rgb(image_set[tag])
                    image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_NEAREST)
                    image_pack = np.concatenate((image_pack, image), axis=1)
                cv2.imwrite(os.path.join(self.val_img_dir, 'val_{:d}'.format(step)+'.png'), image_pack)

    @staticmethod
    def _reshape_rgb(image):
        """
        Transform RGB/L -> BGR for OpenCV.

        Accepts (3,H,W) RGB, (1,H,W) or (H,W) grayscale; always returns HWC
        with 3 channels.
        """
        if len(image.shape) == 3 and image.shape[0] == 3:
            image = image.transpose([1, 2, 0])
            # Reverse channel order: RGB -> BGR.
            image = image[...,::-1]
        elif len(image.shape) == 3 and image.shape[0] == 1:
            image = image.transpose([1, 2, 0])
            image = np.repeat(image, 3, axis=2)
        elif len(image.shape) == 2:
            image = np.stack((image, image, image), axis=2)
        else:
            raise ValueError('Image shape {} not supported to save'.format(image.shape))
        return image

    def __del__(self):
        # Flush and close the writer on rank 0.
        if self.writer is not None:
            self.writer.close()
class MyLogger(logging.Logger):
    """
    Logger subclass that emits records only on the rank-0 process, so
    multi-GPU runs produce a single log stream.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False):
        # Guard clause: silently drop records on non-zero ranks.
        if CONFIG.local_rank != 0:
            return
        super()._log(level, msg, args, exc_info, extra, stack_info)
def get_logger(log_dir=None, tb_log_dir=None, logging_level="DEBUG"):
    """
    Build the process-wide logger and, optionally, a tensorboard logger.

    Returns the console logger alone when tb_log_dir is None, otherwise a
    (logger, TensorBoardLogger) pair. A file handler is attached only on
    rank 0 and only when log_dir is given; the resolved CONFIG is dumped at
    the top of that file.

    :param log_dir: directory for the stdout dump file, or None
    :param tb_log_dir: tensorboard directory, or None
    :param logging_level: one of the LEVELS keys
    :return: Logger or [Logger, TensorBoardLogger]
    """
    level = LEVELS[logging_level.upper()]
    # Timestamped run id shared by the log file and tensorboard writer.
    run_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    logging.setLoggerClass(MyLogger)
    logger = logging.getLogger('Logger')
    logger.setLevel(level)

    fmt = logging.Formatter('[%(asctime)s] %(levelname)s: %(message)s', datefmt='%m-%d %H:%M:%S')

    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(fmt)
    logger.addHandler(console)

    if log_dir is not None and CONFIG.local_rank == 0:
        file_handler = logging.FileHandler(os.path.join(log_dir, run_id) + '.log', mode='w')
        file_handler.setLevel(level)
        file_handler.setFormatter(fmt)
        logger.addHandler(file_handler)
        # Record the resolved config at the top of the log file.
        pprint(CONFIG, stream=file_handler.stream)

    if tb_log_dir is None:
        return logger
    return logger, TensorBoardLogger(tb_log_dir=tb_log_dir, exp_string=run_id)
def normalize_image(image):
    """
    Min-max normalize a (3, H, W) image tensor to [0, 1] per channel.
    """
    flat = torch.flatten(image, start_dim=1)
    channel_min = flat.min(dim=1, keepdim=False)[0].view(3, 1, 1)
    channel_max = flat.max(dim=1, keepdim=False)[0].view(3, 1, 1)
    return (image - channel_min) / (channel_max - channel_min + 1e-8)
import os
import cv2
import torch
import logging
import numpy as np
from utils.config import CONFIG
import torch.distributed as dist
import torch.nn.functional as F
from skimage.measure import label
def make_dir(target_dir):
    """
    Create target_dir (including parents) if it does not already exist.

    Uses exist_ok=True instead of a separate exists() check, so there is
    no TOCTOU race when several processes (e.g. DDP ranks) create the
    same directory concurrently.
    """
    os.makedirs(target_dir, exist_ok=True)
def print_network(model, name):
    """
    Log the model architecture, its version name, and its total
    trainable-parameter count through the project logger.
    """
    logger = logging.getLogger("Logger")
    total = sum(p.numel() for p in model.parameters())
    logger.info(model)
    logger.info(name)
    logger.info("Number of parameters: {}".format(total))
def update_lr(lr, optimizer):
    """Set every parameter group of the optimizer to the given learning rate."""
    for group in optimizer.param_groups:
        group['lr'] = lr
def warmup_lr(init_lr, step, iter_num):
    """
    Linear warm-up: ramp the learning rate from 0 at step 0 up to
    init_lr at step == iter_num.
    """
    progress = step / iter_num
    return progress * init_lr
def add_prefix_state_dict(state_dict, prefix="module"):
    """
    Prepend "<prefix>." to every key of a pretrained state dict (for
    Data-Parallel wrappers). Whether the prefix is needed is decided
    from the first key; values are cast to float either way.
    """
    first_key = next(iter(state_dict))
    needs_prefix = not first_key.startswith(prefix)
    renamed = {}
    for key, value in state_dict.items():
        target = prefix + "." + key if needs_prefix else key
        renamed[target] = value.float()
    return renamed
def remove_prefix_state_dict(state_dict, prefix="module"):
    """
    Strip the leading "<prefix>." from every key of a pretrained state
    dict (for Data-Parallel wrappers). Whether stripping is needed is
    decided from the first key; values are cast to float either way.
    """
    first_key = next(iter(state_dict))
    cut = len(prefix) + 1 if first_key.startswith(prefix) else 0
    return {key[cut:]: value.float() for key, value in state_dict.items()}
def load_imagenet_pretrain(model, checkpoint_file):
    """
    Load an ImageNet-pretrained, spectral-normalized ResNet encoder into
    model.module.encoder, widening the first conv to also accept the
    extra mask channel(s).

    The first conv weight is stored as spectral-norm factors
    (weight_u, weight_v, weight_bar); weight_v and weight_bar must be
    zero-padded consistently so the tracked singular vectors remain
    valid for the widened (3 + mask_channel)-input kernel.

    :param model: DataParallel/DDP-wrapped generator whose encoder is loaded
    :param checkpoint_file: path to the pretrained checkpoint file
    """
    checkpoint = torch.load(checkpoint_file, map_location = lambda storage, loc: storage.cuda(CONFIG.gpu))
    state_dict = remove_prefix_state_dict(checkpoint['state_dict'])
    # Cast every tensor to float32 (checkpoints may store other precisions).
    for key, value in state_dict.items():
        state_dict[key] = state_dict[key].float()
    logger = logging.getLogger("Logger")
    logger.debug("Imagenet pretrained keys:")
    logger.debug(state_dict.keys())
    logger.debug("Generator keys:")
    logger.debug(model.module.encoder.state_dict().keys())
    logger.debug("Intersection keys:")
    logger.debug(set(model.module.encoder.state_dict().keys())&set(state_dict.keys()))
    # Spectral-norm factors of the first conv: u/v are the tracked singular
    # vectors, weight_bar the unnormalized kernel (32 output channels).
    weight_u = state_dict["conv1.module.weight_u"]
    weight_v = state_dict["conv1.module.weight_v"]
    weight_bar = state_dict["conv1.module.weight_bar"]
    logger.debug("weight_v: {}".format(weight_v))
    logger.debug("weight_bar: {}".format(weight_bar.view(32, -1)))
    logger.debug("sigma: {}".format(weight_u.dot(weight_bar.view(32, -1).mv(weight_v))))
    # Widen the input dimension with zeros: the first 3 input channels keep
    # the pretrained RGB weights, the extra mask channel(s) start at zero.
    new_weight_v = torch.zeros((3+CONFIG.model.mask_channel), 3, 3).cuda()
    new_weight_bar = torch.zeros(32, (3+CONFIG.model.mask_channel), 3, 3).cuda()
    new_weight_v[:3, :, :].copy_(weight_v.view(3, 3, 3))
    new_weight_bar[:, :3, :, :].copy_(weight_bar)
    logger.debug("new weight_v: {}".format(new_weight_v.view(-1)))
    logger.debug("new weight_bar: {}".format(new_weight_bar.view(32, -1)))
    logger.debug("new sigma: {}".format(weight_u.dot(new_weight_bar.view(32, -1).mv(new_weight_v.view(-1)))))
    state_dict["conv1.module.weight_v"] = new_weight_v.view(-1)
    state_dict["conv1.module.weight_bar"] = new_weight_bar
    # strict=False: encoder keys absent from the checkpoint stay untouched.
    model.module.encoder.load_state_dict(state_dict, strict=False)
def load_VGG_pretrain(model, checkpoint_file):
    """
    Load a pretrained VGG backbone checkpoint into model.module.encoder.
    Keys are stripped of any Data-Parallel "module." prefix first;
    strict=False tolerates keys the encoder does not have.
    """
    # Map checkpoint tensors straight onto the current GPU.
    ckpt = torch.load(checkpoint_file, map_location = lambda storage, loc: storage.cuda())
    encoder_weights = remove_prefix_state_dict(ckpt['state_dict'])
    model.module.encoder.load_state_dict(encoder_weights, strict=False)
def get_unknown_tensor(trimap):
    """
    Extract the 1-channel unknown-region mask from a trimap tensor.

    A 3-channel trimap stores the unknown region in channel 1; a
    1-channel trimap marks unknown pixels with the value 1.
    """
    if trimap.shape[1] == 3:
        return trimap[:, 1:2, :, :].float()
    return trimap.eq(1).float()
def get_gaborfilter(angles):
    """
    Build a bank of 5x5 Gabor kernels usable as a conv weight.

    :param angles: number of orientations; orientation i uses theta = i*pi/8
    :return: float32 array of shape (angles, 1, 5, 5)
    """
    kernels = [
        cv2.getGaborKernel(ksize=(5, 5), sigma=0.5, theta=idx * np.pi / 8, lambd=5, gamma=0.5)
        for idx in range(angles)
    ]
    # Insert the singleton in-channel axis expected by conv weights.
    bank = np.expand_dims(np.array(kernels), axis=1)
    return bank.astype(np.float32)
def get_gradfilter():
    """
    Build the pair of 3x3 Sobel kernels (vertical and horizontal
    gradients), shaped (2, 1, 3, 3) for use as a conv weight.
    """
    sobel_y = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
    sobel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
    kernels = np.array([sobel_y, sobel_x])
    # Insert the singleton in-channel axis expected by conv weights.
    return np.expand_dims(kernels, axis=1).astype(np.float32)
def reduce_tensor_dict(tensor_dict, mode='mean'):
    """
    Reduce every non-None tensor in the dict across all GPUs, in place,
    and return the same dict. None entries are left untouched.
    """
    for name in tensor_dict:
        value = tensor_dict[name]
        if value is not None:
            tensor_dict[name] = reduce_tensor(value, mode)
    return tensor_dict
def reduce_tensor(tensor, mode='mean'):
    """
    Sum a tensor over all distributed processes; for mode='mean' the sum
    is divided by the world size. Operates on a clone, leaving the input
    untouched.

    :raises NotImplementedError: for any mode other than 'mean' or 'sum'
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    if mode == 'sum':
        return reduced
    if mode == 'mean':
        reduced /= CONFIG.world_size
        return reduced
    raise NotImplementedError("reduce mode can only be 'mean' or 'sum'")
### preprocess the image and mask for inference (np array), crop based on ROI
def preprocess(image, mask, thres):
    """
    Crop image and mask to the bounding box of mask >= thres, padded by
    10% of each image dimension and clamped to the image bounds.

    :return: (cropped image, cropped mask, bbox) with
             bbox = [top, bottom, left, right]
    """
    fg = (mask >= thres).astype(np.float32)
    rows, cols = np.nonzero(fg)
    h, w = mask.shape
    pad_h, pad_w = 0.1 * h, 0.1 * w
    top = max(0, int(rows.min() - pad_h))
    bottom = min(h, int(rows.max() + pad_h))
    left = max(0, int(cols.min() - pad_w))
    right = min(w, int(cols.max() + pad_w))
    bbox = [top, bottom, left, right]
    return image[top:bottom, left:right, :], mask[top:bottom, left:right], bbox
### postprocess the alpha prediction to keep the largest connected component (np array) and uncrop, alpha in [0, 1]
### based on https://github.com/senguptaumd/Background-Matting/blob/master/test_background-matting_image.py
def postprocess(alpha, orih=None, oriw=None, bbox=None):
    """
    Keep only the largest connected component of the alpha matte
    (alpha in [0, 1]) and optionally paste it back into the original
    (orih, oriw) canvas at bbox = [top, bottom, left, right].

    :return: processed alpha (or uncropped full-size alpha when bbox is
             given); None when no pixel exceeds the 0.05 threshold.
    """
    labels = label((alpha > 0.05).astype(int))
    # Explicit empty-foreground check: the original used an `assert`
    # (stripped under python -O) wrapped in a bare except, which also
    # silently swallowed any unrelated error.
    if labels.max() == 0:
        return None
    # Component with the highest pixel count, ignoring background label 0.
    largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
    alpha = alpha * largestCC
    if bbox is None:
        return alpha
    # Uncrop: paste the matte back at its original location.
    ori_alpha = np.zeros(shape=[orih, oriw], dtype=np.float32)
    ori_alpha[bbox[0]:bbox[1], bbox[2]:bbox[3]] = alpha
    return ori_alpha
# Precomputed elliptical structuring elements indexed by size (Kernels[s]
# is s x s); index 0 is a placeholder so a dilation width w maps to Kernels[w].
Kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)]
def get_unknown_tensor_from_pred(pred, rand_width=30, train_mode=True):
    """
    Derive an unknown-region mask from an alpha prediction of shape
    (N, 1, H, W): pixels that are neither confident background
    (< 1/255) nor confident foreground (> 1 - 1/255), dilated by a
    random width in training and by rand_width // 2 at eval time.
    """
    batch = pred.shape[0]
    pred_np = pred.data.cpu().numpy()
    # 1 where the prediction is uncertain, 0 where it is confident.
    uncertain = np.ones_like(pred_np, dtype=np.uint8)
    uncertain[pred_np < 1.0 / 255.0] = 0
    uncertain[pred_np > 1 - 1.0 / 255.0] = 0
    for idx in range(batch):
        if train_mode:
            width = np.random.randint(1, rand_width)
        else:
            width = rand_width // 2
        uncertain[idx, 0] = cv2.dilate(uncertain[idx, 0], Kernels[width])
    weight = (uncertain == 1).astype(np.uint8)
    return torch.from_numpy(weight).cuda()
def get_unknown_tensor_from_pred_oneside(pred, rand_width=30, train_mode=True):
    """
    One-sided variant of get_unknown_tensor_from_pred for (N, 1, H, W)
    predictions: everything that is not confident background is dilated,
    and confident-foreground pixels are zeroed out only AFTER dilation,
    so the unknown band grows outward from the foreground only.
    """
    batch = pred.shape[0]
    pred_np = pred.data.cpu().numpy()
    uncertain = np.ones_like(pred_np, dtype=np.uint8)
    uncertain[pred_np < 1.0 / 255.0] = 0
    for idx in range(batch):
        if train_mode:
            width = np.random.randint(1, rand_width)
        else:
            width = rand_width // 2
        uncertain[idx, 0] = cv2.dilate(uncertain[idx, 0], Kernels[width])
    # Remove confident foreground after dilation (the one-sided part).
    uncertain[pred_np > 1 - 1.0 / 255.0] = 0
    return torch.from_numpy(uncertain).cuda()
# Same precomputed elliptical kernels as Kernels, kept as a separate bank
# for the mask-erosion helpers below; index 0 is a placeholder so an
# erosion width w maps directly to Kernels_mask[w].
Kernels_mask = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)]
def get_unknown_tensor_from_mask(mask, rand_width=30, train_mode=True):
    """
    Build a 1-channel unknown-region band around the boundary of a
    binary (N, 1, H, W) mask: both the foreground and the background are
    eroded, and whatever survives neither erosion is marked unknown (1).
    Erosion width is random in training, rand_width // 2 at eval time.
    """
    batch = mask.shape[0]
    mask_np = mask.data.cpu().numpy().astype(np.uint8)
    unknown = np.ones_like(mask_np, dtype=np.uint8)
    for idx in range(batch):
        if train_mode:
            width = np.random.randint(rand_width // 2, rand_width)
        else:
            width = rand_width // 2
        inner = cv2.erode(mask_np[idx, 0], Kernels_mask[width])
        outer = cv2.erode(1 - mask_np[idx, 0], Kernels_mask[width])
        unknown[idx, 0][inner == 1] = 0
        unknown[idx, 0][outer == 1] = 0
    return torch.from_numpy(unknown).cuda()
def get_unknown_tensor_from_mask_oneside(mask, rand_width=30, train_mode=True):
    """
    One-sided variant of get_unknown_tensor_from_mask: only the
    background is eroded (the foreground is used as-is), so the unknown
    band of the (N, 1, H, W) mask extends outward from the un-eroded
    foreground boundary only.
    """
    batch = mask.shape[0]
    mask_np = mask.data.cpu().numpy().astype(np.uint8)
    unknown = np.ones_like(mask_np, dtype=np.uint8)
    for idx in range(batch):
        if train_mode:
            width = np.random.randint(rand_width // 2, rand_width)
        else:
            width = rand_width // 2
        fg = mask_np[idx, 0]
        bg = cv2.erode(1 - mask_np[idx, 0], Kernels_mask[width])
        unknown[idx, 0][fg == 1] = 0
        unknown[idx, 0][bg == 1] = 0
    return torch.from_numpy(unknown).cuda()
def get_unknown_box_from_mask(mask):
    """
    Mark the tight bounding box of the foreground of the FIRST sample in
    a (N, 1, H, W) mask as known (0); everything outside stays 1.

    NOTE(review): only batch index 0 is processed, and the upper row/col
    bounds are exclusive — presumably callers pass N == 1; verify before
    batching.
    """
    mask_np = mask.data.cpu().numpy().astype(np.uint8)
    weight = np.ones_like(mask_np, dtype=np.uint8)
    ys, xs = np.nonzero(mask_np[0][0])
    weight[0, 0, ys.min():ys.max(), xs.min():xs.max()] = 0
    return torch.from_numpy(weight).cuda()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment