# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import re
import gc
from contextlib import contextmanager

import torch

import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

import transformers
from transformers.models.llama.modeling_llama import (
    LlamaModel,
    LlamaForCausalLM,
    LlamaRMSNorm,
    LlamaConfig,
)
from transformers.modeling_utils import _add_variant, load_state_dict
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files


@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
    """
    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls
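
# Usage sketch for the context manager above (`config` stands for any
# `LlamaConfig`): every HF Llama model constructed inside the `with` block is
# built with the patched decoder class.
#
#   with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
#       model = LlamaForCausalLM(config)  # decoder layers are TELlamaDecoderLayer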


class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
    similar to HF's `LlamaDecoderLayer` and easy to swap in for it in the code.

    Args:
        config: LlamaConfig
        args: positional args (for compatibility with `LlamaDecoderLayer`)
        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
            num_gqa_groups=config.num_key_value_heads,
        )
        te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self, hidden_states, *args, attention_mask, **kwargs):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of the HF's `LlamaDecoderLayer`.
        """
        # Handle case where hidden_states might be a tuple (from previous layer output)
        # This can happen with older versions of HuggingFace transformers
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]

        # Return tensor directly for HuggingFace transformers >= 4.57
        # (older versions wrapped output in tuple and extracted with layer_outputs[0])
        return super().forward(
            hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
        )
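
# Construction sketch for a single TE decoder layer (a CUDA device is assumed;
# the config values below are illustrative placeholders):
#
#   cfg = LlamaConfig(hidden_size=512, intermediate_size=1376,
#                     num_attention_heads=8, num_key_value_heads=8)
#   layer = TELlamaDecoderLayer(cfg)
#   out = layer(torch.randn(2, 16, cfg.hidden_size, device="cuda"),
#               attention_mask=None)  # input layout is [batch, seq, hidden] ("bshd")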


class TELlamaForCausalLM:
    """
    Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
    class is monkey-patched with `TELlamaDecoderLayer` class before
    initializing the causal LM with `LlamaForCausalLM`.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm
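
    # Note: `__new__` above returns the (monkey-patched) `LlamaForCausalLM`
    # instance itself, so `TELlamaForCausalLM(config)` yields an HF model whose
    # decoder layers are `TELlamaDecoderLayer`, not an instance of this class.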

    @classmethod
    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
        """
        Custom method adapted from `from_pretrained` method in HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
        """
        # Before loading the model, set the default dtype for torch
        torch.set_default_dtype(kwargs["torch_dtype"])

        # Create the vanilla model; its weights are loaded from the checkpoint shards below
        vanilla_model = cls(config)
        subfolder = ""
        variant = None
        if os.path.isfile(
            os.path.join(
                pretrained_model_name_or_path,
                subfolder,
                _add_variant("model.safetensors.index.json", variant),
            )
        ):
            # Load from a sharded safetensors checkpoint
            archive_file = os.path.join(
                pretrained_model_name_or_path,
                subfolder,
                _add_variant("model.safetensors.index.json", variant),
            )
            is_sharded = True
        elif os.path.isfile(
            os.path.join(
                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
            )
        ):
            # Load from a sharded PyTorch checkpoint
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
            )
            is_sharded = True
        else:
            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

        resolved_archive_file, _ = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            archive_file,
        )

        # If the checkpoint is not sharded, it's a trivial sharding case
        if not is_sharded:
            assert not isinstance(resolved_archive_file, list)
            resolved_archive_file = [resolved_archive_file]

        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)
            # replace_params copies parameters relevant only to TransformerEngine
            replace_params(state_dict, vanilla_model.state_dict(), config)
            # load_state_dict copies parameters other than those in TransformerEngine
            vanilla_model.load_state_dict(state_dict, strict=False)

            # Force mem release. Taken from huggingface code
            del state_dict
            gc.collect()

        return vanilla_model
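
# Usage sketch for `TELlamaForCausalLM.from_pretrained_local` (the checkpoint
# path and dtype below are placeholders):
#
#   from transformers import AutoConfig
#
#   config = AutoConfig.from_pretrained("<path-to-local-llama-checkpoint>")
#   model = TELlamaForCausalLM.from_pretrained_local(
#       "<path-to-local-llama-checkpoint>",
#       config=config,
#       torch_dtype=torch.bfloat16,  # required; read via kwargs["torch_dtype"] above
#   )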


def replace_params(hf_state_dict, te_state_dict, config):
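    """
    Copy weights from an HF Llama checkpoint shard (`hf_state_dict`) into the
    corresponding TransformerEngine parameters of `te_state_dict`. Returns the
    set of decoder-layer prefixes found in the shard.
    """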
    # collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = r"model\.layers\.\d+\."  # e.g. "model.layers.3."
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # When loading weights into models with fewer layers, skip copying
        # if the corresponding layer doesn't exist in the HF model
        if layer_prefix + "input_layernorm.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
                :
            ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]

        if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
                layer_prefix + "self_attn.o_proj.weight"
            ].data[:]

        if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
                layer_prefix + "post_attention_layernorm.weight"
            ].data[:]

        # gate_proj.weight and up_proj.weight may end up in different checkpoint
        # shards, so load them separately.
        if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                : config.intermediate_size
            ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data

        if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                config.intermediate_size :
            ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data

        if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
                layer_prefix + "mlp.down_proj.weight"
            ].data[:]
    return all_layer_prefixes
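
# Summary of the per-layer HF -> TE parameter mapping performed above:
#
#   input_layernorm.weight          -> self_attention.layernorm_qkv.layer_norm_weight
#   self_attn.{q,k,v}_proj.weight   -> self_attention.layernorm_qkv.{query,key,value}_weight
#   self_attn.o_proj.weight         -> self_attention.proj.weight
#   post_attention_layernorm.weight -> layernorm_mlp.layer_norm_weight
#   mlp.gate_proj.weight            -> layernorm_mlp.fc1_weight[:intermediate_size]
#   mlp.up_proj.weight              -> layernorm_mlp.fc1_weight[intermediate_size:]
#   mlp.down_proj.weight            -> layernorm_mlp.fc2_weight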