test_fused_attn.py 51.6 KB
Newer Older
1
2
3
4
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import functools
Tim Moon's avatar
Tim Moon committed
6
7
from importlib.metadata import version
import os
8
import math
Tim Moon's avatar
Tim Moon committed
9
10
11
from typing import Any, Dict, List, Tuple, Union

from pkg_resources import packaging
12
import pytest
Tim Moon's avatar
Tim Moon committed
13
import torch
14

Tim Moon's avatar
Tim Moon committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast
from transformer_engine.pytorch.attention import (
    DotProductAttention,
    RotaryPositionEmbedding,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
    QKVLayout,
    fused_attn_bwd,
    fused_attn_fwd,
)
31
32
33
34
from transformer_engine.pytorch.distributed import (
    _set_cuda_rng_state,
    CudaRNGStatesTracker,
)
Tim Moon's avatar
Tim Moon committed
35
36
37
38
39
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
    TransformerEngineBaseModule,
    _prepare_backward,
)
40
from transformer_engine.pytorch.utils import (
Tim Moon's avatar
Tim Moon committed
41
    get_device_compute_capability,
42
43
44
    init_method_normal,
    scaled_init_method_normal,
)
Tim Moon's avatar
Tim Moon committed
45
import transformer_engine_extensions as tex
46
from transformer_engine_extensions import NVTE_Fused_Attn_Backend
47

48
# Only run FP8 tests on hardware with FP8 support (e.g. H100): record whether
# FP8 is available and, if not, the reason reported by TransformerEngine.
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()

# Initialize RNG state: seed the CPU and CUDA generators once and snapshot the
# resulting states so reset_rng_states() can restore them before every run.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

def reset_rng_states() -> None:
    """Revert back to initial RNG state.

    Restores the CPU and CUDA RNG states snapshotted at module import, so that
    each backend run in this file starts from identical random inputs.
    """
    torch.set_rng_state(_cpu_rng_state)
    _set_cuda_rng_state(_cuda_rng_state)

63
64
65
66
67
68
69
@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
    """Return the runtime cuDNN version as a ``(major, minor, patch)`` tuple.

    The extension reports a single encoded integer laid out as
    ``major * 1000 + minor * 100 + patch``; this unpacks it once and caches it.
    """
    encoded = ext.get_cudnn_version()
    major = encoded // 1000
    minor = (encoded % 1000) // 100
    patch = encoded % 100
    return (major, minor, patch)
70

71
72
class ModelConfig:
    """Attention model configuration shared by the test cases in this file.

    Args:
        batch_size: number of sequences per batch (``b``).
        num_heads: number of query heads (``h``).
        num_gqa_groups: number of key/value heads (``hg``); equal to
            ``num_heads`` when grouped-query attention is not used.
        head_dim: per-head hidden dimension (``d``).
        max_seqlen_q: maximum query sequence length (``sq``).
        max_seqlen_kv: maximum key/value sequence length (``skv``).
        dropout_p: attention dropout probability (``p``).
        attn_mask_type: mask type name, e.g. ``"no_mask"``, ``"causal"``,
            ``"padding"`` or ``"padding_causal"``.
        attn_bias_type: bias type name, e.g. ``"no_bias"``,
            ``"post_scale_bias"`` or ``"alibi"``.
        num_layers: number of layers (default 1).

    Derived attributes:
        hidden_size: ``num_heads * head_dim``.
        hidden_size_kv: ``num_gqa_groups * head_dim``.
        attn_type: ``"self"`` when ``max_seqlen_q == max_seqlen_kv``,
            otherwise ``"cross"``.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        num_gqa_groups: int,
        head_dim: int,
        max_seqlen_q: int,
        max_seqlen_kv: int,
        dropout_p: float,
        attn_mask_type: str,
        attn_bias_type: str,
        num_layers: int = 1,
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.num_gqa_groups = num_gqa_groups
        self.head_dim = head_dim
        self.hidden_size = num_heads * head_dim
        self.hidden_size_kv = num_gqa_groups * head_dim
        self.max_seqlen_q = max_seqlen_q
        self.max_seqlen_kv = max_seqlen_kv
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type
        self.attn_bias_type = attn_bias_type
        # Equal query/key lengths are treated as self attention, otherwise cross.
        self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
        self.num_layers = num_layers
98

Tim Moon's avatar
Tim Moon committed
99
100
101
102
def _is_fused_attention_supported(
    config: ModelConfig,
    dtype: torch.dtype,
    qkv_layout: str = "sbh3d",
) -> Tuple[bool, List[NVTE_Fused_Attn_Backend]]:
    """Check if FusedAttention supports a model configuration.

    Queries ``tex.get_fused_attn_backend`` with ``NVTE_FUSED_ATTN_BACKEND=0``.
    If that selects the F16 max-512-seqlen backend, it queries again with
    ``NVTE_FUSED_ATTN_BACKEND=1`` to see whether the F16 arbitrary-seqlen
    backend can also run the configuration.

    Returns:
        ``(supported, backends)``. ``backends`` lists every fused backend found
        for the configuration, and is empty when it is unsupported.

    Note:
        Leaves ``NVTE_FUSED_ATTN_BACKEND`` set in ``os.environ``: ``"0"``, or
        ``"1"`` after the second query. Callers that need a specific backend
        should set it explicitly before running.
    """

    def _query_backend() -> NVTE_Fused_Attn_Backend:
        # Presumably the selected backend depends on the current
        # NVTE_FUSED_ATTN_BACKEND value, which is why it is set before each call.
        return tex.get_fused_attn_backend(
            TE_DType[dtype],
            TE_DType[dtype],
            QKVLayout[qkv_layout],
            AttnBiasType[config.attn_bias_type],
            AttnMaskType[config.attn_mask_type],
            config.dropout_p,
            config.num_heads,
            config.num_gqa_groups,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            config.head_dim,
        )

    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
    backend = _query_backend()
    if backend in (FusedAttnBackend["FP8"], FusedAttnBackend["F16_arbitrary_seqlen"]):
        return True, [backend]
    if backend == FusedAttnBackend["F16_max512_seqlen"]:
        backends = [backend]
        # Check whether the arbitrary-seqlen backend supports this config as well.
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
        backend = _query_backend()
        if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
            backends.append(backend)
        return True, backends
    return False, []

@functools.cache
def _flash_attn_version():
    """Return the installed flash-attn version, or ``None`` if it is not installed.

    A missing package is reported as ``None`` instead of raising
    ``PackageNotFoundError``. These checks are also evaluated at import time
    (inside ``pytest.mark.skipif``), and raising there would abort collection
    of the whole test module.
    """
    from importlib.metadata import PackageNotFoundError

    try:
        installed = version("flash-attn")
    except PackageNotFoundError:
        return None
    return packaging.version.Version(installed)

def _flash_attn_at_least(minimum: str) -> bool:
    """Check whether flash-attn is installed with a version of at least ``minimum``."""
    installed = _flash_attn_version()
    return installed is not None and installed >= packaging.version.Version(minimum)

@functools.cache
def _is_flash_attention_2_available() -> bool:
    """Check if flash-attn 2.0+ is available"""
    return _flash_attn_at_least("2")

@functools.cache
def _is_flash_attention_2_1() -> bool:
    """Check if flash-attn 2.1+ is available"""
    return _flash_attn_at_least("2.1")

@functools.cache
def _is_flash_attention_2_3() -> bool:
    """Check if flash-attn 2.3+ is available"""
    return _flash_attn_at_least("2.3")

165
166
def _is_flash_attention_supported(config: ModelConfig) -> bool:
    """Check if FlashAttention supports a model configuration.

    The configuration is rejected when:
      * the GPU compute capability is below 8.0,
      * any attention bias is requested,
      * grouped-query attention is used without flash-attn 2.0+, or
      * a causal mask is used for cross attention with flash-attn 2.1+.
    """
    if get_device_compute_capability() < (8, 0):
        return False
    if config.attn_bias_type != "no_bias":
        return False
    if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
        return False
    if "causal" in config.attn_mask_type and config.attn_type == "cross":
        if _is_flash_attention_2_1():
            # FAv2.1 implements causal mask for cross attention differently
            # https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
            return False
    return True

def _is_unfused_attention_supported(config: ModelConfig) -> bool:
    """Check if UnfusedDotProductAttention supports a model configuration.

    The configuration is rejected for any padding mask type (``"padding"``,
    ``"padding_causal"``) and for causal masks in cross attention.
    """
    if "padding" in config.attn_mask_type:
        return False
    if "causal" in config.attn_mask_type and config.attn_type == "cross":
        return False
    return True

188
189
190
191
192
193
194
195
196
197
198
199
200
# Base configurations: self and cross attention at short (128/256) and long
# (2048/4096) sequence lengths, with no mask and no bias. The trailing comments
# record the attention type and, presumably, the fused-attention backend index
# the author expects to be selected.
model_configs_base = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,      mask,      bias   # attn , backend
    "base_1_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0, "no_mask", "no_bias"), # self , 0
    "base_1_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0, "no_mask", "no_bias"), # cross, 0
    "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
    "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
}

# Parameter dtypes: fp16 always, plus bf16 when the GPU supports it. The "lean"
# list (bf16 only) is used by the narrower mask/bias/layout/swa tests.
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

201
202
203
204
205
206
207
208
209
210
211
def get_swa(seq_q, seq_kv, w=None, device="cuda"):
    """Generate a sliding window and its equivalent attention mask.

    Args:
        seq_q: query sequence length.
        seq_kv: key/value sequence length.
        w: ``(left, right)`` window size. If ``None``, both values are drawn
            uniformly from ``[0, seq_kv)`` as an int32 tensor.
        device: device for the random window and the mask (default ``"cuda"``).

    Returns:
        ``(w, mask)``. ``mask`` is a ``[seq_q, seq_kv]`` bool tensor that is
        ``True`` for positions OUTSIDE the window, i.e. positions to mask out.
        With ``offset = seq_kv - seq_q``, query ``i`` may attend to keys ``j``
        satisfying ``i + offset - left <= j <= i + offset + right``.
    """
    if w is None:
        w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device=device)
    offset = seq_kv - seq_q
    ones = torch.ones(seq_q, seq_kv, dtype=torch.bool, device=device)
    # Keep keys at or after the window's left edge ...
    inside = torch.triu(ones, diagonal=offset - w[0])
    # ... and at or before its right edge.
    inside = torch.tril(inside, diagonal=offset + w[1])
    return w, ~inside

212
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa):
    """Test DotProductAttention module

    Runs each supported backend (unfused, fused, flash) on identical inputs and
    checks that forward outputs and input gradients agree pairwise. When both
    fused backends are available, it also compares them with each other.
    """

    # Get configs
    tols = dict(atol=5e-3, rtol=5e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    config = model_configs[model]
    if qkv_layout is None:
        qkv_layout = "sb3hd" if config.attn_type == "self" else "sbhd_sb2hd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip(
            "No need to test this layout for cross attention"
        )

    # Skip if only unfused backend is supported
    unfused_attn_supported = _is_unfused_attention_supported(config)
    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
        config, dtype, qkv_layout=qkv_layout,
    )
    if swa:
        # FusedAttention is not run with a sliding window. Also drop its backends
        # from the count below, otherwise a single real backend would "pass"
        # with nothing to compare against.
        fused_attn_supported = False
        fused_attn_backend = []
    flash_attn_supported = _is_flash_attention_supported(config)
    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        if swa:
            # The unfused run uses an explicit "arbitrary" mask built from the window.
            attn_mask_type = config.attn_mask_type
            config.attn_mask_type = "arbitrary"
        try:
            unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
                dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
            )
        finally:
            # Configs are shared module-level objects: always undo the override,
            # even if the run fails, so later tests see the original mask type.
            if swa:
                config.attn_mask_type = attn_mask_type

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backend) == 1:
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
            )
        if len(fused_attn_backend) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
            )
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
                dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
            )

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
        )

    # Compare every pair of backends that ran (argument order kept as-is,
    # since assert_close measures rtol relative to its second argument).
    if unfused_attn_supported and fused_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        for fused_grad, unfused_grad in zip(fused_attn_bwd, unfused_attn_bwd):
            torch.testing.assert_close(fused_grad, unfused_grad, **tols)
    if unfused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for unfused_grad, flash_grad in zip(unfused_attn_bwd, flash_attn_bwd):
            torch.testing.assert_close(unfused_grad, flash_grad, **tols)
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for fused_grad, flash_grad in zip(fused_attn_bwd, flash_attn_bwd):
            torch.testing.assert_close(fused_grad, flash_grad, **tols)
    if fused_attn_supported and len(fused_attn_backend) == 2:
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for grad_0, grad_1 in zip(fused_attn_bwd, fused_attn_bwd_1):
            torch.testing.assert_close(grad_0, grad_1, **tols)

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing"""
    # Same comparison as test_dot_product_attention, with core-attention
    # checkpointing enabled on the cross-attention base configs.
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=True, workspace_opt=True, qkv_layout=None, swa=False,
    )
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

