# llama_eagle.py 6.38 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from typing import Optional
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import LlamaConfig

11
12
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
13
from vllm.distributed.parallel_state import get_pp_group
14
15
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
17
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
18
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
19
20
21
22
23
24
25
26
27

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose input layernorm can be switched off.

    Subclasses the imported ``LlamaDecoderLayer`` (and rebinds that name in
    this module). When ``disable_input_layernorm`` is true, the layer's
    ``input_layernorm`` submodule is replaced by ``nn.Identity``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
        config: Optional[LlamaConfig] = None,
    ) -> None:
        """Build the base decoder layer, then optionally drop its input norm.

        Args:
            vllm_config: Passed through positionally to the base layer.
            disable_input_layernorm: If true, swap ``input_layernorm`` for
                an identity module.
            prefix: Passed through to the base layer as ``prefix``.
            config: Passed through to the base layer as ``config``.
        """
        super().__init__(vllm_config, prefix=prefix, config=config)

        if not disable_input_layernorm:
            return

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        del self.input_layernorm
        self.input_layernorm = nn.Identity()


42
@support_torch_compile
class LlamaModel(nn.Module):
    """Decoder stack used as the EAGLE draft model body.

    Holds a token embedding, ``num_hidden_layers`` decoder layers (the first
    one built with its input layernorm disabled) and a bias-free linear
    layer ``fc`` that projects the concatenation of token embeddings and
    incoming hidden states from ``2 * hidden_size`` down to ``hidden_size``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Create the embedding, decoder layers and fusion projection.

        Args:
            vllm_config: Source of the draft model's HF config
                (``speculative_config.draft_model_config.hf_config``); also
                handed to every decoder layer.
            prefix: Prefix used when naming submodules.
            start_layer_id: Offset added to each layer index when building
                the layer's prefix.
        """
        super().__init__()
        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        decoder_layers = []
        for layer_idx in range(hf_config.num_hidden_layers):
            layer_prefix = maybe_prefix(
                prefix, f"layers.{layer_idx + start_layer_id}"
            )
            decoder_layers.append(
                LlamaDecoderLayer(
                    vllm_config,
                    # Only the first layer has its input layernorm disabled.
                    layer_idx == 0,
                    prefix=layer_prefix,
                    config=hf_config,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

        self.fc = torch.nn.Linear(
            hf_config.hidden_size * 2,
            hf_config.hidden_size,
            bias=False,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft decoder stack.

        The token embeddings are concatenated with ``hidden_states`` along
        the last dimension, projected by ``fc`` and fed through every
        decoder layer. The final hidden states plus residual are returned
        twice, as both elements of the result tuple.
        """
        token_embeds = self.embed_tokens(input_ids)
        fused_input = torch.cat((token_embeds, hidden_states), dim=-1)
        hidden_states = self.fc(fused_input)

        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(
                positions, hidden_states, residual
            )

        output = hidden_states + residual
        return output, output

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint shards are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. Embedding weights
        are skipped when the pipeline-parallel world size is 1.

        Returns:
            The set of parameter names that were loaded.
        """
        # (fused param suffix, checkpoint shard suffix, shard id)
        stacked_params_mapping = (
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First mapping entry whose shard suffix occurs in the name.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )

            if stacked is not None:
                fused_suffix, shard_suffix, shard_id = stacked
                name = name.replace(shard_suffix, fused_suffix)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                # if PP disabled then draft will share embed with target
                if get_pp_group().world_size == 1 and "embed_tokens." in name:
                    continue

                param = params_dict[name]
                load_fn = getattr(param, "weight_loader", default_weight_loader)
                load_fn(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    """EAGLE draft head built on the Llama architecture.

    Wraps a ``LlamaModel`` draft decoder and a ``LogitsProcessor``. Only
    ``nn.Module.__init__`` is invoked during construction, so the
    ``LlamaForCausalLM`` initializer is bypassed.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft model from the speculative draft config.

        Args:
            vllm_config: Supplies the draft HF config, and the target
                model's layer count used as the draft layer-id offset.
            prefix: Accepted for interface compatibility; the inner model
                is always created with the prefix ``"model"``.
        """
        nn.Module.__init__(self)
        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config

        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(hf_config, "draft_vocab_size", None) is None:
            hf_config.draft_vocab_size = getattr(hf_config, "vocab_size", None)

        num_target_layers = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.model = LlamaModel(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=num_target_layers,
        )

        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size,
            scale=getattr(hf_config, "logit_scale", 1.0),
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the inner model's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the inner draft model.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is provided.
        """
        if inputs_embeds is None:
            return self.model(input_ids, positions, hidden_states)

        raise NotImplementedError(
            f"{type(self).__name__} does not support multimodal inputs yet."
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every weight name that does not contain ``"lm_head"`` is prefixed
        with ``"model."`` before being handed to the loader.
        """
        renamed_weights = (
            (name if "lm_head" in name else "model." + name, tensor)
            for name, tensor in weights
        )

        loader = AutoWeightsLoader(self, skip_prefixes=None)
        loader.load_weights(renamed_weights)