te_llama.py 7.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import re
import gc
from contextlib import contextmanager

import torch
from torch import nn

import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init

import transformers
from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files

@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.

    Context manager that rebinds
    `transformers.models.llama.modeling_llama.LlamaDecoderLayer` to
    `te_decoder_cls` for the duration of the `with` block. The original class
    is restored on exit, including when the body raises.

    Args:
        te_decoder_cls: class installed in place of `LlamaDecoderLayer` while
            the context is active.
    """
    llama_module = transformers.models.llama.modeling_llama
    original_llama_decoder_cls = llama_module.LlamaDecoderLayer
    llama_module.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        # Always undo the monkey-patch so later model constructions are unaffected.
        llama_module.LlamaDecoderLayer = original_llama_decoder_cls


class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
    similar to HF's `LlamaDecoderLayer` and easier to replace it in the code.

    Args:
        config: LlamaConfig
        args: positional args (for compatibility with `LlamaDecoderLayer`;
            ignored)
        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`;
            ignored)
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            # Keep Q/K/V as separate params so HF q/k/v_proj weights map 1:1.
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
            num_gqa_groups=config.num_key_value_heads,
        )
        # Precompute the rotary embedding table once, sized for
        # `max_position_embeddings`, and reuse it on every forward call.
        # It is stored as a plain tensor attribute (not a registered buffer),
        # so `Module.to(...)` and `state_dict()` do not touch it.
        # NOTE(review): `config.rope_theta` is not forwarded, so whatever default
        # base `RotaryPositionEmbedding` uses applies — confirm it matches the
        # checkpoint being loaded.
        te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self,
                hidden_states,
                *args,
                attention_mask,
                **kwargs):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of the HF's `LlamaDecoderLayer`.

        Any extra positional/keyword arguments supplied by the HF caller are
        accepted and ignored.

        Returns:
            A 1-tuple holding the `TransformerLayer` output.
        """
        layer_output = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb=self.te_rope_emb,
        )
        return (layer_output,)


class TELlamaForCausalLM:
    """
    Causal LM created with `LlamaForCausalLM`. The underlying `LlamaDecoderLayer`
    class is monkey-patched with `TELlamaDecoderLayer` class before
    initializing the causal LM with `LlamaForCausalLM`.

    Instantiating this class returns the constructed `LlamaForCausalLM`
    object (see `__new__`), not a `TELlamaForCausalLM` instance.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        # Every decoder layer built inside this context is a TELlamaDecoderLayer.
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm

    @classmethod
    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
        """
        Custom method adapted from `from_pretrained` method in HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579

        Args:
            pretrained_model_name_or_path: local directory containing a sharded
                PyTorch checkpoint (an index file named by `WEIGHTS_INDEX_NAME`
                plus its shard files).
            args: accepted for signature compatibility; unused.
            config: LlamaConfig used to build the model.
            kwargs: if `torch_dtype` is given (and not None), the model is cast
                to it before weights are loaded. Other keys are ignored.

        Returns:
            The model returned by `cls(config)` with checkpoint weights loaded.

        Raises:
            AssertionError: if the directory has no sharded PyTorch index file.
        """
        import warnings

        vanilla_model = cls(config)
        torch_dtype = kwargs.get("torch_dtype")
        if torch_dtype is not None:
            vanilla_model = vanilla_model.to(torch_dtype)

        subfolder = ""
        variant = None
        archive_file = os.path.join(
            pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
        )
        if not os.path.isfile(archive_file):
            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

        # Only the sharded path is supported, so this is always a list of shard files.
        resolved_archive_file, _sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            archive_file,
        )

        error_msgs = []
        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)
            # Copy per-layer HF weights into the TE layer parameters in place
            # (the tensors in `state_dict()` share storage with the model params).
            replace_params(state_dict, vanilla_model.state_dict())

            # Load whatever keys from this shard match the model directly.
            error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")

            # Force mem release. Taken from huggingface code
            del state_dict
            gc.collect()

        if error_msgs:
            # Previously these were collected and silently dropped.
            warnings.warn(
                "Error(s) while loading checkpoint into "
                f"{vanilla_model.__class__.__name__}:\n\t" + "\n\t".join(map(str, error_msgs))
            )

        return vanilla_model

# Matches the per-layer prefix of HF Llama parameter names, e.g. "model.layers.12.".
_LAYER_PREFIX_RE = re.compile(r"model\.layers\.\d+\.")

# (HF suffix, TE suffix) pairs whose tensors are copied one-to-one.
_HF_TO_TE_PARAM_MAP = (
    ("input_layernorm.weight", "self_attention.layernorm_qkv.layer_norm_weight"),
    ("self_attn.q_proj.weight", "self_attention.layernorm_qkv.query_weight"),
    ("self_attn.k_proj.weight", "self_attention.layernorm_qkv.key_weight"),
    ("self_attn.v_proj.weight", "self_attention.layernorm_qkv.value_weight"),
    ("self_attn.o_proj.weight", "self_attention.proj.weight"),
    ("post_attention_layernorm.weight", "layernorm_mlp.layer_norm_weight"),
    ("mlp.down_proj.weight", "layernorm_mlp.fc2_weight"),
)


def replace_params(hf_state_dict, te_state_dict):
    """
    Copy HF Llama decoder-layer weights into the matching TE tensors in place.

    A tensor is copied only when both its HF key and its TE key are present.
    This covers two cases:
      * the TE model has fewer layers than the checkpoint;
      * a single layer's weights are split across checkpoint shards, so one
        shard holds only some of that layer's HF keys.

    HF `mlp.gate_proj.weight` and `mlp.up_proj.weight` are written into the
    first and second row-halves of TE `layernorm_mlp.fc1_weight`, respectively.
    Each half is written independently, so they may come from different shards.

    Args:
        hf_state_dict: mapping of HF parameter names to tensors (one shard).
        te_state_dict: mapping of TE parameter names to tensors; written in place.

    Returns:
        The set of layer prefixes (e.g. "model.layers.0.") found in
        `hf_state_dict`.
    """
    # Collect all layer prefixes present in this shard.
    all_layer_prefixes = set()
    for param_key in hf_state_dict:
        m = _LAYER_PREFIX_RE.match(param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        for hf_suffix, te_suffix in _HF_TO_TE_PARAM_MAP:
            hf_key = layer_prefix + hf_suffix
            te_key = layer_prefix + te_suffix
            if hf_key in hf_state_dict and te_key in te_state_dict:
                te_state_dict[te_key].data[:] = hf_state_dict[hf_key].data[:]

        fc1_key = layer_prefix + "layernorm_mlp.fc1_weight"
        if fc1_key in te_state_dict:
            fc1 = te_state_dict[fc1_key].data
            # fc1 stacks [gate; up] along dim 0, so each half is one projection.
            half = fc1.shape[0] // 2
            gate_key = layer_prefix + "mlp.gate_proj.weight"
            up_key = layer_prefix + "mlp.up_proj.weight"
            if gate_key in hf_state_dict:
                fc1[:half] = hf_state_dict[gate_key].data
            if up_key in hf_state_dict:
                fc1[half:] = hf_state_dict[up_key].data

    return all_layer_prefixes