# latency.py - committed by Zhekai Zhang
import argparse
import time

import torch
from torch import nn
from tqdm import trange

from utils import get_pipeline


def get_args() -> argparse.Namespace:
    """Parse the command-line options of the latency benchmark.

    The defaults of ``--num-inference-steps`` and ``--guidance-scale`` depend
    on ``--model``, so the command line is parsed in two passes: a tolerant
    first pass reads the model name, the defaults are adjusted for it, and a
    strict second pass produces the final namespace.

    Returns:
        The parsed arguments: ``model``, ``precision``, ``num_inference_steps``,
        ``guidance_scale``, ``warmup_times``, ``test_times``, ``mode`` and
        ``ignore_ratio``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
    )
    parser.add_argument(
        "-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
    )

    parser.add_argument("-t", "--num-inference-steps", type=int, default=4, help="Number of inference steps")
    parser.add_argument("-g", "--guidance-scale", type=float, default=0, help="Guidance scale")

    # Test related
    parser.add_argument("--warmup-times", type=int, default=2, help="Number of warmup times")
    parser.add_argument("--test-times", type=int, default=10, help="Number of test times")
    parser.add_argument(
        "--mode",
        type=str,
        default="end2end",
        choices=["end2end", "step"],
        help="Measure mode: end-to-end latency or per-step latency",
    )
    # The hyphenated spelling matches every other option; the original
    # underscore spelling stays accepted so existing command lines keep working.
    parser.add_argument(
        "--ignore-ratio",
        "--ignore_ratio",
        dest="ignore_ratio",
        type=float,
        default=0.2,
        help="Ignored ratio of the slowest and fastest steps",
    )

    # First pass: only the model name is needed, unknown/other options are tolerated.
    known_args, _ = parser.parse_known_args()

    # "dev" gets different defaults; set_defaults only changes the fallback
    # values, so options given explicitly on the command line still win.
    if known_args.model == "dev":
        parser.set_defaults(num_inference_steps=50, guidance_scale=3.5)

    # Second pass: strict parse with the model-specific defaults in place.
    return parser.parse_args()


def _measure_latencies(run, warmup_times: int, test_times: int) -> list:
    """Call ``run`` repeatedly and return the wall-clock seconds of each timed call.

    ``run`` is first called ``warmup_times`` times without timing and then
    ``test_times`` times with timing. CUDA is synchronized after every call so
    a sample covers the queued GPU work, not only the time to launch it.
    """
    for _ in trange(warmup_times, desc="Warmup", position=0, leave=False):
        run()
        torch.cuda.synchronize()

    # Make sure no earlier GPU work is still pending when the first sample
    # starts (matters when warmup_times is 0).
    torch.cuda.synchronize()

    latencies = []
    for _ in trange(test_times, desc="Test", position=0, leave=False):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        run()
        torch.cuda.synchronize()
        latencies.append(time.perf_counter() - start_time)
    return latencies


def _trimmed_mean(latencies, ignore_ratio: float) -> float:
    """Average ``latencies`` after dropping the fastest and slowest samples.

    ``ignore_ratio`` is the total fraction of samples to discard; half of it is
    removed from each end of the sorted list.

    Raises:
        ValueError: If no sample is left to average.
    """
    ordered = sorted(latencies)
    ignored_count = int(ignore_ratio * len(ordered) / 2)
    if ignored_count > 0:
        ordered = ordered[ignored_count:-ignored_count]
    if not ordered:
        raise ValueError(
            f"No latency samples left to average ({len(latencies)} measured, ignore ratio {ignore_ratio})"
        )
    return sum(ordered) / len(ordered)


def main():
    """Benchmark FLUX.1 inference latency and print the trimmed mean in seconds.

    In ``end2end`` mode each sample is one full pipeline call. In ``step`` mode
    the transformer inputs of a single denoising step are captured with a
    forward pre-hook and each sample is one transformer forward pass on them.

    Raises:
        ValueError: If ``--test-times`` is below 1, the mode is unknown, or the
            ignore ratio discards every sample.
    """
    args = get_args()
    # Fail before the expensive pipeline load instead of dividing by zero at the end.
    if args.test_times < 1:
        raise ValueError("--test-times must be at least 1 to report a latency")

    pipeline = get_pipeline(model_name=args.model, precision=args.precision, device="cuda")

    dummy_prompt = "A cat holding a sign that says hello world"

    if args.mode == "end2end":
        # Keep the pipeline's own step bar on the line below the outer progress bar.
        pipeline.set_progress_bar_config(position=1, desc="Step", leave=False)

        def run_once():
            pipeline(
                prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale
            )

    elif args.mode == "step":
        inputs = {}

        def get_input_hook(module: nn.Module, input_args, input_kwargs):
            inputs["args"] = input_args
            inputs["kwargs"] = input_kwargs

        # Run one single-step generation only to record what the pipeline feeds
        # into the transformer; the hook is removed afterwards so it does not
        # fire again during the timed calls.
        hook_handle = pipeline.transformer.register_forward_pre_hook(get_input_hook, with_kwargs=True)
        try:
            pipeline(
                prompt=dummy_prompt, num_inference_steps=1, guidance_scale=args.guidance_scale, output_type="latent"
            )
        finally:
            hook_handle.remove()

        def run_once():
            # NOTE(review): presumably the pipeline call itself runs without
            # gradient tracking (confirm against get_pipeline); disable it here
            # so the direct transformer call does not build an autograd graph.
            with torch.no_grad():
                pipeline.transformer(*inputs["args"], **inputs["kwargs"])

    else:
        raise ValueError(f"Unknown mode: {args.mode}")

    latency_list = _measure_latencies(run_once, args.warmup_times, args.test_times)

    print(f"Latency: {_trimmed_mean(latency_list, args.ignore_ratio):.5f} s")


# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()