import hashlib
import json
import logging
from typing import Any


KERNELIZATION_AVAILABLE = False
try:
    from kernels import Mode, kernelize  # noqa: F401

    KERNELIZATION_AVAILABLE = True
except ImportError:
    pass

logger = logging.getLogger(__name__)


class BenchmarkConfig:
    """Configuration for a single benchmark scenario."""

    def __init__(
        self,
        warmup_iterations: int = 5,
        measurement_iterations: int = 20,
        gpu_monitoring: bool = True,  # NOTE: you may want to disable this at times, as we have observed it can heavily slow down benchmarks on AMD
        batch_size: int = 1,
        sequence_length: int = 128,
        num_tokens_to_generate: int = 128,
        attn_implementation: str = "eager",
        sdpa_backend: str | None = None,
        compile_mode: str | None = None,
        compile_options: dict[str, Any] | None = None,
        kernelize: bool = False,
        name: str | None = None,
        skip_validity_check: bool = False,
    ) -> None:
        # Benchmark parameters
        self.warmup_iterations = warmup_iterations
        self.measurement_iterations = measurement_iterations
        self.gpu_monitoring = gpu_monitoring
        # Input parameters
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.num_tokens_to_generate = num_tokens_to_generate
        # Generation parameters
        self.attn_implementation = attn_implementation
        self.sdpa_backend = sdpa_backend
        # Optimization parameters
        self.compile_mode = compile_mode
        self.compile_options = compile_options if compile_options is not None else {}
        self.kernelize = kernelize
        # Constant parameters
        self.dtype = "torch.bfloat16"
        self.device = "cuda"

        self.check_validity(skip_validity_check)
        self.name = name if name is not None else self.infer_name()

    def check_validity(self, skip_validity_check: bool = False) -> None:
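        """Check the config for incompatible options and fix them in place, unless the check is skipped."""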
        if skip_validity_check:
            return
        # Flash attention does not support compile mode, so we turn it off # FIXME: it would be better to support it
        is_fa = self.attn_implementation == "flash_attention_2"
        is_fa |= self.attn_implementation == "sdpa" and self.sdpa_backend == "flash_attention"
        if is_fa:
            logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
            self.compile_mode = None

    @property
    def hash(self) -> str:
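        """Return a SHA-256 hex digest of the serialized config, usable as a stable identifier."""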
        return hashlib.sha256(json.dumps(self.to_dict()).encode()).hexdigest()

    def infer_name(self, compact: bool = True) -> str:
        """Infer a human-readable name for the benchmark config, either compact or verbose."""
        if compact:
            iter_str = f"w{self.warmup_iterations}_i{self.measurement_iterations}"
            gpu_monitor_str = "monitored" if self.gpu_monitoring else "unmonitored"
            dimensions_str = f"b{self.batch_size}_s{self.sequence_length}_n{self.num_tokens_to_generate}"
            attn_code = self.attn_implementation
            attn_code += f"_{self.sdpa_backend}" if self.attn_implementation == "sdpa" else ""
            compile_str = f"compiled_{self.compile_mode}" if self.compile_mode is not None else "uncompiled"
            kernelize_str = "kernelized" if self.kernelize else "unkernelized"
            sep = "-"
        else:
            iter_str = f"{self.warmup_iterations} warmup, {self.measurement_iterations} iterations"
            gpu_monitor_str = ("with" if self.gpu_monitoring else "no") + " GPU monitoring"
            dimensions_str = f"batch size {self.batch_size}, sequence length {self.sequence_length}, {self.num_tokens_to_generate} generated tokens"
            attn_code = f"{self.attn_implementation} attention"
            attn_code += f" with {self.sdpa_backend} backend" if self.attn_implementation == "sdpa" else ""
            compile_str = "compiled" if self.compile_mode is not None else "not compiled"
            kernelize_str = "kernelized" if self.kernelize else "not kernelized"
            sep = ", "
        return sep.join([iter_str, gpu_monitor_str, dimensions_str, attn_code, compile_str, kernelize_str])

    def to_dict(self) -> dict[str, Any]:
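        """Serialize the config into a JSON-serializable dictionary."""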
        return {
            "name": self.name,
            "warmup_iterations": self.warmup_iterations,
            "measurement_iterations": self.measurement_iterations,
            "gpu_monitoring": self.gpu_monitoring,
            "batch_size": self.batch_size,
            "sequence_length": self.sequence_length,
            "num_tokens_to_generate": self.num_tokens_to_generate,
            "attn_implementation": self.attn_implementation,
            "sdpa_backend": self.sdpa_backend,
            "compile_mode": self.compile_mode,
            "compile_options": self.compile_options | {},  # to avoid inplace modification of the original dict
            "kernelize": self.kernelize,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any], skip_validity_check: bool = False) -> "BenchmarkConfig":
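        """Build a BenchmarkConfig from a dictionary such as the one produced by `to_dict`."""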
        return cls(
            warmup_iterations=data.get("warmup_iterations", 5),
            measurement_iterations=data.get("measurement_iterations", 20),
            gpu_monitoring=data.get("gpu_monitoring", False),
            batch_size=data.get("batch_size", 1),
            sequence_length=data.get("sequence_length", 128),
            num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
            attn_implementation=data.get("attn_implementation", "eager"),
            sdpa_backend=data.get("sdpa_backend"),
            compile_mode=data.get("compile_mode"),
            compile_options=data.get("compile_options"),
            kernelize=data.get("kernelize", False),
            name=data.get("name"),
            skip_validity_check=skip_validity_check,
        )


def cross_generate_configs(
    attn_impl_and_sdpa_backend: list[tuple[str, str | None]],
    compiled_mode: list[str | None],
    kernelized: list[bool],
    warmup_iterations: int = 5,
    measurement_iterations: int = 20,
    batch_size: int = 1,
    sequence_length: int = 128,
    num_tokens_to_generate: int = 128,
    gpu_monitoring: bool = True,
) -> list[BenchmarkConfig]:
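    """Generate one config per unique combination of attention implementation (with its SDPA backend),
    compile mode and kernelization flag, sharing the remaining parameters across all configs."""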
    # Create kwargs common to all configs
    kwargs = {
        "warmup_iterations": warmup_iterations,
        "measurement_iterations": measurement_iterations,
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "num_tokens_to_generate": num_tokens_to_generate,
        "gpu_monitoring": gpu_monitoring,
    }
    # Cross-generate all combinations of attn_implementation, compiled_mode, and kernelized
    configs = []
    for attn_implementation, sdpa_backend in list(dict.fromkeys(attn_impl_and_sdpa_backend)):
        for cm in list(dict.fromkeys(compiled_mode)):
            for kernelize_on in list(dict.fromkeys(kernelized)):
                config = BenchmarkConfig(
                    attn_implementation=attn_implementation,
                    sdpa_backend=sdpa_backend,
                    compile_mode=cm,
                    kernelize=kernelize_on,
                    **kwargs,
                )
                configs.append(config)
    return configs


def generate_all_configs(
    warmup_iterations: int = 5,
    measurement_iterations: int = 20,
    batch_size: int = 1,
    sequence_length: int = 128,
    num_tokens_to_generate: int = 128,
    gpu_monitoring: bool = True,
) -> list[BenchmarkConfig]:
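    """Generate configs covering all attention implementations, compile modes and kernelization options
    (kernelization is only included when the `kernels` package is available)."""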
    all_attn_implementations = [
        ("flash_attention_2", None),
        ("eager", None),
        ("sdpa", "math"),
        ("sdpa", "flash_attention"),
        ("flex_attention", None),
    ]
    return cross_generate_configs(
        attn_impl_and_sdpa_backend=all_attn_implementations,
        compiled_mode=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
        kernelized=[False, KERNELIZATION_AVAILABLE],
        warmup_iterations=warmup_iterations,
        measurement_iterations=measurement_iterations,
        batch_size=batch_size,
        sequence_length=sequence_length,
        num_tokens_to_generate=num_tokens_to_generate,
        gpu_monitoring=gpu_monitoring,
    )


def generate_main_configs(
    warmup_iterations: int = 5,
    measurement_iterations: int = 20,
    batch_size: int = 1,
    sequence_length: int = 128,
    num_tokens_to_generate: int = 128,
) -> list[BenchmarkConfig]:
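    """Return a small hand-picked set of configs used as the main benchmarking scenarios."""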
    # Create kwargs common to all configs
    kwargs = {
        "warmup_iterations": warmup_iterations,
        "measurement_iterations": measurement_iterations,
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "num_tokens_to_generate": num_tokens_to_generate,
    }
    return [  # TODO: test max-autotune instead of default
        BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=False, **kwargs),
        BenchmarkConfig(attn_implementation="flex_attention", compile_mode="default", gpu_monitoring=True, **kwargs),
        BenchmarkConfig(attn_implementation="eager", compile_mode="default", gpu_monitoring=True, **kwargs),
        BenchmarkConfig(attn_implementation="flash_attention_2", gpu_monitoring=True, **kwargs),
    ]
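

# Illustrative usage sketch (not part of the original module): build the main configs,
# print their inferred names and hashes, and round-trip one of them through its
# dictionary form. The batch size and sequence length below are arbitrary example values.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    for config in generate_main_configs(batch_size=2, sequence_length=256):
        print(f"{config.hash[:8]}  {config.name}")
    # Serialization round trip: from_dict(to_dict(...)) should rebuild an equivalent config.
    restored = BenchmarkConfig.from_dict(generate_main_configs()[0].to_dict())
    print(restored.infer_name(compact=False))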