# Mask-type configurations: causal, padding and padding_causal masks for self
# and cross attention at short and long sequence lengths (no bias).
model_configs_mask = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,             mask,      bias
    "mask_1_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0,         "causal", "no_bias"),
    "mask_1_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0,         "causal", "no_bias"),
    "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,         "causal", "no_bias"),
    "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0,         "causal", "no_bias"),
    "mask_3_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0,        "padding", "no_bias"),
    "mask_3_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0,        "padding", "no_bias"),
    "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,        "padding", "no_bias"),
    "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0,        "padding", "no_bias"),
    "mask_5_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0, "padding_causal", "no_bias"),
    "mask_5_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0, "padding_causal", "no_bias"),
    "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
    "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
324

325
326
327
328
329
330
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types"""
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=False, workspace_opt=True, qkv_layout=None, swa=False,
    )
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359

# Bias-type configurations: post_scale_bias and alibi combined with each mask
# type at short and long sequence lengths. Entries annotated "# skipped" are,
# per the author's notes, expected to be skipped at runtime (presumably because
# fewer than two backends support them) -- confirm against a test run.
model_configs_bias = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "bias_1_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "no_mask",           "alibi"), # skipped
    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "no_mask",           "alibi"), # skipped
    "bias_2_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "padding",           "alibi"), # skipped
    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "padding",           "alibi"), # skipped
    "bias_3_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,         "causal", "post_scale_bias"),
    "bias_3_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0,         "causal", "post_scale_bias"),
    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,         "causal", "post_scale_bias"),
    "bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,         "causal", "post_scale_bias"), # skipped
    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,         "causal",           "alibi"),
    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,         "causal",           "alibi"), # skipped
    "bias_4_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal",           "alibi"), # skipped
    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal",           "alibi"), # skipped
}
360

361
362
363
364
365
366
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types"""
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=False, workspace_opt=True, qkv_layout=None, swa=False,
    )

# Sliding-window configurations (no mask, no bias). When the test runs with
# swa=True, the window itself is drawn randomly by get_swa().
model_configs_swa = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "swa_1_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,        "no_mask",          "no_bias"),
    "swa_1_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0,        "no_mask",          "no_bias"),
    "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "no_mask",          "no_bias"),
    "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "no_mask",          "no_bias"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
    """Test DotProductAttention module with sliding window attention.

    Delegates to test_dot_product_attention with a random sliding window
    enabled, no checkpointing, workspace optimization forced on and the
    default QKV layout.
    """
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=False, workspace_opt=True, qkv_layout=None, swa=True,
    )
383

384
385
386
387
388
389
390
# QKV memory layouts exercised by test_dpa_qkv_layout. Each "_"-separated group
# describes one input tensor; a digit ("3" or "2") marks the dimension that packs
# Q/K/V (or K/V) into one tensor, which _run_dot_product_attention splits apart.
qkv_layouts = [
    'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
    'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
    # will add tests for thd layouts later when the support is available in fused attention
    #'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
    ]

391
392
393
394
395
396
397
398
399
400
401
402
403
# Configurations for the QKV-layout test, mixing mask and bias types at short
# and long sequence lengths for self and cross attention.
model_configs_layout = {
    #       test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "layout_0_0": ModelConfig(2, 16, 16,  64,  128,  128, 0.0,        "no_mask",         "no_bias"),
    "layout_0_1": ModelConfig(2, 16, 16,  64,  128,  128, 0.0,         "causal", "post_scale_bias"),
    "layout_0_2": ModelConfig(1, 16, 16,  64,  128,  256, 0.0,        "padding",         "no_bias"),
    "layout_0_3": ModelConfig(1, 16, 16,  64,  128,  256, 0.0, "padding_causal", "post_scale_bias"),
    "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,        "no_mask",         "no_bias"),
    "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,         "causal", "post_scale_bias"),
    "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0,        "padding",         "no_bias"),
    "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
}

@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=False, workspace_opt=True, qkv_layout=qkv_layout, swa=False,
    )
411
412
413
414
415
416
417
418

