nemotron_nas.py 16.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
30
31
32
33
34
35
36
37
38
39

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
40
from vllm.model_executor.layers.rotary_embedding import get_rope
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
43
44
    ParallelLMHead,
    VocabParallelEmbedding,
)
45
from vllm.model_executor.model_loader.weight_utils import (
46
47
48
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
49
50
from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP
from vllm.sequence import IntermediateTensors
51
from vllm.v1.attention.backend import AttentionType
52
53

from .interfaces import HasNoOps, SupportsLoRA, SupportsPP
54
55
56
57
58
59
60
61
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76


def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    """Convert a DeciLM ``ffn_mult`` multiplier into an MLP intermediate size.

    The raw width ``2 * ffn_mult * n_embd / 3`` is truncated to an int and
    then rounded up to the next multiple of 256.
    """
    # DeciLM-specific code; the expression order is kept exactly so the
    # float result (and hence the truncated width) matches the checkpoint.
    raw_width = int(2 * ffn_mult * n_embd / 3)
    return _find_multiple(raw_width, 256)


def _find_multiple(n: int, k: int) -> int:
    # DeciLM-specific code
    if n % k == 0:
        return n
    return n + k - (n % k)


77
78
79
80
81
82
83
84
class DeciLMAttention(LlamaAttention):
    """Llama attention whose rotary embedding style follows the DeciLM config.

    Construction is delegated unchanged to ``LlamaAttention``; only
    ``_init_rotary_emb`` is overridden here.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        # Every argument is forwarded positionally, in the same order it was
        # received.
        super().__init__(
            config,
            hidden_size,
            num_heads,
            num_kv_heads,
            max_position_embeddings,
            quant_config,
            bias,
            bias_o_proj,
            cache_config,
            prefix,
            attn_type,
        )

    def _init_rotary_emb(
        self,
        config,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` with ``get_rope``.

        NeoX-style rotation is used unless the config declares a
        ``position_embedding_type`` of ``"mistral_yarn"`` or ``"rope_llama4"``.
        ``quant_config`` is accepted but not used here (presumably to match
        the base-class hook signature).
        """
        # Enables YARN for Mistral and LLaMA4 derivatives. A config without
        # position_embedding_type yields None, which keeps NeoX style.
        embedding_type = getattr(config, "position_embedding_type", None)
        use_neox_style = embedding_type not in ("mistral_yarn", "rope_llama4")

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=use_neox_style,
        )
125
126


127
128
129
130
131
class DeciLMDecoderLayer(nn.Module):
    """One heterogeneous DeciLM / Nemotron-NAS decoder layer.

    The layer's shape comes from ``config.block_configs[layer_idx]``. Either
    the attention sub-block or the FFN sub-block (or both) may be a no-op; in
    that case the corresponding module and its RMSNorm are not created.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        block_config = config.block_configs[layer_idx]
        self._is_no_op_attention = block_config.attention.no_op
        self._is_no_op_ffn = block_config.ffn.no_op

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        # o_proj bias is decided before a qkv_bias override is applied.
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, "qkv_bias"):
            attention_bias = config.qkv_bias

        if not self._is_no_op_attention:
            # Grouped-query attention: n_heads_in_group query heads share
            # each KV head.
            num_kv_heads = (
                config.num_attention_heads // block_config.attention.n_heads_in_group
            )
            self.self_attn = DeciLMAttention(
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                num_kv_heads=num_kv_heads,
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                bias=attention_bias,
                bias_o_proj=bias_o_proj,
                cache_config=cache_config,
                prefix=f"{prefix}.self_attn",
            )
            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        if not self._is_no_op_ffn:
            # The FFN width is given either as a multiplier (ffn_mult) or as
            # an explicit intermediate_size. A present-but-None ffn_mult falls
            # back to intermediate_size instead of failing inside the
            # multiplier arithmetic.
            ffn_mult = getattr(block_config.ffn, "ffn_mult", None)
            if ffn_mult is not None:
                intermediate_size = _ffn_mult_to_intermediate_size(
                    ffn_mult, config.hidden_size
                )
            else:
                intermediate_size = block_config.ffn.intermediate_size

            self.mlp = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                bias=getattr(config, "mlp_bias", False),
                prefix=f"{prefix}.mlp",
            )
            self.post_attention_layernorm = RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )

    @staticmethod
    def _apply_norm(
        norm: nn.Module,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply a pre-norm, seeding the residual stream on first use.

        With ``residual is None`` the un-normalized input becomes the residual
        and ``norm`` is called on the input alone (returning a single tensor).
        Otherwise ``norm`` is called with both tensors and its
        ``(hidden_states, residual)`` pair is returned.
        """
        if residual is None:
            return norm(hidden_states), hidden_states
        return norm(hidden_states, residual)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the (possibly partial) layer.

        Returns ``(hidden_states, residual)``. ``residual`` stays None only
        when both sub-blocks are no-ops and no residual was passed in.
        """
        # Self Attention
        if not self._is_no_op_attention:
            hidden_states, residual = self._apply_norm(
                self.input_layernorm, hidden_states, residual
            )
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

        # Fully Connected. If the attention above was a no-op on the very first
        # layer, residual can still be None here, so the norm must seed it
        # rather than unpack a single tensor.
        if not self._is_no_op_ffn:
            hidden_states, residual = self._apply_norm(
                self.post_attention_layernorm, hidden_states, residual
            )
            hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class DeciModel(nn.Module):
    """Decoder stack for DeciLM / Nemotron-NAS checkpoints.

    Holds the token embedding (on the first pipeline-parallel rank, or on the
    last rank when word embeddings are tied), this rank's slice of the
    heterogeneous decoder layers, and the final RMSNorm (last rank only).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id

        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def get_layer(prefix: str):
            # The trailing dotted component of the layer prefix is parsed as
            # the layer index into config.block_configs.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            return layer_type(
                config,
                layer_idx,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            get_layer,
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's layers.

        Returns ``IntermediateTensors`` on non-last pipeline ranks and the
        final-normed hidden states on the last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Every layer is called the same way; no-op sub-blocks are handled
        # inside the layer itself.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        if residual is None:
            # No layer applied a fused add+norm, so there is no residual to
            # fold in; RMSNorm called without a residual returns one tensor.
            return self.norm(hidden_states)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id.
        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name or "zero_point" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: the two skips below `continue` the inner loop. The
                # renamed name no longer matches any later mapping entry, so
                # the loop ends without `break` and the for-else re-applies
                # the same checks, which then skip this weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
    """Causal-LM wrapper around :class:`DeciModel`.

    Adds the LM head and logits processor on the last pipeline-parallel
    rank; other ranks get a ``PPMissingLayer`` placeholder for the head.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        self.model = self._init_model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if not get_pp_group().is_last_rank:
            # Only the last pipeline rank produces logits.
            self.lm_head = PPMissingLayer()
        else:
            head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                head = head.tie_weights(self.model.embed_tokens)
            self.lm_head = head
            self.logits_processor = LogitsProcessor(
                hf_config.vocab_size,
                scale=getattr(hf_config, "logit_scale", 1.0),
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone; kept separate so subclasses can swap it."""
        return DeciModel(vllm_config=vllm_config, prefix=prefix)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        With tied embeddings, ``lm_head.*`` tensors are skipped.
        """
        skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skip_prefixes).load_weights(
            weights
        )