bailing_moe.py 22.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/inclusionAI/Ling/blob/master/models/modeling_bailing_moe.py
# Copyright 2023 The vLLM team.
# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BailingMoE model compatible with HuggingFace weights."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
30
31
32
33
34

import torch
import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig

35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
38
39
40
41
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
42
from vllm.model_executor.layers.activation import SiluAndMul
43
from vllm.model_executor.layers.attention import Attention
44
45
46
47
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    fused_moe_make_expert_params_mapping,
)
48
from vllm.model_executor.layers.layernorm import RMSNorm
49
50
51
52
53
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
54
from vllm.model_executor.layers.logits_processor import LogitsProcessor
55
from vllm.model_executor.layers.quantization import QuantizationConfig
56
57
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
58
59
60
    ParallelLMHead,
    VocabParallelEmbedding,
)
61
62
63
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

64
from .interfaces import SupportsLoRA, SupportsPP
65
66
67
68
69
70
71
72
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
73
74
75
76
77
78


class BailingAttention(nn.Module):
    """Self-attention layer of a BailingMoE decoder block.

    Hidden states are projected to fused Q/K/V by a tensor-parallel
    ``QKVParallelLinear``. Q and K are optionally normalized per head
    (``config.use_qk_norm``), rotary embeddings are applied, attention is
    computed by vLLM's ``Attention`` layer, and the result is projected back
    to ``hidden_size`` by a ``RowParallelLinear`` (``self.dense``).

    Args:
        config: HF model config. Reads ``hidden_size``,
            ``num_attention_heads``, ``num_key_value_heads``, optional
            ``head_dim``, ``use_bias``, ``use_qkv_bias``, optional
            ``use_qk_norm`` / ``use_rmsnorm`` / ``rotary_dim``,
            ``rms_norm_eps``, ``max_position_embeddings`` and
            ``rope_parameters``.
        cache_config: KV-cache configuration forwarded to ``Attention``.
        quant_config: Optional quantization config for the linear layers.
        reduce_results: Forwarded to the output ``RowParallelLinear``.
        prefix: Parameter-name prefix of this module.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.total_kv_heads = config.num_key_value_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.total_num_heads % tp_size == 0
        assert self.total_num_heads >= self.total_kv_heads

        self.num_heads = self.total_num_heads // tp_size
        # ``head_dim`` may be missing from the config or explicitly None; in
        # both cases derive it from hidden_size / num_attention_heads.
        self.head_dim = getattr(config, "head_dim", None) or (
            self.hidden_size // self.total_num_heads
        )
        self.q_size_per_rank = self.head_dim * self.num_heads
        # Every rank holds at least one KV head, even when there are fewer KV
        # heads than tensor-parallel ranks.
        self.num_kv_heads = max(1, self.total_kv_heads // tp_size)
        self.kv_size_per_rank = self.num_kv_heads * self.head_dim
        self.scale = self.head_dim**-0.5

        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.use_rmsnorm = getattr(config, "use_rmsnorm", False)

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_kv_heads,
            bias=(config.use_bias or config.use_qkv_bias),
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )

        if self.use_qk_norm:
            # Per-head normalization over head_dim: RMSNorm when
            # ``use_rmsnorm`` is set, otherwise LayerNorm with eps=1e-6.
            self.query_layernorm = (
                RMSNorm(self.head_dim, eps=config.rms_norm_eps)
                if self.use_rmsnorm
                else nn.LayerNorm(self.head_dim, eps=1e-6)
            )
            self.key_layernorm = (
                RMSNorm(self.head_dim, eps=config.rms_norm_eps)
                if self.use_rmsnorm
                else nn.LayerNorm(self.head_dim, eps=1e-6)
            )

        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.dense",
        )

        rotary_dim = getattr(config, "rotary_dim", self.head_dim)
        # NOTE: this writes into the caller's ``config.rope_parameters`` dict
        # (mutating the config) before building the rotary embedding.
        config.rope_parameters["partial_rotary_factor"] = rotary_dim / self.head_dim

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=config.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scale,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over a flat batch of tokens.

        Args:
            hidden_states: Token hidden states; presumably
                ``(num_tokens, hidden_size)`` — TODO confirm.
            position_ids: Token positions used by the rotary embedding.

        Returns:
            The output of the ``dense`` projection (its bias output is
            discarded).
        """
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split(
            [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], dim=-1
        )

        if self.use_qk_norm:
            # Reshape to (tokens, heads, head_dim) so the norm runs per head,
            # then flatten back for the rotary embedding / attention kernel.
            q = q.view(-1, self.num_heads, self.head_dim)
            k = k.view(-1, self.num_kv_heads, self.head_dim)
            q = self.query_layernorm(q)
            k = self.key_layernorm(k)
            q = q.view(-1, self.q_size_per_rank)
            k = k.view(-1, self.kv_size_per_rank)

        q, k = self.rotary_emb(position_ids, q, k)

        context_layer = self.attn(q, k, v)

        attn_output, _ = self.dense(context_layer)
        return attn_output


