predict.py 1.44 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
import tempfile
import shutil
import os
from PIL import Image
import subprocess
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    """Cog predictor that runs HAT x4 super-resolution on a single image.

    Each call stages the input image into a scratch directory, runs the
    repository's ``hat/test.py`` script as a subprocess with the
    ``HAT_SRx4_ImageNet-LR`` test options, and returns the single image that
    the script writes under the results directory.
    """

    # Directory the input image is copied into before the script runs.
    # NOTE(review): presumably this must match the dataset path configured in
    # the options YAML below — confirm.
    _INPUT_DIR = "input_dir"
    # Root directory the test script writes its outputs into.
    _RESULTS_DIR = "results"
    # Test-options file passed to hat/test.py via -opt.
    _OPTIONS_FILE = "options/test/HAT_SRx4_ImageNet-LR.yml"
    # Sub-path under _RESULTS_DIR where the upscaled image is expected.
    _RESULT_SUBDIRS = ("HAT_SRx4_ImageNet-LR", "visualization", "custom")

    def predict(
        self,
        image: Path = Input(
            description="Input Image.",
        ),
    ) -> Path:
        """Upscale ``image`` and return the path of the resulting PNG.

        Args:
            image: Path of the image to upscale.

        Returns:
            Path to ``output.png`` inside a fresh temporary directory.

        Raises:
            subprocess.CalledProcessError: if ``hat/test.py`` exits non-zero.
            RuntimeError: if the result directory does not contain exactly
                one file.
        """
        input_dir = self._INPUT_DIR
        results_dir = self._RESULTS_DIR
        output_path = Path(tempfile.mkdtemp()) / "output.png"

        try:
            # Remove leftovers from a previous (possibly crashed) run so the
            # single-result check below only sees this run's output.
            for d in (input_dir, results_dir):
                if os.path.exists(d):
                    shutil.rmtree(d)
            os.makedirs(input_dir, exist_ok=False)

            input_path = os.path.join(input_dir, os.path.basename(image))
            shutil.copy(str(image), input_path)

            # check=True surfaces a failing script immediately instead of
            # failing later on a missing results directory.
            subprocess.run(
                ["python", "hat/test.py", "-opt", self._OPTIONS_FILE],
                check=True,
            )

            res_dir = os.path.join(results_dir, *self._RESULT_SUBDIRS)
            results = os.listdir(res_dir)
            # Explicit check (not assert) so it is not stripped under -O.
            if len(results) != 1:
                raise RuntimeError(
                    "Should contain only one result for Single prediction, "
                    f"found {len(results)}."
                )
            # Context manager closes the underlying file handle.
            with Image.open(os.path.join(res_dir, results[0])) as res:
                res.save(str(output_path))
        finally:
            # Best-effort cleanup: a missing directory (e.g. the script failed
            # before creating results/) must not mask the original exception.
            shutil.rmtree(input_dir, ignore_errors=True)
            shutil.rmtree(results_dir, ignore_errors=True)

        return output_path