"vscode:/vscode.git/clone" did not exist on "2683991e32c81f1e44211a116261cb7e2d7cb2a0"
Commit e2696ece authored by mashun1's avatar mashun1
Browse files

controlnet

parents
Pipeline #643 canceled with stages
import argparse
import cv2
import numpy as np
from os import path as osp
from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.utils import bgr2ycbcr, scandir
def main(args):
    """Compute per-image and average PSNR/SSIM between GT and restored images.

    Pairs each ground-truth image under ``args.gt`` with its restored
    counterpart (by sorted order when ``args.suffix`` is empty, otherwise by
    ``basename + suffix + ext`` under ``args.restored``) and prints metrics.
    """
    psnr_values = []
    ssim_values = []
    gt_paths = sorted(scandir(args.gt, recursive=True, full_path=True))
    restored_paths = sorted(scandir(args.restored, recursive=True, full_path=True))

    print('Testing Y channel.' if args.test_y_channel else 'Testing RGB channels.')

    for idx, gt_path in enumerate(gt_paths):
        basename, ext = osp.splitext(osp.basename(gt_path))
        img_gt = cv2.imread(gt_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        if args.suffix == '':
            restored_path = restored_paths[idx]
        else:
            restored_path = osp.join(args.restored, basename + args.suffix + ext)
        img_restored = cv2.imread(restored_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.

        if args.correct_mean_var:
            gt_means = [np.mean(img_gt[:, :, ch]) for ch in range(3)]
            gt_stds = [np.std(img_gt[:, :, ch]) for ch in range(3)]
            # Match each restored channel's mean/std to the GT statistics.
            # The correction is applied twice on purpose: shifting the mean
            # perturbs the std (and vice versa), so a second pass tightens it.
            for _ in range(2):
                for ch in range(3):
                    mean = np.mean(img_restored[:, :, ch])
                    img_restored[:, :, ch] = img_restored[:, :, ch] - mean + gt_means[ch]
                    std = np.std(img_restored[:, :, ch])
                    img_restored[:, :, ch] = img_restored[:, :, ch] / std * gt_stds[ch]

        if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
            img_gt = bgr2ycbcr(img_gt, y_only=True)
            img_restored = bgr2ycbcr(img_restored, y_only=True)

        # calculate PSNR and SSIM on the [0, 255] range
        psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
        print(f'{idx+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
        psnr_values.append(psnr)
        ssim_values.append(ssim)

    print(args.gt)
    print(args.restored)
    print(f'Average: PSNR: {sum(psnr_values) / len(psnr_values):.6f} dB, SSIM: {sum(ssim_values) / len(ssim_values):.6f}')
if __name__ == '__main__':
    # CLI for the PSNR/SSIM evaluation script.
    cli = argparse.ArgumentParser()
    cli.add_argument('--gt', type=str, default='datasets/val_set14/Set14', help='Path to gt (Ground-Truth)')
    cli.add_argument('--restored', type=str, default='results/Set14', help='Path to restored images')
    cli.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
    cli.add_argument('--suffix', type=str, default='', help='Suffix for restored images')
    cli.add_argument(
        '--test_y_channel',
        action='store_true',
        help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.')
    cli.add_argument('--correct_mean_var', action='store_true', help='Correct the mean and var of restored images.')
    main(cli.parse_args())
import argparse
import math
import numpy as np
import torch
from torch import nn
from basicsr.archs.stylegan2_arch import StyleGAN2Generator
from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
def calculate_stylegan2_fid():
    """Sample a StyleGAN2 generator and report FID against dataset statistics.

    CLI positional args:
        ckpt: path to the StyleGAN2 checkpoint (expects key ``params_ema``).
        fid_stats: path to a torch-saved dict with ``mean``/``cov`` of the
            real data's Inception features.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()
    parser.add_argument('ckpt', type=str, help='Path to the stylegan2 checkpoint.')
    parser.add_argument('fid_stats', type=str, help='Path to the dataset fid statistics.')
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_sample', type=int, default=50000)
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    opts = parser.parse_args()

    # Build the generator and load the EMA weights.
    generator = StyleGAN2Generator(
        out_size=opts.size,
        num_style_feat=512,
        num_mlp=8,
        channel_multiplier=opts.channel_multiplier,
        resample_kernel=(1, 3, 3, 1))
    generator.load_state_dict(torch.load(opts.ckpt)['params_ema'])
    generator = nn.DataParallel(generator).eval().to(device)

    # Truncation trick: the mean latent is only needed when truncation < 1.
    truncation_latent = None
    if opts.truncation < 1:
        with torch.no_grad():
            truncation_latent = generator.mean_latent(opts.truncation_mean)

    # Inception model used for feature extraction.
    inception = load_patched_inception_v3(device)

    n_batches = math.ceil(opts.num_sample / opts.batch_size)

    def batched_samples(num_batches):
        # Lazily yield generated image batches for feature extraction.
        for _ in range(num_batches):
            with torch.no_grad():
                latent = torch.randn(opts.batch_size, 512, device=device)
                samples, _ = generator([latent], truncation=opts.truncation, truncation_latent=truncation_latent)
            yield samples

    features = extract_inception_features(batched_samples(n_batches), inception, n_batches, device)
    features = features.numpy()
    total_len = features.shape[0]
    features = features[:opts.num_sample]
    print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
    sample_mean = np.mean(features, 0)
    sample_cov = np.cov(features, rowvar=False)

    # Load the real-data statistics and compute the FID metric.
    stats = torch.load(opts.fid_stats)
    fid = calculate_fid(sample_mean, sample_cov, stats['mean'], stats['cov'])
    print('fid:', fid)
if __name__ == '__main__':
    # Entry point: parse CLI args and compute FID for a StyleGAN2 checkpoint.
    calculate_stylegan2_fid()
import torch
from basicsr.archs.dfdnet_arch import DFDNet
from basicsr.archs.vgg_arch import NAMES
def convert_net(ori_net, crt_net):
    """Convert the official DFDNet state dict to the BasicSR key layout.

    Iterates over every key of ``crt_net`` (the target BasicSR layout),
    derives the corresponding key in ``ori_net`` (the official checkpoint)
    and copies the tensor over in place.

    Args:
        ori_net (dict): State dict of the official DFDNet checkpoint.
        crt_net (dict): State dict of the BasicSR DFDNet model; filled in place.

    Returns:
        dict: ``crt_net`` with the converted parameters.

    Raises:
        ValueError: If an ``attn_blocks`` key matches no known facial part,
            or if a mapped tensor pair disagrees in size.
    """
    for crt_k in crt_net:  # only values are overwritten; keys never change
        # vgg feature extractor
        if 'vgg_extractor' in crt_k:
            ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model')
            if 'mean' in crt_k:
                ori_k = ori_k.replace('mean', 'RGB_mean')
            elif 'std' in crt_k:
                ori_k = ori_k.replace('std', 'RGB_std')
            else:
                # Map the named VGG layer to its index in the sequential
                # `features` module via the basicsr NAMES table.
                idx = NAMES['vgg19'].index(crt_k.split('.')[2])
                if 'weight' in crt_k:
                    ori_k = f'VggExtract.model.features.{idx}.weight'
                else:
                    ori_k = f'VggExtract.model.features.{idx}.bias'
        # facial-part attention blocks
        elif 'attn_blocks' in crt_k:
            if 'left_eye' in crt_k:
                ori_k = crt_k.replace('attn_blocks.left_eye', 'le')
            elif 'right_eye' in crt_k:
                ori_k = crt_k.replace('attn_blocks.right_eye', 're')
            elif 'mouth' in crt_k:
                ori_k = crt_k.replace('attn_blocks.mouth', 'mo')
            elif 'nose' in crt_k:
                ori_k = crt_k.replace('attn_blocks.nose', 'no')
            else:
                raise ValueError('Wrong!')
        elif 'multi_scale_dilation' in crt_k:
            if 'conv_blocks' in crt_k:
                # Original numbers dilation convs from 1, current from 0.
                _, _, c, d, e = crt_k.split('.')
                ori_k = f'MSDilate.conv{int(c)+1}.{d}.{e}'
            else:
                ori_k = crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi')
        elif crt_k.startswith('upsample'):
            ori_k = crt_k.replace('upsample', 'up')
            if 'scale_block' in crt_k:
                ori_k = ori_k.replace('scale_block', 'ScaleModel1')
            elif 'shift_block' in crt_k:
                ori_k = ori_k.replace('shift_block', 'ShiftModel1')
            elif 'upsample4' in crt_k and 'body' in crt_k:
                ori_k = ori_k.replace('body', 'Model')
        else:
            # Bug fix: the original fell through here with `ori_k` unbound
            # (NameError) or stale from the previous key (silently copying a
            # wrong tensor). Report the unmapped key and leave it untouched.
            print('unprocess key: ', crt_k)
            continue

        # copy over, guarding against silent shape mismatches
        if crt_net[crt_k].size() != ori_net[ori_k].size():
            raise ValueError('Wrong tensor size: \n'
                             f'crt_net: {crt_net[crt_k].size()}\n'
                             f'ori_net: {ori_net[ori_k].size()}')
        crt_net[crt_k] = ori_net[ori_k]
    return crt_net
if __name__ == '__main__':
    # Load the official DFDNet weights (original key layout).
    ori_net = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official_original.pth')
    # Build the BasicSR DFDNet to obtain the target state-dict layout.
    dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
    crt_net = dfd_net.state_dict()
    crt_net_params = convert_net(ori_net, crt_net)
    # NOTE(review): legacy (non-zipfile) format, presumably so older torch
    # versions can load the result — confirm the intended consumers.
    torch.save(
        dict(params=crt_net_params),
        'experiments/pretrained_models/DFDNet/DFDNet_official.pth',
        _use_new_zipfile_serialization=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment