# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM

from .utils import (
    AutoWeightsLoader,
    get_draft_quant_config,
    maybe_prefix,
    process_eagle_weight,
)

# Module-level logger, created through vLLM's init_logger with this module's name.
logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer variant used by the EAGLE draft model.

    Subclasses the ``LlamaDecoderLayer`` imported from
    ``vllm.model_executor.models.llama`` and can replace its input layernorm
    with an identity module. Because this class reuses the imported name, the
    module-level name ``LlamaDecoderLayer`` refers to this subclass from this
    point on.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the base layer, then optionally drop its input layernorm.

        Args:
            vllm_config: Forwarded positionally to the base constructor.
            disable_input_layernorm: When true, ``self.input_layernorm`` is
                replaced by ``nn.Identity()``.
            prefix: Forwarded to the base constructor as ``prefix``.
            config: Forwarded to the base constructor as ``config``.
        """
        super().__init__(vllm_config, prefix=prefix, config=config)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            # Remove the layernorm built by the base class before registering
            # the identity module under the same attribute name.
            del self.input_layernorm
            self.input_layernorm = nn.Identity()

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Use drafter's quantization config instead of verifier's.

        Returns:
            Whatever ``get_draft_quant_config(vllm_config)`` returns.
        """
        return get_draft_quant_config(vllm_config)


@support_torch_compile
class LlamaModel(nn.Module):
    """Body of the EAGLE draft model.

    Holds a token embedding, an ``fc`` projection that fuses the token
    embeddings with the hidden states passed in by the caller, and a stack of
    decoder layers. All sizes come from the draft model's HF config
    (``vllm_config.speculative_config.draft_model_config.hf_config``).
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Create the embedding, decoder layers and fusion projection.

        Args:
            vllm_config: Source of the draft model HF config, the drafter's
                quantization config and the dtype used for ``fc``.
            prefix: Prefix prepended (via ``maybe_prefix``) to the names of
                the submodules created here.
            start_layer_id: Offset added to each layer index when building
                the layer prefix ``layers.{i + start_layer_id}``.
        """
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # Get drafter's quantization config
        self.quant_config = get_draft_quant_config(vllm_config)

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Only the first layer (i == 0) is built with its input layernorm
        # disabled; every other layer keeps the one from the base class.
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    vllm_config,
                    i == 0,
                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                    config=self.config,
                )
                for i in range(self.config.num_hidden_layers)
            ]
        )

        # Projects the concatenation of token embedding and incoming hidden
        # state (2 * hidden_size) back down to hidden_size.
        self.fc = ReplicatedLinear(
            input_size=self.config.hidden_size * 2,
            output_size=self.config.hidden_size,
            bias=False,
            params_dtype=vllm_config.model_config.dtype,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "fc"),
            return_bias=False,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft model body.

        Args:
            input_ids: Token ids that are embedded with ``embed_tokens``.
            positions: Passed unchanged to every decoder layer.
            hidden_states: Concatenated with the token embeddings along the
                last dimension before the ``fc`` projection.

        Returns:
            A pair holding the same output tensor twice.
            NOTE(review): presumably callers expect a two-element result;
            confirm against the call site.
        """
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        # Fold the residual left over from the last layer into the output.
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint entries are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters with the matching shard id.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded.

        Raises:
            KeyError: If a (possibly remapped) name is not a parameter of
                this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Handle kv cache quantization scales
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A 0-dim tensor is used as is; otherwise only its first
                # element is loaded.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            # Remapping the name FP8 kv-scale
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: load the tensor under its own name.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    """EAGLE draft model wrapper around ``LlamaModel``.

    Derives from ``LlamaForCausalLM`` but initializes itself through
    ``nn.Module.__init__`` only, so the parent constructor is not run; the
    submodules used here are created explicitly in ``__init__``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft model body and the logits processor.

        Args:
            vllm_config: Source of the draft model HF config, the target
                model's layer count and the parallel config.
            prefix: Accepted for interface compatibility; not referenced in
                this constructor (the body is always built with ``"model"``).
        """
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size
        # The layer count reported by model_config is used as the index
        # offset for the draft layers' prefixes.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        # Falls back to a scale of 1.0 when the config has no logit_scale.
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size, scale=logit_scale
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the draft model's token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the draft model body.

        Args:
            input_ids: Forwarded to ``self.model``.
            positions: Forwarded to ``self.model``.
            hidden_states: Forwarded to ``self.model``.
            inputs_embeds: Must be ``None``.

        Returns:
            The pair returned by ``self.model``.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is provided.
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every name that does not contain ``"lm_head"`` is prefixed with
        ``"model."`` and ``process_eagle_weight`` is called for each
        (renamed) entry before it is handed to the loader. The loader's
        result is not returned.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """

        def transform(inputs):
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            process_eagle_weight(self, name)
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        # map() is lazy: names are rewritten as the loader consumes them.
        loader.load_weights(map(transform, weights))