Commit d118e789 authored by Rayyyyy

First commit.
# ignored folders
experiments/*
results/*
tb_logger/*
wandb/*
tmp/*
weights/*
version.py
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
xintao.wang@outlook.com or xintaowang@tencent.com.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
BSD 3-Clause License
Copyright (c) 2021, Xintao Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include assets/*
include inputs/*
include scripts/*.py
include inference_realesrgan.py
include VERSION
include LICENSE
include requirements.txt
include weights/README.md
# Real-ESRGAN
## Paper
[Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data](https://arxiv.org/abs/2107.10833)
## Model Architecture
Generator: Real-ESRGAN adopts the ESRGAN generator. For x4 super-resolution, the network runs exactly as the ESRGAN generator; for x2 and x1 super-resolution, the network first applies pixel-unshuffle (the inverse of pixel-shuffle, which can be understood as enlarging an image's spatial size by compressing its channels). It lowers the spatial resolution while expanding the channel count, and then feeds the processed tensor into the network for super-resolution reconstruction.
<div align=center>
<img src="./doc/ESRGAN.png"/>
</div>
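To make the trade-off concrete, here is a minimal PyTorch sketch of the pixel-unshuffle step used for the x2 and x1 scales (the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# pixel_unshuffle trades spatial size for channels before the RRDB trunk runs:
x = torch.randn(1, 3, 128, 128)                # LR input
x2 = F.pixel_unshuffle(x, downscale_factor=2)  # -> (1, 12, 64, 64), used for x2
x1 = F.pixel_unshuffle(x, downscale_factor=4)  # -> (1, 48, 32, 32), used for x1

# pixel_shuffle is the exact inverse operation:
assert torch.equal(F.pixel_shuffle(x2, upscale_factor=2), x)
```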
Discriminator: Because the training pairs are synthesized with a complex degradation pipeline, a more capable discriminator is needed to judge the generated images. A U-Net discriminator can make a real/fake judgment for every individual pixel, which keeps the generated image realistic as a whole while attending to its fine details.
<div align=center>
<img src="./doc/UNet.png"/>
</div>
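As a rough sketch of the per-pixel judgment (assuming the `UNetDiscriminatorSN` architecture shipped with basicsr, which Real-ESRGAN trains against):

```python
import torch
from basicsr.archs.discriminator_arch import UNetDiscriminatorSN

# The U-Net discriminator emits one realness logit per pixel instead of a
# single scalar, so each generated pixel is judged real/fake individually.
d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
fake = torch.randn(1, 3, 256, 256)
print(d(fake).shape)  # torch.Size([1, 1, 256, 256])
```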
## Algorithm
By synthesizing training pairs with a more practical degradation process, the powerful ESRGAN is extended to restore general real-world LR images.
<div align=center>
<img src="./doc/High-order的pipeline.png"/>
</div>
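A much-simplified sketch of what "high-order" means in practice: the classical blur → resize → noise → JPEG chain is applied more than once. The file path and parameter ranges below are illustrative, not the paper's exact settings:

```python
import cv2
import numpy as np

def degrade_once(img):
    """One classical degradation round: blur -> resize -> noise -> JPEG."""
    img = cv2.GaussianBlur(img, (7, 7), sigmaX=1.5)
    h, w = img.shape[:2]
    s = np.random.uniform(0.5, 1.0)
    img = cv2.resize(img, (int(w * s), int(h * s)), interpolation=cv2.INTER_LINEAR)
    noisy = img.astype(np.float32) + np.random.normal(0, 5, img.shape)
    img = np.clip(noisy, 0, 255).astype(np.uint8)
    _, buf = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, int(np.random.randint(30, 95))])
    return cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)

hr = cv2.imread('datasets/DF2K/DF2K_HR/0001.png')  # illustrative path
lr = degrade_once(degrade_once(hr))                # second-order degradation
```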
## Environment Setup
Adjust the `-v` paths, `docker_name`, and `imageID` below to match your environment.
### Docker (Option 1)
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04.1-py38-latest
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/real-esrgan_pytorch
# basicsr is required for both inference and training
pip install basicsr
# facexlib and gfpgan are used for face enhancement
pip install facexlib
pip install gfpgan
pip install -r requirements.txt
python setup.py develop
```
### Dockerfile (Option 2)
```bash
cd ./docker
cp ../requirements.txt requirements.txt
docker build --no-cache -t real_esrgan:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
cd /your_code_path/real-esrgan_pytorch
# basicsr is required for both inference and training
pip install basicsr
# facexlib and gfpgan are used for face enhancement
pip install facexlib
pip install gfpgan
pip install -r requirements.txt
python setup.py develop
```
### Anaconda (Option 3)
1. The special deep learning libraries this project requires for DCU GPUs can be downloaded from the Guanghe developer community: https://developer.hpccube.com/tool/
```bash
DTK stack: dtk23.04.1
python: python3.8
torch: 1.13.1
torchvision: 0.14.1
```
Tips: the versions of the DTK stack, Python, torch, and the other DCU-related tools above must correspond strictly one-to-one.
2. Install the remaining, ordinary libraries with the following steps:
```bash
# basicsr is required for both inference and training
pip install basicsr
# facexlib and gfpgan are used for face enhancement
pip install facexlib
pip install gfpgan
pip install -r requirements.txt
python setup.py develop
```
## Datasets
### Prepare the datasets
The required datasets are DF2K (DIV2K and Flickr2K) + OST. Only the HR images are needed.
[DIV2K](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip)
[Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar)
[OST](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/datasets/OST_dataset.zip)
The dataset directory structure is as follows:
```bash
├── datasets
│   ├── DF2K
│   │   ├── DF2K_HR              # put the HR images of DIV2K and Flickr2K here
│   │   ├── DF2K_multiscale      # generated
│   │   ├── DF2K_multiscale_sub  # generated
│   │   └── meta_info            # generated
│   └── OST
```
### Preprocess the datasets
#### 1. [Optional] Generate multi-scale images
For the DF2K dataset, a multi-scale strategy is used: the HR images are downsampled to obtain ground-truth images at several scales. <br>
Use the script [scripts/generate_multiscale_DF2K.py](scripts/generate_multiscale_DF2K.py) to generate the multi-scale images quickly.<br>
This step is optional if you just want a quick try rather than accurate training.
```bash
# example
python scripts/generate_multiscale_DF2K.py --input datasets/DF2K/DF2K_HR --output datasets/DF2K/DF2K_multiscale
```
#### 2. [Optional] Crop into sub-images
Use the script [scripts/extract_subimages.py](scripts/extract_subimages.py) to crop the DF2K images into sub-images, which speeds up IO and processing.<br>
This step is optional if your IO is fast enough or your storage space is limited.
```bash
# example
python scripts/extract_subimages.py --input datasets/DF2K/DF2K_multiscale --output datasets/DF2K/DF2K_multiscale_sub --crop_size 400 --step 200
```
#### 3. Prepare the meta-info txt
You need to prepare a txt file containing the image paths.
Use the script [scripts/generate_meta_info.py](scripts/generate_meta_info.py) to generate a txt file with the image paths.<br>
You can also merge the image paths of several folders into one meta-info txt; for example:
```bash
python scripts/generate_meta_info.py --input datasets/DF2K/DF2K_HR, datasets/DF2K/DF2K_multiscale --root datasets/DF2K, datasets/DF2K --meta_info datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt
```
## Training
### Full training
1. Download the pre-trained model [ESRGAN](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth) and put it under the `experiments/pretrained_models` directory.
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models
```
2. Modify the option file `options/train_realesrnet_x4plus.yml` accordingly:
```yml
train:
    name: DF2K+OST
    type: RealESRGANDataset
    dataroot_gt: datasets/DF2K  # modify to the root directory of your dataset
    meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt  # modify to the meta-info txt you generated
    io_backend:
        type: disk
```
To point the pre-trained model path at another file, change the value of the `pretrain_network_g` parameter; the current default is `experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/model/net_g_1000000.pth`.
3. If you want to perform validation during training, uncomment these lines and modify them accordingly:
```yml
# Uncomment these for validation
# val:
#     name: validation
#     type: PairedImageDataset
#     dataroot_gt: path_to_gt
#     dataroot_lq: path_to_lq
#     io_backend:
#         type: disk
...
# Uncomment these for validation
# validation settings
# val:
#     val_freq: !!float 5e3
#     save_img: True
#     metrics:
#         psnr: # metric name; can be arbitrary
#             type: calculate_psnr
#             crop_border: 4
#             test_y_channel: false
```
### Fine-tuning
You can fine-tune Real-ESRGAN with your own dataset. In general, fine-tuning can be done in two ways:
1. [Generate degraded images on the fly](#generate-degraded-images-on-the-fly)
2. [Use your own **paired** data](#use-your-own-paired-data)
#### Generate degraded images on the fly
Only high-resolution images are required; low-quality images are generated during training with the degradation model described in Real-ESRGAN.
1. Download the pre-trained models to the `experiments/pretrained_models` directory:
- *RealESRGAN_x4plus.pth*:
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```
- *RealESRGAN_x4plus_netD.pth*:
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
```
2. Modify the option file [options/finetune_realesrgan_x4plus.yml](options/finetune_realesrgan_x4plus.yml), especially the `datasets` section:
```yml
train:
    name: DF2K+OST
    type: RealESRGANDataset
    dataroot_gt: datasets/DF2K  # modify to the root directory of your dataset
    meta_info: realesrgan/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt  # modify to the meta-info txt you generated
    io_backend:
        type: disk
```
#### Use your own paired data
You can also fine-tune RealESRGAN with your own paired data; this is more similar to fine-tuning ESRGAN.
1. Prepare data
Suppose you already have two folders:
- **gt folder** (ground-truth, high-resolution images): *datasets/DF2K/DIV2K_train_HR_sub*
- **lq folder** (low-quality, low-resolution images): *datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub*
Then use the script [scripts/generate_meta_info_pairdata.py](scripts/generate_meta_info_pairdata.py) to generate the meta-info txt file.
```bash
python scripts/generate_meta_info_pairdata.py --input datasets/DF2K/DIV2K_train_HR_sub datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub --meta_info datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
```
2. Download pre-trained models
Download the required pre-trained models to the `experiments/pretrained_models` directory.
- *RealESRGAN_x4plus.pth*:
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
```
- *RealESRGAN_x4plus_netD.pth*:
```bash
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.3/RealESRGAN_x4plus_netD.pth -P experiments/pretrained_models
```
3. Fine-tune
Modify the option file [options/finetune_realesrgan_x4plus_pairdata.yml](options/finetune_realesrgan_x4plus_pairdata.yml), especially the `datasets` section:
```yml
train:
    name: DIV2K
    type: RealESRGANPairedDataset
    dataroot_gt: datasets/DF2K  # modify to the root directory of your gt folder
    dataroot_lq: datasets/DF2K  # modify to the root directory of your lq folder
    meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt  # modify to the meta-info txt you generated
    io_backend:
        type: disk
```
### Training commands
#### Single node, multiple GPUs
auto_resume mode is enabled by default. Depending on whether you run full training or fine-tuning, change the yml file passed via the `-opt` argument; the current default is full training.
```bash
bash train_multi.sh
```
#### Single node, single GPU
auto_resume mode is enabled by default.
```bash
bash train.sh
```
## Inference
Download the pre-trained model [RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth) and put it under the `weights` folder; results are saved to the `results` folder by default.
```bash
# download the pre-trained model
wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights
# run inference
python inference_realesrgan.py -n RealESRGAN_x4plus -i inputs --face_enhance
```
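If you prefer calling the Python API directly, here is a minimal sketch that mirrors the calls made in `cog_predict.py` in this repo (the input/output paths are illustrative):

```python
import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.utils import RealESRGANer

# RealESRGAN_x4plus is an RRDBNet with 23 blocks at scale 4
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upsampler = RealESRGANer(
    scale=4,
    model_path='weights/RealESRGAN_x4plus.pth',
    model=model,
    tile=0,          # set e.g. 400 if you run out of GPU memory
    tile_pad=10,
    pre_pad=0,
    half=torch.cuda.is_available())

img = cv2.imread('inputs/00017_gray.png', cv2.IMREAD_UNCHANGED)
output, _ = upsampler.enhance(img, outscale=4)
cv2.imwrite('results/00017_gray_out.png', output)
```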
## Results
<div align=center>
<img src="./doc/00017_gray.jpg"/>
<img src="./doc/00017_gray_out.jpg"/>
</div>
### Accuracy
| NIQE | ADE20K val | OST300 |
| :------: | :------: | :------: |
| ours | xxx | xxx |
| paper | 3.7886 | 2.8659 |
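To fill in the `ours` row, NIQE (a no-reference metric; lower is better) can be computed with basicsr, for example (a sketch assuming basicsr's `calculate_niqe`; the image path is illustrative):

```python
import cv2
from basicsr.metrics import calculate_niqe

img = cv2.imread('results/00017_gray_out.jpg')
# crop_border=0 evaluates the full restored image
print(calculate_niqe(img, crop_border=0))
```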
## Application Scenarios
### Algorithm category
Image super-resolution
### Key application industries
Network security, transportation, government, industry
## Source Repository & Issue Reporting
https://developer.hpccube.com/codes/modelzoo/real-esrgan_pytorch
## References
https://github.com/xinntao/Real-ESRGAN
# This file is used for constructing the Replicate env
image: "r8.im/tencentarc/realesrgan"
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "torch==1.7.1"
    - "torchvision==0.8.2"
    - "numpy==1.21.1"
    - "lmdb==1.2.1"
    - "opencv-python==4.5.3.56"
    - "PyYAML==5.4.1"
    - "tqdm==4.62.2"
    - "yapf==0.31.0"
    - "basicsr==1.4.2"
    - "facexlib==0.2.5"
predict: "cog_predict.py:Predictor"
# flake8: noqa
# This file is used for deploying replicate models
# running: cog predict -i img=@inputs/00017_gray.png -i version='General - v3' -i scale=2 -i face_enhance=True -i tile=0
# push: cog push r8.im/xinntao/realesrgan
import os

os.system('pip install gfpgan')
os.system('python setup.py develop')

import cv2
import shutil
import tempfile
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from realesrgan.utils import RealESRGANer

try:
    from cog import BasePredictor, Input, Path
    from gfpgan import GFPGANer
except Exception:
    print('please install cog and realesrgan package')
class Predictor(BasePredictor):

    def setup(self):
        os.makedirs('output', exist_ok=True)
        # download weights
        if not os.path.exists('weights/realesr-general-x4v3.pth'):
            os.system(
                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P ./weights'
            )
        if not os.path.exists('weights/GFPGANv1.4.pth'):
            os.system('wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P ./weights')
        if not os.path.exists('weights/RealESRGAN_x4plus.pth'):
            os.system(
                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights'
            )
        if not os.path.exists('weights/RealESRGAN_x4plus_anime_6B.pth'):
            os.system(
                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P ./weights'
            )
        if not os.path.exists('weights/realesr-animevideov3.pth'):
            os.system(
                'wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth -P ./weights'
            )
    def choose_model(self, scale, version, tile=0):
        half = True if torch.cuda.is_available() else False
        if version == 'General - RealESRGANplus':
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
            model_path = 'weights/RealESRGAN_x4plus.pth'
            self.upsampler = RealESRGANer(
                scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
        elif version == 'General - v3':
            model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            model_path = 'weights/realesr-general-x4v3.pth'
            self.upsampler = RealESRGANer(
                scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
        elif version == 'Anime - anime6B':
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
            model_path = 'weights/RealESRGAN_x4plus_anime_6B.pth'
            self.upsampler = RealESRGANer(
                scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)
        elif version == 'AnimeVideo - v3':
            model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
            model_path = 'weights/realesr-animevideov3.pth'
            self.upsampler = RealESRGANer(
                scale=4, model_path=model_path, model=model, tile=tile, tile_pad=10, pre_pad=0, half=half)

        self.face_enhancer = GFPGANer(
            model_path='weights/GFPGANv1.4.pth',
            upscale=scale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=self.upsampler)
    def predict(
        self,
        img: Path = Input(description='Input'),
        version: str = Input(
            description='RealESRGAN version. Please see [Readme] below for more descriptions',
            choices=['General - RealESRGANplus', 'General - v3', 'Anime - anime6B', 'AnimeVideo - v3'],
            default='General - v3'),
        scale: float = Input(description='Rescaling factor', default=2),
        face_enhance: bool = Input(
            description='Enhance faces with GFPGAN. Note that it does not work for anime images/videos', default=False),
        tile: int = Input(
            description=
            'Tile size. Default is 0, that is no tile. When encountering the out-of-GPU-memory issue, please specify it, e.g., 400 or 200',
            default=0)
    ) -> Path:
        # check for None before comparing, otherwise `None <= 100` raises a TypeError
        if tile is None or tile <= 100:
            tile = 0
        print(f'img: {img}. version: {version}. scale: {scale}. face_enhance: {face_enhance}. tile: {tile}.')
        out_path = None  # initialized so the final return cannot hit an unbound local on early failure
        try:
            # splitext keeps the leading dot; strip it so f'out.{extension}' stays well-formed
            extension = os.path.splitext(os.path.basename(str(img)))[1].lstrip('.')
            img = cv2.imread(str(img), cv2.IMREAD_UNCHANGED)
            if len(img.shape) == 3 and img.shape[2] == 4:
                img_mode = 'RGBA'
            elif len(img.shape) == 2:
                img_mode = None
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            else:
                img_mode = None

            h, w = img.shape[0:2]
            if h < 300:
                img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

            self.choose_model(scale, version, tile)

            try:
                if face_enhance:
                    _, _, output = self.face_enhancer.enhance(
                        img, has_aligned=False, only_center_face=False, paste_back=True)
                else:
                    output, _ = self.upsampler.enhance(img, outscale=scale)
            except RuntimeError as error:
                print('Error', error)
                print('If you encounter CUDA out of memory, try to set "tile" to a smaller size, e.g., 400.')

            if img_mode == 'RGBA':  # RGBA images should be saved in png format
                extension = 'png'
            # save_path = f'output/out.{extension}'
            # cv2.imwrite(save_path, output)
            out_path = Path(tempfile.mkdtemp()) / f'out.{extension}'
            cv2.imwrite(str(out_path), output)
        except Exception as error:
            print('global exception: ', error)
        finally:
            clean_folder('output')
        return out_path
def clean_folder(folder):
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')