# app.py - Gradio demo that upscales an input image 4x with AuraSR-v2.
import gradio as gr
from PIL import Image
from aura_sr import AuraSR

# NOTE: the block below is disabled. Re-enabling it also requires
# `import torch`, which is not among the imports visible in this file.
# # Force CPU usage
# torch.set_default_tensor_type(torch.FloatTensor)
# torch.set_default_device('cpu')
#
# # Override torch.load to always use CPU
# original_load = torch.load
# torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, map_location=torch.device('cpu'))

# Initialize the AuraSR model once at import time so that every request
# handled by this script reuses the same instance.
# NOTE(review): the argument looks like a path to a local checkpoint file
# rather than a hub repository id - confirm the file exists at deploy time.
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2/model.safetensors")

# # Restore original torch.load
# torch.load = original_load

def process_image(input_image):
    """Upscale the image supplied by the Gradio input component.

    Args:
        input_image: Image data as delivered by the input component (which
            is declared with ``type="numpy"``), or ``None`` when the user
            has not provided an image.

    Returns:
        The value returned by ``process_image_on_gpu`` for the converted
        image.

    Raises:
        gr.Error: If ``input_image`` is ``None``.
    """
    if input_image is None:
        raise gr.Error("Please provide an image to upscale.")

    # The upscaler is handed a PIL image, so wrap the incoming array first.
    pil_image = Image.fromarray(input_image)

    # Upscale the image using AuraSR
    upscaled_image = process_image_on_gpu(pil_image)

    return upscaled_image


def process_image_on_gpu(pil_image):
    """Run the module-level AuraSR model's 4x upscaler on ``pil_image``.

    Args:
        pil_image: The image to upscale; forwarded unchanged to
            ``aura_sr.upscale_4x``.

    Returns:
        Whatever ``aura_sr.upscale_4x`` returns for the given image.
    """
    # NOTE(review): despite the name, nothing in this function selects a
    # device; where the model runs is decided when `aura_sr` is created.
    upscaled = aura_sr.upscale_4x(pil_image)
    return upscaled


# Page heading, rendered as raw HTML at the top of the demo.
title = """<h1 align="center">AuraSR-v2:一款基于GAN图像修复工具,可从低分辨率图片生成高分辨率图片</h1>"""

with gr.Blocks() as demo:
    gr.HTML(title)

    # Two equally weighted columns: input and trigger button on the left,
    # result on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Label reads "Input image"; the handler receives a numpy array.
            input_image = gr.Image(label="输入图片", type="numpy")
            # Button caption reads "Generate".
            process_btn = gr.Button("生成")
        with gr.Column(scale=1):
            # Label reads "Generated image".
            gallery = gr.Image(label="生成图片")

    # Clicking the button runs the upscaler on the current input image and
    # shows the result in the output component.
    process_btn.click(
        fn=process_image,
        inputs=[input_image],
        outputs=gallery
    )

    # Add examples
    # The example files are referenced by relative path, so they must be
    # present in the working directory. With cache_examples=True Gradio
    # runs `fn` on the examples itself to pre-compute their outputs.
    gr.Examples(
        examples=[
            "image1.png",
            "image3.png"
        ],
        inputs=input_image,
        outputs=gallery,
        fn=process_image,
        cache_examples=True
    )

# Listen on all network interfaces and additionally request a public
# Gradio share link.
demo.launch(server_name='0.0.0.0', share=True)