Commit 76b9024b authored by yangzhong

git init

clipiqa: 0.77264
musiq: 71.75000
niqe: 4.18718
maniqa: 0.48229
We enable automatic model download in our code. If you need to run offline inference, download the pretrained [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) model and the S3Diff weights [[HuggingFace](https://huggingface.co/zhangap/S3Diff) | [GoogleDrive](https://drive.google.com/drive/folders/1cWYQYRFpadC4K2GuH8peg_hWEoFddZtj?usp=sharing)].
Put the weights into `pretrained_weight/`.
This repository also provides a script for quickly downloading the model from Hugging Face; run the following command:
python /S3Diff/down_model.py
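For reference, here is a minimal sketch of what such a download script could look like, assuming `huggingface_hub` is installed and the weights should land under `pretrained_weight/` (the actual `down_model.py` in this repository may differ):

```python
# Hypothetical download sketch; the real down_model.py may differ.
from huggingface_hub import snapshot_download, hf_hub_download

# SD-Turbo base model -> pretrained_weight/stabilityai/sd-turbo/
snapshot_download(repo_id="stabilityai/sd-turbo",
                  local_dir="./pretrained_weight/stabilityai/sd-turbo")

# S3Diff weights -> pretrained_weight/zhangap/S3Diff/
for filename in ["s3diff.pkl", "de_net.pth"]:
    hf_hub_download(repo_id="zhangap/S3Diff",
                    filename=filename,
                    local_dir="./pretrained_weight/zhangap/S3Diff")
```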
einops>=0.6.1
#numpy>=1.24.4
open-clip-torch>=2.20.0
#opencv-python==4.6.0.66
pillow>=9.5.0
scipy==1.11.1
timm>=0.9.2
tokenizers
#torch==2.1.0
#torchaudio==2.1.0
torchdata==0.6.1
torchmetrics>=1.0.1
#torchvision==0.16.0
tqdm>=4.65.0
transformers==4.35.2
#triton==2.0.0
urllib3<1.27,>=1.25.4
#xformers>=0.0.20
streamlit-keyup==0.2.0
lpips
peft==0.10.0
pyiqa
omegaconf
dominate
#diffusers==0.25.1
gradio==3.43.1
huggingface_hub==0.22.2
import os
from PIL import Image
# Configure paths
ref_dir = "./datasets/Set5/image_SRF_2/HR/"  # original reference image directory (512x512)
resized_ref_dir = "./datasets/Set5/image_SRF_2/HR_1024/"  # resized reference image directory (1024x1024)
# Create the output directory
os.makedirs(resized_ref_dir, exist_ok=True)
# Batch-resize (bicubic interpolation)
for img_name in os.listdir(ref_dir):
    if img_name.endswith((".png", ".jpg", ".jpeg")):
        img_path = os.path.join(ref_dir, img_name)
        save_path = os.path.join(resized_ref_dir, img_name)
        # Open the image and upscale it to 1024x1024
        with Image.open(img_path) as img:
            resized_img = img.resize((1024, 1024), Image.BICUBIC)
            resized_img.save(save_path, quality=100)
            print(f"Resized: {img_name} → 1024x1024")
print(f"\n✅ All reference images saved to: {resized_ref_dir}")
python src/evaluate_img.py -i "path_to_generated_HR" -r "path_to_ground_truth"
accelerate launch --num_processes=1 --gpu_ids="0," --main_process_port 29300 src/inference_s3diff.py \
--de_net_path="./pretrained_weight/zhangap/S3Diff/de_net.pth" \
--output_dir="./output" \
--ref_path="./datasets/Set5/image_SRF_2/HR_1024/" \
--align_method="wavelet"
accelerate launch --num_processes=4 --gpu_ids="0,1,2,3" --main_process_port 29300 src/train_s3diff.py \
--de_net_path="assets/mm-realsr/de_net.pth" \
--output_dir="./output" \
--resolution=512 \
--train_batch_size=4 \
--enable_xformers_memory_efficient_attention \
--viz_freq 25
from setuptools import setup, find_packages

setup(
    name='S3Diff',
    version='0.0.1',
    description='',
    packages=find_packages(),
    install_requires=[
        'torch',
        'numpy',
        'tqdm',
    ],
)
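With this setup.py in place, the package can presumably be installed in editable mode from the repository root (an assumed workflow, not stated explicitly in the repository):
pip install -e .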
import torch
import copy
from torch import nn as nn
from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights
class DEResNet(nn.Module):
"""Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore
As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal',
resnet arch works for image quality estimation.
Args:
num_in_ch (int): channel number of inputs. Default: 3.
num_degradation (int): num of degradation the DE should estimate. Default: 2(blur+noise).
degradation_embed_size (int): embedding size of each degradation vector.
degradation_degree_actv (int): activation function for degradation degree scalar. Default: sigmoid.
num_feats (list): channel number of each stage.
num_blocks (list): residual block of each stage.
downscales (list): downscales of each stage.
"""
def __init__(self,
num_in_ch=3,
num_degradation=2,
degradation_degree_actv='sigmoid',
num_feats=[64, 64, 64, 128],
num_blocks=[2, 2, 2, 2],
downscales=[1, 1, 2, 1]):
super(DEResNet, self).__init__()
assert isinstance(num_feats, list)
assert isinstance(num_blocks, list)
assert isinstance(downscales, list)
assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales)
num_stage = len(num_feats)
self.conv_first = nn.ModuleList()
for _ in range(num_degradation):
self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1))
self.body = nn.ModuleList()
for _ in range(num_degradation):
body = list()
for stage in range(num_stage):
for _ in range(num_blocks[stage]):
body.append(ResidualBlockNoBN(num_feats[stage]))
if downscales[stage] == 1:
if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
continue
elif downscales[stage] == 2:
body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
else:
raise NotImplementedError
self.body.append(nn.Sequential(*body))
self.num_degradation = num_degradation
self.fc_degree = nn.ModuleList()
if degradation_degree_actv == 'sigmoid':
actv = nn.Sigmoid
elif degradation_degree_actv == 'tanh':
actv = nn.Tanh
else:
raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
f'{degradation_degree_actv} is not supported yet.')
for _ in range(num_degradation):
self.fc_degree.append(
nn.Sequential(
nn.Linear(num_feats[-1], 512),
nn.ReLU(inplace=True),
nn.Linear(512, 1),
actv(),
))
self.avg_pool = nn.AdaptiveAvgPool2d(1)
default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)
def clone_module(self, module):
new_module = copy.deepcopy(module)
return new_module
def average_parameters(self, modules):
avg_module = self.clone_module(modules[0])
for name, param in avg_module.named_parameters():
avg_param = sum([mod.state_dict()[name].data for mod in modules]) / len(modules)
param.data.copy_(avg_param)
return avg_module
def expand_degradation_modules(self, new_num_degradation):
if new_num_degradation <= self.num_degradation:
return
initial_modules = [self.conv_first, self.body, self.fc_degree]
for modules in initial_modules:
avg_module = self.average_parameters(modules[:2])
while len(modules) < new_num_degradation:
modules.append(self.clone_module(avg_module))
def load_and_expand_model(self, path, num_degradation):
state_dict = torch.load(path, map_location=torch.device('cpu'))
self.load_state_dict(state_dict, strict=True)
self.expand_degradation_modules(num_degradation)
self.num_degradation = num_degradation
def load_model(self, path):
state_dict = torch.load(path, map_location=torch.device('cpu'))
self.load_state_dict(state_dict, strict=True)
def set_train(self):
self.conv_first.requires_grad_(True)
self.fc_degree.requires_grad_(True)
for n, _p in self.body.named_parameters():
if "lora" in n:
_p.requires_grad = True
def forward(self, x):
degrees = []
for i in range(self.num_degradation):
x_out = self.conv_first[i](x)
feat = self.body[i](x_out)
feat = self.avg_pool(feat)
feat = feat.squeeze(-1).squeeze(-1)
# for i in range(self.num_degradation):
degrees.append(self.fc_degree[i](feat).squeeze(-1))
return torch.stack(degrees, dim=1)
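# --- Usage sketch (added for illustration; not part of the original de_net.py) ---
# Runs the degradation estimator on a random image and prints the estimated
# degradation degrees (one scalar per degradation, in [0, 1] with the sigmoid head).
if __name__ == "__main__":
    net = DEResNet(num_in_ch=3, num_degradation=2)
    net.eval()
    with torch.no_grad():
        dummy = torch.rand(1, 3, 128, 128)  # B x C x H x W
        degrees = net(dummy)
    print(degrees.shape)  # expected: torch.Size([1, 2])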
import pyiqa
import os
import argparse
from pathlib import Path
import torch
from utils import util_image
import tqdm
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(pyiqa.list_models())
def evaluate(in_path, ref_path, ntest):
metric_dict = {}
metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device)
metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device)
metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device)
metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device)
metric_paired_dict = {}
in_path = Path(in_path) if not isinstance(in_path, Path) else in_path
assert in_path.is_dir()
ref_path_list = None
if ref_path is not None:
ref_path = Path(ref_path) if not isinstance(ref_path, Path) else ref_path
ref_path_list = sorted([x for x in ref_path.glob("*.[jpJP][pnPN]*[gG]")])
if ntest is not None: ref_path_list = ref_path_list[:ntest]
metric_paired_dict["psnr"]=pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
metric_paired_dict["lpips"]=pyiqa.create_metric('lpips').to(device)
metric_paired_dict["dists"]=pyiqa.create_metric('dists').to(device)
metric_paired_dict["ssim"]=pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr' ).to(device)
lr_path_list = sorted([x for x in in_path.glob("*.[jpJP][pnPN]*[gG]")])
if ntest is not None: lr_path_list = lr_path_list[:ntest]
print(f'Find {len(lr_path_list)} images in {in_path}')
result = {}
for i in tqdm.tqdm(range(len(lr_path_list))):
_in_path = lr_path_list[i]
_ref_path = ref_path_list[i] if ref_path_list is not None else None
im_in = util_image.imread(_in_path, chn='rgb', dtype='float32') # h x w x c
im_in_tensor = util_image.img2tensor(im_in).cuda() # 1 x c x h x w
for key, metric in metric_dict.items():
with torch.cuda.amp.autocast():
result[key] = result.get(key, 0) + metric(im_in_tensor).item()
if ref_path is not None:
im_ref = util_image.imread(_ref_path, chn='rgb', dtype='float32') # h x w x c
im_ref_tensor = util_image.img2tensor(im_ref).cuda()
for key, metric in metric_paired_dict.items():
result[key] = result.get(key, 0) + metric(im_in_tensor, im_ref_tensor).item()
if ref_path is not None:
fid_metric = pyiqa.create_metric('fid')
result['fid'] = fid_metric(in_path, ref_path)
for key, res in result.items():
if key == 'fid':
print(f"{key}: {res:.2f}")
else:
print(f"{key}: {res/len(lr_path_list):.5f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-i',"--in_path", type=str, required=True)
parser.add_argument("-r", "--ref_path", type=str, default=None)
parser.add_argument("--ntest", type=int, default=None)
args = parser.parse_args()
evaluate(args.in_path, args.ref_path, args.ntest)
import gradio as gr
import os
import sys
import math
from typing import List
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.utils.import_utils import is_xformers_available
from my_utils.testing_utils import parse_args_paired_testing
from de_net import DEResNet
from s3diff_tile import S3Diff
from torchvision import transforms
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
tensor_transforms = transforms.Compose([
transforms.ToTensor(),
])
args = parse_args_paired_testing()
# Load scheduler, tokenizer and models.
if args.pretrained_path is None:
from huggingface_hub import hf_hub_download
pretrained_path = hf_hub_download(repo_id="zhangap/S3Diff", filename="s3diff.pkl")
else:
pretrained_path = args.pretrained_path
if args.sd_path is None:
from huggingface_hub import snapshot_download
sd_path = snapshot_download(repo_id="stabilityai/sd-turbo")
else:
sd_path = args.sd_path
de_net_path = 'assets/mm-realsr/de_net.pth'
# initialize net_sr
net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=sd_path, pretrained_path=pretrained_path, args=args)
net_sr.set_eval()
# initialize the degradation estimation network
net_de = DEResNet(num_in_ch=3, num_degradation=2)
net_de.load_model(de_net_path)
net_de = net_de.cuda()
net_de.eval()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
net_sr.unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if args.gradient_checkpointing:
net_sr.unet.enable_gradient_checkpointing()
weight_dtype = torch.float32
device = "cuda"
# Move the SR and degradation networks to GPU and cast to weight_dtype
net_sr.to(device, dtype=weight_dtype)
net_de.to(device, dtype=weight_dtype)
@torch.no_grad()
def process(
input_image: Image.Image,
scale_factor: float,
cfg_scale: float,
latent_tiled_size: int,
latent_tiled_overlap: int,
align_method: str,
) -> List[np.ndarray]:
# positive_prompt = ""
# negative_prompt = ""
net_sr._set_latent_tile(latent_tiled_size = latent_tiled_size, latent_tiled_overlap = latent_tiled_overlap)
im_lr = tensor_transforms(input_image).unsqueeze(0).to(device)
ori_h, ori_w = im_lr.shape[2:]
im_lr_resize = F.interpolate(
im_lr,
size=(int(ori_h * scale_factor),
int(ori_w * scale_factor)),
mode='bilinear',
align_corners=False # align_corners with this model causes the output to be shifted, presumably due to training without align_corners
)
im_lr_resize = im_lr_resize.contiguous()
im_lr_resize_norm = im_lr_resize * 2 - 1.0
im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0)
resize_h, resize_w = im_lr_resize_norm.shape[2:]
pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h
pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w
im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect')
try:
with torch.autocast("cuda"):
deg_score = net_de(im_lr)
pos_tag_prompt = [args.pos_prompt]
neg_tag_prompt = [args.neg_prompt]
x_tgt_pred = net_sr(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt)
x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]
out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach()
output_pil = transforms.ToPILImage()(out_img[0])
if align_method == 'no fix':
image = output_pil
else:
im_lr_resize = transforms.ToPILImage()(im_lr_resize[0])
if align_method == 'wavelet':
image = wavelet_color_fix(output_pil, im_lr_resize)
elif align_method == 'adain':
image = adain_color_fix(output_pil, im_lr_resize)
except Exception as e:
print(e)
image = Image.new(mode="RGB", size=(512, 512))
return image
#
MARKDOWN = \
"""
## Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors
[GitHub](https://github.com/ArcticHare105/S3Diff) | [Paper](https://arxiv.org/abs/2409.17058)
If S3Diff is helpful to you, please consider starring the GitHub repo. Thanks!
"""
block = gr.Blocks().queue()
with block:
with gr.Row():
gr.Markdown(MARKDOWN)
with gr.Row():
with gr.Column():
input_image = gr.Image(source="upload", type="pil")
run_button = gr.Button(label="Run")
with gr.Accordion("Options", open=True):
cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set a value larger than 1 to enable it!)", minimum=1.0, maximum=1.1, value=1.07, step=0.01)
scale_factor = gr.Number(label="SR Scale", value=4)
latent_tiled_size = gr.Slider(label="Tile Size", minimum=64, maximum=160, value=96, step=1)
latent_tiled_overlap = gr.Slider(label="Tile Overlap", minimum=16, maximum=48, value=32, step=1)
align_method = gr.Dropdown(label="Color Correction", choices=["wavelet", "adain", "no fix"], value="wavelet")
with gr.Column():
result_image = gr.Image(label="Output", show_label=False, elem_id="result_image", source="canvas", width="100%", height="auto")
inputs = [
input_image,
scale_factor,
cfg_scale,
latent_tiled_size,
latent_tiled_overlap,
align_method
]
run_button.click(fn=process, inputs=inputs, outputs=[result_image])
block.launch()
import os
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import gc
import tqdm
import math
import argparse
import clip
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from omegaconf import OmegaConf
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
import diffusers
import utils.misc as misc
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler
from de_net import DEResNet
from s3diff_tile import S3Diff
from my_utils.testing_utils import parse_args_paired_testing, PlainDataset, lr_proc
from utils.util_image import ImageSpliterTh
from my_utils.utils import instantiate_from_config
from pathlib import Path
from utils import util_image
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
def evaluate(in_path, ref_path, ntest):
"""仅保留无参考评估(不依赖参考图)"""
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# 仅保留无参考指标(无需参考图,不要求尺寸匹配)
metric_dict = {}
metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device)
metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device)
metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device)
metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device)
in_path = Path(in_path) if not isinstance(in_path, Path) else in_path
assert in_path.is_dir()
# No-reference evaluation; no need to handle reference image paths
lr_path_list = sorted([x for x in in_path.glob("*.[jpJP][pnPN]*[gG]")])
if ntest is not None:
lr_path_list = lr_path_list[:ntest]
print(f'Find {len(lr_path_list)} images in {in_path}')
result = {}
for i in tqdm.tqdm(range(len(lr_path_list))):
_in_path = lr_path_list[i]
# Load only the super-resolved image (no reference image needed for no-reference evaluation)
im_in = util_image.imread(_in_path, chn='rgb', dtype='float32') # h x w x c
im_in_tensor = util_image.img2tensor(im_in).cuda() # 1 x c x h x w
# Compute no-reference metric scores
for key, metric in metric_dict.items():
with torch.cuda.amp.autocast():
result[key] = result.get(key, 0) + metric(im_in_tensor).item()
# Print average scores (no FID or other reference-based metrics)
print_results = []
for key, res in result.items():
avg_score = res / len(lr_path_list)
print(f"{key}: {avg_score:.5f}")
print_results.append(f"{key}: {avg_score:.5f}")
return print_results
def main(args):
config = OmegaConf.load(args.base_config)
if args.pretrained_path is None:
from huggingface_hub import hf_hub_download
#pretrained_path = hf_hub_download(repo_id="zhangap/S3Diff", filename="s3diff.pkl")
pretrained_path = "./pretrained_weight/zhangap/S3Diff/s3diff.pkl" ###
else:
pretrained_path = args.pretrained_path
if args.sd_path is None:
#from huggingface_hub import snapshot_download
#sd_path = snapshot_download(repo_id="stabilityai/sd-turbo")
# Use the locally downloaded sd-turbo path directly
sd_path = "./pretrained_weight/stabilityai/sd-turbo/"
else:
sd_path = args.sd_path
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
# initialize net_sr
net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=sd_path, pretrained_path=pretrained_path, args=args)
net_sr.set_eval()
net_de = DEResNet(num_in_ch=3, num_degradation=2)
net_de.load_model(args.de_net_path)
net_de = net_de.cuda()
net_de.eval()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
net_sr.unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available, please install it by running `pip install xformers`")
if args.gradient_checkpointing:
net_sr.unet.enable_gradient_checkpointing()
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
dataset_val = PlainDataset(config.validation)
dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
# Prepare everything with our `accelerator`.
net_sr, net_de = accelerator.prepare(net_sr, net_de)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move all networks to device and cast to weight_dtype
net_sr.to(accelerator.device, dtype=weight_dtype)
net_de.to(accelerator.device, dtype=weight_dtype)
offset = args.padding_offset
for step, batch_val in enumerate(dl_val):
lr_path = batch_val['lr_path'][0]
(path, name) = os.path.split(lr_path)
im_lr = batch_val['lr'].cuda()
im_lr = im_lr.to(memory_format=torch.contiguous_format).float()
ori_h, ori_w = im_lr.shape[2:]
im_lr_resize = F.interpolate(
im_lr,
size=(ori_h * config.sf,
ori_w * config.sf),
mode='bilinear',
align_corners=False # align_corners with this model causes the output to be shifted, presumably due to training without align_corners
)
im_lr_resize = im_lr_resize.contiguous()
im_lr_resize_norm = im_lr_resize * 2 - 1.0
im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0)
resize_h, resize_w = im_lr_resize_norm.shape[2:]
pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h
pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w
im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect')
B = im_lr_resize.size(0)
with torch.no_grad():
# forward pass
deg_score = net_de(im_lr)
pos_tag_prompt = [args.pos_prompt for _ in range(B)]
neg_tag_prompt = [args.neg_prompt for _ in range(B)]
x_tgt_pred = accelerator.unwrap_model(net_sr)(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt)
x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]
out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach()
output_pil = transforms.ToPILImage()(out_img[0])
if args.align_method == 'nofix':
output_pil = output_pil
else:
im_lr_resize = transforms.ToPILImage()(im_lr_resize[0].cpu().detach())
if args.align_method == 'wavelet':
output_pil = wavelet_color_fix(output_pil, im_lr_resize)
elif args.align_method == 'adain':
output_pil = adain_color_fix(output_pil, im_lr_resize)
fname, ext = os.path.splitext(name)
outf = os.path.join(args.output_dir, fname+'.png')
output_pil.save(outf)
# Run no-reference evaluation (the ref_path argument has no effect; it is ignored inside the function)
print_results = evaluate(args.output_dir, args.ref_path, None)
out_t = os.path.join(args.output_dir, 'results.txt')
with open(out_t, 'w', encoding='utf-8') as f:
for item in print_results:
f.write(f"{item}\n")
gc.collect()
torch.cuda.empty_cache()
if __name__ == "__main__":
# Import pyiqa (make sure it is imported before evaluation runs)
import pyiqa
args = parse_args_paired_testing()
main(args)
import os
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import gc
import tqdm
import math
import lpips
import pyiqa
import argparse
import clip
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from omegaconf import OmegaConf
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
# from tqdm.auto import tqdm
import diffusers
import utils.misc as misc
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler
from de_net import DEResNet
from s3diff_tile import S3Diff
from my_utils.testing_utils import parse_args_paired_testing, PlainDataset, lr_proc
from utils.util_image import ImageSpliterTh
from my_utils.utils import instantiate_from_config
from pathlib import Path
from utils import util_image
from utils.wavelet_color import wavelet_color_fix, adain_color_fix
def evaluate(in_path, ref_path, ntest):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
metric_dict = {}
metric_dict["clipiqa"] = pyiqa.create_metric('clipiqa').to(device)
metric_dict["musiq"] = pyiqa.create_metric('musiq').to(device)
metric_dict["niqe"] = pyiqa.create_metric('niqe').to(device)
metric_dict["maniqa"] = pyiqa.create_metric('maniqa').to(device)
metric_paired_dict = {}
in_path = Path(in_path) if not isinstance(in_path, Path) else in_path
assert in_path.is_dir()
ref_path_list = None
if ref_path is not None:
ref_path = Path(ref_path) if not isinstance(ref_path, Path) else ref_path
ref_path_list = sorted([x for x in ref_path.glob("*.[jpJP][pnPN]*[gG]")])
if ntest is not None: ref_path_list = ref_path_list[:ntest]
metric_paired_dict["psnr"]=pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr').to(device)
metric_paired_dict["lpips"]=pyiqa.create_metric('lpips').to(device)
metric_paired_dict["dists"]=pyiqa.create_metric('dists').to(device)
metric_paired_dict["ssim"]=pyiqa.create_metric('ssim', test_y_channel=True, color_space='ycbcr' ).to(device)
lr_path_list = sorted([x for x in in_path.glob("*.[jpJP][pnPN]*[gG]")])
if ntest is not None: lr_path_list = lr_path_list[:ntest]
print(f'Find {len(lr_path_list)} images in {in_path}')
result = {}
for i in tqdm.tqdm(range(len(lr_path_list))):
_in_path = lr_path_list[i]
_ref_path = ref_path_list[i] if ref_path_list is not None else None
im_in = util_image.imread(_in_path, chn='rgb', dtype='float32') # h x w x c
im_in_tensor = util_image.img2tensor(im_in).cuda() # 1 x c x h x w
for key, metric in metric_dict.items():
with torch.cuda.amp.autocast():
result[key] = result.get(key, 0) + metric(im_in_tensor).item()
if ref_path is not None:
im_ref = util_image.imread(_ref_path, chn='rgb', dtype='float32') # h x w x c
im_ref_tensor = util_image.img2tensor(im_ref).cuda()
for key, metric in metric_paired_dict.items():
result[key] = result.get(key, 0) + metric(im_in_tensor, im_ref_tensor).item()
if ref_path is not None:
fid_metric = pyiqa.create_metric('fid')
result['fid'] = fid_metric(in_path, ref_path)
print_results = []
for key, res in result.items():
if key == 'fid':
print(f"{key}: {res:.2f}")
print_results.append(f"{key}: {res:.2f}")
else:
print(f"{key}: {res/len(lr_path_list):.5f}")
print_results.append(f"{key}: {res/len(lr_path_list):.5f}")
return print_results
def main(args):
config = OmegaConf.load(args.base_config)
if args.pretrained_path is None:
from huggingface_hub import hf_hub_download
#pretrained_path = hf_hub_download(repo_id="zhangap/S3Diff", filename="s3diff.pkl")
pretrained_path = "./pretrained_weight/zhangap/S3Diff/s3diff.pkl" ###
else:
pretrained_path = args.pretrained_path
if args.sd_path is None:
#from huggingface_hub import snapshot_download
#sd_path = snapshot_download(repo_id="stabilityai/sd-turbo")
# Use the locally downloaded sd-turbo path directly
sd_path = "./pretrained_weight/stabilityai/sd-turbo/"
else:
sd_path = args.sd_path
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
# initialize net_sr
net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae, sd_path=sd_path, pretrained_path=pretrained_path, args=args)
net_sr.set_eval()
net_de = DEResNet(num_in_ch=3, num_degradation=2)
net_de.load_model(args.de_net_path)
net_de = net_de.cuda()
net_de.eval()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
net_sr.unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available, please install it by running `pip install xformers`")
if args.gradient_checkpointing:
net_sr.unet.enable_gradient_checkpointing()
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
dataset_val = PlainDataset(config.validation)
dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
# Prepare everything with our `accelerator`.
net_sr, net_de = accelerator.prepare(net_sr, net_de)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move all networks to device and cast to weight_dtype
net_sr.to(accelerator.device, dtype=weight_dtype)
net_de.to(accelerator.device, dtype=weight_dtype)
offset = args.padding_offset
for step, batch_val in enumerate(dl_val):
lr_path = batch_val['lr_path'][0]
(path, name) = os.path.split(lr_path)
im_lr = batch_val['lr'].cuda()
im_lr = im_lr.to(memory_format=torch.contiguous_format).float()
ori_h, ori_w = im_lr.shape[2:]
im_lr_resize = F.interpolate(
im_lr,
size=(ori_h * config.sf,
ori_w * config.sf),
mode='bilinear',
align_corners=False # align_corners with this model causes the output to be shifted, presumably due to training without align_corners
)
im_lr_resize = im_lr_resize.contiguous()
im_lr_resize_norm = im_lr_resize * 2 - 1.0
im_lr_resize_norm = torch.clamp(im_lr_resize_norm, -1.0, 1.0)
resize_h, resize_w = im_lr_resize_norm.shape[2:]
pad_h = (math.ceil(resize_h / 64)) * 64 - resize_h
pad_w = (math.ceil(resize_w / 64)) * 64 - resize_w
im_lr_resize_norm = F.pad(im_lr_resize_norm, pad=(0, pad_w, 0, pad_h), mode='reflect')
B = im_lr_resize.size(0)
with torch.no_grad():
# forward pass
deg_score = net_de(im_lr)
pos_tag_prompt = [args.pos_prompt for _ in range(B)]
neg_tag_prompt = [args.neg_prompt for _ in range(B)]
x_tgt_pred = accelerator.unwrap_model(net_sr)(im_lr_resize_norm, deg_score, pos_prompt=pos_tag_prompt, neg_prompt=neg_tag_prompt)
x_tgt_pred = x_tgt_pred[:, :, :resize_h, :resize_w]
out_img = (x_tgt_pred * 0.5 + 0.5).cpu().detach()
output_pil = transforms.ToPILImage()(out_img[0])
if args.align_method == 'nofix':
output_pil = output_pil
else:
im_lr_resize = transforms.ToPILImage()(im_lr_resize[0].cpu().detach())
if args.align_method == 'wavelet':
output_pil = wavelet_color_fix(output_pil, im_lr_resize)
elif args.align_method == 'adain':
output_pil = adain_color_fix(output_pil, im_lr_resize)
fname, ext = os.path.splitext(name)
outf = os.path.join(args.output_dir, fname+'.png')
output_pil.save(outf)
print_results = evaluate(args.output_dir, args.ref_path, None)
out_t = os.path.join(args.output_dir, 'results.txt')
with open(out_t, 'w', encoding='utf-8') as f:
for item in print_results:
f.write(f"{item}\n")
gc.collect()
torch.cuda.empty_cache()
if __name__ == "__main__":
args = parse_args_paired_testing()
main(args)
import torch
import os
import requests
from tqdm import tqdm
from diffusers import DDPMScheduler, EulerDiscreteScheduler
from typing import Any, Optional, Union
# def make_1step_sched(pretrained_path, step=4):
# noise_scheduler_1step = EulerDiscreteScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
# noise_scheduler_1step.set_timesteps(step, device="cuda")
# noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
# return noise_scheduler_1step
def make_1step_sched(pretrained_path):
noise_scheduler_1step = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
noise_scheduler_1step.set_timesteps(1, device="cuda")
noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
return noise_scheduler_1step
def my_lora_fwd(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
_tmp = lora_A(dropout(x))
if isinstance(lora_A, torch.nn.Conv2d):
_tmp = torch.einsum('...khw,...kr->...rhw', _tmp, self.de_mod)
elif isinstance(lora_A, torch.nn.Linear):
_tmp = torch.einsum('...lk,...kr->...lr', _tmp, self.de_mod)
else:
raise NotImplementedError('only Conv2d and Linear layers are supported.')
result = result + lora_B(_tmp) * scaling
else:
x = dropout(x)
result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
result = result.to(torch_result_dtype)
return result
def download_url(url, outf):
if not os.path.exists(outf):
print(f"Downloading checkpoint to {outf}")
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(outf, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
print("ERROR, something went wrong")
print(f"Downloaded successfully to {outf}")
else:
print(f"Skipping download, {outf} already exists")