# llama_eagle.py (6.53 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from typing import Optional
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import LlamaConfig

11
12
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
13
from vllm.distributed.parallel_state import get_pp_group
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose input layernorm can be switched off.

    This subclass intentionally shadows the imported ``LlamaDecoderLayer`` so
    the draft model below builds instances of this variant.
    """

    def __init__(
        self,
        config: LlamaConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        """Build the base layer, then optionally neutralise its input norm.

        Args:
            config: HF Llama config forwarded to the base layer.
            disable_input_layernorm: When true, ``input_layernorm`` is
                replaced by ``nn.Identity()``.
            prefix: Parameter-name prefix forwarded to the base layer.
        """
        super().__init__(config, prefix=prefix)

        if not disable_input_layernorm:
            return

        # EAGLE skips the input_layernorm of this layer, see
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        # The existing submodule is deleted first, then an identity is
        # registered under the same attribute name.
        del self.input_layernorm
        self.input_layernorm = nn.Identity()


@support_torch_compile
class LlamaModel(nn.Module):
    """EAGLE draft model body built from Llama decoder layers.

    The configuration is read from
    ``vllm_config.speculative_config.draft_model_config.hf_config``.
    Token embeddings are concatenated with caller-supplied hidden states,
    projected from ``2 * hidden_size`` back to ``hidden_size`` by ``fc``,
    and then run through ``num_hidden_layers`` decoder layers. Only the
    first layer has its input layernorm disabled.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Build the embedding, decoder layers and fusion projection.

        Args:
            vllm_config: Engine config; only the speculative draft model's
                HF config is read here.
            prefix: Parameter-name prefix for the submodules.
            start_layer_id: Offset added to each layer index when forming
                layer prefixes, so draft layer names can be numbered after
                another model's layers.
        """
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
                # Only the first draft layer skips its input layernorm.
                disable_input_layernorm=(i == 0),
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        # Fuses [token embedding, incoming hidden state] back to hidden_size.
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft step.

        Args:
            input_ids: Token ids to embed.
            positions: Positions forwarded to every decoder layer.
            hidden_states: Incoming hidden states. Their last dimension must
                be ``hidden_size`` so that the concatenation with the token
                embeddings matches ``fc``'s ``2 * hidden_size`` input.
                NOTE(review): presumably produced by the target model;
                confirm at the call site.

        Returns:
            The final hidden states (last layer output plus residual),
            returned twice as a tuple.
        """
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        # The residual starts as None. The decoder layer is expected to
        # initialise it on the first call (its implementation is not
        # visible here).
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        # There is no final norm module in this class; the residual is added
        # back explicitly instead.
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load draft checkpoint weights into this module.

        Separate ``q/k/v_proj`` and ``gate/up_proj`` tensors are routed into
        the fused ``qkv_proj`` and ``gate_up_proj`` parameters through their
        ``weight_loader`` with the matching shard id. When the pipeline-
        parallel group has world size 1, ``embed_tokens`` weights are skipped
        because the draft then shares its embedding with the target.

        Args:
            weights: ``(name, tensor)`` pairs whose names (after the fused
                renaming) match this module's parameter names.

        Returns:
            Names of the parameters that were loaded. Skipped embedding
            weights are not included.

        Raises:
            KeyError: If a (renamed) weight name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Load this shard into the fused parameter.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # if PP disabled then draft will share embed with target
                if get_pp_group().world_size == 1 and \
                    "embed_tokens." in name:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    """EAGLE draft model exposed through the ``LlamaForCausalLM`` interface.

    Holds a draft ``LlamaModel`` under the ``model`` prefix and a
    ``LogitsProcessor``. No ``lm_head`` is constructed in this class.
    NOTE(review): presumably the caller attaches one (e.g. shared with the
    target); confirm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft submodules from the speculative draft config.

        Args:
            vllm_config: Engine config. The draft HF config comes from
                ``speculative_config.draft_model_config``; the layer offset
                comes from ``model_config``.
            prefix: Accepted for interface compatibility; not used here.
        """
        # Deliberately skip LlamaForCausalLM.__init__ so only the draft
        # submodules below are built.
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config

        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size

        # Number the draft layers after the layers counted by
        # vllm_config.model_config (the non-draft model config).
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the draft model's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the inner draft model.

        Args:
            input_ids: Token ids for the draft step.
            positions: Positions forwarded to the decoder layers.
            hidden_states: Incoming hidden states fused with the embeddings.
            inputs_embeds: Must be ``None``; precomputed embeddings are not
                supported.

        Returns:
            Whatever ``LlamaModel.forward`` returns: a pair of tensors.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is provided.
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load draft checkpoint weights through ``AutoWeightsLoader``.

        Every weight name that does not contain ``lm_head`` is prefixed with
        ``model.`` so it resolves inside ``self.model``. NOTE(review): this
        assumes draft checkpoints name tensors relative to the inner model.
        Nothing is returned.
        """

        def transform(inputs):
            # Re-root non-lm_head weights under the "model" submodule.
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        loader.load_weights(map(transform, weights))