# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The rednote-hilab team.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only dots1 model."""

from collections.abc import Iterable
from itertools import islice
from typing import Any

import torch
from torch import nn
from transformers import Dots1Config

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class Dots1MLP(nn.Module):
    """Dense gated-SiLU feed-forward block used by dots1.

    The gate and up projections share one column-parallel linear layer
    (``gate_up_proj``). ``SiluAndMul`` combines its two halves, and a
    row-parallel layer (``down_proj``) maps the result back to
    ``hidden_size``.

    Args:
        hidden_size: Width of the block's input and output.
        intermediate_size: Width of each of the fused gate / up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both linear layers.
        reduce_results: Forwarded to ``down_proj``. ``Dots1MoE`` passes
            ``False`` for its shared experts and all-reduces itself.
        prefix: Parameter-name prefix for the submodules.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Two equal-width output shards: [gate | up].
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Return ``down_proj(SiluAndMul(gate_up_proj(x)))``."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class Dots1MoE(nn.Module):
    """Mixture-of-experts feed-forward block for dots1.

    Each token is routed to ``config.num_experts_per_tok`` of
    ``config.n_routed_experts`` experts using grouped top-k selection. The
    routed output is scaled by ``config.routed_scaling_factor``. When
    ``config.n_shared_experts`` is set, the output of a dense shared-expert
    MLP is added unscaled.

    Both the shared MLP and the fused experts are built with
    ``reduce_results=False``. The tensor-parallel all-reduce is therefore
    done once, on the combined output, in ``forward``.

    Args:
        config: HF ``Dots1Config`` providing expert counts, sizes and routing
            options.
        quant_config: Optional quantization config for the experts. The
            router ``gate`` is always unquantized.
        prefix: Parameter-name prefix for the submodules.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: Dots1Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Router logits: replicated on every TP rank, never quantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        if config.topk_method == "noaux_tc":
            # Per-expert routing bias, handed to the fused experts below as
            # e_score_correction_bias. It is created with torch.empty, so it
            # must be filled from the checkpoint.
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts)
            )
        else:
            self.gate.e_score_correction_bias = None

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = Dots1MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            # we do scaling outside, set factor to 1.0 to avoid double mul
            routed_scaling_factor=1.0,
            e_score_correction_bias=self.gate.e_score_correction_bias,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block.

        Args:
            hidden_states: Token activations. The shape unpack below
                requires exactly two dims, ``(num_tokens, hidden_dim)``.

        Returns:
            A tensor of the same shape, all-reduced across tensor-parallel
            ranks when ``tp_size > 1``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits, _ = self.gate(hidden_states)
        experts_out = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        # With shared experts attached, the fused layer returns a pair rather
        # than a tensor. The scaling factor must therefore be applied to the
        # routed element only. Multiplying the tuple itself does not scale
        # any tensor: it raises TypeError for a float factor, or repeats the
        # tuple for an int factor. The shared output must stay unscaled.
        # NOTE(review): assumes the pair is ordered (shared_out, routed_out);
        # confirm against SharedFusedMoE.forward.
        if isinstance(experts_out, tuple):
            shared_output, routed_output = experts_out
        else:
            shared_output, routed_output = None, experts_out

        # The fused layer was built with routed_scaling_factor=1.0, so the
        # routed scaling is applied exactly once, here.
        final_hidden_states = routed_output * self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_dim)


