..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Frequently Asked Questions (FAQ)
================================

FP8 checkpoint compatibility
----------------------------

Transformer Engine added support for FP8 attention in version 1.6. It stores the FP8 metadata, i.e. scaling factors and amax histories, under a `._extra_state` key in the checkpoint. As FP8 attention support has expanded from one backend to multiple backends, the location of the `._extra_state` key has also shifted.

As an example, consider the `MultiheadAttention` module. In Transformer Engine 1.11, its FP8 attention metadata is stored as `core_attention._extra_state`, as shown below.

.. code-block:: python

    >>> import torch
    >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init
    >>> with fp8_model_init(enabled=True):
    ...     mha = MultiheadAttention(
    ...         hidden_size=1024,
    ...         num_attention_heads=16,
    ...         bias=True,
    ...         params_dtype=torch.bfloat16,
    ...         input_layernorm=False,
    ...         fuse_qkv_params=True,
    ...         attention_type="self",
    ...         qkv_weight_interleaved=True,
    ...     ).to(dtype=torch.bfloat16, device="cuda")
    ...
    >>> state_dict = mha.state_dict()
    >>> print(state_dict.keys())
    odict_keys(['qkv.weight', 'qkv.bias', 'qkv._extra_state', 'core_attention._extra_state', 'proj.weight', 'proj.bias', 'proj._extra_state'])
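
The state dict above can be saved and restored with the standard PyTorch utilities; the FP8 metadata travels with the `._extra_state` entries. Below is a minimal sketch, assuming the same Transformer Engine version on both ends (the file name `mha_checkpoint.pt` is only for illustration):

.. code-block:: python

    >>> # The FP8 metadata is carried in the `._extra_state` entries,
    >>> # so saving the state dict is enough to preserve it.
    >>> torch.save(mha.state_dict(), "mha_checkpoint.pt")
    >>> # `._extra_state` keys are matched by name on load, which is why
    >>> # the key relocations listed below can raise unexpected-key errors.
    >>> mha.load_state_dict(torch.load("mha_checkpoint.pt"))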

Here is a full list of the checkpoint save/load behaviors across all Transformer Engine versions.

.. list-table::

   * - **Version: <= 1.5**

         - Saves no FP8 metadata since FP8 attention is not supported
         - Loading behavior for checkpoints created by the following versions:

             :<= 1.5:    Loads no FP8 metadata
             :>  1.5:    Error: unexpected key
   * - **Version: 1.6, 1.7**

         - Saves FP8 metadata to `core_attention.fused_attention._extra_state`
         - Loading behavior for checkpoints created by the following versions:

             :<= 1.5:    Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
             :1.6, 1.7:  Loads FP8 metadata from checkpoint
             :>= 1.8:    Error: unexpected key
   * - **Version: >= 1.8, <= 1.11**

         - Saves FP8 metadata to `core_attention._extra_state`
         - Loading behavior for checkpoints created by the following versions:

             :<= 1.5:    Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
             :1.6, 1.7:  This save/load combination requires the user to map the 1.6/1.7 key to the 1.8-1.11 key; otherwise, FP8 metadata is initialized to the default, i.e. 1s for scaling factors and 0s for amaxes. In this `MultiheadAttention` example, the mapping can be done by

              .. code-block:: python

                  >>> state_dict["core_attention._extra_state"] = \
                          state_dict["core_attention.fused_attention._extra_state"]
                  >>> del state_dict["core_attention.fused_attention._extra_state"]

             :>= 1.8:    Loads FP8 metadata from checkpoint
   * - **Version: >= 1.12**

         - Saves FP8 metadata to `core_attention._extra_state`
         - Loading behavior for checkpoints created by the following versions:

             :<= 1.5:    Initializes FP8 metadata to the default, i.e. 1s for scaling factors and 0s for amaxes
             :>= 1.6:    Loads FP8 metadata from checkpoint
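
When loading a whole-model checkpoint that contains many attention layers, the 1.6/1.7-to-1.8+ key mapping shown above can be applied to every affected key before calling `load_state_dict`. Below is a minimal sketch; the helper name `remap_fp8_attention_keys` is ours for illustration and not part of the Transformer Engine API:

.. code-block:: python

    def remap_fp8_attention_keys(state_dict):
        # Move 1.6/1.7-style keys, e.g.
        #   <prefix>.core_attention.fused_attention._extra_state
        # to their 1.8+ location:
        #   <prefix>.core_attention._extra_state
        old_suffix = "core_attention.fused_attention._extra_state"
        new_suffix = "core_attention._extra_state"
        for key in [k for k in state_dict if k.endswith(old_suffix)]:
            state_dict[key[: -len(old_suffix)] + new_suffix] = state_dict.pop(key)
        return state_dict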