Commit 8f9687f5 authored by mashun1's avatar mashun1
Browse files

ridcp

parents
Pipeline #617 canceled with stages
import argparse
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from basicsr.data import build_dataset
from basicsr.metrics.fid import extract_inception_features, load_patched_inception_v3
def calculate_stats_from_dataset():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--size', type=int, default=512)
parser.add_argument('--dataroot', type=str, default='datasets/ffhq')
args = parser.parse_args()
# inception model
inception = load_patched_inception_v3(device)
# create dataset
opt = {}
opt['name'] = 'FFHQ'
opt['type'] = 'FFHQDataset'
opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb'
opt['io_backend'] = dict(type='lmdb')
opt['use_hflip'] = False
opt['mean'] = [0.5, 0.5, 0.5]
opt['std'] = [0.5, 0.5, 0.5]
dataset = build_dataset(opt)
# create dataloader
data_loader = DataLoader(
dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, sampler=None, drop_last=False)
total_batch = math.ceil(args.num_sample / args.batch_size)
def data_generator(data_loader, total_batch):
for idx, data in enumerate(data_loader):
if idx >= total_batch:
break
else:
yield data['gt']
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)
save_path = f'inception_{opt["name"]}_{args.size}.pth'
torch.save(
dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False)
if __name__ == '__main__':
calculate_stats_from_dataset()
import cv2
import glob
import numpy as np
import os.path as osp
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor
try:
import lpips
except ImportError:
print('Please install lpips: pip install lpips')
def main():
# Configurations
# -------------------------------------------------------------------------
folder_gt = 'datasets/celeba/celeba_512_validation'
folder_restored = 'datasets/celeba/celeba_512_validation_lq'
# crop_border = 4
suffix = ''
# -------------------------------------------------------------------------
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1]
lpips_all = []
img_list = sorted(glob.glob(osp.join(folder_gt, '*')))
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
for i, img_path in enumerate(img_list):
basename, ext = osp.splitext(osp.basename(img_path))
img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(
np.float32) / 255.
img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
# norm to [-1, 1]
normalize(img_gt, mean, std, inplace=True)
normalize(img_restored, mean, std, inplace=True)
# calculate lpips
lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda())
print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
lpips_all.append(lpips_val)
print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
if __name__ == '__main__':
main()
import argparse
import cv2
import os
import warnings
from basicsr.metrics import calculate_niqe
from basicsr.utils import scandir
def main(args):
niqe_all = []
img_list = sorted(scandir(args.input, recursive=True, full_path=True))
for i, img_path in enumerate(img_list):
basename, _ = os.path.splitext(os.path.basename(img_path))
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y')
print(f'{i+1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
niqe_all.append(niqe_score)
print(args.input)
print(f'Average: NIQE: {sum(niqe_all) / len(niqe_all):.6f}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, default='datasets/val_set14/Set14', help='Input path')
parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
args = parser.parse_args()
main(args)
import argparse
import cv2
import numpy as np
from os import path as osp
from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.utils import scandir
from basicsr.utils.matlab_functions import bgr2ycbcr
def main(args):
"""Calculate PSNR and SSIM for images.
"""
psnr_all = []
ssim_all = []
img_list_gt = sorted(list(scandir(args.gt, recursive=True, full_path=True)))
img_list_restored = sorted(list(scandir(args.restored, recursive=True, full_path=True)))
if args.test_y_channel:
print('Testing Y channel.')
else:
print('Testing RGB channels.')
for i, img_path in enumerate(img_list_gt):
basename, ext = osp.splitext(osp.basename(img_path))
img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
if args.suffix == '':
img_path_restored = img_list_restored[i]
else:
img_path_restored = osp.join(args.restored, basename + args.suffix + ext)
img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
if args.correct_mean_var:
mean_l = []
std_l = []
for j in range(3):
mean_l.append(np.mean(img_gt[:, :, j]))
std_l.append(np.std(img_gt[:, :, j]))
for j in range(3):
# correct twice
mean = np.mean(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
std = np.std(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
mean = np.mean(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
std = np.std(img_restored[:, :, j])
img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
img_gt = bgr2ycbcr(img_gt, y_only=True)
img_restored = bgr2ycbcr(img_restored, y_only=True)
# calculate PSNR and SSIM
psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
psnr_all.append(psnr)
ssim_all.append(ssim)
print(args.gt)
print(args.restored)
print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gt', type=str, default='datasets/val_set14/Set14', help='Path to gt (Ground-Truth)')
parser.add_argument('--restored', type=str, default='results/Set14', help='Path to restored images')
parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
parser.add_argument('--suffix', type=str, default='', help='Suffix for restored images')
parser.add_argument(
'--test_y_channel',
action='store_true',
help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.')
parser.add_argument('--correct_mean_var', action='store_true', help='Correct the mean and var of restored images.')
args = parser.parse_args()
main(args)
import argparse
import math
import numpy as np
import torch
from torch import nn
from basicsr.archs.stylegan2_arch import StyleGAN2Generator
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
def calculate_stylegan2_fid():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('ckpt', type=str, help='Path to the stylegan2 checkpoint.')
parser.add_argument('fid_stats', type=str, help='Path to the dataset fid statistics.')
parser.add_argument('--size', type=int, default=256)
parser.add_argument('--channel_multiplier', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_sample', type=int, default=50000)
parser.add_argument('--truncation', type=float, default=1)
parser.add_argument('--truncation_mean', type=int, default=4096)
args = parser.parse_args()
# create stylegan2 model
generator = StyleGAN2Generator(
out_size=args.size,
num_style_feat=512,
num_mlp=8,
channel_multiplier=args.channel_multiplier,
resample_kernel=(1, 3, 3, 1))
generator.load_state_dict(torch.load(args.ckpt)['params_ema'])
generator = nn.DataParallel(generator).eval().to(device)
if args.truncation < 1:
with torch.no_grad():
truncation_latent = generator.mean_latent(args.truncation_mean)
else:
truncation_latent = None
# inception model
inception = load_patched_inception_v3(device)
total_batch = math.ceil(args.num_sample / args.batch_size)
def sample_generator(total_batch):
for _ in range(total_batch):
with torch.no_grad():
latent = torch.randn(args.batch_size, 512, device=device)
samples, _ = generator([latent], truncation=args.truncation, truncation_latent=truncation_latent)
yield samples
features = extract_inception_features(sample_generator(total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
features = features[:args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)
# load the dataset stats
stats = torch.load(args.fid_stats)
real_mean = stats['mean']
real_cov = stats['cov']
# calculate FID metric
fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
print('fid:', fid)
if __name__ == '__main__':
calculate_stylegan2_fid()
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import MODEL_REGISTRY
__all__ = ['build_model']
# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
def build_model(opt):
"""Build model from options.
Args:
opt (dict): Configuration. It must contain:
model_type (str): Model type.
"""
opt = deepcopy(opt)
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
logger = get_root_logger()
logger.info(f'Model [{model.__class__.__name__}] is created.')
return model
import os
import time
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils import get_root_logger
from basicsr.utils.dist_util import master_only
class BaseModel():
"""Base model."""
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []
def feed_data(self, data):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
pass
def save(self, epoch, current_iter):
"""Save networks and training state."""
pass
def validation(self, dataloader, current_iter, tb_logger, save_img=False, save_as_dir=None):
"""Validation function.
Args:
dataloader (torch.utils.data.DataLoader): Validation dataloader.
current_iter (int): Current iteration.
tb_logger (tensorboard logger): Tensorboard logger.
save_img (bool): Whether to save images. Default: False.
"""
if self.opt['dist']:
self.dist_validation(dataloader, current_iter, tb_logger, save_img, save_as_dir)
else:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img, save_as_dir)
def _initialize_best_metric_results(self, dataset_name):
"""Initialize the best metric results dict for recording the best metric value and iteration."""
if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
return
elif not hasattr(self, 'best_metric_results'):
self.best_metric_results = dict()
# add a dataset record
record = dict()
for metric, content in self.opt['val']['metrics'].items():
better = content.get('better', 'higher')
init_val = float('-inf') if better == 'higher' else float('inf')
record[metric] = dict(better=better, val=init_val, iter=-1)
self.best_metric_results[dataset_name] = record
def _update_metric_result(self, dataset_name, metric, val, current_iter):
self.best_metric_results[dataset_name][metric]['val'] = val
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
if val >= self.best_metric_results[dataset_name][metric]['val']:
self.best_metric_results[dataset_name][metric]['val'] = val
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
return True
else:
return False
else:
if val <= self.best_metric_results[dataset_name][metric]['val']:
self.best_metric_results[dataset_name][metric]['val'] = val
self.best_metric_results[dataset_name][metric]['iter'] = current_iter
return True
else:
return False
def model_ema(self, decay=0.999):
net_g = self.get_bare_model(self.net_g)
net_g_params = dict(net_g.named_parameters())
net_g_ema_params = dict(self.net_g_ema.named_parameters())
for k in net_g_ema_params.keys():
net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
def copy_model(self, net_a, net_b):
"""copy model from net_a to net_b"""
tmp_net_a = self.get_bare_model(net_a)
tmp_net_b = self.get_bare_model(net_b)
tmp_net_b.load_state_dict(tmp_net_a.state_dict())
def get_current_log(self):
return self.log_dict
def model_to_device(self, net):
"""Model to device. It also warps models with DistributedDataParallel
or DataParallel.
Args:
net (nn.Module)
"""
net = net.to(self.device)
if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', False)
net = DistributedDataParallel(
net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
elif self.opt['num_gpu'] > 1:
net = DataParallel(net)
return net
def get_optimizer(self, optim_type, params, lr, **kwargs):
if optim_type == 'Adam':
optimizer = torch.optim.Adam(params, lr, **kwargs)
else:
raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
return optimizer
def setup_schedulers(self):
"""Set up schedulers."""
train_opt = self.opt['train']
scheduler_type = train_opt['scheduler'].pop('type')
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
elif scheduler_type == 'CosineAnnealingRestartLR':
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
else:
raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
def get_bare_model(self, net):
"""Get bare model, especially under wrapping with
DistributedDataParallel or DataParallel.
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net = net.module
return net
@master_only
def print_network(self, net):
"""Print the str and parameter number of a network.
Args:
net (nn.Module)
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
else:
net_cls_str = f'{net.__class__.__name__}'
net = self.get_bare_model(net)
net_str = str(net)
net_params = sum(map(lambda x: x.numel(), net.parameters()))
logger = get_root_logger()
logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
logger.info(net_str)
def _set_lr(self, lr_groups_l):
"""Set learning rate for warmup.
Args:
lr_groups_l (list): List for lr_groups, each for an optimizer.
"""
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
for param_group, lr in zip(optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def _get_init_lr(self):
"""Get the initial lr, which is set by the scheduler.
"""
init_lr_groups_l = []
for optimizer in self.optimizers:
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
return init_lr_groups_l
def update_learning_rate(self, current_iter, warmup_iter=-1):
"""Update learning rate.
Args:
current_iter (int): Current iteration.
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
Default: -1.
"""
if current_iter > 1:
for scheduler in self.schedulers:
scheduler.step()
# set up warm-up learning rate
if current_iter < warmup_iter:
# get initial lr for each group
init_lr_g_l = self._get_init_lr()
# modify warming-up learning rates
# currently only support linearly warm up
warm_up_lr_l = []
for init_lr_g in init_lr_g_l:
warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
# set learning rate
self._set_lr(warm_up_lr_l)
def get_current_learning_rate(self):
return [optim.param_groups[0]['lr'] for optim in self.optimizers]
@master_only
def save_network(self, net, net_label, current_iter, param_key='params'):
"""Save networks.
Args:
net (nn.Module | list[nn.Module]): Network(s) to be saved.
net_label (str): Network label.
current_iter (int): Current iter number.
param_key (str | list[str]): The parameter key(s) to save network.
Default: 'params'.
"""
if current_iter == -1:
current_iter = 'latest'
save_filename = f'{net_label}_{current_iter}.pth'
save_path = os.path.join(self.opt['path']['models'], save_filename)
net = net if isinstance(net, list) else [net]
param_key = param_key if isinstance(param_key, list) else [param_key]
assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
save_dict = {}
for net_, param_key_ in zip(net, param_key):
net_ = self.get_bare_model(net_)
state_dict = net_.state_dict()
for key, param in state_dict.items():
if key.startswith('module.'): # remove unnecessary 'module.'
key = key[7:]
state_dict[key] = param.cpu()
save_dict[param_key_] = state_dict
# avoid occasional writing errors
retry = 3
while retry > 0:
try:
torch.save(save_dict, save_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
time.sleep(1)
else:
break
finally:
retry -= 1
if retry == 0:
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
# raise IOError(f'Cannot save {save_path}.')
def _print_different_keys_loading(self, crt_net, load_net, strict=True):
"""Print keys with different name or different size when loading models.
1. Print keys with different names.
2. If strict=False, print the same key but with different tensor size.
It also ignore these keys with different sizes (not load).
Args:
crt_net (torch model): Current network.
load_net (dict): Loaded network.
strict (bool): Whether strictly loaded. Default: True.
"""
crt_net = self.get_bare_model(crt_net)
crt_net = crt_net.state_dict()
crt_net_keys = set(crt_net.keys())
load_net_keys = set(load_net.keys())
logger = get_root_logger()
if crt_net_keys != load_net_keys:
logger.warning('Current net - loaded net:')
for v in sorted(list(crt_net_keys - load_net_keys)):
logger.warning(f' {v}')
logger.warning('Loaded net - current net:')
for v in sorted(list(load_net_keys - crt_net_keys)):
logger.warning(f' {v}')
# check the size for the same keys
if not strict:
common_keys = crt_net_keys & load_net_keys
for k in common_keys:
if crt_net[k].size() != load_net[k].size():
logger.warning(f'Size different, ignore [{k}]: crt_net: '
f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
load_net[k + '.ignore'] = load_net.pop(k)
def load_network(self, net, load_path, strict=True, param_key='params'):
"""Load network.
Args:
load_path (str): The path of networks to be loaded.
net (nn.Module): Network.
strict (bool): Whether strictly loaded.
param_key (str): The parameter key of loaded network. If set to
None, use the root 'path'.
Default: 'params'.
"""
logger = get_root_logger()
net = self.get_bare_model(net)
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
if param_key is not None:
if param_key not in load_net and 'params' in load_net:
param_key = 'params'
logger.info('Loading: params_ema does not exist, use params.')
load_net = load_net[param_key]
logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
self._print_different_keys_loading(net, load_net, strict)
net.load_state_dict(load_net, strict=strict)
@master_only
def save_training_state(self, epoch, current_iter):
"""Save training states during training, which will be used for
resuming.
Args:
epoch (int): Current epoch.
current_iter (int): Current iteration.
"""
if current_iter != -1:
state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
for s in self.schedulers:
state['schedulers'].append(s.state_dict())
save_filename = f'{current_iter}.state'
save_path = os.path.join(self.opt['path']['training_states'], save_filename)
# avoid occasional writing errors
retry = 3
while retry > 0:
try:
torch.save(state, save_path)
except Exception as e:
logger = get_root_logger()
logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
time.sleep(1)
else:
break
finally:
retry -= 1
if retry == 0:
logger.warning(f'Still cannot save {save_path}. Just ignore it.')
# raise IOError(f'Cannot save {save_path}.')
def resume_training(self, resume_state):
"""Reload the optimizers and schedulers for resumed training.
Args:
resume_state (dict): Resume state.
"""
resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
for i, o in enumerate(resume_optimizers):
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
self.schedulers[i].load_state_dict(s)
def reduce_loss_dict(self, loss_dict):
"""reduce loss dict.
In distributed training, it averages the losses among different GPUs .
Args:
loss_dict (OrderedDict): Loss dict.
"""
with torch.no_grad():
if self.opt['dist']:
keys = []
losses = []
for name, value in loss_dict.items():
keys.append(name)
losses.append(value)
losses = torch.stack(losses, 0)
torch.distributed.reduce(losses, dst=0)
if self.opt['rank'] == 0:
losses /= self.opt['world_size']
loss_dict = {key: loss for key, loss in zip(keys, losses)}
log_dict = OrderedDict()
for name, value in loss_dict.items():
log_dict[name] = value.mean().item()
return log_dict
This diff is collapsed.
This diff is collapsed.
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
modulated_deform_conv)
__all__ = [
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv'
]
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment