llama_eagle.py 5.81 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from collections.abc import Iterable
4
5
6
7
8

import torch
import torch.nn as nn
from transformers import LlamaConfig

9
10
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
11
from vllm.distributed.parallel_state import get_pp_group
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """EAGLE variant of vLLM's Llama decoder layer.

    This class shadows the imported base-class name, so code later in this
    module that builds a ``LlamaDecoderLayer`` gets this variant. Its only
    addition is an option to swap ``input_layernorm`` for ``nn.Identity``,
    following the reference EAGLE implementation:
    https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
    """

    def __init__(
        self,
        config: LlamaConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        """Build the decoder layer.

        Args:
            config: Hugging Face Llama config forwarded to the base layer.
            disable_input_layernorm: When true, ``input_layernorm`` is
                replaced by ``nn.Identity``.
            prefix: Parameter-name prefix forwarded to the base layer.
        """
        super().__init__(config, prefix=prefix)
        if not disable_input_layernorm:
            return
        # Unregister the base layer's norm first, then register an identity
        # module under the same attribute name.
        del self.input_layernorm
        self.input_layernorm = nn.Identity()


42
@support_torch_compile
class LlamaModel(nn.Module):
    """Transformer trunk of the EAGLE draft model for Llama.

    Each step embeds ``input_ids``, concatenates the embeddings with the
    ``hidden_states`` passed in by the caller, projects the result back to
    ``hidden_size`` with ``fc`` and runs it through the decoder layers. The
    first layer has its input layernorm disabled.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Build the draft trunk.

        Args:
            vllm_config: Engine config. The draft's HF config is read from
                ``speculative_config.draft_model_config.hf_config``.
            prefix: Parameter-name prefix for this module.
            start_layer_id: Offset added to each layer index when building
                layer prefixes. NOTE(review): presumably this keeps draft
                layer names distinct from the target's; confirm with callers.
        """
        super().__init__()
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        self.vocab_size = self.config.vocab_size

        # If PP is disabled, the draft shares embed_tokens with the target, so
        # a private embedding table is only created under pipeline parallelism.
        if get_pp_group().world_size > 1:
            self.embed_tokens = VocabParallelEmbedding(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "embed_tokens"),
            )

        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
                # disable_input_layernorm: only the first layer skips it.
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        # Projects concat(input_embeds, hidden_states) (2 * hidden_size wide)
        # back down to hidden_size.
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft step.

        Args:
            input_ids: Token ids to embed.
            positions: Positions passed to every decoder layer.
            hidden_states: Hidden states supplied by the caller (presumably
                from the target model; confirm against the proposer). They
                are concatenated with the embeddings on the last dimension.

        Returns:
            The final hidden states twice, as a ``(hidden, hidden)`` tuple.
            The last layer's residual is added back in and no final norm is
            applied.
        """
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` and ``gate_up_proj`` parameters through their
        ``weight_loader``, together with the matching shard id. When pipeline
        parallelism is disabled, ``embed_tokens`` weights are skipped because
        this module does not create its own embedding in that case.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        # Loop-invariant: with PP disabled, the draft shares embed_tokens with
        # the target, and __init__ does not create it here.
        shares_target_embed = get_pp_group().world_size == 1
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No local embed_tokens parameter exists to load into.
                if shares_target_embed and "embed_tokens." in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    """EAGLE draft head for Llama target models.

    Holds the draft trunk (``LlamaModel`` from this module) and a logits
    processor. No ``lm_head`` is created here. NOTE(review): presumably the
    inherited logits path uses an ``lm_head`` shared from the target model;
    confirm with the code that wires up the draft.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately bypass LlamaForCausalLM.__init__ so the full causal-LM
        # stack is not built; only nn.Module state is initialized.
        nn.Module.__init__(self)
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        # Offset draft layer prefixes past the target model's layer count.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate one draft step to the trunk and return its outputs."""
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load a draft checkpoint.

        Checkpoint names without ``lm_head`` are prefixed with ``model.`` so
        they resolve under ``self.model``. Names containing ``lm_head`` are
        passed through unchanged.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """

        def _to_module_name(item):
            # Map one checkpoint entry to this module's parameter namespace.
            name, loaded_weight = item
            if "lm_head" not in name:
                name = "model." + name
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        # Stream the renamed pairs instead of collecting them into a dict
        # first, so every checkpoint tensor is not held alive simultaneously.
        return loader.load_weights(map(_to_module_name, weights))