# benchmark_swiglu.py
# Speed and memory benchmarks comparing Liger's SwiGLU MLP against the
# HuggingFace LlamaMLP reference across a sweep of sequence lengths.
import math

import torch

from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.utils import infer_device

# Device on which every input tensor and MLP layer built by _setup_swiglu is
# placed; chosen once at import time by liger_kernel's infer_device().
device = infer_device()


def _setup_swiglu(input: SingleBenchmarkRunInput):
    """Build the input activations and the MLP under test for one run.

    Args:
        input: Benchmark run description. ``input.x`` is the second dimension
            of the generated input (the swept sequence length in this
            script), ``input.kernel_provider`` selects the implementation
            (``"liger"`` or ``"huggingface"``), and
            ``input.extra_benchmark_config`` must provide ``bsz``,
            ``hidden_size``, ``intermediate_size``, ``hidden_act`` and
            ``dtype``.

    Returns:
        ``(x, layer)`` where ``x`` is a random ``(bsz, input.x, hidden_size)``
        tensor on ``device`` with ``requires_grad=True``, and ``layer`` is the
        selected MLP moved to ``device`` and cast to ``dtype``.

    Raises:
        ValueError: If ``input.kernel_provider`` is not a supported provider.
    """
    extra = input.extra_benchmark_config
    mlp_config = LlamaConfig(
        hidden_size=extra["hidden_size"],
        intermediate_size=extra["intermediate_size"],
        hidden_act=extra["hidden_act"],
    )
    activations = torch.randn(
        extra["bsz"],
        input.x,
        extra["hidden_size"],
        device=device,
        dtype=extra["dtype"],
        requires_grad=True,
    )

    # Both implementations are constructed from the same LlamaConfig, so the
    # only difference between providers is the MLP class itself.
    mlp_classes = {
        "liger": LigerSwiGLUMLP,
        "huggingface": LlamaMLP,
    }
    provider = input.kernel_provider
    if provider not in mlp_classes:
        raise ValueError(f"Invalid provider: {input.kernel_provider} for SwiGLU")

    mlp = mlp_classes[provider](config=mlp_config).to(device).to(extra["dtype"])
    return activations, mlp


def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one SwiGLU MLP configuration.

    Builds the input and layer with ``_setup_swiglu``, then hands a
    zero-argument forward closure, the requested ``kernel_operation_mode``,
    and a one-element list holding the input tensor to
    ``run_speed_benchmark``. The list is presumably the set of tensors whose
    gradients are reset between timed iterations; confirm in ``utils``.
    """
    activations, mlp = _setup_swiglu(input)

    def forward():
        return mlp(activations)

    return run_speed_benchmark(forward, input.kernel_operation_mode, [activations])


def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory for one SwiGLU MLP configuration.

    Builds the input and layer with ``_setup_swiglu``, then passes a
    zero-argument forward closure and the requested ``kernel_operation_mode``
    to ``run_memory_benchmark``. Unlike the speed benchmark, no tensor list is
    passed alongside.
    """
    activations, mlp = _setup_swiglu(input)

    def forward():
        return mlp(activations)

    return run_memory_benchmark(forward, input.kernel_operation_mode)


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    model = get_benchmark_model_config(args.model)

    def _swiglu_extra_config(bsz):
        """Return the per-run SwiGLU config dict for ``model`` at batch size ``bsz``."""
        return {
            "bsz": bsz,
            "hidden_size": model.hidden_size,
            "intermediate_size": model.intermediate_size,
            "hidden_act": "silu",
            "dtype": model.dtype,
        }

    # Probe a single HuggingFace forward pass (bsz=1, probe_seq_len tokens) to
    # get a peak-memory figure, then turn it into a per-token cost that sizes
    # the sequence-length / batch-size sweep.
    probe_seq_len = 1024

    def _probe():
        probe_input = SingleBenchmarkRunInput(
            x=probe_seq_len,
            kernel_provider="huggingface",
            extra_benchmark_config=_swiglu_extra_config(1),
        )
        x, layer = _setup_swiglu(probe_input)
        return layer(x)

    # NOTE(review): _probe builds the layer (its weights) inside the probed
    # function and runs forward only. Unless estimate_kernel_peak_memory
    # accounts for that, kernel_bpt also absorbs the sequence-length-independent
    # weight footprint. Confirm against benchmark_model_configs.
    kernel_bpt = estimate_kernel_peak_memory(probe_fn=_probe) // probe_seq_len
    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)

    # Powers of two from 2**10 up to the largest power of two <= config.seq_len.
    # NOTE(review): this list is empty when config.seq_len < 1024 (nothing gets
    # benchmarked), and math.log2 raises ValueError if config.seq_len <= 0.
    max_exponent = int(math.log2(config.seq_len))
    seq_lens = [2**exponent for exponent in range(10, max_exponent + 1)]

    common_configs = {
        "kernel_name": "swiglu",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": seq_lens,
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [_swiglu_extra_config(config.batch_size)],
        "overwrite": args.overwrite,
    }

    # Speed first, then memory, each across all three operation modes.
    for bench_fn, metric, unit in (
        (bench_speed_swiglu, "speed", "ms"),
        (bench_memory_swiglu, "memory", "MB"),
    ):
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=["full", "forward", "backward"],
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )