test_fused_attn.py 49.4 KB
Newer Older
1
2
3
4
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
import functools
Tim Moon's avatar
Tim Moon committed
6
7
from importlib.metadata import version
import os
8
import math
Tim Moon's avatar
Tim Moon committed
9
10
11
from typing import Any, Dict, List, Tuple, Union

from pkg_resources import packaging
12
import pytest
Tim Moon's avatar
Tim Moon committed
13
import torch
14

Tim Moon's avatar
Tim Moon committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast
from transformer_engine.pytorch.attention import (
    DotProductAttention,
    RotaryPositionEmbedding,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
    QKVLayout,
    fused_attn_bwd,
    fused_attn_fwd,
)
31
32
33
34
from transformer_engine.pytorch.distributed import (
    _set_cuda_rng_state,
    CudaRNGStatesTracker,
)
Tim Moon's avatar
Tim Moon committed
35
36
37
38
39
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
    TransformerEngineBaseModule,
    _prepare_backward,
)
40
from transformer_engine.pytorch.utils import (
Tim Moon's avatar
Tim Moon committed
41
    get_device_compute_capability,
42
43
44
    init_method_normal,
    scaled_init_method_normal,
)
Tim Moon's avatar
Tim Moon committed
45
import transformer_engine_extensions as tex
46
from transformer_engine_extensions import NVTE_Fused_Attn_Backend
47

48
# Only run FP8 tests on H100
Tim Moon's avatar
Tim Moon committed
49
# Query FP8 support once at import time; presumably yields (is_available, reason_if_not)
fp8_available, reason_for_no_fp8 = (
    fp8.FP8GlobalStateManager.is_fp8_available()
)
50

51
# Initialize RNG state
52
53
54
55
56
57
58
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Snapshot the freshly seeded CPU and CUDA generator states so that every
# backend run can restart from exactly the same point
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

def reset_rng_states() -> None:
    """Revert back to initial RNG state

    Restores the CPU and CUDA generator states captured at import time, right
    after seeding with ``seed``, so that each backend run draws identical
    random inputs.
    """
    torch.set_rng_state(_cpu_rng_state)
    _set_cuda_rng_state(_cuda_rng_state)

63
64
65
66
67
68
69
@functools.cache
def _cudnn_version() -> Tuple[int, int, int]:
    """Return the runtime cuDNN version as a ``(major, minor, patch)`` tuple.

    The extension reports a single integer which is decoded here as
    ``major * 1000 + minor * 100 + patch``. The result is cached.
    """
    raw = ext.get_cudnn_version()
    return (raw // 1000, (raw % 1000) // 100, raw % 100)
70

71
72
class ModelConfig:
    """Attention model configuration shared by the tests in this file.

    Besides storing the constructor arguments, it derives:
        hidden_size: ``num_heads * head_dim``
        hidden_size_kv: ``num_gqa_groups * head_dim``
        attn_type: ``"self"`` when ``max_seqlen_q == max_seqlen_kv``, else ``"cross"``
    """
    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        num_gqa_groups: int,
        head_dim: int,
        max_seqlen_q: int,
        max_seqlen_kv: int,
        dropout_p: float,
        attn_mask_type: str,
        attn_bias_type: str,
        num_layers: int = 1,
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.num_gqa_groups = num_gqa_groups
        self.head_dim = head_dim
        self.hidden_size = num_heads * head_dim
        self.hidden_size_kv = num_gqa_groups * head_dim
        self.max_seqlen_q = max_seqlen_q
        self.max_seqlen_kv = max_seqlen_kv
        self.dropout_p = dropout_p
        self.attn_mask_type = attn_mask_type
        self.attn_bias_type = attn_bias_type
        # Equal query/key sequence lengths are treated as self attention
        self.attn_type = "self" if max_seqlen_q == max_seqlen_kv else "cross"
        self.num_layers = num_layers
98

Tim Moon's avatar
Tim Moon committed
99
100
101
102
def _is_fused_attention_supported(
    config: ModelConfig,
    dtype: torch.dtype,
    qkv_layout: str = "sbh3d",
) -> Tuple[bool, List[NVTE_Fused_Attn_Backend]]:
    """Check if FusedAttention supports a model configuration

    Queries ``tex.get_fused_attn_backend`` with ``NVTE_FUSED_ATTN_BACKEND="0"``.
    If that selects the FP8 or F16_arbitrary_seqlen backend, it is the only
    one reported. If it selects F16_max512_seqlen, the query is repeated with
    ``NVTE_FUSED_ATTN_BACKEND="1"`` and F16_arbitrary_seqlen is added when that
    second query selects it, so callers can run and compare both backends.

    Returns:
        ``(supported, backends)``: ``supported`` is True when at least one fused
        backend was found; ``backends`` lists the backends found (empty if none).

    Note:
        Leaves ``NVTE_FUSED_ATTN_BACKEND`` set to the value used by the last
        query ("0", or "1" after probing a max512-capable configuration).
    """

    def _query_backend():
        # Same arguments for every probe; only the environment variable differs
        return tex.get_fused_attn_backend(
            TE_DType[dtype],
            TE_DType[dtype],
            QKVLayout[qkv_layout],
            AttnBiasType[config.attn_bias_type],
            AttnMaskType[config.attn_mask_type],
            config.dropout_p,
            config.num_heads,
            config.num_gqa_groups,
            config.max_seqlen_q,
            config.max_seqlen_kv,
            config.head_dim,
        )

    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
    backend = _query_backend()
    if backend == FusedAttnBackend["FP8"]:
        return True, [backend]
    if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        return True, [backend]
    if backend == FusedAttnBackend["F16_max512_seqlen"]:
        backends = [backend]
        # Probe whether the arbitrary-seqlen backend also handles this config
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
        backend = _query_backend()
        if backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
            backends.append(backend)
        return True, backends
    return False, []

@functools.cache
def _is_flash_attention_2_available() -> bool:
    """Return True when the installed flash-attn package is version 2.0 or newer.

    The installed version is looked up once and the answer cached.
    """
    installed = packaging.version.Version(version("flash-attn"))
    return installed >= packaging.version.Version("2")

@functools.cache
def _is_flash_attention_2_1() -> bool:
    """Check if flash-attn 2.1+ is available

    flash-attn 2.1 changed the behavior of the causal flag for cross attention,
    so callers need to tell 2.1+ apart from 2.0.x. The result is cached for
    the lifetime of the process.
    """
    Version = packaging.version.Version
    return Version(version("flash-attn")) >= Version("2.1")

def _is_flash_attention_supported(config: ModelConfig) -> bool:
    """Check if FlashAttention supports a model configuration

    Returns False when any of the following holds:
      * the device compute capability is below 8.0,
      * an attention bias is requested,
      * MQA/GQA is requested (num_heads != num_gqa_groups) without flash-attn 2.0+,
      * a causal mask is used for cross attention with flash-attn 2.1+.
    """
    if get_device_compute_capability() < (8, 0):
        return False
    if config.attn_bias_type != "no_bias":
        return False
    if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
        return False
    if (
        "causal" in config.attn_mask_type
        and config.attn_type == "cross"
        and _is_flash_attention_2_1()
    ):
        # FAv2.1 implements causal mask for cross attention differently
        # https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag
        return False
    return True

def _is_unfused_attention_supported(config: ModelConfig) -> bool:
    """Check if UnfusedDotProductAttention supports a model configuration

    Returns False for any padding mask type and for causal masks in cross
    attention; True otherwise.
    """
    if "padding" in config.attn_mask_type:
        return False
    if "causal" in config.attn_mask_type and config.attn_type == "cross":
        return False
    return True

182
183
184
185
186
187
188
189
190
191
192
193
194
195
model_configs_base = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,      mask,      bias   # attn , backend
    "base_1_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0, "no_mask", "no_bias"), # self , 0
    "base_1_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0, "no_mask", "no_bias"), # cross, 0
    "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
    "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
}

# Parameter dtypes: FP16 always, BF16 only when the current device supports it
param_types = [torch.float16] + ([torch.bfloat16] if torch.cuda.is_bf16_supported() else [])
# Reduced dtype list for the narrower test variants
param_types_lean = [torch.bfloat16]

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout):
    """Test DotProductAttention module

    Runs one forward/backward pass on every supported backend (unfused, fused,
    flash) and checks that each pair of supported backends agrees within
    dtype-dependent tolerances. When two fused backends are reported, both are
    run (NVTE_FUSED_ATTN_BACKEND "0" and "1") and compared with each other.

    ``qkv_layout=None`` picks "sb3hd" for self attention and "sbhd_sb2hd" for
    cross attention. Also called directly by the test_dpa_* wrappers.
    """

    # Get configs
    tols = dict(atol=5e-3, rtol=5e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)
    config = model_configs[model]
    if qkv_layout is None:
        qkv_layout = "sb3hd" if config.attn_type == "self" else "sbhd_sb2hd"
    if "3" in qkv_layout and config.attn_type == "cross":
        pytest.skip(
            "No need to test this layout for cross attention"
        )

    # Skip if fewer than two backends can be compared
    unfused_attn_supported = _is_unfused_attention_supported(config)
    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
        config, dtype, qkv_layout=qkv_layout,
    )
    flash_attn_supported = _is_flash_attention_supported(config)
    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    def run(backend):
        """Run one forward/backward pass on ``backend`` with this test's settings."""
        return _run_dot_product_attention(
            dtype, config, backend, ckpt_attn, qkv_layout, workspace_opt,
        )

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = run("UnfusedDotProductAttention")

    # FusedAttention backend
    if fused_attn_supported:
        if len(fused_attn_backend) == 1:
            fused_attn_fwd, fused_attn_bwd = run("FusedAttention")
        if len(fused_attn_backend) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
            fused_attn_fwd, fused_attn_bwd = run("FusedAttention")
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, fused_attn_bwd_1 = run("FusedAttention")

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = run("FlashAttention")

    # Pairwise comparisons of outputs and q/k/v gradients
    if unfused_attn_supported and fused_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        for fused_grad, unfused_grad in zip(fused_attn_bwd, unfused_attn_bwd):
            torch.testing.assert_close(fused_grad, unfused_grad, **tols)
    if unfused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for unfused_grad, flash_grad in zip(unfused_attn_bwd, flash_attn_bwd):
            torch.testing.assert_close(unfused_grad, flash_grad, **tols)
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        for fused_grad, flash_grad in zip(fused_attn_bwd, flash_attn_bwd):
            torch.testing.assert_close(fused_grad, flash_grad, **tols)
    if fused_attn_supported and len(fused_attn_backend) == 2:
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        for fused_grad, fused_grad_1 in zip(fused_attn_bwd, fused_attn_bwd_1):
            torch.testing.assert_close(fused_grad, fused_grad_1, **tols)

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
    """Test DotProductAttention module with checkpointing

    Delegates to test_dot_product_attention with core-attention checkpointing
    and the workspace optimization enabled, letting it pick the QKV layout.
    """
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=True, workspace_opt=True, qkv_layout=None,
    )

