# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


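# Disable autograd tracking; this is a pure inference benchmark.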
@torch.inference_mode()
def main(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
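    """Benchmark vLLM's RMSNorm kernel, optionally with a fused residual add."""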
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
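    # Scale inputs down (assumption: this keeps activations at realistic
    # magnitudes and avoids fp16 overflow during the benchmark).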
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

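    # Wall-clock timing: synchronize before and after the loop so all queued
    # CUDA work is included, then report the mean per-iteration latency.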
    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
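            # RMSNorm fuses the residual add into the kernel when a
            # residual tensor is provided.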
            layer(x, residual)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
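        # A single iteration suffices to capture a profiler trace.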
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--add-residual", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )

    args = parser.parse_args()
    print(args)

    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        add_residual=args.add_residual,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )