llama4_eagle.py 8.74 KB
Newer Older
zhiweiz's avatar
zhiweiz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
29
from vllm.model_executor.layers.quantization import QuantizationConfig
zhiweiz's avatar
zhiweiz committed
30
from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
31
32
33
34
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
zhiweiz's avatar
zhiweiz committed
35
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
36
from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM
zhiweiz's avatar
zhiweiz committed
37
38
from vllm.model_executor.models.utils import extract_layer_index

39
from .interfaces import SupportsMultiModal
40
from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
zhiweiz's avatar
zhiweiz committed
41
42
43
44
45
46
47
48
49
50
51
52

# Module-level logger created through vLLM's ``init_logger`` with this
# module's ``__name__``.
logger = init_logger(__name__)


@support_torch_compile
class LlamaModel(nn.Module):
    """Draft decoder stack built from ``Llama4DecoderLayer`` blocks.

    The model embeds the input tokens, concatenates those embeddings with
    the hidden states passed in by the caller along the last dimension,
    projects the result back to ``hidden_size`` with ``fc`` and then runs it
    through the decoder layers followed by a final ``RMSNorm``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Build the embedding, decoder layers, fusion projection and norm.

        Args:
            vllm_config: Global configuration. The draft model's HF config is
                read from ``speculative_config.draft_model_config.hf_config``.
            prefix: Prefix prepended to the names of the sub-modules.
            start_layer_id: Offset added to every decoder layer index when
                building the layer prefixes (``layers.{i + start_layer_id}``).
            quant_config: Quantization config used while constructing the
                decoder layers; it temporarily replaces
                ``vllm_config.quant_config`` during layer construction.
        """
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.validate_and_update_config(start_layer_id, quant_config)
        self.vocab_size = self.config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Temporarily modify vllm_config.quant_config for draft model layers
        original_quant_config = vllm_config.quant_config
        vllm_config.quant_config = quant_config
        try:
            self.layers = nn.ModuleList(
                [
                    Llama4DecoderLayer(
                        vllm_config=vllm_config,
                        prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                        config=self.config,
                    )
                    for i in range(self.config.num_hidden_layers)
                ]
            )
        finally:
            # Restore original quant_config even if layer construction fails
            vllm_config.quant_config = original_quant_config
        # Fuses [token embedding, incoming hidden state] (2 * hidden_size)
        # back down to hidden_size.
        self.fc = torch.nn.Linear(
            self.config.hidden_size * 2, self.config.hidden_size, bias=False
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft decoder stack.

        Args:
            input_ids: Token ids; only used when ``inputs_embeds`` is None.
            positions: Position tensor forwarded to every decoder layer.
            hidden_states: Hidden states supplied by the caller; they are
                concatenated with the token embeddings on the last dimension.
            inputs_embeds: Optional precomputed embeddings that replace the
                lookup of ``input_ids``.

        Returns:
            A tuple containing the final normalized hidden states twice.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_input_ids(input_ids)
        hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        A leading ``model.`` is stripped from every checkpoint name. Separate
        q/k/v and gate/up projections are routed into the fused ``qkv_proj``
        and ``gate_up_proj`` parameters with their shard id; all other
        tensors use the parameter's own ``weight_loader`` when it has one and
        ``default_weight_loader`` otherwise.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that received a weight.

        Raises:
            KeyError: If a (remapped) checkpoint name has no matching
                parameter.
            AssertionError: If any parameter of this module was not loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            name = name.removeprefix("model.")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a shard of a fused parameter: load it directly.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        for name in params_dict:
            assert name in loaded_params, f"{name} is not loaded!"
        return loaded_params

    def validate_and_update_config(
        self, start_layer_id: int, quant_config: QuantizationConfig | None = None
    ) -> None:
        """Check draft-model constraints and shift per-layer settings.

        Asserts that the draft config uses neither YOCO nor MoE layers, then
        left-pads ``no_rope_layers`` with ``start_layer_id`` zeros so that it
        can be indexed with the shifted layer ids. For a ``TorchAOConfig`` the
        keys of ``module_fqn_to_config`` are rewritten so that their layer
        index is shifted by ``start_layer_id`` as well.

        Args:
            start_layer_id: Offset applied to the draft layer indices.
            quant_config: Draft quantization config; only ``TorchAOConfig``
                instances are modified.

        Raises:
            AssertionError: If the draft config enables YOCO or MoE layers.
        """
        # NOTE(review): both updates below mutate shared config objects in
        # place, so calling this twice on the same config pads twice; confirm
        # the model is only constructed once per config.
        # yoco and moe are not supported by the draft model yet
        assert self.config.yoco_global_kv_layer is None
        assert self.config.yoco_local_kv_layer is None
        assert len(self.config.moe_layers) == 0
        # draft model layer index is increased by start_layer_id,
        # so we need to pad relevant configs accordingly
        self.config.no_rope_layers = [0] * start_layer_id + self.config.no_rope_layers
        # currently only TorchAO quantization is supported
        if isinstance(quant_config, TorchAOConfig):

            def pad_layer_name(layer: str) -> str:
                """Shift the layer index embedded in ``layer`` by the offset."""
                layer_index = extract_layer_index(layer)
                old = str(layer_index)
                new = str(layer_index + start_layer_id)
                parts = layer.split(".")
                if old in parts:
                    # Rewrite only the dotted component equal to the index so
                    # that digits elsewhere in the name (for example
                    # "layers.1.mlp.fc1") are left untouched.
                    return ".".join(new if part == old else part for part in parts)
                # The index is not a standalone dotted component: keep the
                # previous plain substring replacement.
                return layer.replace(old, new)

            torchao_config = quant_config.torchao_config
            torchao_config.module_fqn_to_config = {
                pad_layer_name(layer): quantization
                for layer, quantization in torchao_config.module_fqn_to_config.items()
            }


class EagleLlama4ForCausalLM(Llama4ForCausalLM):
    """Causal-LM wrapper around the draft ``LlamaModel``.

    The wrapper owns the draft decoder stack (``self.model``), a logits
    processor and an LM head, and adapts checkpoint names before handing
    the weights to ``AutoWeightsLoader``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Assemble the draft model, logits processor and LM head.

        Args:
            vllm_config: Global configuration; the draft HF config is read
                from ``speculative_config.draft_model_config.hf_config``.
            prefix: Prefix used for the ``lm_head`` module name.
        """
        # Only nn.Module is initialised here; the constructor of
        # Llama4ForCausalLM is intentionally not invoked.
        nn.Module.__init__(self)

        draft_model_config = vllm_config.speculative_config.draft_model_config
        self.config = draft_model_config.hf_config

        # The draft layers are numbered after the layers reported by the
        # main model config.
        num_target_layers = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )

        # The draft model may be quantized differently from the target
        # model, so its quantization config is resolved separately.
        draft_quant_config = VllmConfig.get_quantization_config(
            draft_model_config, vllm_config.load_config
        )

        self.model = LlamaModel(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=num_target_layers,
            quant_config=draft_quant_config,
        )

        # Fall back to a neutral scale when the config defines none.
        scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size, scale=scale)

        self.lm_head = ParallelLMHead(
            self.config.draft_vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        # Set MoE hyperparameters
        self.set_moe_parameters()

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped draft decoder stack."""
        return self.model

    # Reuse the embedding entry point defined on SupportsMultiModal.
    embed_input_ids = SupportsMultiModal.embed_input_ids  # type: ignore

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft decoder stack and return its output tuple."""
        outputs = self.model(input_ids, positions, hidden_states, inputs_embeds)
        return outputs

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        """Rename and permute checkpoint tensors, then load them.

        Every ``(name, tensor)`` pair first goes through
        ``permute_qk_weight_for_rotary``. Names that do not contain
        ``lm_head`` are prefixed with ``model.``, and each final name is
        passed to ``process_eagle_weight`` before loading.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.
        """

        def remap(entry):
            raw_name, raw_weight = entry
            mapped_name, mapped_weight = self.permute_qk_weight_for_rotary(
                raw_name, raw_weight
            )
            if "lm_head" not in mapped_name:
                mapped_name = "model." + mapped_name
            process_eagle_weight(self, mapped_name)
            return mapped_name, mapped_weight

        # NOTE(review): the previous comment here said lm_head is tied with
        # the target model (Llama4ForCausalLM), yet no prefix is skipped;
        # confirm whether "lm_head" should be listed in skip_prefixes.
        weights_loader = AutoWeightsLoader(self, skip_prefixes=[])
        weights_loader.load_weights(remap(entry) for entry in weights)