def _run_dot_product_attention(
        dtype: torch.dtype,
        config: ModelConfig,
        backend: str,
        ckpt_attn: bool,
        qkv_layout: str,
        workspace_opt: bool,
        swa: bool,
        ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass

    Args:
        dtype: dtype of the inputs, the bias and the module.
        config: model configuration to run.
        backend: ``"FlashAttention"``, ``"FusedAttention"`` or
            ``"UnfusedDotProductAttention"``. Selected via the
            ``NVTE_FLASH_ATTN`` / ``NVTE_FUSED_ATTN`` environment variables;
            presumably the unfused path is used when both are ``"0"``.
        ckpt_attn: forwarded as ``checkpoint_core_attention``.
        qkv_layout: layout string such as ``"sb3hd"`` or ``"bshd_bs2hd"``.
            Each ``_``-separated group describes one input tensor.
        workspace_opt: value of ``NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT``
            (only set for FusedAttention).
        swa: use a random sliding window (and its equivalent explicit mask).

    Returns:
        ``(out, (dq, dk, dv))``: the forward output and the gradients of the
        three input tensors after backpropagating a fixed random output grad.
    """

    # Set RNG and environment variables; the RNG reset makes every backend see
    # identical inputs.
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"

    # Create seqlens: random lengths in [1, max) for padding masks or the "thd"
    # format, otherwise every sequence has the maximum length.
    qkv_format = "".join(c for c in qkv_layout.split("_")[0] if c.isalpha())
    if "padding" in config.attn_mask_type or qkv_format == "thd":
        seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
                dtype=torch.int32, device="cuda")
        if config.attn_type == "self":
            seqlens_kv = seqlens_q
        else:
            seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size],
                    dtype=torch.int32, device="cuda")
    else:
        seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
                dtype=torch.int32, device="cuda")
        seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
                dtype=torch.int32, device="cuda")
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    def _valid_token_mask(seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
        # [b, 1, 1, max_seqlen] bool mask, True for the first seqlens[i] tokens of row i.
        positions = torch.arange(max_seqlen, device="cuda")
        return (positions.unsqueeze(0) < seqlens.unsqueeze(1)).unsqueeze(1).unsqueeze(1)

    # Create attention mask if padding
    attention_mask = None
    if "padding" in config.attn_mask_type:
        if config.attn_type == "self":
            attention_mask = _valid_token_mask(seqlens_q, config.max_seqlen_q)
        else:
            attention_mask = (
                _valid_token_mask(seqlens_q, config.max_seqlen_q),
                _valid_token_mask(seqlens_kv, config.max_seqlen_kv),
            )

    window_size = None
    if swa:
        window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
    else:
        # NOTE(review): the padding mask built above is dropped here (as it
        # always has been), so padding reaches the module only through
        # cu_seqlens_q / cu_seqlens_kv. The mask marks VALID tokens True, which
        # may be the inverse of what DotProductAttention expects -- confirm the
        # polarity before forwarding it.
        attention_mask = None

    # Create input tensors
    dim_to_num = {
        'b'  : config.batch_size,
        'sq' : config.max_seqlen_q,
        'skv': config.max_seqlen_kv,
        'h'  : config.num_heads,
        'hg' : config.num_gqa_groups,
        'd'  : config.head_dim,
        't'  : cu_seqlens_q[-1],
        'tg' : cu_seqlens_kv[-1],
        '3'  : 3,
        '2'  : 2,
        }
    inp = []
    for i, group in enumerate(qkv_layout.split("_")):
        # Expand e.g. "sb3hd" into dim names "s_b_3_h_d"; the first group is Q
        # (query sizes), later groups are K/V (kv sequence length and heads).
        layout = "_".join(group)
        if i == 0:
            layout = layout.replace("s", "sq")
        else:
            layout = layout.replace("s", "skv").replace("h", "hg").replace("t", "tg")
        dims = layout.split("_")
        tensor = 0.1 * torch.randn([dim_to_num[d] for d in dims], dtype=dtype, device="cuda")
        # A digit dimension packs several tensors together: split them apart.
        split_dim = next((k for k, d in enumerate(dims) if d.isdigit()), 0)
        if split_dim != 0:
            inp.extend(t.squeeze(split_dim) for t in torch.split(tensor, 1, dim=split_dim))
        else:
            inp.append(tensor)
    for tensor in inp[:3]:
        tensor.requires_grad = True

    # Create output gradient: query-shaped in qkv_format, heads and head_dim flattened
    out_format = "_".join(qkv_format).replace("s", "sq")
    out_grad_shape = [dim_to_num[d] for d in out_format.split("_")]
    out_grad_shape = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape, dtype=dtype, device="cuda")

    # Create bias: only post_scale_bias needs an explicit tensor; every other bias
    # type (no_bias, alibi, ...) passes None instead of leaving `bias` unbound.
    bias = None
    if config.attn_bias_type == "post_scale_bias":
        bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
                dtype=dtype, device="cuda")

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = (
         DotProductAttention(
                config.num_heads,
                config.head_dim,
                num_gqa_groups=config.num_gqa_groups,
                attention_dropout=config.dropout_p,
                qkv_format=qkv_format,
                attn_mask_type=config.attn_mask_type,
                sequence_parallel=False,
                tp_size=1,
                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                tp_group=None,
                layer_number=1,
                attention_type=config.attn_type,
        ).to(dtype=dtype, device="cuda")
    )

    # Run a forward and backward pass
    out = block(inp[0], inp[1], inp[2],
            window_size=window_size,
            attention_mask=attention_mask,
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=ckpt_attn,
            core_attention_bias_type=config.attn_bias_type,
            core_attention_bias=bias,
            fast_zero_fill=True)
    out.backward(out_grad)

    return out, (inp[0].grad, inp[1].grad, inp[2].grad)

# Configurations for the TransformerLayer tests: self attention with
# no_mask/causal/padding masks, with post_scale_bias at seqlen 128 and without
# bias at seqlen 2048.
model_configs_te_layer = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,      mask,             bias
    "te_1_0": ModelConfig(2, 16, 16,  64,  128,  128, 0.0, "no_mask", "post_scale_bias"),
    "te_1_1": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,  "causal", "post_scale_bias"),
    "te_1_2": ModelConfig(2, 16, 16,  64,  128,  128, 0.0, "padding", "post_scale_bias"),
    "te_2_0": ModelConfig(1, 16, 16,  64, 2048, 2048, 0.0,  "causal",         "no_bias"),
    "te_2_1": ModelConfig(2, 16, 16,  64, 2048, 2048, 0.0, "no_mask",         "no_bias"),
    "te_2_2": ModelConfig(1, 16, 16,  64, 2048, 2048, 0.0, "padding",         "no_bias"),
}
586

587
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE):
    """Test TransformerLayer module

    Runs ``_run_transformer_layer`` once per supported attention backend
    (UnfusedDotProductAttention, FusedAttention, FlashAttention) and checks
    that the outputs and input gradients of every pair of backends that ran
    agree within ``tols``. Skips when fewer than two backends are available.

    Also called directly by ``test_te_layer_misc`` and
    ``test_te_layer_mqa_gqa`` with non-default settings.
    """

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-1, rtol=5e-2)
    workspace_opt = True

    # For short sequences set NVTE_FUSED_ATTN_BACKEND=0. Remember the previous
    # value and restore it on exit so the setting does not leak into tests
    # that run later in the same process (e.g. the long-sequence configs).
    saved_backend_env = os.environ.get("NVTE_FUSED_ATTN_BACKEND")
    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
    try:
        # Skip if fewer than two backends can run this config
        fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
            config,
            dtype,
            qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
        )
        flash_attn_supported = _is_flash_attention_supported(config)
        unfused_attn_supported = _is_unfused_attention_supported(config)
        if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
            pytest.skip("Less than two backends to compare.")

        # Run every supported backend (in this order); each run returns
        # (output, input gradient)
        backend_supported = {
            "UnfusedDotProductAttention": unfused_attn_supported,
            "FusedAttention": fused_attn_supported,
            "FlashAttention": flash_attn_supported,
        }
        results = {}
        for backend, supported in backend_supported.items():
            if supported:
                results[backend] = _run_transformer_layer(
                    dtype,
                    config,
                    backend,
                    ckpt_attn,
                    qkv_format,
                    workspace_opt,
                    fused_qkv_params,
                    RoPE,
                )

        # Compare every pair of backends that actually ran
        for lhs, rhs in (
            ("FusedAttention", "UnfusedDotProductAttention"),
            ("FlashAttention", "UnfusedDotProductAttention"),
            ("FusedAttention", "FlashAttention"),
        ):
            if lhs in results and rhs in results:
                lhs_fwd, lhs_bwd = results[lhs]
                rhs_fwd, rhs_bwd = results[rhs]
                torch.testing.assert_close(lhs_fwd, rhs_fwd, **tols)
                torch.testing.assert_close(lhs_bwd, rhs_bwd, **tols)
    finally:
        if saved_backend_env is None:
            os.environ.pop("NVTE_FUSED_ATTN_BACKEND", None)
        else:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = saved_backend_env

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
def test_te_layer_misc(dtype, model_configs, model):
    """Test TransformerLayer module with miscellaneous settings

    Runs ``test_transformer_layer`` with ``ckpt_attn=True``, the "bshd"
    format, fused QKV parameters and RoPE enabled.
    """
    test_transformer_layer(
        dtype,
        model_configs,
        model,
        ckpt_attn=True,
        qkv_format="bshd",
        fused_qkv_params=True,
        RoPE=True,
    )

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA

    For every divisor ``q >= 2`` of ``config.num_heads``, sets
    ``num_gqa_groups = num_heads // q`` (``q == num_heads`` gives a single
    group, i.e. MQA) and runs ``test_transformer_layer`` with
    ``ckpt_attn=True``, the "bshd" format, fused QKV parameters and RoPE.

    The config object is shared with other tests through
    ``model_configs_te_layer``, so its original ``num_gqa_groups`` is restored
    afterwards, even if a run fails or is skipped.
    """

    def find_factors(x):
        """Return all divisors of ``x`` that are >= 2, in ascending order."""
        return [i for i in range(2, x + 1) if x % i == 0]

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    original_num_gqa_groups = config.num_gqa_groups
    try:
        for num_q_per_gqa_group in find_factors(config.num_heads):
            config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
            test_transformer_layer(dtype, model_configs, model,
                    ckpt_attn, qkv_format, fused_qkv_params, RoPE)
    finally:
        config.num_gqa_groups = original_num_gqa_groups

def _run_transformer_layer(
        dtype: torch.dtype,
        config: ModelConfig,
        backend: str,
        ckpt_attn: bool,
        qkv_layout: str,
        workspace_opt: bool,
        fused_qkv_params: bool,
        RoPE: bool,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run TransformerLayer module with one forward pass and one backward pass

    Parameters
    ----------
    dtype: torch.dtype
        Data type of the input, parameters, bias and RoPE embeddings.
    config: ModelConfig
        Model configuration. It is read but not modified.
    backend: str
        "FlashAttention" or "FusedAttention" enable the corresponding
        NVTE_FLASH_ATTN / NVTE_FUSED_ATTN environment variable. Any other value
        (e.g. "UnfusedDotProductAttention") leaves both disabled.
    ckpt_attn: bool
        Passed to the layer as ``checkpoint_core_attention``.
    qkv_layout: str
        Currently unused; the input is always built as [sq, b, hidden].
    workspace_opt: bool
        Currently unused.
    fused_qkv_params: bool
        Passed to the layer as ``fuse_qkv_params``.
    RoPE: bool
        Whether to pass rotary position embeddings to the layer.

    Returns
    -------
    The layer output and the gradient w.r.t. the input, where the backward
    pass uses ``out.sum()`` as the loss.
    """

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    # Create input tensor, [sq, b, hidden]
    inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size,
            dtype=dtype, device="cuda", requires_grad=True)

    # Create attention mask if padding: random lengths in [1, max_seqlen_q),
    # and a [b, 1, 1, sq] boolean mask that is True for the first seqlens_q[i]
    # positions of sequence i and False for the remaining ones
    attention_mask = None
    if "padding" in config.attn_mask_type:
        seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
                dtype=torch.int32, device="cuda")
        positions = torch.arange(config.max_seqlen_q, device="cuda")
        attention_mask = (positions.unsqueeze(0) < seqlens_q.unsqueeze(1)).view(
                config.batch_size, 1, 1, config.max_seqlen_q)

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias. The bias type handed to the layer is kept in a local so the
    # shared config object is never modified.
    attn_bias_type = config.attn_bias_type
    bias = None
    if attn_bias_type == "post_scale_bias":
        bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
                dtype=dtype, device="cuda")
    elif attn_bias_type == "alibi" and os.environ.get("NVTE_FUSED_ATTN_BACKEND") == "0":
        # With NVTE_FUSED_ATTN_BACKEND=0, ALiBi is passed explicitly as a
        # post-scale bias; otherwise no bias tensor is passed.
        attn_bias_type = "post_scale_bias"
        # Slopes m_k = 2^(-8k/n), k = 1..n; assumes num_heads is a power of two
        n = 2 ** math.floor(math.log2(config.num_heads))
        m_0 = 2.0 ** (-8.0 / n)
        m = torch.pow(m_0, torch.arange(1, 1 + n))

        # d[i, j] = j - i; assumes max_seqlen_q == max_seqlen_kv
        a = torch.ones(config.max_seqlen_q, config.max_seqlen_kv)
        b = torch.triu(a, diagonal=1)
        c = b.cumsum(dim=-1)
        d = c - torch.transpose(c, 0, 1)
        # Materialize one scaled copy of d per head. Scaling heads in place on
        # d.expand(...) would write through the shared storage and give every
        # head the product of all slopes.
        bias = (m.view(-1, 1, 1) * d).unsqueeze(0)
        bias = bias.to(dtype=dtype, device="cuda")

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim)
        rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda")

    # NOTE(review): TE's core-attention checkpointing appears to call the RNG
    # state tracker getter, which fails when it is None. A dummy tracker is
    # supplied (as _run_dpa_fp8_ref does) only when ckpt_attn is set. Confirm
    # against transformer_engine.pytorch.distributed.checkpoint.
    rng_state_tracker_getter = None
    if ckpt_attn:
        dummy_rng_tracker = CudaRNGStatesTracker()
        dummy_rng_tracker.add("model-parallel-rng", seed)

        def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
            """Get cuda rng tracker."""
            return dummy_rng_tracker

        rng_state_tracker_getter = get_dummy_cuda_rng_tracker

    # Set up model
    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            num_gqa_groups=config.num_gqa_groups,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            self_attn_mask_type=config.attn_mask_type,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=rng_state_tracker_getter,
            fuse_wgrad_accumulation=False,
            seq_length=config.max_seqlen_q,
            micro_batch_size=config.batch_size,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=fused_qkv_params,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype, device="cuda")
    )

    # Run a forward and backward pass
    out = block(inp,
        attention_mask=attention_mask,
        self_attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=attn_bias_type,
        core_attention_bias=bias)
    loss = out.sum()
    loss.backward()

    return out, inp.grad

# FP8 dot-product-attention test configurations (same column order as
# model_configs_te_layer: b, h, hg, d, sq, skv, p, mask, bias).
model_configs_fp8 = {
    #  test:             b,  h, hg,   d,   sq,  skv,   p,      mask,      bias
    "fp8_1": ModelConfig(1, 16, 16,  64,  512,  512, 0.0, "no_mask", "no_bias"),
    "fp8_2": ModelConfig(4, 16, 16,  64,  512,  512, 0.0, "no_mask", "no_bias"),
}
# Data types exercised by the FP8 tests
param_types_fp8 = [torch.float16]

@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
# Call the function: comparing the function object itself to (9, 0) is always
# unequal, which silently skipped this test on every GPU.
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, model):
    """Test FP8 dot product attention

    FusedAttention runs ``_dpa_fp8``, which calls the FP8 backend of
    ``fused_attn_fwd``/``fused_attn_bwd`` from cpp_extensions with the packed
    "t3hd" QKV layout. UnfusedDotProductAttention is the reference: it runs
    plain PyTorch operations in FP16 on the FP8-roundtripped QKV and output
    gradient saved by the FusedAttention run.
    """

    config = model_configs_fp8[model]

    # Skip if not supported
    fused_attn_supported, _ = _is_fused_attention_supported(config, dtype)
    if not fused_attn_supported:
        pytest.skip("FusedAttention does not support this model config")

    # Run dot-product attention with both backends. The reference run loads
    # tensors written to disk by the FusedAttention run, so order matters.
    fused_out, fused_dqkv = _run_dpa_fp8(dtype, config, "FusedAttention")
    ref_out, ref_dqkv = _run_dpa_fp8_ref(dtype, config, "UnfusedDotProductAttention")

    tols = dict(atol=2.5e-2, rtol=2.5e-2)
    torch.testing.assert_close(fused_out, ref_out, **tols)
    torch.testing.assert_close(fused_dqkv, ref_dqkv, **tols)

def _run_dpa_fp8(dtype, config, backend):
    """Run the FusedAttention FP8 path through ``DPA_FP8`` / ``_dpa_fp8``.

    ``_dpa_fp8`` calls ``fused_attn_fwd``/``fused_attn_bwd`` from
    cpp_extensions with the FP8 backend and the packed "t3hd" QKV layout.

    Side effects: writes ``out_grad.pt`` to the working directory. The
    ``_dpa_fp8`` forward/backward also write ``qkv.pt``, ``ctx.pt`` and
    ``dqkv.pt``. ``_run_dpa_fp8_ref`` later reads ``qkv.pt`` and
    ``out_grad.pt``.

    Note: the module is cast to float16 and ``_dpa_fp8`` casts FP8 results to
    FP16, so ``dtype`` only affects the generated input and output gradient.

    Returns
    -------
    The attention output, viewed as [max_seqlen_q, batch_size, hidden], and
    dQKV, as [max_seqlen_q, batch_size, 3, num_heads, head_dim].
    """

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    # Flattened [batch * seqlen, hidden] input; every sequence is full length
    num_tokens = config.batch_size * config.max_seqlen_q
    hidden = config.num_heads * config.head_dim
    inp = 0.01 * torch.randn(num_tokens, hidden,
            dtype=dtype, device="cuda", requires_grad=True)
    seqlens = torch.full([config.batch_size], config.max_seqlen_q,
            dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    out_grad = 0.01 * torch.randn(num_tokens, hidden, dtype=dtype, device="cuda")
    # Shared with _run_dpa_fp8_ref through the working directory
    torch.save(out_grad, 'out_grad.pt')

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        interval=1,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    dpa = DPA_FP8(config).to(dtype=torch.float16, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = dpa(inp, cu_seqlens, config.max_seqlen_q)
        out.backward(out_grad)

    # FP16 copies of the attention output and dQKV saved by _dpa_fp8
    context = torch.load("ctx.pt")
    dqkv = torch.load('dqkv.pt')
    return (context.view(config.batch_size, config.max_seqlen_q, -1).transpose(0,1),
            dqkv.view(config.batch_size, config.max_seqlen_q, 3,
            config.num_heads, config.head_dim).transpose(0,1).contiguous())

def _run_dpa_fp8_ref(dtype, config, backend):
    """Run UnfusedDotProductAttention as a reference, i.e.
    plain PyTorch implementation in FP16 and inputs/outputs
    are converted from/to FP8

    Loads ``qkv.pt`` (the FP8-roundtripped QKV saved by ``_dpa_fp8``, laid out
    as [max_seqlen_q, batch_size, 3, num_heads, head_dim]) and ``out_grad.pt``
    (saved by ``_run_dpa_fp8``), so ``_run_dpa_fp8`` must run first.

    Returns
    -------
    The attention output and the gradient w.r.t. the loaded QKV tensor.
    """

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.load('qkv.pt').to(device="cuda")
    inp.requires_grad = True
    # Stored as [b * s, hidden]; reshape to [s, b, hidden] to match the output
    out_grad = torch.load('out_grad.pt').to(device="cuda").view(
            config.batch_size, config.max_seqlen_q, -1).transpose(0,1)

    dummy_rng_tracker = CudaRNGStatesTracker()
    dummy_rng_tracker.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return dummy_rng_tracker

    block = (
        DotProductAttention(
                config.num_heads,
                config.head_dim,
                attention_dropout=config.dropout_p,
                sequence_parallel=False,
                tp_size=1,
                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                tp_group=None,
                layer_number=1,
                attention_type="self"
        ).to(dtype=dtype, device="cuda")
    )

    # Views into the packed [s, b, 3, h, d] QKV; gradients flow back into inp
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad


# cuBLASLt workspace handed to every FP8 GEMM in the FP8 attention harness.
_CUBLASLT_WORKSPACE_SIZE_BYTES = 32 * 1024 * 1024  # 32MiB == 33_554_432 bytes

# use_split_accumulator flags for the forward, data-gradient and
# weight-gradient FP8 GEMMs.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

# FP8 scaling-meta slots reused by _dpa_fp8.
# Forward ("scaling_fwd") slots:
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT   # QKV projection output
META_O = tex.FP8FwdTensors.GEMM2_INPUT      # attention output (context)
META_S = tex.FP8FwdTensors.GEMM3_WEIGHT     # tensor S (scale/amax "_s" arguments)
# Backward ("scaling_bwd") slots:
META_DO = tex.FP8BwdTensors.GRAD_INPUT2     # gradient of the attention output
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1  # gradient of QKV
META_DS = tex.FP8BwdTensors.GRAD_INPUT3     # gradient of S

class _dpa_fp8(torch.autograd.Function):
    """FP8 QKV projection followed by FP8 fused attention (test harness).

    Forward: casts ``inp`` and ``qkv_weight`` to FP8, computes the QKV
    projection with ``ext.fp8_gemm`` (FP8 output viewed as [tokens, 3, h, d]),
    then runs ``fused_attn_fwd`` with the FP8 backend, the "t3hd" layout, no
    bias and a padding mask. Backward: runs ``fused_attn_bwd`` followed by the
    QKV dgrad/wgrad FP8 GEMMs.

    For the reference comparison in ``test_dpa_fp8`` it also writes FP16
    copies of intermediates to the working directory: ``qkv.pt`` and
    ``ctx.pt`` (forward) and ``dqkv.pt`` (backward).
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
    ) -> torch.Tensor:
        """Run the FP8 QKV GEMM and FP8 fused attention on a 2-D ``inp``.

        Returns the attention output cast to FP16, shaped [tokens, in_features],
        where ``in_features = qkv_weight.shape[-1]``.
        """

        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1
        is_nl = False
        # NOTE(review): for 1 < b < 4 the caller's max_s is overridden with 512;
        # the reason is not evident from this file.
        if b < 4 and b > 1:
            max_s = 512
            is_nl = True

        fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

        inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
            inp,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
        )

        qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused(
            qkv_weight,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
        )

        # QKV projection with FP8 output (scaled with the META_QKV slot)
        qkv_out, _ = ext.fp8_gemm(
            qkv_weight_fp8,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
            inputmat,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
            torch.uint8,
            workspace,
            bias=qkv_bias,
            use_bias=True,
            out_index=META_QKV,
            fp8_meta_tensor=fp8_meta["scaling_fwd"],
            use_split_accumulator=_2X_ACC_FPROP,
            D_dtype=fp8_dtype_forward,
        )
        qkv_out = qkv_out.view(-1, 3, h, d)
        # FP16 copy as [s, b, 3, h, d] for the reference run
        qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
                META_QKV, fp8_dtype_forward,
                tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
        torch.save(qkv_out_fp16, 'qkv.pt')

        # FMHA
        context_, aux_ctx_tensors, *rest = fused_attn_fwd(
                is_training,
                max_s,
                max_s,
                cu_seqlens,
                cu_seqlens,
                qkv_out[:,0,:,:],
                qkv_out[:,1,:,:],
                qkv_out[:,2,:,:],
                fp8_dtype_forward,
                FusedAttnBackend["FP8"],
                None,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale=None,
                dropout=p_dropout,
                fast_zero_fill=fast_zero_fill,
                qkv_layout="t3hd",
                attn_bias_type="no_bias",
                attn_mask_type="padding",
                rng_gen=None,
                )
        # The FP8 backend is expected to return exactly three auxiliary tensors
        M, ZInv, philox_unpacked = aux_ctx_tensors

        context = context_.view(-1, in_features)
        context_t = tex.fp8_transpose(context, fp8_dtype_forward)

        ctx.save_for_backward(
            inputmat_t, qkv_weight_t_fp8, workspace,
            qkv_out,
            context_, context_t,
            fp8_meta["scaling_fwd"].scale,
            fp8_meta["scaling_fwd"].scale_inv,
        )
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.is_nl = is_nl
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads

        context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
                META_O, fp8_dtype_forward, tex.DType.kFloat16)
        torch.save(context_fp16, 'ctx.pt')
        return context_fp16


    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Backward through FP8 fused attention and the FP8 QKV GEMM.

        Returns one gradient per ``forward`` input: dgrad, wgrad and bgrad for
        ``inp``, ``qkv_weight`` and ``qkv_bias``, and ``None`` for the rest.
        """

        with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
            (
                inputmat_t,
                qkv_weight_t_fp8,
                workspace,
                qkv_out,
                context, context_t,
                fwd_scales,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
            fp8_dtype_forward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=True
            )
            fp8_dtype_backward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )

            proj_dgrad = ext.cast_to_fp8(
                grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
            )

            dq, dk, dv, *rest = fused_attn_bwd(
                    ctx.max_s,
                    ctx.max_s,
                    ctx.cu_seqlens,
                    ctx.cu_seqlens,
                    qkv_out[:,0,:,:],
                    qkv_out[:,1,:,:],
                    qkv_out[:,2,:,:],
                    context,
                    proj_dgrad.view_as(context),
                    fp8_dtype_forward,
                    ctx.aux_ctx_tensors,
                    FusedAttnBackend["FP8"],
                    fwd_scale_inverses[META_QKV], # d_scale_qkv,
                    fwd_scale_inverses[META_S], # d_scale_s,
                    fwd_scale_inverses[META_O], # d_scale_o,
                    ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                    fwd_scales[META_S], # q_scale_s
                    ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds
                    ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                    None,
                    ctx.p_dropout,
                    ctx.fast_zero_fill,
                    "t3hd",
                    "no_bias",
                    "padding",
                    )
            # Re-pack as [tokens, 3, h, d], matching the forward QKV layout
            dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1)

            dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
            dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                fp8_dtype_backward, tex.DType.kFloat16)
            torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')

            qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"],
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
            )

            # QKV DGRAD
            qkv_dgrad, _ = ext.fp8_gemm(
                qkv_weight_t_fp8,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_DGRAD,
            )
            # QKV WGRAD
            qkv_wgrad, _ = ext.fp8_gemm(
                inputmat_t,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                dqkv_grad_output_t,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_WGRAD,
            )

        # Exactly one entry per forward input (after ctx)
        return (
            qkv_dgrad,  # inp
            qkv_wgrad,  # qkv_weight
            qkv_bgrad,  # qkv_bias
            None,       # cu_seqlens
            None,       # num_heads
            None,       # p_dropout
            None,       # max_s
            None,       # fast_zero_fill
            None,       # fp8_meta
            None,       # workspace
            None,       # is_training
        )

class DPA_FP8(TransformerEngineBaseModule):
    def __init__(
        self,
        config,
        params_dtype: torch.dtype = torch.float32):
        """Create the QKV projection parameters and the cuBLASLt workspace.

        ``qkv_weight`` is [3 * hidden_size, hidden_size] and filled with ones;
        ``qkv_bias`` is [3 * hidden_size] and zeroed. Both live on the current
        CUDA device in ``params_dtype``.
        """
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.fast_zero_fill = True

        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.qkv_weight.shape)
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        # Deterministic parameters so fused and reference runs are comparable
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self, inp: torch.Tensor,
        cu_seqlens, max_s,
    ) -> torch.Tensor:
        """Apply ``_dpa_fp8`` to ``inp`` using this module's parameters.

        ``inp`` first goes through ``prepare_forward`` (``num_gemms=3``). The
        prepared tensor is then handed to ``_dpa_fp8.apply`` together with the
        QKV weight and bias, ``cu_seqlens``, the head count, dropout
        probability, ``max_s``, the fast-zero-fill flag, the FP8 meta, the
        workspace and the current training flag.
        """
        with self.prepare_forward(inp, None, num_gemms=3) as prepared_inp:
            apply_args = (
                prepared_inp,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training,
            )
            result = _dpa_fp8.apply(*apply_args)
        return result

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[torch.Tensor]:
        """Needs override."""