class BailingMLP(nn.Module):
    """Dense gated feed-forward block: fused gate/up projection, SiLU-and-mul,
    then a down projection back to ``hidden_size``.

    Args:
        intermediate_size: Width of each of the gate and up projections.
        config: HF model config (reads ``hidden_size`` and ``use_bias``).
        quant_config: Optional quantization config for both projections.
        reduce_results: Forwarded to the down-projection ``RowParallelLinear``.
        prefix: Parameter-name prefix of this module.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool | None = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        model_width = config.hidden_size
        with_bias = config.use_bias

        # Gate and up projections share one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            model_width,
            [intermediate_size, intermediate_size],
            bias=with_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            model_width,
            bias=with_bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        fused, _ = self.gate_up_proj(x)
        gated = self.act_fn(fused)
        projected, _ = self.down_proj(gated)
        return projected


class BailingMoE(nn.Module):
    """Sparse mixture-of-experts feed-forward block of BailingMoE.

    A router (``self.gate``) produces one logit per expert for every token.
    ``FusedMoE`` then picks ``num_experts_per_tok`` experts per token,
    optionally using grouped top-k, an expert-score correction bias and a
    routed scaling factor. It also receives the optional shared experts.

    Args:
        intermediate_size: Not read directly; the routed and shared experts
            are sized from config fields.
        config: HF model config.
        quant_config: Optional quantization config.
        reduce_results: Not read by this block.
        prefix: Parameter-name prefix of this module.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool | None = True,
        prefix: str = "",
    ):
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_expert_prob = config.norm_topk_prob
        self.hidden_size = config.hidden_size
        self.quant_config = quant_config
        self.num_shared_experts = config.num_shared_experts

        # Optional routing settings; grouped top-k needs both group fields.
        self.score_function = getattr(config, "score_function", None)
        self.n_group = getattr(config, "n_group", None)
        self.topk_group = getattr(config, "topk_group", None)
        self.use_grouped_topk = all(
            value is not None for value in (self.n_group, self.topk_group)
        )
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        # Router dtype: unset -> None (nn.Linear uses the default dtype),
        # "fp32" -> float32, any other value -> bfloat16.
        configured_router_dtype = getattr(config, "router_dtype", None)
        if configured_router_dtype is None:
            self.router_dtype = None
        else:
            self.router_dtype = (
                torch.float32 if configured_router_dtype == "fp32" else torch.bfloat16
            )

        self.gate = nn.Linear(
            self.hidden_size, self.num_experts, bias=False, dtype=self.router_dtype
        )

        # Optional per-expert correction bias, registered on the gate as the
        # parameter ``gate.expert_bias`` (a plain None attribute otherwise).
        enable_expert_bias = getattr(config, "moe_router_enable_expert_bias", False)
        self.gate.expert_bias = (
            nn.Parameter(torch.empty((config.num_experts,), dtype=torch.float32))
            if enable_expert_bias
            else None
        )
        bias_param = self.gate.expert_bias
        self.correction_bias = None if bias_param is None else bias_param.data

        if self.score_function is None:
            # default value for scoring_func
            self.score_function = "softmax"
        else:
            # softmax routing must not have a correction bias; sigmoid must.
            has_correction_bias = self.correction_bias is not None
            assert (
                self.score_function == "softmax" and not has_correction_bias
            ) or (self.score_function == "sigmoid" and has_correction_bias), (
                "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"  # noqa: E501
            )

        if self.num_shared_experts <= 0:
            self.shared_experts = None
        else:
            # All shared experts are merged into a single dense MLP whose
            # width is per-expert size * number of shared experts.
            if hasattr(config, "moe_shared_expert_intermediate_size"):
                per_expert_width = config.moe_shared_expert_intermediate_size
            else:
                per_expert_width = config.moe_intermediate_size
            self.shared_experts = BailingMLP(
                intermediate_size=per_expert_width * config.num_shared_experts,
                config=config,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            num_experts=self.num_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=self.norm_expert_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            scoring_func=self.score_function,
            e_score_correction_bias=self.gate.expert_bias,
            num_expert_group=self.n_group,
            topk_group=self.topk_group,
            use_grouped_topk=self.use_grouped_topk,
            router_logits_dtype=self.router_dtype,
            routed_scaling_factor=self.routed_scaling_factor,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its experts and return the combined output.

        ``hidden_states`` must be 2-D ``(num_tokens, hidden_size)`` (it is
        unpacked into exactly two dimensions); the result has the same shape.
        """
        token_count, width = hidden_states.shape
        flat_states = hidden_states.view(-1, width)

        # Router logits of shape (num_tokens, num_experts), computed in the
        # router dtype and cast back to the activation dtype.
        logits = self.gate(flat_states.to(self.router_dtype)).to(flat_states.dtype)

        mixed = self.experts(hidden_states=flat_states, router_logits=logits)
        return mixed.view(token_count, width)


class BailingMoeBlock(nn.Module):
    """A single BailingMoE decoder layer.

    Structure: ``input_layernorm`` -> ``attention`` ->
    ``post_attention_layernorm`` -> ``mlp``. Layers whose index is below
    ``config.first_k_dense_replace`` get a dense ``BailingMLP``; the remaining
    layers get a sparse ``BailingMoE``.

    Args:
        config: HF model config.
        cache_config: KV-cache configuration forwarded to attention.
        quant_config: Optional quantization config.
        prefix: Module prefix; its last dot-separated component must be the
            integer layer index (e.g. ``"model.layers.3"``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # The layer index is parsed from the trailing component of ``prefix``.
        layer_idx = int(prefix.split(".")[-1])
        self.config = config
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.attention = BailingAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attention"
        )
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)

        # The first ``first_k_dense_replace`` layers are dense; the rest are MoE.
        if layer_idx < config.first_k_dense_replace:
            mlp_class = BailingMLP
        else:
            mlp_class = BailingMoE
        self.mlp = mlp_class(
            intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp"
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the decoder layer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            position_ids: Token positions for the rotary embedding.
            residual: Running residual stream, or None when no residual has
                been accumulated yet.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            stream it still has to be added to.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # RMSNorm called with a residual returns (normalized, new residual).
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.attention(
            hidden_states=hidden_states,
            position_ids=position_ids,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class BailingMoeModel(nn.Module):
    """BailingMoE decoder stack: word embeddings, decoder layers, final norm.

    Pipeline parallelism is handled as follows:
    - The embeddings exist on the first rank, or on the last rank when
      ``tie_word_embeddings`` is set.
    - The final norm exists only on the last rank.
    - Each rank runs layers ``[start_layer, end_layer)``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_dim = config.hidden_size
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

        # The last PP rank also needs the embedding table when the LM head is
        # tied to it.
        if get_pp_group().is_first_rank or (
            self.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.word_embeddings = VocabParallelEmbedding(
                self.vocab_size,
                self.embed_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.word_embeddings",
            )
        else:
            self.word_embeddings = PPMissingLayer()

        # Not applied in forward(). Configs without ``embedding_dropout`` fall
        # back to 0.0 instead of raising AttributeError.
        self.embedding_dropout = torch.nn.Dropout(
            getattr(config, "embedding_dropout", 0.0)
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BailingMoeBlock(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the word embeddings for ``input_ids``."""
        return self.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline rank's slice of the decoder.

        On the first rank the input is ``inputs_embeds`` if given, otherwise
        the embeddings of ``input_ids``. Other ranks read ``hidden_states``
        and ``residual`` from ``intermediate_tensors``. Non-last ranks return
        ``IntermediateTensors``; the last rank returns the normalized hidden
        states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                hidden_states,
                position_ids,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        else:
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                # Final residual add is done inside the norm call.
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples that
        map per-expert checkpoint tensors (``gate_proj`` / ``up_proj`` /
        ``down_proj``) onto the fused MoE parameters."""
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each weight is dispatched by the first matching rule:
        1. Dense/shared ``gate_proj``/``up_proj`` go into fused ``gate_up_proj``.
        2. Per-expert tensors go into the fused MoE parameters.
        3. Anything else is loaded by name.

        Names not present on this rank (missing parameters, PP-missing layers,
        extra GPTQ biases) are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if (
                hasattr(self.config, "norm_head")
                and self.config.norm_head
                and "lm_head.weight" in name
            ):
                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Routed-expert weights are handled by the expert mapping below.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Plain (non-stacked, non-expert) parameter.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if name not in params_dict:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """BailingMoE causal language model: ``BailingMoeModel`` plus an LM head.

    The LM head exists only on the last pipeline-parallel rank. It either
    reuses the input embedding (``tie_word_embeddings``) or is a separate
    ``ParallelLMHead``.
    """

    # Checkpoint sub-module names that map onto one fused vLLM module
    # (presumably consumed by LoRA / quantization support).
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Use the text sub-config and store it back as hf_config so that
        # BailingMoeModel, which reads hf_config, sees the same object.
        config = vllm_config.model_config.hf_config.get_text_config()
        vllm_config.model_config.hf_config = config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.max_position_embeddings = config.max_position_embeddings
        self.model = BailingMoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

        if get_pp_group().is_last_rank:
            if self.tie_word_embeddings:
                self.lm_head = self.model.word_embeddings
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=maybe_prefix(prefix, "lm_head"),
                )
            self.logits_processor = LogitsProcessor(config.vocab_size)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the word embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; logits are computed by ``compute_logits``."""
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits (last PP rank)."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def _maybe_normalize_lm_head(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Lazily re-yield ``weights``. When ``config.norm_head`` is truthy,
        L2-normalize any ``lm_head.weight`` tensor along dim 0 first."""
        norm_head = getattr(self.config, "norm_head", False)
        for name, loaded_weight in weights:
            if norm_head and "lm_head.weight" in name:
                loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
            yield name, loaded_weight

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the loaded parameter names.

        ``lm_head.*`` tensors are skipped when the head is tied to the
        embeddings. ``AutoWeightsLoader`` hands ``lm_head.*`` tensors to the
        ``lm_head`` submodule rather than to ``BailingMoeModel.load_weights``.
        The ``norm_head`` normalization is therefore applied here, before
        dispatch.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),
        )
        return loader.load_weights(self._maybe_normalize_lm_head(weights))

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Delegate to ``BailingMoeModel.get_expert_mapping``."""
        return self.model.get_expert_mapping()


class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
    """Same implementation as ``BailingMoeForCausalLM`` under a second class
    name (presumably so the BailingMoeV2 architecture name resolves to it)."""