# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from __future__ import annotations

import logging
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

import pytest
import torch

import transformer_engine
import transformer_engine.common.recipe
from transformer_engine.common.recipe import Recipe
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    get_attention_backend,
    AttentionParams,
    AttentionLogging,
    check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend


def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
    """Convert type name to PyTorch dtype"""
    if isinstance(dtype, torch.dtype):
        return dtype
    name = str(dtype).strip().lower()
    if name.startswith("torch."):
        name = name.replace("torch.", "", 1)
    if name.startswith("fp"):
        name = name.replace("fp", "float", 1)
    dtype = dict(
        float32=torch.float32,
        float=torch.float32,
        float64=torch.float64,
        double=torch.float64,
        float16=torch.float16,
        half=torch.float16,
        bfloat16=torch.bfloat16,
        bf16=torch.bfloat16,
        float8_e4m3fn=torch.float8_e4m3fn,
        float8_e4m3=torch.float8_e4m3fn,
        float8e4m3=torch.float8_e4m3fn,
        float8=torch.float8_e4m3fn,
        float8_e5m2=torch.float8_e5m2,
        float8e5m2=torch.float8_e5m2,
        uint8=torch.uint8,
        byte=torch.uint8,
        int8=torch.int8,
        char=torch.int8,
        int16=torch.int16,
        short=torch.int16,
        int32=torch.int32,
        int=torch.int32,
        int64=torch.int64,
        long=torch.int64,
        bool=torch.bool,
    )[name]
    return dtype


def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
    """Estimated numerical error for a datatype

    Based on tolerances for torch.testing.assert_close.

    """

    # Transformer Engine dtypes
    if isinstance(dtype, tex.DType):
        dtype = {
            tex.DType.kByte: torch.uint8,
            tex.DType.kInt32: torch.int32,
            tex.DType.kFloat32: torch.float32,
            tex.DType.kFloat16: torch.half,
            tex.DType.kBFloat16: torch.bfloat16,
            tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
            tex.DType.kFloat8E5M2: torch.float8_e5m2,
        }[dtype]

    # PyTorch dtypes
    if dtype == torch.float16:
        return dict(rtol=1e-3, atol=1e-5)
    if dtype == torch.bfloat16:
        return dict(rtol=1.6e-2, atol=1e-5)
    if dtype == torch.float32:
        return dict(rtol=1.3e-6, atol=1e-5)
    if dtype == torch.float64:
        return dict(rtol=1e-7, atol=1e-7)
    if dtype == torch.float8_e4m3fn:
        return dict(rtol=0.125, atol=0.0675)  # machine epsilon = 0.125
    if dtype == torch.float8_e5m2:
        return dict(rtol=0.25, atol=0.125)  # machine epsilon = 0.25
    raise ValueError(f"Unsupported dtype ({dtype})")


def make_recipe(name: Optional[str]) -> Optional[Recipe]:
    """Make recipe for quantization scheme"""
    if name is None:
        return None
    if name in ("fp8", "fp8_delayed_scaling"):
        return transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
            amax_history_len=8,
        )
    if name == "fp8_current_scaling":
        return transformer_engine.common.recipe.Float8CurrentScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "mxfp8":
        return transformer_engine.common.recipe.MXFP8BlockScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
    if name == "fp8_block_scaling":
        return transformer_engine.common.recipe.Float8BlockScaling()
    raise ValueError(f"Unsupported quantization scheme ({name})")


# Cached RNG state
_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None


def reset_rng_states() -> None:
    """Revert to deterministic RNG state"""
    global _rng_states
    if _rng_states is None:
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
        _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
    else:
        cpu_rng_state, cuda_rng_state = _rng_states
        torch.set_rng_state(cpu_rng_state)
        torch.cuda.set_rng_state(cuda_rng_state)


def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
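    """Compare two tensors, with an RMSE fallback criterion for FP8 data

    Non-FP8 tensors must pass torch.testing.assert_close with the given
    atol/rtol. For FP8 tensors, an assert_close failure is only logged and the
    RMSE between the tensors must instead stay below rmse_tol times the
    combined value range of the two tensors.
    """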
    if not is_fp8:
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
        return

    try:
        if a.dtype != b.dtype:
            a = a.to(b.dtype)
        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
    except Exception as e:
        logging.debug(e)

    rmse = torch.sqrt((a - b).square().mean()).item()
    logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
    rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
    assert rmse < rmse_tol * rmse_range, (
        name_a
        + " vs "
        + name_b
        + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
            rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
        )
    )


class ModelConfig:
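    """Collection of attention layer dimensions and options used by the tests

    Example (illustrative): ModelConfig(2, 512, 16, 64, num_gqa_groups=4)
    describes a GQA "self"-attention case with 16 query heads, 4 KV heads,
    head dimension 64, and max_seqlen_kv defaulting to max_seqlen_q.
    """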
    def __init__(
        self,
        batch_size: int,
        max_seqlen_q: int,
        num_heads: int,
        head_dim_qk: int,
        max_seqlen_kv: Optional[int] = None,
        num_gqa_groups: Optional[int] = None,
        head_dim_v: Optional[int] = None,
        softmax_type: str = "vanilla",
        dropout_p: float = 0.0,
        attn_mask_type: str = "no_mask",
        attn_bias_type: str = "no_bias",
        alibi_type: str = "none",
        bias_shape: str = "1hss",
        window_size: Tuple[int, int] = (-1, -1),
        context_parallel: bool = False,
        cp_comm_type: str = "p2p",
        total_requests: Optional[int] = None,
        max_ctx_len: Optional[int] = None,
        num_layers: int = 1,
        eps: float = 1e-5,
    ):
        self.batch_size = batch_size
        self.max_seqlen_q = max_seqlen_q
        self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv
        self.num_heads = num_heads
        self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups
        self.head_dim_qk = head_dim_qk
        self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
        if self.head_dim_qk == self.head_dim_v:
            self.kv_channels = self.head_dim_qk
        else:
            self.kv_channels = (self.head_dim_qk, self.head_dim_v)
        self.hidden_size = self.num_heads * self.head_dim_qk
        self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
        self.softmax_type = softmax_type
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type
        self.attn_bias_type = attn_bias_type
        self.alibi_type = alibi_type
        self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
        self.bias_shape = bias_shape
        self.window_size = check_set_window_size(self.attn_mask_type, window_size)
        self.context_parallel = context_parallel
        self.cp_comm_type = cp_comm_type
        self.total_requests = total_requests
        self.max_ctx_len = max_ctx_len
        self.num_layers = num_layers
        self.eps = eps


@contextmanager
def logging_context(highest_level=logging.WARNING):
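    """Temporarily suppress logging messages at or below highest_level"""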
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous_level)


def get_available_attention_backends(
    config: ModelConfig,
    qkv_dtype: torch.dtype,
    qkv_layout: str,
    pad_between_seqs: bool = False,
    deterministic: bool = False,
    fp8: bool = False,
    fp8_meta: Optional[Dict[str, Any]] = None,
    is_training: bool = True,
    inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
    """Check for all available attention backends that support a model configuration"""

    os.environ["NVTE_FLASH_ATTN"] = "1"
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_UNFUSED_ATTN"] = "1"
    _attention_backends["backend_selection_requires_update"] = True

    alibi_slopes_shape = None
    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
        if config.bias_shape == "1hss":
            alibi_slopes_shape = [config.num_heads]
        if config.bias_shape == "bhss":
            alibi_slopes_shape = [config.batch_size, config.num_heads]

    core_attention_bias_shape = (
        config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
    )
    core_attention_bias_requires_grad = False
    # d=256 is supported by cuDNN 9.0+ for inference but not training
    if (
        config.attn_bias_type == "post_scale_bias"
        and config.head_dim_qk <= 128
        and config.head_dim_v <= 128
    ):
        core_attention_bias_requires_grad = True

    fused_attn_backends = []
    available_backends = None
    flash_attention_backend = None
    fused_attention_backend = None

    def test():
        attention_params = AttentionParams(
            qkv_dtype=qkv_dtype,
            qkv_layout=qkv_layout,
            batch_size=config.batch_size,
            num_heads=config.num_heads,
            num_gqa_groups=config.num_gqa_groups,
            max_seqlen_q=config.max_seqlen_q,
            max_seqlen_kv=config.max_seqlen_kv,
            head_dim_qk=config.head_dim_qk,
            head_dim_v=config.head_dim_v,
            attn_mask_type=config.attn_mask_type,
            window_size=config.window_size,
            alibi_slopes_shape=alibi_slopes_shape,
            core_attention_bias_type=config.attn_bias_type,
            core_attention_bias_shape=core_attention_bias_shape,
            core_attention_bias_requires_grad=core_attention_bias_requires_grad,
            pad_between_seqs=pad_between_seqs,
            attention_dropout=config.dropout_p,
            context_parallel=config.context_parallel,
            cp_comm_type=config.cp_comm_type,
            deterministic=deterministic,
            fp8=fp8,
            fp8_meta=fp8_meta,
            is_training=is_training,
            inference_params=inference_params,
            softmax_type=config.softmax_type,
        )
        (
            use_flash_attention,
            flash_attention_backend,
            use_fused_attention,
            fused_attention_backend,
            use_unfused_attention,
            available_backends,
        ) = get_attention_backend(attention_params)
        # Set attention.py _attention_backends var using return value
        # from get_attention_backend()
        _attention_backends["use_flash_attention"] = use_flash_attention
        _attention_backends["use_fused_attention"] = use_fused_attention
        _attention_backends["flash_attention_backend"] = flash_attention_backend
        _attention_backends["fused_attention_backend"] = fused_attention_backend
        _attention_backends["use_unfused_attention"] = use_unfused_attention
        _attention_backends["backend_selection_requires_update"] = False
        return available_backends, flash_attention_backend, fused_attention_backend

    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
    if AttentionLogging._is_logging_setup is False:
        AttentionLogging.setup_logging()
    with logging_context(highest_level=AttentionLogging._log_level):
        for i in range(3):
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
            _attention_backends["backend_selection_requires_update"] = True
            available_backends, flash_attention_backend, fused_attention_backend = test()
            if fused_attention_backend == FusedAttnBackend[backends[i]]:
                fused_attn_backends.append(fused_attention_backend)
    return available_backends, flash_attention_backend, fused_attn_backends