Commit 8f9687f5 authored by mashun1's avatar mashun1
Browse files

ridcp

parents
Pipeline #617 canceled with stages
import argparse
import cv2
import glob
import os
from tqdm import tqdm
import torch
from yaml import load
import time
from basicsr.utils import img2tensor, tensor2img, imwrite
from basicsr.archs.dehaze_vq_weight_arch import VQWeightDehazeNet
from basicsr.utils.download_util import load_file_from_url
def main():
    """Inference demo for the RIDCP (FeMaSR-based) dehazing model.

    Reads image(s) from ``--input`` (a file or a folder), runs
    ``VQWeightDehazeNet`` on each, writes results to ``--output``, and
    prints the average per-image runtime.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument('-w', '--weight', type=str, default=None, help='path for model weights')
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument('--use_weight', action="store_true")
    parser.add_argument('--alpha', type=float, default=1.0, help='value of alpha')
    parser.add_argument('--suffix', type=str, default='', help='Suffix of the restored image')
    parser.add_argument('--max_size', type=int, default=1500, help='Max image size for whole image inference, otherwise use tiled_test')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    weight_path = args.weight
    # Fail early with a clear message instead of letting torch.load(None) raise.
    if weight_path is None:
        raise ValueError('Model weight path is required: pass it with -w/--weight')

    # set up the model
    sr_model = VQWeightDehazeNet(
        codebook_params=[[64, 1024, 512]],
        LQ_stage=True,
        use_weight=args.use_weight,
        weight_alpha=args.alpha,
    ).to(device)
    # map_location lets CUDA-saved checkpoints load on CPU-only machines.
    sr_model.load_state_dict(torch.load(weight_path, map_location=device)['params'], strict=False)
    sr_model.eval()

    os.makedirs(args.output, exist_ok=True)
    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*')))

    pbar = tqdm(total=len(paths), unit='image')
    run_time_records = []
    for path in paths:
        img_name = os.path.basename(path)
        save_path = os.path.join(args.output, f'{img_name}')
        pbar.set_description(f'Test {img_name}')

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # Skip unreadable / non-image files instead of crashing on img.max().
            pbar.update(1)
            continue
        # Scale down high-bit-depth (e.g. 16-bit) images so values land near [0, 255];
        # img2tensor below divides by 255 again to reach [0, 1].
        if img.max() > 255.0:
            img = img / 255.0
        # Drop the alpha channel if present.
        if img.shape[-1] > 3:
            img = img[:, :, :3]

        img_tensor = img2tensor(img).to(device) / 255.
        img_tensor = img_tensor.unsqueeze(0)

        max_size = args.max_size ** 2
        h, w = img_tensor.shape[2:]
        start_time = time.time()
        # Inference only: no_grad avoids building autograd graphs and saves memory.
        with torch.no_grad():
            if h * w < max_size:
                output, _ = sr_model.test(img_tensor)
            else:
                # Too large for whole-image inference: downsample 2x, run, upsample back.
                down_img = torch.nn.UpsamplingBilinear2d((h // 2, w // 2))(img_tensor)
                output, _ = sr_model.test(down_img)
                output = torch.nn.UpsamplingBilinear2d((h, w))(output)
        end_time = time.time()
        run_time_records.append(end_time - start_time)

        output_img = tensor2img(output)
        imwrite(output_img, save_path)
        pbar.update(1)
    pbar.close()

    # Guard against ZeroDivisionError when no images were processed.
    if run_time_records:
        print(f"样本数量:{len(run_time_records)}, 平均运行时间:{sum(run_time_records) / len(run_time_records)}")


if __name__ == '__main__':
    main()
# 模型编码
modelCode=400
# 模型名称
modelName=ridcp_pytorch
# 模型描述
modelDescription=ridcp是一种基于codebook先验知识的图像去雾模型
# 应用场景
appScenario=推理,训练,图像去雾,交通,医疗,环保,气象
# 框架类型
frameType=pytorch
# general settings
name: vq_weight_dehaze_trained_on_ours
model_type: VQDehazeModel
scale: &upscale 1
num_gpu: 4 # set num_gpu: 0 for cpu mode
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: General_Image_Train
type: HazeOnlineDataset
dataroot_gt: datasets/rgb_500
dataroot_depth: datasets/depth_500
beta_range: [0.3, 1.5]
A_range: [0.25, 1.0]
color_p: 1.0
color_range: [-0.025, 0.025]
io_backend:
type: disk
gt_size: 256
use_resize_crop: true
use_flip: true
use_rot: false
# data loader
use_shuffle: true
batch_size_per_gpu: &bsz 4
num_worker_per_gpu: *bsz
dataset_enlarge_ratio: 1
prefetch_mode: cpu
num_prefetch_queue: *bsz
val:
name: General_Image_Train
type: HazeOnlineDataset
dataroot_gt: datasets/rgb_500
dataroot_depth: datasets/depth_500
beta_range: [0.3, 1.5]
A_range: [0.25, 1.0]
color_p: 1.0
color_range: [-0.025, 0.025]
io_backend:
type: disk
# network structures
network_g:
type: VQWeightDehazeNet
gt_resolution: 256
norm_type: 'gn'
act_type: 'silu'
scale_factor: *upscale
codebook_params:
- [64, 1024, 512]
LQ_stage: true
use_weight: false
weight_alpha: -1.0
frozen_module_keywords: ['quantize', 'decoder_group', 'after_quant_group', 'out_conv']
network_d:
type: UNetDiscriminatorSN
num_in_ch: 512
# path
path:
pretrain_network_hq: pretrained_models/pretrained_HQPs.pth
pretrain_network_g:
pretrain_network_d: ~
strict_load: false
# resume_state: ~
# training settings
train:
optim_g:
type: Adam
lr: !!float 1e-4
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 4e-4
weight_decay: 0
betas: [0.9, 0.99]
scheduler:
type: MultiStepLR
milestones: [5000, 10000, 15000, 20000, 250000, 300000, 350000]
gamma: 1
total_iter: 45000
warmup_iter: -1 # no warm up
# losses
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: LPIPSLoss
loss_weight: !!float 1.0
gan_opt:
type: GANLoss
gan_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
loss_weight: 0.1
codebook_opt:
loss_weight: 1.0
semantic_opt:
loss_weight: 0.1
net_d_iters: 1
net_d_init_iters: !!float 0
# validation settings
val:
val_freq: !!float 80000
save_img: true
key_metric: psnr
metrics:
psnr: # metric name, can be arbitrary
type: psnr
crop_border: 4
test_y_channel: true
ssim:
type: ssim
crop_border: 4
test_y_channel: true
lpips:
type: lpips
better: lower
# logging settings
logger:
print_freq: 10
save_checkpoint_freq: !!float 1e3
save_latest_freq: !!float 5e2
show_tf_imgs_freq: !!float 1e2
use_tb_logger: true
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment