granitemoe.py 20.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
from typing import Any
30
31
32
33

import torch
from torch import nn

34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
37
38
39
40
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
41
from vllm.model_executor.layers.attention import Attention
42
43
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
44
45
46
47
48
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
49
from vllm.model_executor.layers.logits_processor import LogitsProcessor
50
from vllm.model_executor.layers.quantization import QuantizationConfig
51
52
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
53
54
55
    ParallelLMHead,
    VocabParallelEmbedding,
)
56
from vllm.model_executor.model_loader.weight_utils import (
57
58
59
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
60
from vllm.model_executor.models.utils import sequence_parallel_chunk
61
62
from vllm.sequence import IntermediateTensors

63
from .interfaces import SupportsLoRA, SupportsPP
64
from .utils import AutoWeightsLoader, is_pp_missing_parameter, make_layers, maybe_prefix
65
66
67
68
69
70
71
72
73
74


class GraniteMoeMoE(nn.Module):
    """A tensor-parallel MoE implementation for GraniteMoe that shards each
    expert across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    When ``is_sequence_parallel`` is set, the flattened tokens are split with
    ``sequence_parallel_chunk`` before routing, and the expert outputs are
    all-gathered along dim 0 afterwards.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        tp_size: int | None = None,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ):
        """Build the router (``gate``) and the fused expert layer (``experts``).

        Args:
            num_experts: Number of experts (also the router's output width).
            top_k: Experts selected per token.
            hidden_size: Model hidden size.
            intermediate_size: Per-expert intermediate size.
            params_dtype: Optional parameter dtype for gate and experts.
            quant_config: Quantization config for the experts only; the gate
                is always built unquantized.
            tp_size: Optional tensor-parallel size forwarded to ``FusedMoE``.
            is_sequence_parallel: Chunk tokens across ranks before the MoE
                and all-gather the results afterwards.
            prefix: Weight-name prefix for this module.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.is_sequence_parallel = is_sequence_parallel

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # renormalize=True is forwarded to FusedMoE (presumably re-normalizes
        # the selected top-k routing weights -- confirm against FusedMoE).
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return a tensor shaped like
        the input.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.

        Returns:
            The MoE output, viewed back to ``hidden_states``' original shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # Count tokens on the flattened view: for a 1D input orig_shape[0]
        # is hidden_size, not the number of tokens, so it must not be used
        # to trim the gathered output below.
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # The gathered tensor may hold more rows than the local input
            # (e.g. padding from chunking); keep only the real tokens.
            final_hidden_states = final_hidden_states[:num_tokens]

        return final_hidden_states.view(orig_shape)


class GraniteMoeAttention(nn.Module):
    """Self-attention for GraniteMoe with tensor-parallel projections.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many of them as ranks, and replicated
    otherwise. Rotary embeddings are applied to q/k before attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_parameters: dict[str, Any] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        attention_multiplier: float | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        world = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size

        # Query heads: must divide evenly across ranks.
        self.total_num_heads = num_heads
        assert num_heads % world == 0
        self.num_heads = num_heads // world

        # KV heads: partitioned if there are enough of them, otherwise each
        # rank holds a replica (which requires world % num_kv_heads == 0).
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < world:
            assert world % num_kv_heads == 0
        else:
            assert num_kv_heads % world == 0
        self.num_kv_heads = max(1, num_kv_heads // world)

        self.head_dim = hidden_size // num_heads
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads

        if attention_multiplier is None:
            # Fallback scale is 1/head_dim (not the usual 1/sqrt(head_dim)).
            self.scaling = self.head_dim**-1
        else:
            self.scaling = attention_multiplier

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, and project out."""
        fused = self.qkv_proj(hidden_states)[0]
        sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        return self.o_proj(context)[0]


class GraniteMoeDecoderLayer(nn.Module):
    """A single GraniteMoe transformer block.

    Pre-norm attention followed by a pre-norm sparse-MoE block. Each
    sub-layer's output is multiplied by ``config.residual_multiplier``
    before being added to the residual stream.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.hidden_size = hf_config.hidden_size
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            num_kv_heads=hf_config.num_key_value_heads,
            max_position=hf_config.max_position_embeddings,
            rope_parameters=hf_config.rope_parameters,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            attention_multiplier=hf_config.attention_multiplier,
            prefix=f"{prefix}.self_attn",
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=hf_config.num_local_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            quant_config=quant_config,
            is_sequence_parallel=(
                vllm_config.parallel_config.use_sequence_parallel_moe
            ),
            prefix=f"{prefix}.block_sparse_moe",
        )

        norm_eps = hf_config.rms_norm_eps
        self.input_layernorm = RMSNorm(hf_config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(hf_config.hidden_size, eps=norm_eps)

        # Scale applied to each sub-layer output before the residual add.
        self.residual_multiplier = hf_config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention then MoE, each with a scaled residual connection."""
        scale = self.residual_multiplier

        # Self Attention
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )
        hidden_states = hidden_states + attn_out * scale

        # Sparse MoE
        moe_out = self.block_sparse_moe(self.post_attention_layernorm(hidden_states))
        return hidden_states + moe_out * scale


282
@support_torch_compile
class GraniteMoeModel(nn.Module):
    """GraniteMoe decoder stack: token embedding, decoder layers, final norm.

    On the first pipeline-parallel rank the input embeddings are multiplied
    by ``config.embedding_multiplier``. Non-last ranks return an
    ``IntermediateTensors`` holding ``"hidden_states"``; the last rank
    returns the normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config  # Required by MixtralModel

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.embedding_multiplier = config.embedding_multiplier

        # forward() only runs layers in [start_layer, end_layer) on this rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (the embedding multiplier is NOT applied)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every decoder layer.
            intermediate_tensors: Hidden states from the previous pipeline
                stage; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings (first rank only).
                This tensor is not modified.

        Returns:
            ``IntermediateTensors`` on non-last ranks, otherwise the final
            normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Out-of-place on purpose: an in-place `*=` would rescale the
            # caller-owned `inputs_embeds` tensor, so any reuse of that
            # buffer would be scaled twice.
            hidden_states = hidden_states * self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        This function is copied from `MixtralModel.load_weights`, mainly to
        decouple from mixtral, avoiding impact on support like BNB
        quantization.

        Args:
            weights: ``(name, tensor)`` pairs. Expert weights are expected
                under per-expert ``w1``/``w2``/``w3`` names (the form produced
                by ``load_weights``).

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Scalar scales are used as-is; otherwise take the first entry.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # 1) q/k/v projections stacked into qkv_proj.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # 2) Per-expert weights loaded into the fused expert params.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # 3) Everything else is loaded one-to-one.
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Rename/split GraniteMoe checkpoint tensors, then load them.

        Conversions performed before delegating to ``_load_weights``:

        - ``block_sparse_moe.input_linear.weight``: dim 0 indexes experts;
          each expert slice is split in half along its dim 0 into
          ``experts.{e}.w1.weight`` (first half) and ``experts.{e}.w3.weight``
          (second half).
        - ``block_sparse_moe.output_linear.weight``: dim 0 indexes experts;
          slice ``e`` becomes ``experts.{e}.w2.weight``.
        - ``block_sparse_moe.router.layer.weight`` becomes
          ``block_sparse_moe.gate.weight``.
        - All other tensors pass through unchanged.

        Returns:
            The set of parameter names that were loaded.
        """
        new_weights = {}
        for n, p in weights:
            if n.endswith(".block_sparse_moe.input_linear.weight"):
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w1.weight",
                    )
                    w3_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w3.weight",
                    )
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    assert w1_name not in new_weights
                    assert w3_name not in new_weights
                    new_weights[w1_name] = w1_param
                    new_weights[w3_name] = w3_param
            elif n.endswith(".block_sparse_moe.output_linear.weight"):
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        ".block_sparse_moe.output_linear.weight",
                        f".block_sparse_moe.experts.{e}.w2.weight",
                    )
                    w2_param = p[e]
                    assert w2_name not in new_weights
                    new_weights[w2_name] = w2_param
            elif n.endswith(".block_sparse_moe.router.layer.weight"):
                gate_name = n.replace(
                    ".block_sparse_moe.router.layer.weight",
                    ".block_sparse_moe.gate.weight",
                )
                assert gate_name not in new_weights
                new_weights[gate_name] = p
            else:
                new_weights[n] = p
        return self._load_weights(new_weights.items())
481

482

483
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoe causal language model: ``GraniteMoeModel`` plus an LM head.

    The LM head weight is shared with the token embedding when
    ``config.tie_word_embeddings`` is set. Logits go through a
    ``LogitsProcessor`` constructed with ``scale = 1 / config.logits_scaling``.
    """

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = GraniteMoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=vllm_config.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Reuse the input embedding weight for the output head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size,
            scale=1 / hf_config.logits_scaling,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the decoder stack; logits are produced by ``compute_logits``."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Apply the LM head (via the logits processor) to hidden states."""
        return self.logits_processor(self.lm_head, hidden_states)

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Allocate a zero ``hidden_states`` buffer for pipeline transfer."""
        placeholder = torch.zeros(
            (batch_size, self.config.hidden_size), dtype=dtype, device=device
        )
        return IntermediateTensors({"hidden_states": placeholder})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; ``lm_head.*`` is skipped when tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)