# app.py — Gradio demo for AuraSR-v2 4x image upscaling (last committed by chenpangpang)
import gradio as gr
from PIL import Image
import numpy as np
from aura_sr import AuraSR
import torch
import os
# Feature flags. Neither flag is referenced anywhere else in this file.
USE_TORCH_COMPILE = False
# Opt-in via the environment: only the exact value "1" turns it on.
ENABLE_CPU_OFFLOAD = os.environ.get("ENABLE_CPU_OFFLOAD", "0") == "1"

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is never handed to AuraSR below, so the model ends up
# on whatever device `from_pretrained` defaults to — confirm this is intended.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the AuraSR-v2 upscaler once, at import time; the resulting module-level
# `aura_sr` object is shared by every upscale call.
# The argument looks like a local checkpoint path rather than a hub repo id —
# TODO confirm against the installed aura_sr version.
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2/model.safetensors")


def process_image(input_image):
    """Gradio click/example callback: upscale *input_image* 4x with AuraSR.

    Args:
        input_image: The value produced by the ``gr.Image(type="pil")`` input
            component, or ``None`` when no image has been provided.

    Returns:
        Whatever ``process_image_on_gpu`` returns for the image (the upscaled
        result that is shown in the output component).

    Raises:
        gr.Error: If *input_image* is ``None``; Gradio surfaces the message
            to the user instead of treating it as a server crash.
    """
    # Local import keeps this block self-contained; after the first call this
    # is just a sys.modules lookup.
    import logging

    logger = logging.getLogger(__name__)

    if input_image is None:
        raise gr.Error("Please provide an image to upscale.")

    # Use logging (lazy %-formatting) instead of print so per-request
    # diagnostics can be enabled or silenced through logging configuration.
    logger.info("get input image: %s", input_image)

    # Upscale the image using AuraSR.
    upscaled_image = process_image_on_gpu(input_image)
    logger.info("upscaled_image: %s", upscaled_image)

    return upscaled_image


def process_image_on_gpu(pil_image):
    """Return the 4x AuraSR upscale of *pil_image*.

    Thin wrapper around the module-level ``aura_sr`` model. Despite the name,
    nothing here selects or checks a device; the computation runs wherever
    the AuraSR model was placed when it was loaded.
    """
    model = aura_sr
    return model.upscale_4x(pil_image)


# Page heading rendered as raw HTML at the top of the app. In English:
# "AuraSR-v2: a GAN-based image restoration tool that generates
# high-resolution images from low-resolution ones".
title = (
    '<h1 align="center">'
    "AuraSR-v2:一款基于GAN的图像修复工具,可从低分辨率图片生成高分辨率图片"
    "</h1>"
)

# Two-column layout: upload + trigger button on the left, result on the right.
with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Row():
        # Left column: source image (passed to callbacks as a PIL image) and
        # the "generate" (生成) button. The label means "input image".
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="输入图片")
            process_btn = gr.Button("生成")
        # Right column: the upscaled output ("generated image"). Despite its
        # name, `gallery` is a single gr.Image, not a gr.Gallery.
        with gr.Column(scale=1):
            gallery = gr.Image(label="生成图片")

    # Clicking the button runs the upscaler on the uploaded image and shows
    # the result in the output image.
    process_btn.click(
        inputs=[input_image],
        outputs=gallery,
        fn=process_image,
    )

    # Clickable sample inputs. With cache_examples=True Gradio precomputes
    # process_image's output for each example instead of running it on click.
    # NOTE(review): the paths are relative, presumably to the working
    # directory the app is launched from — confirm deployment cwd.
    gr.Examples(
        fn=process_image,
        examples=["image1.png", "image3.png"],
        inputs=input_image,
        outputs=gallery,
        cache_examples=True,
    )

# Start the web server only when this file is executed as a script
# (e.g. `python app.py`); importing the module no longer launches a server
# as a side effect.
# NOTE: server_name='0.0.0.0' listens on all network interfaces, and
# share=True additionally requests a public Gradio share link.
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0', share=True)