# Mask-type coverage: every mask type at small (<=512) and large sequence
# lengths, for both self (x_0) and cross (x_1) attention
model_configs_mask = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,             mask,      bias
    "mask_1_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0,         "causal", "no_bias"),
    "mask_1_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0,         "causal", "no_bias"),
    "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,         "causal", "no_bias"),
    "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0,         "causal", "no_bias"),
    "mask_3_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0,        "padding", "no_bias"),
    "mask_3_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0,        "padding", "no_bias"),
    "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,        "padding", "no_bias"),
    "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0,        "padding", "no_bias"),
    "mask_5_0": ModelConfig(8, 16, 16,  64,  128,  128, 0.0, "padding_causal", "no_bias"),
    "mask_5_1": ModelConfig(4, 16, 16,  64,  128,  256, 0.0, "padding_causal", "no_bias"),
    "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
    "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
    """Test DotProductAttention module with different mask types

    Delegates to test_dot_product_attention without checkpointing, with the
    workspace optimization enabled and the default QKV layout.
    """
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=False, workspace_opt=True, qkv_layout=None,
    )

# Bias-type coverage across mask types; entries tagged "skipped" are expected
# to find fewer than two backends and be skipped by test_dot_product_attention
model_configs_bias = {
    #     test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "bias_1_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "no_mask", "post_scale_bias"),
    "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "no_mask",           "alibi"), # skipped
    "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "no_mask",           "alibi"), # skipped
    "bias_2_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "padding", "post_scale_bias"), # skipped
    "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,        "padding",           "alibi"), # skipped
    "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,        "padding",           "alibi"), # skipped
    "bias_3_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,         "causal", "post_scale_bias"),
    "bias_3_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0,         "causal", "post_scale_bias"),
    "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,         "causal", "post_scale_bias"),
    "bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,         "causal", "post_scale_bias"), # skipped
    "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,         "causal",           "alibi"),
    "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0,         "causal",           "alibi"), # skipped
    "bias_4_0": ModelConfig(4, 16, 16,  64,  128,  128, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_1": ModelConfig(2, 16, 16,  64,  128,  256, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
    "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal",           "alibi"), # skipped
    "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal",           "alibi"), # skipped
}

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
    """Test DotProductAttention module with different bias types

    Delegates to test_dot_product_attention without checkpointing, with the
    workspace optimization enabled and the default QKV layout.
    """
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=False, workspace_opt=True, qkv_layout=None,
    )
343

344
345
346
347
348
349
350
qkv_layouts = [
    'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
    'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
    # TODO: add thd layouts ('t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd')
    # once fused attention supports them
    ]

model_configs_layout = {
    #       test:             b,  h, hg,   d,   sq,  skv,   p,             mask,             bias
    "layout_0_0": ModelConfig(2, 16, 16,  64,  128,  128, 0.0,        "no_mask",         "no_bias"),
    "layout_0_1": ModelConfig(2, 16, 16,  64,  128,  128, 0.0,         "causal", "post_scale_bias"),
    "layout_0_2": ModelConfig(1, 16, 16,  64,  128,  256, 0.0,        "padding",         "no_bias"),
    "layout_0_3": ModelConfig(1, 16, 16,  64,  128,  256, 0.0, "padding_causal", "post_scale_bias"),
    "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,        "no_mask",         "no_bias"),
    "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,         "causal", "post_scale_bias"),
    "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0,        "padding",         "no_bias"),
    "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
}

@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
    """Test DotProductAttention module with different QKV layouts

    Delegates to test_dot_product_attention with an explicit ``qkv_layout``,
    no checkpointing and the workspace optimization enabled. Packed "3"
    layouts are skipped there for cross-attention configs.
    """
    test_dot_product_attention(
        dtype, model_configs, model,
        ckpt_attn=False, workspace_opt=True, qkv_layout=qkv_layout,
    )

def _make_padding_mask(seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
    """Build a boolean mask of shape [batch, 1, 1, max_seqlen] from sequence lengths.

    Entry ``[i, 0, 0, j]`` is True for ``j < seqlens[i]`` and False for the
    remaining (padded) positions. The mask lives on ``seqlens``' device.
    """
    positions = torch.arange(max_seqlen, device=seqlens.device)
    valid = positions.unsqueeze(0) < seqlens.unsqueeze(1)
    return valid.unsqueeze(1).unsqueeze(1)

def _run_dot_product_attention(
        dtype: torch.dtype,
        config: ModelConfig,
        backend: str,
        ckpt_attn: bool,
        qkv_layout: str,
        workspace_opt: bool,
        ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Run DotProductAttention module with one forward pass and one backward pass

    Args:
        dtype: dtype of the inputs, bias, output gradient and module.
        config: attention configuration.
        backend: "FlashAttention", "FusedAttention" or "UnfusedDotProductAttention";
            selects which NVTE_FLASH_ATTN / NVTE_FUSED_ATTN variables are enabled.
        ckpt_attn: forwarded as ``checkpoint_core_attention``.
        qkv_layout: underscore-separated layout string such as "sb3hd" or
            "bshd_bs2hd"; its first group (digits removed) is the ``qkv_format``.
        workspace_opt: sets NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT (FusedAttention only).

    Returns:
        The attention output and the gradients of the query, key and value inputs.
    """

    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"

    # Create seqlens: random lengths in [1, max) for padded/thd inputs, full length otherwise
    qkv_format = ''.join(c for c in qkv_layout.split('_')[0] if c.isalpha())
    if "padding" in config.attn_mask_type or qkv_format == 'thd':
        seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
                dtype=torch.int32, device="cuda")
        if config.attn_type == 'self':
            seqlens_kv = seqlens_q
        else:
            seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size],
                    dtype=torch.int32, device="cuda")
    else:
        seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
                dtype=torch.int32, device="cuda")
        seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
                dtype=torch.int32, device="cuda")
    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)

    # Create attention mask if padding (a (q, kv) pair for cross attention)
    attention_mask = None
    if "padding" in config.attn_mask_type:
        attention_mask_q = _make_padding_mask(seqlens_q, config.max_seqlen_q)
        if config.attn_type == 'self':
            attention_mask = attention_mask_q
        else:
            attention_mask = (
                    attention_mask_q, _make_padding_mask(seqlens_kv, config.max_seqlen_kv))

    # Create input tensors
    dim_to_num = {
        'b'  : config.batch_size,
        'sq' : config.max_seqlen_q,
        'skv': config.max_seqlen_kv,
        'h'  : config.num_heads,
        'hg' : config.num_gqa_groups,
        'd'  : config.head_dim,
        't'  : cu_seqlens_q[-1],
        'tg' : cu_seqlens_kv[-1],
        '3'  : 3,
        '2'  : 2,
        }
    inp = []
    for i, layout in enumerate(qkv_layout.split('_')):
        # Turn e.g. "sb3hd" into "sq_b_3_h_d"; non-query groups use the kv dims
        layout = '_'.join(layout)
        if i == 0:
            layout = layout.replace('s', 'sq')
        else:
            layout = layout.replace('s', 'skv')
            layout = layout.replace('h', 'hg')
            layout = layout.replace('t', 'tg')
        dims = layout.split('_')
        tensor = 0.1 * torch.randn([dim_to_num[d] for d in dims], dtype=dtype, device="cuda")
        # A packed layout has a digit dimension ("3" or "2"); unpack it into
        # that many separate tensors, otherwise the tensor is used as-is
        split_dim = next((d for d, name in enumerate(dims) if name.isdigit()), 0)
        if split_dim != 0:
            inp.extend(t.squeeze(split_dim) for t in torch.split(tensor, 1, dim=split_dim))
        else:
            inp.append(tensor)
    for tensor in inp[:3]:
        tensor.requires_grad = True

    # Create output gradient: query layout with heads and head_dim merged
    out_grad_dims = '_'.join(qkv_format).replace('s', 'sq').split('_')
    out_grad_shape = [dim_to_num[d] for d in out_grad_dims]
    out_grad_shape = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
    out_grad = 0.001 * torch.randint(0, 200, out_grad_shape, dtype=dtype, device="cuda")

    # Create bias (only post_scale_bias needs an explicit tensor)
    bias = None
    if config.attn_bias_type == 'post_scale_bias':
        bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
                dtype=dtype, device="cuda")

    # Create RNG
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    # Set up model
    block = DotProductAttention(
            config.num_heads,
            config.head_dim,
            num_gqa_groups=config.num_gqa_groups,
            attention_dropout=config.dropout_p,
            qkv_format=qkv_format,
            attn_mask_type=config.attn_mask_type,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type=config.attn_type,
    ).to(dtype=dtype, device="cuda")

    # Run a forward and backward pass
    out = block(inp[0], inp[1], inp[2],
            attention_mask=attention_mask,
            qkv_format=qkv_format,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            attn_mask_type=config.attn_mask_type,
            checkpoint_core_attention=ckpt_attn,
            core_attention_bias_type=config.attn_bias_type,
            core_attention_bias=bias,
            fast_zero_fill=True)
    out.backward(out_grad)

    return out, (inp[0].grad, inp[1].grad, inp[2].grad)

model_configs_te_layer = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,      mask,             bias
    "te_1_0": ModelConfig(2, 16, 16,  64,  128,  128, 0.0, "no_mask", "post_scale_bias"),
    "te_1_1": ModelConfig(4, 16, 16,  64,  128,  128, 0.0,  "causal", "post_scale_bias"),
    "te_1_2": ModelConfig(2, 16, 16,  64,  128,  128, 0.0, "padding", "post_scale_bias"),
    "te_2_0": ModelConfig(1, 16, 16,  64, 2048, 2048, 0.0,  "causal",         "no_bias"),
    "te_2_1": ModelConfig(2, 16, 16,  64, 2048, 2048, 0.0, "no_mask",         "no_bias"),
    "te_2_2": ModelConfig(1, 16, 16,  64, 2048, 2048, 0.0, "padding",         "no_bias"),
}

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE):
    """Test TransformerLayer module

    Runs one forward/backward pass of TransformerLayer on every supported
    attention backend and checks that each pair of supported backends agrees
    within loose tolerances. When two fused backends are reported, both are
    run (NVTE_FUSED_ATTN_BACKEND "0" and "1") and compared with each other,
    as in test_dot_product_attention, so the skip check never passes a
    configuration with nothing actually compared.

    Also called directly by test_te_layer_misc and test_te_layer_mqa_gqa.
    """

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-1, rtol=5e-2)
    workspace_opt = True

    # Skip if fewer than two backends can be compared
    if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
    fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
        config,
        dtype,
        qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
    )
    flash_attn_supported = _is_flash_attention_supported(config)
    unfused_attn_supported = _is_unfused_attention_supported(config)
    if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

    def run(backend):
        """Run one TransformerLayer forward/backward pass on ``backend``."""
        return _run_transformer_layer(
            dtype,
            config,
            backend,
            ckpt_attn,
            qkv_format,
            workspace_opt,
            fused_qkv_params,
            RoPE,
        )

    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = run("UnfusedDotProductAttention")

    # FusedAttention backend(s)
    if fused_attn_supported:
        if len(fused_attn_backend) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
        fused_attn_fwd, fused_attn_bwd = run("FusedAttention")
        if len(fused_attn_backend) == 2:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
            fused_attn_fwd_1, fused_attn_bwd_1 = run("FusedAttention")

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = run("FlashAttention")

    if unfused_attn_supported and fused_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
    if unfused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
    if fused_attn_supported and flash_attn_supported:
        torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
    if fused_attn_supported and len(fused_attn_backend) == 2:
        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
        torch.testing.assert_close(fused_attn_bwd, fused_attn_bwd_1, **tols)
618

619
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
def test_te_layer_misc(dtype, model_configs, model):
    """Test TransformerLayer module with miscellaneous settings

    Delegates to test_transformer_layer with core-attention checkpointing,
    the "bshd" format, fused QKV parameters and RoPE all enabled.
    """
    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    test_transformer_layer(dtype, model_configs, model,
            ckpt_attn, qkv_format, fused_qkv_params, RoPE)

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
    """Test TransformerLayer module with MQA/GQA.

    For every divisor ``q >= 2`` of the head count, runs
    ``test_transformer_layer`` with ``num_heads // q`` key/value groups
    (``q`` query heads per group). The largest divisor is ``num_heads``
    itself, i.e. a single KV group (MQA).

    ``config.num_gqa_groups`` is restored afterwards so the temporary change
    does not leak into other tests sharing the same config object.
    """
    def find_factors(x):
        """Return all divisors of ``x`` that are >= 2, in ascending order."""
        return [i for i in range(2, x + 1) if x % i == 0]

    ckpt_attn = True
    qkv_format = "bshd"
    fused_qkv_params = True
    RoPE = True
    config = model_configs[model]
    num_querys_per_gqa_group = find_factors(config.num_heads)

    # test_transformer_layer receives the config dict and key rather than the
    # config object, so the group count has to be set on the shared object;
    # remember the original value and always put it back.
    original_num_gqa_groups = config.num_gqa_groups
    try:
        for num_q_per_gqa_group in num_querys_per_gqa_group:
            config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
            test_transformer_layer(dtype, model_configs, model,
                    ckpt_attn, qkv_format, fused_qkv_params, RoPE)
    finally:
        config.num_gqa_groups = original_num_gqa_groups

def _alibi_bias(num_heads: int, max_seqlen_q: int, max_seqlen_kv: int) -> torch.Tensor:
    """Build a reference ALiBi bias of shape (1, num_heads, max_seqlen_q, max_seqlen_kv).

    Head ``i`` uses slope ``m_0 ** (i + 1)`` with ``m_0 = 2 ** (-8 / num_heads)``,
    and entry ``[0, i, q, k]`` equals ``slope_i * (k - q)``. The result is a
    float32 CPU tensor; callers cast and move it as needed.

    Each head gets its own slice. Scaling slices of an ``expand``-ed tensor in
    place would write through to one shared storage and collapse every head
    onto the product of all slopes.

    Raises:
        ValueError: if ``num_heads`` is not a power of two. Only then does this
            slope formula yield exactly one slope per head.
    """
    n = 2 ** math.floor(math.log2(num_heads))
    if n != num_heads:
        raise ValueError(
            f"ALiBi reference bias needs a power-of-two number of heads, got {num_heads}")
    m_0 = 2.0 ** (-8.0 / n)
    slopes = torch.pow(m_0, torch.arange(1, 1 + n))
    # distance[q, k] = k - q
    distance = (torch.arange(max_seqlen_kv).unsqueeze(0)
                - torch.arange(max_seqlen_q).unsqueeze(1)).to(torch.float32)
    return (slopes.view(n, 1, 1) * distance).unsqueeze(0)


def _run_transformer_layer(
        dtype: torch.dtype,
        config: ModelConfig,
        backend: str,
        ckpt_attn: bool,
        qkv_layout: str,
        workspace_opt: bool,
        fused_qkv_params: bool,
        RoPE: bool,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run TransformerLayer module with one forward pass and one backward pass.

    Args:
        dtype: parameter and input dtype.
        config: model configuration; only read, never modified.
        backend: "FlashAttention" or "FusedAttention" enables that backend via
            NVTE_FLASH_ATTN / NVTE_FUSED_ATTN; any other value leaves both
            disabled.
        ckpt_attn, qkv_layout, workspace_opt: accepted for signature
            compatibility but currently unused. Core-attention checkpointing
            is always off, and the input is always laid out as
            (max_seqlen_q, batch_size, hidden_size).
        fused_qkv_params: forwarded as ``fuse_qkv_params``.
        RoPE: whether to apply rotary position embeddings.

    Returns:
        (layer output, gradient of ``out.sum()`` w.r.t. the input).
    """
    # Set RNG and environment variables
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    # Create input tensor
    inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size,
            dtype=dtype, device="cuda", requires_grad = True)

    # For padding masks, draw random sequence lengths and build a boolean
    # mask of shape (batch_size, 1, 1, max_seqlen_q). Row i is True on its
    # first seqlens_q[i] positions and False on the padding after them.
    attention_mask = None
    if "padding" in config.attn_mask_type:
        seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
                dtype=torch.int32, device="cuda")
        positions = torch.arange(config.max_seqlen_q, device="cuda")
        attention_mask = (positions.unsqueeze(0) < seqlens_q.unsqueeze(1)).view(
                config.batch_size, 1, 1, config.max_seqlen_q)

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    # Create bias. A local bias type is used so the shared config is never
    # mutated; otherwise later backend runs on the same config would see
    # 'post_scale_bias' instead of 'alibi'.
    bias_type = config.attn_bias_type
    bias = None
    if bias_type == 'post_scale_bias':
        bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
                dtype=dtype, device="cuda")
    elif bias_type == 'alibi':
        # NOTE(review): raises KeyError when NVTE_FUSED_ATTN_BACKEND is unset;
        # presumably the caller always sets it — confirm.
        if os.environ['NVTE_FUSED_ATTN_BACKEND'] == '0':
            # Emulate ALiBi with an explicit post-scale bias.
            bias_type = 'post_scale_bias'
            bias = _alibi_bias(
                config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
            ).to(dtype=dtype, device="cuda")

    # Create RoPE
    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim)
        rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda")

    # Set up model
    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_heads,
            num_gqa_groups=config.num_gqa_groups,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            self_attn_mask_type=config.attn_mask_type,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.max_seqlen_q,
            micro_batch_size=config.batch_size,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=fused_qkv_params,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype, device="cuda")
    )

    # Run a forward and backward pass
    out = block(inp,
        attention_mask=attention_mask,
        self_attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=False,
        rotary_pos_emb=rotary_pos_emb,
        core_attention_bias_type=bias_type,
        core_attention_bias=bias)
    loss = out.sum()
    loss.backward()

    return out, inp.grad
788
789


790
# FP8 dot-product-attention configurations keyed by test id; positional
# ModelConfig arguments follow the column header below.
model_configs_fp8 = dict(
    #  test:          b,  h, hg,   d,   sq,  skv,   p,      mask,      bias
    fp8_1=ModelConfig(1, 16, 16,  64,  512,  512, 0.0, "no_mask", "no_bias"),
    fp8_2=ModelConfig(4, 16, 16,  64,  512,  512, 0.0, "no_mask", "no_bias"),
)
# The FP8 path is only exercised with FP16 inputs.
param_types_fp8 = [torch.float16]

797
@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
# The capability helper must be *called*: comparing the function object itself
# to a tuple is always unequal and would skip this test unconditionally.
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, model):
    """Test FP8 dot product attention.

    FusedAttention runs the FP8 path through ``fused_attn_fwd``/``fused_attn_bwd``
    from cpp_extensions (wrapped by ``DPA_FP8``). UnfusedDotProductAttention
    uses plain PyTorch operations in FP16 as the reference. Its QKV input is
    the fused run's QKV projection dequantized from FP8.
    """
    config = model_configs_fp8[model]

    # Skip if not supported
    fused_attn_supported, _ = _is_fused_attention_supported(config, dtype)
    if not fused_attn_supported:
        pytest.skip("FusedAttention does not support this model config")

    # The reference run loads tensors the fused run saves to disk
    # (qkv.pt, out_grad.pt), so the fused run must come first.
    fused_out, fused_dqkv = _run_dpa_fp8(
        dtype, config, "FusedAttention")
    unfused_out, unfused_dqkv = _run_dpa_fp8_ref(
        dtype, config, "UnfusedDotProductAttention")

    tols = dict(atol=2.5e-2, rtol=2.5e-2)
    torch.testing.assert_close(fused_out, unfused_out, **tols)
    torch.testing.assert_close(fused_dqkv, unfused_dqkv, **tols)
828

829
830
831
def _run_dpa_fp8(dtype, config, backend):
    """Run one forward/backward pass of the FP8 ``DPA_FP8`` module.

    Side effects on the current working directory: this function writes
    ``out_grad.pt``, and the module writes ``qkv.pt``, ``ctx.pt`` and
    ``dqkv.pt``. ``qkv.pt`` and ``out_grad.pt`` are what ``_run_dpa_fp8_ref``
    reads back.

    Returns:
        The FP16 context viewed as (max_seqlen_q, batch_size, -1), and dQKV
        viewed as (max_seqlen_q, batch_size, 3, num_heads, head_dim).
    """
    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "1" if backend == "FlashAttention" else "0"
    os.environ["NVTE_FUSED_ATTN"] = "1" if backend == "FusedAttention" else "0"

    batch, seqlen = config.batch_size, config.max_seqlen_q
    heads, head_dim = config.num_heads, config.head_dim
    flat_shape = (batch * seqlen, heads * head_dim)

    # Input first, then output gradient: keeps the RNG draw order fixed.
    inp = 0.01 * torch.randn(
            *flat_shape, dtype=dtype, device="cuda", requires_grad=True)
    per_seq_lens = torch.full([batch], seqlen, dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(batch + 1, device="cuda", dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(per_seq_lens, dim=0)
    out_grad = 0.01 * torch.randn(*flat_shape, dtype=dtype, device="cuda")
    torch.save(out_grad, 'out_grad.pt')

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        interval=1,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    dpa = DPA_FP8(config).to(dtype=torch.float16, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        fwd_out = dpa(inp, cu_seqlens, seqlen)
        fwd_out.backward(out_grad)

    saved_context = torch.load("ctx.pt")
    saved_dqkv = torch.load('dqkv.pt')
    context_sbh = saved_context.view(batch, seqlen, -1).transpose(0, 1)
    dqkv_sb3hd = saved_dqkv.view(batch, seqlen, 3, heads, head_dim).transpose(0, 1).contiguous()
    return (context_sbh, dqkv_sb3hd)
871

872
873
874
875
def _run_dpa_fp8_ref(dtype, config, backend):
    """Run UnfusedDotProductAttention as the reference for the FP8 test.

    Plain PyTorch attention in FP16. Its input is the QKV tensor that
    ``_run_dpa_fp8`` saved after dequantizing it from FP8 (``qkv.pt``). The
    output gradient is the one ``_run_dpa_fp8`` saved (``out_grad.pt``). This
    function must therefore run after ``_run_dpa_fp8``.

    Returns:
        (attention output, gradient w.r.t. the loaded QKV tensor).
    """
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    # qkv.pt is saved as (max_seqlen_q, batch_size, 3, num_heads, head_dim).
    inp = torch.load('qkv.pt').to(device="cuda")
    inp.requires_grad = True
    out_grad = torch.load('out_grad.pt').to(device="cuda").view(
            config.batch_size, config.max_seqlen_q, -1).transpose(0,1)

    # A single dummy RNG tracker (a previously duplicated, unused copy was removed;
    # CudaRNGStatesTracker.add restores the global CUDA RNG state, so this is
    # behaviour-neutral).
    _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)

    def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        """Get cuda rng tracker."""
        return _DUMMY_CUDA_RNG_STATE_TRACKER

    block = (
         DotProductAttention(
                config.num_heads,
                config.head_dim,
                attention_dropout=config.dropout_p,
                sequence_parallel=False,
                tp_size=1,
                get_rng_state_tracker=get_dummy_cuda_rng_tracker,
                tp_group=None,
                layer_number=1,
                attention_type="self"
        ).to(dtype=dtype, device="cuda")
    )

    q = inp[:, :,0,:,:]
    k = inp[:, :,1,:,:]
    v = inp[:, :,2,:,:]
    out = block(q, k, v, attn_mask_type=config.attn_mask_type)
    out.backward(out_grad)

    return out, inp.grad
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949


# Size of the int8 workspace buffer DPA_FP8 allocates and passes to the FP8 GEMMs.
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432  # 32MiB
# use_split_accumulator flags for the QKV fprop, dgrad and wgrad FP8 GEMMs.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

# FP8 scaling-meta slots reused for the attention tensors in _dpa_fp8.
# QKV projection output (GEMM out_index / fused-attention QKV descale):
META_QKV  = tex.FP8FwdTensors.GEMM1_OUTPUT
# Attention output O (fused-attention scale/amax, cast back to FP16):
META_O    = tex.FP8FwdTensors.GEMM2_INPUT
# Incoming output gradient dO (cast to FP8 in backward):
META_DO   = tex.FP8BwdTensors.GRAD_INPUT2
# dQKV produced by fused_attn_bwd (scale/amax, dgrad/wgrad inputs):
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1

# Intermediate S and its gradient dS inside fused attention
# (presumably the softmax output — confirm against the kernel docs):
META_S    = tex.FP8FwdTensors.GEMM3_WEIGHT
META_DS   = tex.FP8BwdTensors.GRAD_INPUT3

class _dpa_fp8(torch.autograd.Function):
    """FP8 QKV projection followed by FP8 fused attention, as one autograd op.

    Forward:
        Casts ``inp`` and ``qkv_weight`` to FP8 and runs the biased QKV GEMM
        with an FP8 output. Then calls ``fused_attn_fwd`` on the packed QKV
        with the FP8 backend (``"t3hd"`` layout, ``"padding"`` mask, no
        attention bias).
    Backward:
        Runs ``fused_attn_bwd`` followed by the QKV dgrad, wgrad and bgrad.

    Test hook: dequantized FP16 copies of the QKV projection, the attention
    context and dQKV are written to ``qkv.pt``, ``ctx.pt`` and ``dqkv.pt`` in
    the current working directory for comparison against an unfused reference.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
    ) -> torch.Tensor:
        """Run the FP8 QKV GEMM and the fused attention forward pass.

        Args:
            inp: 2-D input (asserted); presumably (tokens, hidden) — confirm.
            qkv_weight: QKV projection weight; its last dim is the hidden size.
            qkv_bias: QKV projection bias.
            cu_seqlens: cumulative sequence lengths with ``batch + 1`` entries.
            num_heads: number of attention heads.
            p_dropout: attention dropout probability.
            max_s: maximum sequence length (forced to 512 when 1 < batch < 4).
            fast_zero_fill: forwarded to the fused attention kernels.
            fp8_meta: FP8 meta dict with ``"recipe"``, ``"scaling_fwd"`` and
                ``"scaling_bwd"`` entries.
            workspace: cuBLASLt workspace buffer for the FP8 GEMMs.
            is_training: forwarded to ``fused_attn_fwd``.

        Returns:
            The attention context cast back to FP16, viewed as ``(-1, hidden)``.
        """
        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1
        is_nl = False
        # NOTE(review): the reason for forcing max_s to 512 for batch sizes
        # 2-3 is not visible here; kept as-is.
        if 1 < b < 4:
            max_s = 512
            is_nl = True

        fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

        inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
            inp,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
        )

        qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused(
            qkv_weight,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
        )

        # QKV projection with FP8 output stored in the META_QKV slot.
        qkv_out, _ = ext.fp8_gemm(
            qkv_weight_fp8,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
            inputmat,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
            torch.uint8,
            workspace,
            bias=qkv_bias,
            use_bias=True,
            out_index=META_QKV,
            fp8_meta_tensor=fp8_meta["scaling_fwd"],
            use_split_accumulator=_2X_ACC_FPROP,
            D_dtype=fp8_dtype_forward,
        )
        qkv_out = qkv_out.view(-1, 3, h, d)
        # Dequantized copy, saved as (max_s, b, 3, h, d) for the reference run.
        qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"],
                META_QKV, fp8_dtype_forward,
                tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous()
        torch.save(qkv_out_fp16, 'qkv.pt')

        # FMHA
        context_, aux_ctx_tensors, *rest = fused_attn_fwd(
                is_training,
                max_s,
                max_s,
                cu_seqlens,
                cu_seqlens,
                qkv_out[:,0,:,:],
                qkv_out[:,1,:,:],
                qkv_out[:,2,:,:],
                fp8_dtype_forward,
                FusedAttnBackend["FP8"],
                None,
                fp8_meta["scaling_fwd"].scale_inv[META_QKV],
                fp8_meta["scaling_fwd"].scale[META_S],
                fp8_meta["scaling_fwd"].scale[META_O],
                fp8_meta["scaling_fwd"].amax_history[0][META_S],
                fp8_meta["scaling_fwd"].amax_history[0][META_O],
                attn_scale=None,
                dropout=p_dropout,
                fast_zero_fill=fast_zero_fill,
                qkv_layout="t3hd",
                attn_bias_type="no_bias",
                attn_mask_type="padding",
                rng_gen=None,
                )
        # Unused locally; the unpack doubles as a check that the kernel
        # returned exactly three auxiliary tensors.
        M, ZInv, philox_unpacked = aux_ctx_tensors

        context = context_.view(-1, in_features)
        context_t = tex.fp8_transpose(context, fp8_dtype_forward)

        ctx.save_for_backward(
            inputmat_t, qkv_weight_t_fp8, workspace,
            qkv_out,
            context_, context_t,
            fp8_meta["scaling_fwd"].scale,
            fp8_meta["scaling_fwd"].scale_inv,
        )
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.is_nl = is_nl
        ctx.hidden_size = in_features
        ctx.num_heads = num_heads

        context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"],
                META_O, fp8_dtype_forward, tex.DType.kFloat16)
        torch.save(context_fp16, 'ctx.pt')
        return context_fp16


    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Backward through fused attention and the FP8 QKV projection.

        Returns exactly one entry per ``forward`` input: gradients for
        ``inp``, ``qkv_weight`` and ``qkv_bias``, then ``None`` for the eight
        non-differentiable arguments.
        """
        with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
            (
                inputmat_t,
                qkv_weight_t_fp8,
                workspace,
                qkv_out,
                context, context_t,
                fwd_scales,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
            fp8_dtype_forward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=True
            )
            fp8_dtype_backward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )

            proj_dgrad = ext.cast_to_fp8(
                grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
            )

            dq, dk, dv, *rest = fused_attn_bwd(
                    ctx.max_s,
                    ctx.max_s,
                    ctx.cu_seqlens,
                    ctx.cu_seqlens,
                    qkv_out[:,0,:,:],
                    qkv_out[:,1,:,:],
                    qkv_out[:,2,:,:],
                    context,
                    proj_dgrad.view_as(context),
                    fp8_dtype_forward,
                    ctx.aux_ctx_tensors,
                    FusedAttnBackend["FP8"],
                    fwd_scale_inverses[META_QKV], # d_scale_qkv,
                    fwd_scale_inverses[META_S], # d_scale_s,
                    fwd_scale_inverses[META_O], # d_scale_o,
                    ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
                    fwd_scales[META_S], # q_scale_s
                    ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds
                    ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds
                    ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
                    None,
                    ctx.p_dropout,
                    ctx.fast_zero_fill,
                    "t3hd",
                    "no_bias",
                    "padding",
                    )
            # Re-pack dq/dk/dv along a new dim 1 to match the forward's (-1, 3, h, d).
            dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1)

            dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size)
            dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                fp8_dtype_backward, tex.DType.kFloat16)
            torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')

            qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"],
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
            )

            # QKV DGRAD
            qkv_dgrad, _ = ext.fp8_gemm(
                qkv_weight_t_fp8,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_DGRAD,
            )
            # QKV WGRAD
            qkv_wgrad, _ = ext.fp8_gemm(
                inputmat_t,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                dqkv_grad_output_t,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_WGRAD,
            )

        return (qkv_dgrad,
            qkv_wgrad,
            qkv_bgrad,
            None,  # cu_seqlens
            None,  # num_heads
            None,  # p_dropout
            None,  # max_s
            None,  # fast_zero_fill
            None,  # fp8_meta
            None,  # workspace
            None)  # is_training

