test_fused_attn.py 36.2 KB
Newer Older
1
2
3
4
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

Tim Moon's avatar
Tim Moon committed
5
6
7
8
9
from importlib.metadata import version
import os
from typing import Any, Dict, List, Tuple, Union

from pkg_resources import packaging
10
import pytest
Tim Moon's avatar
Tim Moon committed
11
import torch
12

Tim Moon's avatar
Tim Moon committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast
from transformer_engine.pytorch.attention import (
    DotProductAttention,
    RotaryPositionEmbedding,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    AttnBiasType,
    AttnMaskType,
    FusedAttnBackend,
    QKVLayout,
    fused_attn_bwd,
    fused_attn_fwd,
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_qkvpacked,
)
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
    TransformerEngineBaseModule,
    _prepare_backward,
)
36
from transformer_engine.pytorch.utils import (
Tim Moon's avatar
Tim Moon committed
37
    get_device_compute_capability,
38
39
40
    init_method_normal,
    scaled_init_method_normal,
)
Tim Moon's avatar
Tim Moon committed
41
import transformer_engine_extensions as tex
42

43
from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
Tim Moon's avatar
Tim Moon committed
44
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()

_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")


def _get_cudnn_version() -> List[int]:
    """Return the cuDNN version as a list of ints, e.g. ``[8, 9, 5]``.

    The ``CUDNN_VERSION`` environment variable (e.g. ``"8.9.5"``) takes
    precedence, as before. When it is unset, fall back to the cuDNN library
    PyTorch reports instead of raising ``KeyError`` at import time; if no
    cuDNN is available the result is ``[0, 0, 0]`` so version-gated tests
    are skipped.
    """
    env_version = os.environ.get("CUDNN_VERSION")
    if env_version:
        return [int(i) for i in env_version.split(".")]
    encoded = torch.backends.cudnn.version() or 0
    # cuDNN 8.x encodes major*1000 + minor*100 + patch;
    # cuDNN 9.x and later encode major*10000 + minor*100 + patch.
    major, rest = divmod(encoded, 10000 if encoded >= 90000 else 1000)
    minor, patch = divmod(rest, 100)
    return [major, minor, patch]


_cudnn_version = _get_cudnn_version()


class ModelConfig:
    """Hyperparameters describing one attention test configuration.

    ``hidden_size`` must equal ``num_attention_heads * head_dim``; this is
    enforced with an ``assert`` after the size attributes are stored.
    """

    def __init__(
        self, num_layers, hidden_size, num_attention_heads, head_dim, seq_len,
        dropout_p, attn_mask_type,
    ):
        (self.num_layers, self.hidden_size,
         self.num_attention_heads, self.head_dim) = (
            num_layers, hidden_size, num_attention_heads, head_dim)
        assert self.hidden_size == self.num_attention_heads * self.head_dim, (
            "hidden_size must be = num_heads x head_dim.")
        self.seq_len, self.dropout_p, self.attn_mask_type = (
            seq_len, dropout_p, attn_mask_type)


# Model configurations swept by test_dot_product_attention.
model_configs = {
    "test1": ModelConfig(1, 1024, 16, 64, 128, 0.0, "causal"),
    "test2": ModelConfig(1, 1024, 16, 64, 2048, 0.0, "causal"),
    "test3": ModelConfig(1, 2048, 16, 128, 128, 0.0, "causal"),
    "test4": ModelConfig(1, 3072, 24, 128, 2048, 0.0, "causal"),
    "test5": ModelConfig(1, 1024, 16, 64, 128, 0.0, "no_mask"),
}

# Parameter dtypes; bfloat16 is only added when the device supports it.
param_types = [torch.float16]
if torch.cuda.is_bf16_supported():
    param_types.append(torch.bfloat16)

batch_sizes = [1, 32]

# Smaller ("lean") sweep used by the QKV-layout, TransformerLayer and GQA tests.
model_configs_lean = {
    "test6": ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
    "test7": ModelConfig(1, 2048, 16, 128, 2048, 0.0, "causal"),
}

param_types_lean = [torch.bfloat16]

batch_sizes_lean = [2]


def _is_fused_attention_supported(
    config: ModelConfig,
    dtype: torch.dtype,
    qkv_layout: str = "sbh3d",
    bias_type: str = "no_bias",
) -> bool:
    """Report whether some FusedAttention backend can handle ``config``.

    Asks the extension which fused-attention backend it would select for the
    given dtype, QKV layout, bias type, mask type, dropout probability and
    head size, using ``config.seq_len`` for both the query and key/value
    sequence lengths. Returns ``False`` only for ``No_Backend``.
    """
    te_dtype = TE_DType[dtype]
    seq_len = config.seq_len
    selected_backend = tex.get_fused_attn_backend(
        te_dtype,
        te_dtype,
        QKVLayout[qkv_layout],
        AttnBiasType[bias_type],
        AttnMaskType[config.attn_mask_type],
        config.dropout_p,
        seq_len,
        seq_len,
        config.head_dim,
    )
    no_backend = FusedAttnBackend["No_Backend"]
    return selected_backend != no_backend


def _is_flash_attention_supported(bias_type: str = "no_bias") -> bool:
    """Report whether the FlashAttention backend is usable for these tests.

    Requires a device with compute capability 8.0 or newer and no attention
    bias (any ``bias_type`` other than ``"no_bias"`` is rejected).
    """
    has_ampere_or_newer = get_device_compute_capability() >= (8, 0)
    return has_ampere_or_newer and bias_type == "no_bias"


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("ckpt_attn", [True, False])
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
    """Test DotProductAttention module with different backends

    UnfusedDotProductAttention is the reference; FusedAttention and
    FlashAttention are each compared against it when supported.
    """

    # Get configs
    config = model_configs[model]
    tols = dict(atol=5e-3, rtol=5e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)

    # Skip if only unfused backend is supported. _run_dot_product_attention
    # slices q, k and v out of one (s, b, 3, h, d) tensor, so query support
    # for the matching "sb3hd" layout rather than the helper's default.
    fused_attn_supported = _is_fused_attention_supported(
        config,
        dtype,
        qkv_layout="sb3hd",
        bias_type=bias_type,
    )
    flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
    if not (fused_attn_supported or flash_attn_supported):
        pytest.skip(
            "Neither FusedAttention nor FlashAttention support this model config"
        )

    # UnfusedDotProductAttention backend (reference)
    unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
        dtype,
        bs,
        config,
        "UnfusedDotProductAttention",
        ckpt_attn,
        bias_type,
    )

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
            dtype,
            bs,
            config,
            "FusedAttention",
            ckpt_attn,
            bias_type,
        )
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            bs,
            config,
            "FlashAttention",
            ckpt_attn,
            bias_type,
        )
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)


