# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import re
import gc
from contextlib import contextmanager

import torch
from torch import nn

import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init

import transformers
from transformers.models.llama.modeling_llama import (
    LlamaModel,
    LlamaForCausalLM,
    LlamaRMSNorm,
    LlamaConfig,
)
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files

@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Temporarily swap HF's `LlamaDecoderLayer` for `te_decoder_cls`.

    While the context is active, any Llama model constructed through
    `transformers.models.llama.modeling_llama` instantiates `te_decoder_cls`
    instead of the stock decoder layer. The original class is restored on
    exit, even if the body raises.
    """
    llama_module = transformers.models.llama.modeling_llama
    saved_decoder_cls = llama_module.LlamaDecoderLayer
    llama_module.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        llama_module.LlamaDecoderLayer = saved_decoder_cls


class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Thin adapter exposing TE's `TransformerLayer` behind the construction and
    forward surface of HF's `LlamaDecoderLayer`, so it can be swapped in via
    `replace_decoder`.

    Args:
        config: LlamaConfig
        args: positional args (accepted for `LlamaDecoderLayer` compatibility)
        kwargs: keyword args (accepted for `LlamaDecoderLayer` compatibility)
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            num_gqa_groups=config.num_key_value_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
        )
        # Precompute the rotary table once for the maximum sequence length and
        # keep it on the GPU for reuse in every forward pass.
        head_dim = config.hidden_size // config.num_attention_heads
        rope = RotaryPositionEmbedding(head_dim)
        self.te_rope_emb = rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self, hidden_states, *args, attention_mask, **kwargs):
        """
        Pass only the arguments `TransformerLayer` understands, and wrap the
        result in a 1-tuple so the output format matches HF's
        `LlamaDecoderLayer`.
        """
        layer_out = super().forward(
            hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
        )
        return (layer_out,)


class TELlamaForCausalLM:
    """
    Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
    class is monkey-patched with `TELlamaDecoderLayer` class before
    initializing the causal LM with `LlamaForCausalLM`.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        # Build a stock HF causal LM while `LlamaDecoderLayer` is patched to
        # the TE implementation; the returned model therefore contains
        # `TELlamaDecoderLayer` instances.
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm

    @classmethod
    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
        """
        Custom method adapted from `from_pretrained` method in HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579

        Only sharded checkpoints (safetensors or PyTorch, each with an index
        file) in a local directory are supported.

        Args:
            pretrained_model_name_or_path: local directory holding the checkpoint.
            config: LlamaConfig used to build the (TE-patched) model.
            kwargs: must contain ``torch_dtype``, used as torch's default
                dtype while the model is materialized.

        Raises:
            AssertionError: if no sharded-checkpoint index file is found.
        """
        # Before loading the model, set the default dtype for torch
        torch.set_default_dtype(kwargs["torch_dtype"])

        # Build the TE-patched model (randomly initialized weights).
        vanilla_model = cls(config)

        subfolder = ""
        variant = None
        safetensors_index = os.path.join(
            pretrained_model_name_or_path,
            subfolder,
            _add_variant("model.safetensors.index.json", variant),
        )
        pytorch_index = os.path.join(
            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
        )
        if os.path.isfile(safetensors_index):
            # Load from a sharded safetensors checkpoint
            archive_file = safetensors_index
        elif os.path.isfile(pytorch_index):
            # Load from a sharded PyTorch checkpoint
            archive_file = pytorch_index
        else:
            raise AssertionError(
                "Only sharded checkpoints (safetensors or PyTorch) are supported at the moment"
            )

        # Resolve the individual shard files listed in the index.
        # NOTE: both accepted formats above are sharded, so the non-sharded
        # normalization step from the upstream HF code was unreachable here
        # and has been removed.
        resolved_archive_file, _ = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            archive_file,
        )

        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)
            # replace_params copies parameters relevant only to TransformerEngine
            replace_params(state_dict, vanilla_model.state_dict(), config)
            # _load_state_dict_into_model copies parameters other than those in TransformerEngine
            _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")

            # Force mem release. Taken from huggingface code
            del state_dict
            gc.collect()

        return vanilla_model

def replace_params(hf_state_dict, te_state_dict, config):
    """
    Copy per-layer weights from a HuggingFace Llama state dict into the
    matching TransformerEngine parameters of `te_state_dict`, in place.

    Only parameters that live inside `TELlamaDecoderLayer` are handled here;
    everything else (embeddings, final norm, LM head) is left to
    `_load_state_dict_into_model`. Keys absent from `hf_state_dict` (e.g.
    when a shard holds only part of a layer) are skipped, so this function
    can be called once per checkpoint shard.

    Args:
        hf_state_dict: state dict loaded from a HF checkpoint (shard).
        te_state_dict: state dict of the TE-patched model; modified in place.
        config: LlamaConfig; `intermediate_size` determines where the gate/up
            projections land inside the fused `fc1_weight`.

    Returns:
        The set of `model.layers.<i>.` prefixes found in `hf_state_dict`.
    """
    # Compiled raw, fully-escaped pattern. The original "model.layers.\d+."
    # was a non-raw string with an invalid escape sequence (SyntaxWarning on
    # Python 3.12+) and let the unescaped dots match any character.
    layer_prefix_pat = re.compile(r"model\.layers\.\d+\.")

    # collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        m = layer_prefix_pat.match(param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # When loading weights into models with less number of layers, skip the
        # copy if the corresponding layer doesn't exist in HF model
        if layer_prefix + "input_layernorm.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
                :
            ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]

        if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
            )

        if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
                layer_prefix + "self_attn.o_proj.weight"
            ].data[:]

        if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
                layer_prefix + "post_attention_layernorm.weight"
            ].data[:]

        # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
        # load them separately: they fill the two halves of the fused fc1 weight.
        if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                : config.intermediate_size
            ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data

        if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                config.intermediate_size :
            ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data

        if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
                layer_prefix + "mlp.down_proj.weight"
            ].data[:]
    return all_layer_prefixes