granitemoeshared.py 11.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Inference-only GraniteMoeShared model.

The architecture is the same as granitemoe but with the addition of shared
experts.
"""
8

9
from collections.abc import Iterable
10
from itertools import islice
11
12
13
14
15
16
17
18
19
20

import torch
from torch import nn
from transformers.models.granitemoeshared import GraniteMoeSharedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
21
22
23
24
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
29
30
    ParallelLMHead,
    VocabParallelEmbedding,
)
31
32
from vllm.sequence import IntermediateTensors

33
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
34
from .interfaces import SupportsLoRA, SupportsPP
35
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
36
37
38
39
40
41


class GraniteMoeSharedMLP(nn.Module):
    """Dense shared-expert MLP evaluated alongside the sparse MoE block.

    Computes ``output_linear(SiluAndMul(input_linear(x)))``. ``input_linear``
    is a merged projection with two output slices, each of width
    ``config.shared_intermediate_size``; ``SiluAndMul`` combines them, and
    ``output_linear`` projects back to ``config.hidden_size``.

    Args:
        config: Model config; reads ``hidden_size``,
            ``shared_intermediate_size`` and ``hidden_act``.
        quant_config: Optional quantization config forwarded to both linears.
        prefix: Parameter-name prefix for the sub-layers.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # Reject unsupported configs before allocating any parameters.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.input_size = config.hidden_size
        # NOTE: despite the attribute name, this is the shared MLP's
        # intermediate width, not the model hidden size.
        self.hidden_size = config.shared_intermediate_size

        # Merged projection producing two slices of width self.hidden_size,
        # which SiluAndMul below consumes together.
        self.input_linear = MergedColumnParallelLinear(
            input_size=self.input_size,
            output_sizes=[self.hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear",
        )
        self.output_linear = RowParallelLinear(
            self.hidden_size,
            self.input_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the shared MLP to ``hidden_states`` and return the result."""
        hidden_states, _ = self.input_linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.output_linear(hidden_states)
        return hidden_states


class GraniteMoeSharedDecoderLayer(nn.Module):
    """One GraniteMoeShared transformer layer.

    The layer runs pre-norm self-attention, then a pre-norm sparse MoE block.
    When the config has a non-zero ``shared_intermediate_size``, a dense
    shared MLP is applied to the same normalized input and its output is
    added to the MoE output. Both residual branches are scaled by
    ``config.residual_multiplier`` before being added back.

    Args:
        config: Model config.
        cache_config: Optional KV-cache config forwarded to attention.
        quant_config: Optional quantization config forwarded to sub-layers.
        prefix: Parameter-name prefix for the sub-layers.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )

        # A missing, zero or None shared_intermediate_size disables the shared
        # MLP. A plain `== 0` check would try to build it with size None.
        shared_intermediate_size = getattr(config, "shared_intermediate_size", 0)
        self.shared_mlp = (
            GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )
            if shared_intermediate_size
            else None
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MoE sub-blocks, each with a scaled residual."""
        # Self attention.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        # Sparse MoE, plus the optional dense shared MLP.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            hidden_states = self.block_sparse_moe(hidden_states)
        else:
            # block_sparse_moe may modify its input in place. Run the shared
            # MLP first, while the normalized activations are still intact.
            # This avoids cloning the whole activation tensor.
            shared_output = self.shared_mlp(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states) + shared_output
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states


@support_torch_compile
class GraniteMoeSharedModel(nn.Module):
    """GraniteMoeShared decoder stack: embeddings, decoder layers, final norm.

    Supports pipeline parallelism. Only the first rank embeds inputs, and only
    the last rank applies the final norm. Intermediate ranks exchange
    ``IntermediateTensors`` keyed by ``"hidden_states"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config  # Required by MixtralModel
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up (unscaled) token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids. Used on the first PP rank when
                ``inputs_embeds`` is None.
            positions: Position ids forwarded to each layer.
            intermediate_tensors: Hidden states from the previous PP rank.
                Required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            ``IntermediateTensors`` on non-last PP ranks. On the last rank,
            the final-normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Out-of-place on purpose: inputs_embeds is owned by the caller,
            # and scaling it in place would silently modify their tensor.
            hidden_states = hidden_states * self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Remap checkpoint MoE weights to per-expert names, then load them.

        Checkpoint entries are rewritten as follows:

        * ``...block_sparse_moe.input_linear.weight`` is indexed by expert on
          dim 0. Each expert slice is split in half along its dim 0 into
          ``...experts.{e}.w1.weight`` and ``...experts.{e}.w3.weight``.
        * ``...block_sparse_moe.output_linear.weight`` is indexed by expert on
          dim 0. Each slice becomes ``...experts.{e}.w2.weight``.
        * ``...block_sparse_moe.router.layer.weight`` is renamed to
          ``...block_sparse_moe.gate.weight``.
        * All other entries pass through unchanged.

        The remapped weights are handed to ``GraniteMoeModel._load_weights``.

        Returns:
            Whatever ``GraniteMoeModel._load_weights`` returns (the declared
            type is ``set[str]``).

        Raises:
            ValueError: If two checkpoint entries remap to the same expert or
                gate parameter name.
        """
        moe_in_suffix = ".block_sparse_moe.input_linear.weight"
        moe_out_suffix = ".block_sparse_moe.output_linear.weight"
        router_suffix = ".block_sparse_moe.router.layer.weight"

        new_weights: dict[str, torch.Tensor] = {}

        def _add_unique(name: str, tensor: torch.Tensor) -> None:
            # Explicit check: an `assert` would be stripped under `python -O`
            # and silently let a later tensor overwrite an earlier one.
            if name in new_weights:
                raise ValueError(f"Duplicate weight after remapping: {name}")
            new_weights[name] = tensor

        for name, tensor in weights:
            if name.endswith(moe_in_suffix):
                base = name.removesuffix(moe_in_suffix)
                for e in range(tensor.size(0)):
                    w1_param, w3_param = tensor[e].chunk(2, dim=0)
                    _add_unique(f"{base}.block_sparse_moe.experts.{e}.w1.weight", w1_param)
                    _add_unique(f"{base}.block_sparse_moe.experts.{e}.w3.weight", w3_param)
            elif name.endswith(moe_out_suffix):
                base = name.removesuffix(moe_out_suffix)
                for e in range(tensor.size(0)):
                    _add_unique(f"{base}.block_sparse_moe.experts.{e}.w2.weight", tensor[e])
            elif name.endswith(router_suffix):
                base = name.removesuffix(router_suffix)
                _add_unique(f"{base}.block_sparse_moe.gate.weight", tensor)
            else:
                new_weights[name] = tensor

        return GraniteMoeModel._load_weights(self, new_weights.items())


class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoeShared decoder stack with an LM head for causal generation.

    Logits are computed by ``LogitsProcessor`` with a scale of
    ``1 / config.logits_scaling``. When ``config.tie_word_embeddings`` is set,
    the LM head shares the input embedding weight and checkpoint ``lm_head.*``
    entries are skipped during loading.
    """

    fall_back_to_pt_during_load = False

    # Maps the fused attention projection to the checkpoint modules it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.model = GraniteMoeSharedModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Share a single Parameter between input embeddings and LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            config.vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (unscaled) token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack.

        Returns the inner model's output unchanged: hidden states on the last
        PP rank, ``IntermediateTensors`` otherwise.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head.

        May return None, as declared by the return annotation.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Allocate a zeroed ``(batch_size, hidden_size)`` PP hand-off buffer."""
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``lm_head.*`` when tied."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)