class DPA_FP8(TransformerEngineBaseModule):
    def __init__(
        self,
        config,
        params_dtype: torch.dtype = torch.float32):
        """Create the QKV projection parameters and the GEMM workspace.

        The QKV weight has shape (3 * hidden_size, hidden_size) and is filled
        with ones. The bias has shape (3 * hidden_size,) and is zeroed. Both
        live on the current CUDA device with ``params_dtype``.
        """
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.fast_zero_fill = True

        qkv_dim = 3 * self.hidden_size
        device = torch.cuda.current_device()

        weight_storage = torch.empty(qkv_dim, self.hidden_size, device=device, dtype=params_dtype)
        self.qkv_weight = torch.nn.Parameter(weight_storage)
        self.fp8_weight_shapes.append(self.qkv_weight.shape)

        bias_storage = torch.empty(qkv_dim, device=device, dtype=params_dtype)
        self.qkv_bias = torch.nn.Parameter(bias_storage)

        # Deterministic parameter values so fused and reference runs agree.
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)

        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self, inp: torch.Tensor,
        cu_seqlens, max_s,
    ) -> torch.Tensor:
        """Apply ``_dpa_fp8`` inside the module's FP8 forward context.

        ``prepare_forward`` is entered with ``num_gemms=3``. The prepared
        input, ``cu_seqlens`` and ``max_s`` are passed to ``_dpa_fp8`` along
        with the module's parameters, dropout, FP8 meta, workspace and
        training flag.
        """
        with self.prepare_forward(inp, None, num_gemms=3) as prepared_inp:
            return _dpa_fp8.apply(
                prepared_inp,
                self.qkv_weight, self.qkv_bias,
                cu_seqlens,
                self.h, self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta, self.workspace,
                self.training,
            )

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[torch.Tensor]:
        """Needs override."""