Unverified Commit 5ffde61d authored by Igor Pavlov's avatar Igor Pavlov Committed by GitHub
Browse files

Merge pull request #11 from ai-forever/refactor

Refactor
parents 9410db54 7db57974
build/
dist/
*.egg-info/
__pycache__/
.ipynb_checkpoints/
from .model import RealESRGAN
\ No newline at end of file
import os
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import cv2 import cv2
from huggingface_hub import hf_hub_url, cached_download
from rrdbnet_arch import RRDBNet from .rrdbnet_arch import RRDBNet
from utils_sr import * from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
unpad_image
HF_MODELS = {
2: dict(
repo_id='sberbank-ai/Real-ESRGAN',
filename='RealESRGAN_x2.pth',
),
4: dict(
repo_id='sberbank-ai/Real-ESRGAN',
filename='RealESRGAN_x4.pth',
),
8: dict(
repo_id='sberbank-ai/Real-ESRGAN',
filename='RealESRGAN_x8.pth',
),
}
class RealESRGAN: class RealESRGAN:
def __init__(self, device, scale=4): def __init__(self, device, scale=4):
self.device = device self.device = device
self.scale = scale self.scale = scale
self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale) self.model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=scale
)
def load_weights(self, model_path, download=True):
if not os.path.exists(model_path) and download:
assert self.scale in [2,4,8], 'You can download models only with scales: 2, 4, 8'
config = HF_MODELS[self.scale]
cache_dir = os.path.dirname(model_path)
local_filename = os.path.basename(model_path)
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
def load_weights(self, model_path):
loadnet = torch.load(model_path) loadnet = torch.load(model_path)
if 'params' in loadnet: if 'params' in loadnet:
self.model.load_state_dict(loadnet['params'], strict=True) self.model.load_state_dict(loadnet['params'], strict=True)
...@@ -33,8 +64,9 @@ class RealESRGAN: ...@@ -33,8 +64,9 @@ class RealESRGAN:
lr_image = np.array(lr_image) lr_image = np.array(lr_image)
lr_image = pad_reflect(lr_image, pad_size) lr_image = pad_reflect(lr_image, pad_size)
patches, p_shape = split_image_into_overlapping_patches(lr_image, patch_size=patches_size, patches, p_shape = split_image_into_overlapping_patches(
padding_size=padding) lr_image, patch_size=patches_size, padding_size=padding
)
img = torch.FloatTensor(patches/255).permute((0,3,1,2)).to(device).detach() img = torch.FloatTensor(patches/255).permute((0,3,1,2)).to(device).detach()
with torch.no_grad(): with torch.no_grad():
...@@ -47,8 +79,10 @@ class RealESRGAN: ...@@ -47,8 +79,10 @@ class RealESRGAN:
padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,) padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,) scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
np_sr_image = stich_together(np_sr_image, padded_image_shape=padded_size_scaled, np_sr_image = stich_together(
target_shape=scaled_image_shape, padding_size=padding * scale) np_sr_image, padded_image_shape=padded_size_scaled,
target_shape=scaled_image_shape, padding_size=padding * scale
)
sr_img = (np_sr_image*255).astype(np.uint8) sr_img = (np_sr_image*255).astype(np.uint8)
sr_img = unpad_image(sr_img, pad_size*scale) sr_img = unpad_image(sr_img, pad_size*scale)
sr_img = Image.fromarray(sr_img) sr_img = Image.fromarray(sr_img)
......
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from arch_util import default_init_weights, make_layer, pixel_unshuffle from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
class ResidualDenseBlock(nn.Module): class ResidualDenseBlock(nn.Module):
......
...@@ -3,7 +3,6 @@ import torch ...@@ -3,7 +3,6 @@ import torch
from PIL import Image from PIL import Image
import os import os
import io import io
import imageio
def pad_reflect(image, pad_size): def pad_reflect(image, pad_size):
imsize = image.shape imsize = image.shape
...@@ -22,13 +21,6 @@ def unpad_image(image, pad_size): ...@@ -22,13 +21,6 @@ def unpad_image(image, pad_size):
return image[pad_size:-pad_size, pad_size:-pad_size, :] return image[pad_size:-pad_size, pad_size:-pad_size, :]
def jpegBlur(im,q):
buf = io.BytesIO()
imageio.imwrite(buf,im,format='jpg',quality=q)
s = buf.getbuffer()
return imageio.imread(s,format='jpg')
def process_array(image_array, expand=True): def process_array(image_array, expand=True):
""" Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """ """ Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """
......
...@@ -2,18 +2,17 @@ import os ...@@ -2,18 +2,17 @@ import os
import torch import torch
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from realesrgan import RealESRGAN from RealESRGAN import RealESRGAN
def main() -> int: def main() -> int:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RealESRGAN(device, scale=4) model = RealESRGAN(device, scale=4)
model.load_weights('weights/RealESRGAN_x4.pth') model.load_weights('weights/RealESRGAN_x4.pth', download=True)
for i, image in enumerate(os.listdir("inputs")): for i, image in enumerate(os.listdir("inputs")):
image = Image.open(f"inputs/{image}").convert('RGB') image = Image.open(f"inputs/{image}").convert('RGB')
sr_image = model.predict(image) sr_image = model.predict(image)
sr_image.save(f'results/{i}.png') sr_image.save(f'results/{i}.png')
return 1
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -4,4 +4,4 @@ Pillow ...@@ -4,4 +4,4 @@ Pillow
torch>=1.7 torch>=1.7
torchvision>=0.8.0 torchvision>=0.8.0
tqdm tqdm
imageio huggingface-hub
\ No newline at end of file \ No newline at end of file
import os
import pkg_resources
from setuptools import setup, find_packages
setup(
name="RealESRGAN",
py_modules=["RealESRGAN"],
version="1.0",
description="",
author="Sberbank AI, Xintao Wang",
url='https://github.com/ai-forever/Real-ESRGAN',
packages=find_packages(include=['RealESRGAN']),
install_requires=[
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
]
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment