run_all.py 3.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import glob
import subprocess
import sys
from typing import List


sys.path.append(".")
from benchmark_text_to_image import ALL_T2I_CKPTS  # noqa: E402


# Glob (resolved relative to the current working directory) used by `main`
# to discover which benchmark scripts to run.
PATTERN: str = "benchmark_*.py"


class SubprocessCallException(Exception):
    """Raised by ``run_command`` when the command it runs exits with a non-zero status."""


# Taken from `test_examples_utils.py`
def run_command(command: List[str], return_stdout=False):
    """
    Run `command` with `subprocess.check_output`, merging stderr into stdout.

    Args:
        command: argv-style list of program and arguments (no shell is involved).
        return_stdout: when True, return the combined stdout/stderr output decoded as UTF-8.

    Returns:
        The decoded combined output if `return_stdout` is True, otherwise None.

    Raises:
        SubprocessCallException: if the command exits with a non-zero status. The message
            contains the command line and its combined output; the original
            `subprocess.CalledProcessError` is chained as the cause.
    """
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        # Decode leniently: a strict decode here would raise UnicodeDecodeError from
        # inside this handler on non-UTF-8 output and hide the actual failure.
        details = e.output.decode("utf-8", errors="replace") if e.output else ""
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{details}"
        ) from e

    if return_stdout:
        if hasattr(output, "decode"):
            output = output.decode("utf-8")
        return output


def _run_eager_and_compiled(command: List[str]) -> None:
    """Run a benchmark command as given, then once more with `--run_compile` appended."""
    run_command(command)
    run_command(command + ["--run_compile"])


def main():
    """Run every script matching `PATTERN` in the current directory.

    Each command is run twice: once as-is and once with `--run_compile`.

    The first pass runs each script with its default arguments. It skips
    `benchmark_text_to_image.py` and `benchmark_ip_adapters.py`, which get no
    default-argument run.

    The second pass runs checkpoint-specific variants for the scripts that
    support them.

    Raises:
        SubprocessCallException: as soon as any benchmark command fails
            (raised by `run_command`).
    """
    # Sort so the benchmarks run in a reproducible order; glob order is arbitrary.
    python_files = sorted(glob.glob(PATTERN))
    # Launch benchmarks with the same interpreter that is running this script, rather
    # than whatever `python` happens to resolve to on PATH.
    interpreter = sys.executable or "python"

    for file in python_files:
        print(f"****** Running file: {file} ******")

        # Run with canonical settings.
        if file not in ("benchmark_text_to_image.py", "benchmark_ip_adapters.py"):
            _run_eager_and_compiled([interpreter, file])

    # Run variants.
    for file in python_files:
        # See: https://github.com/pytorch/pytorch/issues/129637
        if file == "benchmark_ip_adapters.py":
            continue

        if file == "benchmark_text_to_image.py":
            for ckpt in ALL_T2I_CKPTS:
                command = [interpreter, file, "--ckpt", ckpt]
                if "turbo" in ckpt:
                    command += ["--num_inference_steps", "1"]
                _run_eager_and_compiled(command)

        elif file == "benchmark_sd_img.py":
            for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
                command = [interpreter, file, "--ckpt", ckpt]
                if ckpt == "stabilityai/sdxl-turbo":
                    command += ["--num_inference_steps", "2"]
                _run_eager_and_compiled(command)

        elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
            # NOTE: benchmark_ip_adapters.py cannot reach this branch while the skip
            # above is in place; it is kept here so removing the skip restores its run.
            sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
            _run_eager_and_compiled([interpreter, file, "--ckpt", sdxl_ckpt])

        elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
            sdxl_ckpt = (
                "diffusers/controlnet-canny-sdxl-1.0"
                if "controlnet" in file
                else "TencentARC/t2i-adapter-canny-sdxl-1.0"
            )
            _run_eager_and_compiled([interpreter, file, "--ckpt", sdxl_ckpt])


if __name__ == "__main__":
    main()