granitemoeshared.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Inference-only GraniteMoeShared model.

The architecture is the same as GraniteMoe, with the addition of an optional
shared-expert MLP whose output is added to the routed-experts output in each
decoder layer.
"""
8

9
from collections.abc import Iterable
10
from itertools import islice
11
12
13
14
15
16
17
18
19
20

import torch
from torch import nn
from transformers.models.granitemoeshared import GraniteMoeSharedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
21
22
23
24
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
29
30
31
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
32
33
from vllm.sequence import IntermediateTensors

34
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
35
from .interfaces import SupportsLoRA, SupportsPP
36
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
37
38
39
40
41
42


class GraniteMoeSharedMLP(nn.Module):
    """Dense "shared expert" MLP applied alongside the routed MoE experts.

    ``input_linear`` is a merged column-parallel projection from
    ``config.hidden_size`` to two slices of width
    ``config.shared_intermediate_size``. The slices are combined by
    ``SiluAndMul``. ``output_linear`` is a row-parallel projection back to
    ``config.hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # NOTE: ``hidden_size`` here is the shared MLP's *intermediate* width,
        # not the model hidden size (which is ``input_size``).
        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        model_dim, inner_dim = self.input_size, self.hidden_size

        # Gate and up projections fused into one merged column-parallel layer.
        self.input_linear = MergedColumnParallelLinear(
            input_size=model_dim,
            output_sizes=[inner_dim, inner_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear",
        )
        self.output_linear = RowParallelLinear(
            inner_dim,
            model_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear",
        )

        activation_supported = config.hidden_act == "silu"
        if not activation_supported:
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the gated SiLU, and project back down."""
        gate_up, _ = self.input_linear(hidden_states)
        projected, _ = self.output_linear(self.act_fn(gate_up))
        return projected


class GraniteMoeSharedDecoderLayer(nn.Module):
    """One GraniteMoeShared transformer block.

    Structure: pre-norm self-attention, then a pre-norm feed-forward stage.
    The feed-forward stage is the routed MoE (``block_sparse_moe``), plus the
    shared MLP when ``config.shared_intermediate_size`` is non-zero. Each
    sub-block's output is scaled by ``config.residual_multiplier`` before
    being added back to the residual stream.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # rope_theta / rope_scaling are read with defaults because older
        # configs may lack them. Requires transformers > 4.32.0.
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )

        # Routed experts.
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )

        # Optional shared expert; a missing or zero width disables it.
        if getattr(config, "shared_intermediate_size", 0) == 0:
            self.shared_mlp = None
        else:
            self.shared_mlp = GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and the (MoE + shared) feed-forward with scaled residuals."""
        # Attention sub-block; ``hidden_states`` stays the residual stream.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )
        hidden_states = hidden_states + attn_out * self.residual_multiplier

        # Feed-forward sub-block.
        ffn_in = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is not None:
            # block_sparse_moe modifies its input in place, so it gets a
            # private copy; the routed experts run before the shared MLP.
            ffn_out = self.block_sparse_moe(ffn_in.clone()) + self.shared_mlp(ffn_in)
        else:
            ffn_out = self.block_sparse_moe(ffn_in)

        return hidden_states + ffn_out * self.residual_multiplier


@support_torch_compile
class GraniteMoeSharedModel(nn.Module):
    """GraniteMoeShared decoder stack: embeddings, decoder layers, final norm.

    Pipeline-parallel aware. The first PP rank embeds (and scales) the input.
    Non-last ranks hand ``hidden_states`` on via ``IntermediateTensors``.
    Only the last rank applies the final ``RMSNorm``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Original note: "Required by MixtralModel". Presumably read by the
        # borrowed GraniteMoeModel._load_weights (see load_weights) -- confirm.
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id

        # Extra embedding rows for LoRA-added tokens: lora_extra_vocab_size
        # per LoRA slot (max_loras, or 1 when unset).
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (unscaled) for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this PP rank's slice of decoder layers.

        Args:
            input_ids: token ids; used only on the first PP rank when
                ``inputs_embeds`` is not given.
            positions: positions forwarded to every decoder layer.
            intermediate_tensors: must be provided on non-first PP ranks; its
                ``"hidden_states"`` entry is the input to this rank's layers.
            inputs_embeds: optional precomputed embeddings (first rank only).
                Not modified by this method.

        Returns:
            On the last PP rank, the final-normed hidden states. Otherwise an
            ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            # Scale out of place. An in-place ``*=`` would mutate a
            # caller-owned ``inputs_embeds`` tensor, so reusing that buffer
            # would compound the multiplier.
            hidden_states = hidden_states * self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return self.norm(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Convert checkpoint MoE weight names to the per-expert layout and load.

        The checkpoint stores the MoE projections stacked along dim 0 by
        expert:

        * ``...block_sparse_moe.input_linear.weight``: each ``p[e]`` is split
          in half along dim 0 into ``experts.{e}.w1`` (first half) and
          ``experts.{e}.w3`` (second half).
        * ``...block_sparse_moe.output_linear.weight``: ``p[e]`` becomes
          ``experts.{e}.w2``.
        * ``...block_sparse_moe.router.layer.weight`` is renamed to
          ``...block_sparse_moe.gate.weight``.

        All other tensors pass through unchanged. The renamed set is handed to
        ``GraniteMoeModel._load_weights``, whose result is returned.
        """
        moe_in_suffix = ".block_sparse_moe.input_linear.weight"
        moe_out_suffix = ".block_sparse_moe.output_linear.weight"
        router_suffix = ".block_sparse_moe.router.layer.weight"

        new_weights = {}
        for n, p in weights:
            if n.endswith(moe_in_suffix):
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        moe_in_suffix, f".block_sparse_moe.experts.{e}.w1.weight"
                    )
                    w3_name = n.replace(
                        moe_in_suffix, f".block_sparse_moe.experts.{e}.w3.weight"
                    )
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    # Each expert slot must be produced exactly once.
                    assert w1_name not in new_weights
                    assert w3_name not in new_weights
                    new_weights[w1_name] = w1_param
                    new_weights[w3_name] = w3_param
            elif n.endswith(moe_out_suffix):
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        moe_out_suffix, f".block_sparse_moe.experts.{e}.w2.weight"
                    )
                    assert w2_name not in new_weights
                    new_weights[w2_name] = p[e]
            elif n.endswith(router_suffix):
                gate_name = n.replace(router_suffix, ".block_sparse_moe.gate.weight")
                assert gate_name not in new_weights
                new_weights[gate_name] = p
            else:
                new_weights[n] = p
        # Reuse GraniteMoe's loader; this class is not a subclass, so the
        # function is called with ``self`` explicitly.
        return GraniteMoeModel._load_weights(self, new_weights.items())
260

261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288

class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper around ``GraniteMoeSharedModel``.

    Adds a (possibly embedding-tied) ``ParallelLMHead`` and a
    ``LogitsProcessor`` that scales logits by ``1 / config.logits_scaling``.
    """

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.model = GraniteMoeSharedModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        lora_extra = lora_config.lora_extra_vocab_size if lora_config else 0
        self.unpadded_vocab_size = config.vocab_size + lora_extra

        # LoRA needs a bigger vocab padding for kernel compatibility.
        vocab_padding = (
            lora_config.lora_vocab_padding_size
            if lora_config
            else DEFAULT_VOCAB_PADDING_SIZE
        )
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=vocab_padding,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            scale=1 / config.logits_scaling,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the inner model; returns whatever it returns for this PP rank."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Project hidden states to (scaled) vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Zero-filled ``hidden_states`` placeholder of shape (batch, hidden)."""
        placeholder = torch.zeros(
            (batch_size, self.config.hidden_size), dtype=dtype, device=device
        )
        return IntermediateTensors({"hidden_states": placeholder})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping ``lm_head.*`` when it is tied to embeddings."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)