def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
    """Run one forward and backward pass of DotProductAttention.

    ``backend`` selects the implementation through the ``NVTE_FLASH_ATTN`` /
    ``NVTE_FUSED_ATTN`` environment variables (both "0" means unfused).
    Returns the attention output and the gradient of the packed
    ``(seq_len, bs, 3, heads, head_dim)`` QKV input.
    """

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1"

    # Packed QKV input; q, k and v below are views into its third dimension.
    inp = torch.randn(
        config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad = True
    # Every sequence is full length, so cu_seqlens = [0, s, 2s, ..., bs*s].
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    if bias_type != "no_bias":
        bias = torch.randn(
            1, config.num_attention_heads, config.seq_len, config.seq_len,
            dtype=dtype).cuda()
    else:
        bias = None

    block = (
        DotProductAttention(
            config.num_attention_heads,
            config.head_dim,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    op = block(q, k, v,
        qkv_format='sbhd',
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        attn_mask_type=config.attn_mask_type,
        checkpoint_core_attention=ckpt_attn,
        core_attention_bias_type=bias_type,
        core_attention_bias=bias)
    op.backward(op_grad)

    return op, inp.grad


# QKV memory layouts swept by test_dpa_qkv_layout. Tensors are separated by
# '_'; a digit marks the dimension along which q/k/v are packed together.
qkv_layouts = [
    # sequence-first layouts
    'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
    # batch-first layouts
    'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
    # TODO: add the thd layouts once fused attention supports them:
    # 't3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd',
]

@pytest.mark.skipif(
    _cudnn_version < [8, 9, 5], reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, bs, model, workspace_opt, qkv_layout):
    """Test DotProductAttention module with different QKV layouts"""

    # Get configs
    config = model_configs_lean[model]
    tols = dict(atol=5e-3, rtol=5e-3)
    if dtype == torch.bfloat16:
        tols = dict(atol=2.5e-2, rtol=2.5e-2)

    # Skip if only unfused backend is supported. Fused backend support
    # depends on the QKV layout, so ask about the layout under test.
    fused_attn_supported = _is_fused_attention_supported(
        config, dtype, qkv_layout=qkv_layout)
    flash_attn_supported = _is_flash_attention_supported()
    if not (fused_attn_supported or flash_attn_supported):
        pytest.skip(
            "Neither FusedAttention nor FlashAttention support this model config"
        )

    # UnfusedDotProductAttention backend (reference)
    unfused_attn_fwd, unfused_attn_bwd = _run_dpa_qkv_layout(
        dtype, bs, config, "UnfusedDotProductAttention", qkv_layout, workspace_opt)

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_dpa_qkv_layout(
            dtype, bs, config, "FusedAttention", qkv_layout, workspace_opt)
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        for fused_grad, unfused_grad in zip(fused_attn_bwd, unfused_attn_bwd):
            torch.testing.assert_close(fused_grad, unfused_grad, **tols)

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_dpa_qkv_layout(
            dtype, bs, config, "FlashAttention", qkv_layout, workspace_opt)
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        for flash_grad, unfused_grad in zip(flash_attn_bwd, unfused_attn_bwd):
            torch.testing.assert_close(flash_grad, unfused_grad, **tols)


def _run_dpa_qkv_layout(dtype, bs, config, backend, qkv_layout, workspace_opt):
    """Run one forward/backward pass of DotProductAttention for ``qkv_layout``.

    ``qkv_layout`` is an underscore-separated list of per-tensor layouts
    (e.g. ``"sbhd_sb2hd"``). Each letter names a dimension (see
    ``dim_to_num``), and a digit marks the dimension along which that many of
    q, k, v are packed into one tensor. ``backend`` is selected through the
    ``NVTE_*`` environment variables; ``workspace_opt`` only affects the
    FusedAttention backend.

    Returns the attention output and a tuple of the q, k and v gradients.
    """

    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"
        os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"

    dim_to_num = {
        'b': bs,
        's': config.seq_len,
        'h': config.num_attention_heads,
        'd': config.head_dim,
        't': bs * config.seq_len,
        '3': 3,
        '2': 2,
    }

    # Build q, k, v (in that order) from the per-tensor layouts.
    inp = []
    for layout in qkv_layout.split('_'):
        tensor_shape = [dim_to_num[c] for c in layout]
        tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype).cuda()
        # Index of the packing digit, or None when this tensor is unpacked.
        # (None, not 0, is the sentinel so a leading packed dim also works.)
        packed_dim = next(
            (dim for dim, c in enumerate(layout) if c.isdigit()), None)
        if packed_dim is None:
            inp.append(tensor)
        else:
            # One view per packed entry, with the packing dim removed.
            inp.extend(tensor.unbind(packed_dim))
    for t in inp[:3]:
        t.requires_grad = True

    qkv_format = ''.join(c for c in qkv_layout.split('_')[0] if c.isalpha())
    qkv_format_no_thd = qkv_format if qkv_format != 'thd' else 'bshd'
    op_grad_shape = [dim_to_num[c] for c in qkv_format_no_thd]
    # The attention output merges the trailing (h, d) dims into h*d.
    op_grad_shape_new = [*op_grad_shape[:-2], op_grad_shape[-2] * op_grad_shape[-1]]
    op_grad = 0.001 * torch.randint(0, 200, op_grad_shape_new, dtype=dtype).cuda()

    block = (
        DotProductAttention(
            config.num_attention_heads,
            config.head_dim,
            attention_dropout=config.dropout_p,
            attn_mask_type=config.attn_mask_type,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=None,
            tp_group=None,
            layer_number=1,
            attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    if qkv_format != 'thd':
        op = block(inp[0], inp[1], inp[2], qkv_format=qkv_format)
    else:
        # All sequences are full length: offsets are 0, s, 2s, ..., bs*s.
        cu_seqlens = torch.arange(
            0,
            (bs + 1) * config.seq_len,
            step=config.seq_len,
            dtype=torch.int32,
            device=inp[0].device)
        op = block(inp[0], inp[1], inp[2],
                   qkv_format=qkv_format,
                   cu_seqlens_q=cu_seqlens,
                   cu_seqlens_kv=cu_seqlens)
    op.backward(op_grad)

    return op, (inp[0].grad, inp[1].grad, inp[2].grad)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs_lean.keys())
@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
@pytest.mark.parametrize("fused_qkv_params", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
def test_transformer_layer(dtype, bs, model, bias_type, fused_qkv_params, RoPE):
    """Test TransformerLayer module when its DotProductAttention is enabled with
    FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""

    # Get configs
    config = model_configs_lean[model]
    tols = dict(atol=5e-1, rtol=5e-2)

    # Skip if only unfused backend is supported. The packed QKV layout to
    # query depends on whether the QKV projection parameters are fused.
    fused_attn_supported = _is_fused_attention_supported(
        config,
        dtype,
        qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
        bias_type=bias_type,
    )
    flash_attn_supported = _is_flash_attention_supported(bias_type=bias_type)
    if not (fused_attn_supported or flash_attn_supported):
        pytest.skip(
            "Neither FusedAttention nor FlashAttention support this model config"
        )

    # UnfusedDotProductAttention backend (reference)
    unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
        dtype,
        bs,
        config,
        "UnfusedDotProductAttention",
        bias_type,
        fused_qkv_params,
        RoPE,
    )

    # FusedAttention backend
    if fused_attn_supported:
        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
            dtype,
            bs,
            config,
            "FusedAttention",
            bias_type,
            fused_qkv_params,
            RoPE,
        )
        torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)

    # FlashAttention backend
    if flash_attn_supported:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
            dtype,
            bs,
            config,
            "FlashAttention",
            bias_type,
            fused_qkv_params,
            RoPE,
        )
        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)


def _run_transformer_layer(dtype, bs, config, backend, bias_type, fused_qkv_params, RoPE):
    """Run a single-layer TransformerLayer with the requested attention backend.

    ``backend`` is selected through the ``NVTE_FLASH_ATTN`` /
    ``NVTE_FUSED_ATTN`` environment variables. The layer is run
    ``num_iters`` times with ``op.sum()`` as the loss.

    Returns the output of the last iteration and ``inp.grad``. Note that
    ``inp.grad`` is never zeroed, so it holds the sum of the gradients from
    all ``num_iters`` backward passes.
    """

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad = True

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
        rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    if bias_type != "no_bias":
        bias = torch.randn(
            1, config.num_attention_heads, config.seq_len, config.seq_len,
            dtype=dtype).cuda()
    else:
        bias = None

    rotary_pos_emb = None
    if RoPE:
        PE = RotaryPositionEmbedding(dim=config.head_dim)
        rotary_pos_emb = PE(config.seq_len).cuda().to(dtype=dtype)

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=fused_qkv_params,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    num_iters = 5
    for _ in range(num_iters):
        op = block(inp, self_attn_mask_type=config.attn_mask_type,
            rotary_pos_emb=rotary_pos_emb,
            core_attention_bias_type=bias_type,
            core_attention_bias=bias)
        loss = op.sum()
        loss.backward()

    return op, inp.grad


@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("bs", batch_sizes_lean)
@pytest.mark.parametrize("model", model_configs_lean.keys())
def test_transformer_layer_gqa(dtype, bs, model):
    """Test TransformerLayer module when its DotProductAttention is enabled with
    FlashAttention or UnfusedDotProductAttention backend"""

    config = model_configs_lean[model]
    tols = dict(atol=5e-1, rtol=5e-2)

    # Skip if only unfused backend is supported
    if not (_flash_attn_2_available and _is_flash_attention_supported()):
        pytest.skip("FlashAttention does not support this model config")

    # Every divisor of the head count is a valid number of query heads per
    # GQA group, from one head per group up to a single group of all heads.
    num_heads = config.num_attention_heads
    num_querys_per_gqa_group = [
        i for i in range(1, num_heads + 1) if num_heads % i == 0]

    for num_q_per_gqa_group in num_querys_per_gqa_group:
        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer_gqa(
            dtype, bs, config, "FlashAttention", num_q_per_gqa_group)
        unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
            dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)

        torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
        torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)


def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
    """Run one forward/backward pass of a GQA TransformerLayer.

    The layer uses ``num_attention_heads // num_querys_per_gqa_group``
    key/value groups; ``backend`` is selected through the ``NVTE_*``
    environment variables. Returns the layer output and ``inp.grad``.
    """

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad = True
    op_grad = torch.randn(
        config.seq_len, bs, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()

    sigma = 0.02
    init_method = init_method_normal(sigma)
    output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

    layer_number = 1
    drop_path_rate = 0.0
    drop_path_rates = [
        rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]

    block = (
        TransformerLayer(
            config.hidden_size,
            4 * config.hidden_size,
            config.num_attention_heads,
            # Integer division: the group count must be an int, and the
            # caller only passes exact divisors of the head count.
            num_gqa_groups=config.num_attention_heads // num_querys_per_gqa_group,
            layernorm_epsilon=1e-5,
            hidden_dropout=0.0,
            attention_dropout=config.dropout_p,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            kv_channels=config.head_dim,
            tp_group=None,
            tp_size=1,
            params_dtype=dtype,
            get_rng_state_tracker=None,
            fuse_wgrad_accumulation=False,
            seq_length=config.seq_len,
            micro_batch_size=bs,
            sequence_parallel=False,
            apply_residual_connection_post_layernorm=False,
            output_layernorm=False,
            layer_type="encoder",
            drop_path_rate=drop_path_rates[layer_number - 1],
            set_parallel_mode=True,
            fuse_qkv_params=True,
            zero_centered_gamma=False,
            qkv_weight_interleaved=False,
            ub_tp_comm_overlap=False,
            bias=True,
        )
        .to(dtype=dtype)
        .cuda()
    )

    op = block(inp, self_attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)

    return op, inp.grad


# Configuration, batch sizes and dtypes swept by test_dpa_fp8.
model_configs_fp8 = dict(
    test1=ModelConfig(1, 1024, 16, 64, 512, 0.0, "no_mask"),
)
batch_sizes_fp8 = [1, 4]
param_types_fp8 = [torch.float16]

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("bs", batch_sizes_fp8)
@pytest.mark.parametrize("model", model_configs_fp8.keys())
def test_dpa_fp8(dtype, bs, model):
    """Test FP8 dot-product attention with different backends

    FusedAttention uses fused_attn_fwd/bwd_qkvpacked from
    cpp_extensions. UnfusedDotProductAttention uses plain PyTorch
    operations.

    """

    config = model_configs_fp8[model]

    # Skip if not supported
    if not _is_fused_attention_supported(config, dtype):
        pytest.skip("FusedAttention does not support this model config")

    # Run dot-product attention with different backends. Order matters: the
    # reference run loads tensors (op_grad.pt, qkv.pt) saved to disk during
    # the FP8 run.
    fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8(
        dtype,
        bs,
        config,
        "FusedAttention"
    )
    unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
        dtype,
        bs,
        config,
        "UnfusedDotProductAttention",
    )

    # Check that results match
    tols = dict(atol=2.5e-2, rtol=2.5e-2)
    torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
    torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)


def _run_dpa_fp8(dtype, bs, config, backend):
    """Run the FP8 ``DPA_FP8`` module under ``fp8_autocast``.

    Saves the output gradient to ``op_grad.pt`` so that ``_run_dpa_fp8_ref``
    can reuse it. The attention output and QKV gradient are read back from
    ``ctx.pt`` and ``dqkv.pt`` in the working directory. Those files are
    presumably written by the FP8 autograd function during forward/backward
    (not visible here — confirm).

    Returns the output reshaped sequence-first via ``view(bs, seq_len, -1)``
    plus ``transpose(0, 1)``, and the QKV gradient as
    ``(seq_len, bs, 3, heads, head_dim)``.
    """

    reset_rng_states()
    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    # Token-major input of shape (bs * seq_len, heads * head_dim).
    inp = 0.01 * torch.randn(
        bs * config.seq_len, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    inp.requires_grad = True
    # Every sequence is full length, so cu_seqlens = [0, s, 2s, ..., bs*s].
    seqlens = torch.empty(bs, dtype=torch.int32).cuda()
    seqlens.fill_(config.seq_len)
    cu_seqlens = torch.zeros(bs + 1, device=inp.device, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    op_grad = 0.01 * torch.randn(
        bs * config.seq_len, config.num_attention_heads * config.head_dim,
        dtype=dtype).cuda()
    torch.save(op_grad, 'op_grad.pt')

    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        interval=1,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1,
        amax_compute_algo="most_recent",
    )

    # NOTE(review): the module dtype is hard-coded to float16 while the
    # inputs use ``dtype``; equivalent today because param_types_fp8 is
    # [torch.float16]. Confirm DPA_FP8 supports other dtypes before changing.
    dpa = DPA_FP8(config).to(dtype=torch.float16).cuda()
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        op = dpa(inp, cu_seqlens, config.seq_len)
        op.backward(op_grad)

    context = torch.load("ctx.pt")
    dqkv = torch.load('dqkv.pt')
    return (
        context.view(bs, config.seq_len, -1).transpose(0, 1),
        dqkv.view(
            bs, config.seq_len, 3, config.num_attention_heads, config.head_dim,
        ).transpose(0, 1).contiguous(),
    )


def _run_dpa_fp8_ref(dtype, bs, config, backend):
    """Run the non-FP8 DotProductAttention reference for test_dpa_fp8.

    Loads the QKV input from ``qkv.pt`` and the output gradient from
    ``op_grad.pt``. ``op_grad.pt`` is written by ``_run_dpa_fp8``, and
    ``qkv.pt`` presumably is too (via the FP8 module — confirm), so that run
    must come first. Returns the attention output and the gradient of the
    loaded QKV tensor.
    """

    os.environ["NVTE_FLASH_ATTN"] = "0"
    os.environ["NVTE_FUSED_ATTN"] = "0"
    if backend == "FlashAttention":
        os.environ["NVTE_FLASH_ATTN"] = "1"
    if backend == "FusedAttention":
        os.environ["NVTE_FUSED_ATTN"] = "1"

    inp = torch.load('qkv.pt').cuda()
    inp.requires_grad = True
    # op_grad.pt holds a (bs * seq_len, heads * head_dim) tensor; reshape it
    # to the sequence-first layout of the DotProductAttention output.
    op_grad = torch.load('op_grad.pt').cuda().view(bs, config.seq_len, -1).transpose(0, 1)

    block = (
        DotProductAttention(
            config.num_attention_heads,
            config.head_dim,
            attention_dropout=config.dropout_p,
            sequence_parallel=False,
            tp_size=1,
            get_rng_state_tracker=get_dummy_cuda_rng_tracker,
            tp_group=None,
            layer_number=1,
            attention_type="self"
        ).to(dtype=dtype).cuda()
    )

    # NOTE(review): assumes qkv.pt is laid out as (s, b, 3, h, d) — confirm.
    q = inp[:, :, 0, :, :]
    k = inp[:, :, 1, :, :]
    v = inp[:, :, 2, :, :]
    op = block(q, k, v, attn_mask_type=config.attn_mask_type)
    op.backward(op_grad)

    return op, inp.grad

# cuBLASLt workspace size in bytes (32 MiB).
_CUBLASLT_WORKSPACE_SIZE_BYTES = 32 * 1024 * 1024
# Split-accumulator switches. _2X_ACC_FPROP is passed as
# use_split_accumulator to the forward QKV GEMM; the DGRAD/WGRAD flags are
# presumably for the backward GEMMs.
_2X_ACC_FPROP = _2X_ACC_DGRAD = _2X_ACC_WGRAD = False

# FP8 scaling-meta slots reused for the attention tensors. META_QKV is the
# out_index of the QKV GEMM; the others presumably index O, dO, dQKV, S and
# dS in the same way (confirm against _dpa_fp8).
META_QKV, META_O = (
    tex.FP8FwdTensors.GEMM1_OUTPUT,
    tex.FP8FwdTensors.GEMM2_INPUT,
)
META_DO, META_DQKV = (
    tex.FP8BwdTensors.GRAD_INPUT2,
    tex.FP8BwdTensors.GRAD_OUTPUT1,
)

META_S, META_DS = (
    tex.FP8FwdTensors.GEMM3_WEIGHT,
    tex.FP8BwdTensors.GRAD_INPUT3,
)

class _dpa_fp8(torch.autograd.Function):
    """FP8 QKV projection followed by FP8 fused attention (test harness).

    ``forward`` casts the input and the fused QKV weight to FP8 and runs the
    QKV GEMM with an FP8 output. It then calls ``fused_attn_fwd`` on the FP8
    backend with a ``t3hd`` layout and ``padding`` mask. ``backward`` runs
    ``fused_attn_bwd`` followed by the QKV dgrad and wgrad FP8 GEMMs.

    As a debugging / cross-check aid, intermediate tensors are written as
    FP16 to ``qkv.pt``, ``ctx.pt`` and ``dqkv.pt`` in the current working
    directory.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        qkv_weight: torch.Tensor,
        qkv_bias: torch.Tensor,
        cu_seqlens: torch.Tensor,
        num_attention_heads: int,
        p_dropout: float,
        max_s: int,
        fast_zero_fill: bool,
        fp8_meta: Dict[str, Any],
        workspace: torch.Tensor,
        is_training: bool,
    ) -> torch.Tensor:
        """Run the FP8 QKV projection and FP8 fused attention.

        Args:
            inp: 2-D input (rank enforced by the assert below).
            qkv_weight: fused QKV weight; its last dim is the hidden size.
            qkv_bias: fused QKV bias, added inside the QKV GEMM.
            cu_seqlens: cumulative sequence lengths; batch size is
                ``cu_seqlens.numel() - 1``.
            num_attention_heads: number of attention heads.
            p_dropout: attention dropout probability.
            max_s: maximum sequence length. It is overridden to 512 when the
                batch size is 2 or 3.
            fast_zero_fill: forwarded to the fused attention kernels.
            fp8_meta: TE FP8 metadata (``recipe``, ``scaling_fwd``, ``scaling_bwd``).
            workspace: cuBLASLt workspace tensor.
            is_training: forwarded to ``fused_attn_fwd``.

        Returns:
            The attention context, viewed as ``(-1, hidden)`` and cast from
            FP8 to FP16.
        """
        assert inp.dim() == 2
        in_features = qkv_weight.shape[-1]
        h = num_attention_heads
        d = in_features // h
        b = cu_seqlens.numel() - 1
        is_nl = False
        if 1 < b < 4:
            # NOTE(review): test-specific override of max_s for small batches;
            # preserved as-is — confirm intent.
            max_s = 512
            is_nl = True

        fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

        inputmat, inputmat_t = ext.fp8_cast_transpose_fused(
            inp,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
        )

        qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused(
            qkv_weight,
            fp8_meta["scaling_fwd"],
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
        )

        # QKV projection; the output stays in FP8, scaled with the META_QKV slot.
        qkv_out, _ = ext.fp8_gemm(
            qkv_weight_fp8,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_WEIGHT,
            fp8_dtype_forward,
            inputmat,
            fp8_meta["scaling_fwd"].scale_inv,
            tex.FP8FwdTensors.GEMM1_INPUT,
            fp8_dtype_forward,
            torch.uint8,
            workspace,
            bias=qkv_bias,
            use_bias=True,
            out_index=META_QKV,
            fp8_meta_tensor=fp8_meta["scaling_fwd"],
            use_split_accumulator=_2X_ACC_FPROP,
            D_dtype=fp8_dtype_forward,
        )
        qkv_out = qkv_out.view(-1, 3, h, d)
        qkv_out_fp16 = ext.cast_from_fp8(
            qkv_out, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16,
        ).view(b, max_s, 3, h, d).transpose(0, 1).contiguous()
        # Dumped for the non-FP8 reference run, which loads it as its input.
        torch.save(qkv_out_fp16, 'qkv.pt')

        # FMHA
        context_, aux_ctx_tensors, *_ = fused_attn_fwd(
            is_training,
            max_s,
            max_s,
            cu_seqlens,
            cu_seqlens,
            qkv_out[:, 0, :, :],
            qkv_out[:, 1, :, :],
            qkv_out[:, 2, :, :],
            fp8_dtype_forward,
            FusedAttnBackend["FP8"],
            None,
            fp8_meta["scaling_fwd"].scale_inv[META_QKV],
            fp8_meta["scaling_fwd"].scale[META_S],
            fp8_meta["scaling_fwd"].scale[META_O],
            fp8_meta["scaling_fwd"].amax_history[0][META_S],
            fp8_meta["scaling_fwd"].amax_history[0][META_O],
            attn_scale=None,
            dropout=p_dropout,
            fast_zero_fill=fast_zero_fill,
            qkv_layout="t3hd",
            attn_bias_type="no_bias",
            attn_mask_type="padding",
            rng_gen=None,
        )
        # Not used directly. The unpack makes a change in the number of aux
        # tensors fail loudly here instead of inside fused_attn_bwd.
        M, ZInv, philox_unpacked = aux_ctx_tensors

        context = context_.view(-1, in_features)

        ctx.save_for_backward(
            inputmat_t, qkv_weight_t_fp8, workspace,
            qkv_out,
            context_,
            fp8_meta["scaling_fwd"].scale,
            fp8_meta["scaling_fwd"].scale_inv,
        )
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.fp8_meta = fp8_meta
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        ctx.fast_zero_fill = fast_zero_fill
        ctx.is_nl = is_nl
        ctx.hidden_size = in_features
        ctx.num_attention_heads = num_attention_heads

        context_fp16 = ext.cast_from_fp8(
            context, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16,
        )
        torch.save(context_fp16, 'ctx.pt')
        return context_fp16

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Back-propagate through fused attention and the QKV projection.

        Returns exactly one entry per ``forward`` input: gradients for
        ``inp``, ``qkv_weight`` and ``qkv_bias``, followed by ``None`` for
        each non-differentiable argument.
        """
        with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
            (
                inputmat_t,
                qkv_weight_t_fp8,
                workspace,
                qkv_out,
                context,
                fwd_scales,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
            fp8_dtype_forward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=True
            )
            fp8_dtype_backward = fp8.get_fp8_te_dtype(
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )

            proj_dgrad = ext.cast_to_fp8(
                grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
            )

            dq, dk, dv, *_ = fused_attn_bwd(
                ctx.max_s,
                ctx.max_s,
                ctx.cu_seqlens,
                ctx.cu_seqlens,
                qkv_out[:, 0, :, :],
                qkv_out[:, 1, :, :],
                qkv_out[:, 2, :, :],
                context,
                proj_dgrad.view_as(context),
                fp8_dtype_forward,
                ctx.aux_ctx_tensors,
                FusedAttnBackend["FP8"],
                fwd_scale_inverses[META_QKV],  # d_scale_qkv
                fwd_scale_inverses[META_S],  # d_scale_s
                fwd_scale_inverses[META_O],  # d_scale_o
                ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO],  # d_scale_do
                fwd_scales[META_S],  # q_scale_s
                ctx.fp8_meta['scaling_bwd'].scale[META_DS],  # q_scale_ds
                ctx.fp8_meta['scaling_bwd'].scale[META_DQKV],  # q_scale_dqkv
                ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS],  # amax_ds
                ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV],  # amax_dqkv
                None,
                ctx.p_dropout,
                ctx.fast_zero_fill,
                "t3hd",
                "no_bias",
                "padding",
            )
            # Re-pack dq/dk/dv along a new dim 1. This presumably mirrors the
            # (tokens, 3, h, d) layout of qkv_out; TODO confirm dq/dk/dv shapes.
            dqkv = torch.stack([dq, dk, dv], dim=1)

            dqkv_grad_output_c = dqkv.view(-1, 3 * ctx.hidden_size)
            dqkv_grad_output_c_fp16 = ext.cast_from_fp8(
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"], META_DQKV,
                fp8_dtype_backward, tex.DType.kFloat16,
            )
            torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt')

            qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused(
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"],
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
            )

            # QKV DGRAD
            qkv_dgrad, _ = ext.fp8_gemm(
                qkv_weight_t_fp8,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                dqkv_grad_output_c,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_DGRAD,
            )
            # QKV WGRAD
            qkv_wgrad, _ = ext.fp8_gemm(
                inputmat_t,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                dqkv_grad_output_t,
                ctx.fp8_meta["scaling_bwd"].scale_inv,
                META_DQKV,
                fp8_dtype_backward,
                torch.float16,
                workspace,
                use_split_accumulator=_2X_ACC_WGRAD,
            )

        # One entry per forward input (11 total).
        return (
            qkv_dgrad,  # inp
            qkv_wgrad,  # qkv_weight
            qkv_bgrad,  # qkv_bias
            None,  # cu_seqlens
            None,  # num_attention_heads
            None,  # p_dropout
            None,  # max_s
            None,  # fast_zero_fill
            None,  # fp8_meta
            None,  # workspace
            None,  # is_training
        )

class DPA_FP8(TransformerEngineBaseModule):
    def __init__(
        self,
        config,
        params_dtype: torch.dtype = torch.float32):
        """Create the fused QKV parameters and the GEMM workspace.

        Args:
            config: test model config; reads ``dropout_p``,
                ``num_attention_heads``, ``hidden_size`` and ``head_dim``.
            params_dtype: dtype of ``qkv_weight`` and ``qkv_bias``.
        """
        super().__init__()
        self.p_dropout = config.dropout_p
        self.h = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.fast_zero_fill = True

        # Fused QKV weight of shape (3 * hidden_size, hidden_size).
        self.qkv_weight = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                self.hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.qkv_weight.shape)

        # Fused QKV bias of shape (3 * hidden_size,).
        self.qkv_bias = torch.nn.Parameter(
            torch.empty(
                self.hidden_size * 3,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        # Constant (non-random) initialisation so test outputs do not depend
        # on the RNG state.
        with torch.no_grad():
            self.qkv_bias.zero_()
            self.qkv_weight.fill_(1.0)
        # cuBLASLt scratch buffer passed to every FP8 GEMM in the forward/backward.
        self.workspace = torch.empty(
            _CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
        )

    def forward(
        self, inp: torch.Tensor,
        cu_seqlens, max_s,
    ) -> torch.Tensor:
        """Apply the FP8 QKV projection and fused attention to ``inp``.

        ``inp`` first goes through ``prepare_forward`` (``num_gemms=3``). It is
        then handed to ``_dpa_fp8`` together with this module's parameters,
        FP8 metadata, workspace and current ``training`` flag.
        """
        with self.prepare_forward(inp, None, num_gemms=3) as prepared:
            attn_args = (
                prepared,
                self.qkv_weight,
                self.qkv_bias,
                cu_seqlens,
                self.h,
                self.p_dropout,
                max_s,
                self.fast_zero_fill,
                self.fp8_meta,
                self.workspace,
                self.training,
            )
            result = _dpa_fp8.apply(*attn_args)
        return result

    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
    ) -> List[torch.Tensor]:
        """Needs override."""