class Dots1Attention(nn.Module):
    """Multi-head / grouped-query self-attention with per-head QK RMSNorm.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least ``tp_size`` of them, and replicated
    otherwise (``num_kv_heads`` per rank is then clamped to 1).

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads (must divide by TP size).
        num_kv_heads: Total number of key/value heads.
        config: HF ``Dots1Config``; supplies ``head_dim`` (optional),
            ``attention_bias`` and ``rms_norm_eps``.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config forwarded to ``get_rope``.
        max_position_embeddings: Maximum position for the rotary embedding.
        cache_config: KV-cache configuration for ``Attention``.
        quant_config: Optional quantization config for the projections and
            attention.
        prefix: Parameter-name prefix for the submodules.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        config: Dots1Config,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Prefer an explicit head_dim from the config. Fall back to
        # hidden_size / num_heads when the attribute is missing OR set to
        # None; getattr's default alone does not cover an explicit None.
        self.head_dim = (
            getattr(config, "head_dim", None) or hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        attention_bias = config.attention_bias

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        # RMSNorm applied independently to every head's head_dim-wide slice.
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Compute self-attention for ``hidden_states`` at ``positions``.

        ``hidden_states`` is presumably token-flattened,
        ``(num_tokens, hidden_size)`` (TODO confirm). The per-head reshapes
        below only fix the trailing ``head_dim`` layout.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # QK-norm: normalize each head separately, then restore the original
        # flat shape before applying rotary embeddings.
        q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(q.shape)
        k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape(
            k.shape
        )
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Dots1DecoderLayer(nn.Module):
    """One dots1 decoder layer: pre-norm self-attention followed by an MLP.

    The MLP is a ``Dots1MoE`` when the config defines routed experts, the
    layer index is at least ``first_k_dense_replace``, and the index is a
    multiple of ``moe_layer_freq``. Otherwise it is a dense ``Dots1MLP``.

    Args:
        config: HF ``Dots1Config``.
        prefix: Parameter-name prefix. Its last dot-separated component must
            be the integer layer index.
        model_config: vLLM model config (accepted for interface
            compatibility; not read here).
        cache_config: KV-cache configuration for attention.
        quant_config: Optional quantization config.
    """

    def __init__(
        self,
        config: Dots1Config,
        prefix: str,
        model_config: ModelConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # The layer index is encoded as the last component of the prefix.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        self.self_attn = Dots1Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            config=config,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        ):
            self.mlp = Dots1MoE(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Dots1MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Token positions, forwarded to rotary attention.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream. ``None`` for the first layer
                executed on this pipeline rank.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the updated
            residual stream, which the next layer's norm consumes.
        """
        if residual is None:
            # No residual yet: start it from the raw input.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # RMSNorm called with a residual returns a (normed, residual) pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


356
@support_torch_compile
class Dots1Model(nn.Module):
    """dots1 transformer backbone: token embedding, decoder stack, final norm.

    Under pipeline parallelism, only the first rank owns ``embed_tokens`` and
    only the last rank owns ``norm``; other ranks hold ``PPMissingLayer``
    placeholders. Intermediate ranks exchange ``hidden_states`` and
    ``residual`` through ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.vocab_size = config.vocab_size

        if not get_pp_group().is_first_rank:
            self.embed_tokens = PPMissingLayer()
        else:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )

        # Keep the parameter named ``prefix``: presumably make_layers passes
        # the per-layer prefix by keyword (confirm before renaming).
        def build_layer(prefix):
            return Dots1DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if not get_pp_group().is_last_rank:
            self.norm = PPMissingLayer()
        else:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Returns normalized hidden states on the last pipeline rank, and
        ``IntermediateTensors`` carrying ``hidden_states``/``residual``
        everywhere else.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            # Caller-supplied embeddings take precedence over token ids.
            hidden_states = (
                self.embed_input_ids(input_ids)
                if inputs_embeds is None
                else inputs_embeds
            )
            residual = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-fused-parameter mapping for routed experts."""
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Lookup order for each checkpoint entry:
        1. stacked (fused) dense projections such as qkv / gate_up;
        2. per-expert mapping for routed experts;
        3. direct load via the parameter's loader (or the default loader).
        Entries owned by other pipeline ranks are skipped.

        Returns:
            The set of (possibly remapped) parameter names that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        for key, tensor in weights:
            if "rotary_emb.inv_freq" in key:
                continue
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in key:
                    continue
                # Names under mlp.experts. that are not direct params are
                # left for the expert mapping below.
                if ("mlp.experts." in key) and key not in params_dict:
                    continue
                key = key.replace(shard_name, fused_name)
                if key.endswith(".bias") and key not in params_dict:
                    continue
                if is_pp_missing_parameter(key, self):
                    continue
                target = params_dict[key]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                for fused_name, shard_name, expert_id, shard_id in expert_params_mapping:
                    if shard_name not in key:
                        continue
                    key = key.replace(shard_name, fused_name)

                    if is_pp_missing_parameter(key, self):
                        continue

                    target = params_dict[key]
                    target.weight_loader(
                        target,
                        tensor,
                        key,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    if key.endswith(".bias") and key not in params_dict:
                        continue
                    key = maybe_remap_kv_scale_name(key, params_dict)
                    if key is None:
                        continue
                    if is_pp_missing_parameter(key, self):
                        continue
                    target = params_dict[key]
                    loader = getattr(target, "weight_loader", default_weight_loader)
                    loader(target, tensor)
            loaded_params.add(key)
        return loaded_params


class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """dots1 causal language model: ``Dots1Model`` backbone plus LM head.

    The LM head exists only on the last pipeline rank; other ranks hold a
    ``PPMissingLayer`` placeholder.
    """

    # Fused vLLM module name -> checkpoint submodules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_cfg

        self.model = Dots1Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = (
            ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_cfg,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if get_pp_group().is_last_rank
            else PPMissingLayer()
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Reuse the backbone's factory for pipeline-parallel placeholders.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights by delegating to ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's routed-expert parameter mapping."""
        return self.model.get_expert_mapping()