# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The rednote-hilab team.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only dots1 model."""
27

28
from collections.abc import Iterable
29
from itertools import islice
30
31
32

import torch
from torch import nn
33
from transformers import Dots1Config
34
35
36

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
37
38
39
40
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
41
from vllm.model_executor.layers.activation import SiluAndMul
42
from vllm.model_executor.layers.attention import Attention
43
from vllm.model_executor.layers.fused_moe import FusedMoE
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
46
47
48
49
50
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
51
52
53
54
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
56
57
    ParallelLMHead,
    VocabParallelEmbedding,
)
58
from vllm.model_executor.model_loader.weight_utils import (
59
60
61
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
62
63
from vllm.sequence import IntermediateTensors

64
from .interfaces import SupportsLoRA, SupportsPP
65
66
67
68
69
70
71
72
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
73
74
75
76
77
78
79
80


class Dots1MLP(nn.Module):
    """Gated SiLU feed-forward block used for dense layers and shared experts.

    ``gate_up_proj`` emits the gate and up projections side by side (two
    shards of ``intermediate_size``), ``SiluAndMul`` combines them, and
    ``down_proj`` maps the result back to ``hidden_size``.

    Args:
        hidden_size: Input and output width.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both linear layers.
        reduce_results: Forwarded to ``down_proj``. ``Dots1MoE`` passes
            ``False`` when building its shared experts.
        prefix: Parameter-name prefix for the submodules.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel matmul.
        merged_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-multiply, then down projection."""
        gate_and_up, _bias = self.gate_up_proj(x)
        activated = self.act_fn(gate_and_up)
        projected, _bias = self.down_proj(activated)
        return projected


class Dots1MoE(nn.Module):
    """Mixture-of-experts feed-forward block with optional shared experts.

    A replicated linear ``gate`` produces one router logit per routed expert,
    and ``FusedMoE`` performs grouped top-k routing over
    ``config.n_routed_experts`` experts. When ``config.n_shared_experts`` is
    set, a ``Dots1MLP`` of width ``moe_intermediate_size * n_shared_experts``
    is built and handed to ``FusedMoE`` through ``shared_experts``.

    Args:
        config: HF model config supplying the MoE hyper-parameters.
        quant_config: Optional quantization config for the experts
            (the router gate is always built unquantized).
        prefix: Parameter-name prefix for the submodules.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: Dots1Config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        num_routed = config.n_routed_experts

        # Router: one logit per routed expert, deliberately unquantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            num_routed,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        # Only the "noaux_tc" top-k method carries a per-expert score
        # correction bias. It is allocated uninitialized (torch.empty);
        # presumably its values arrive from the checkpoint — confirm.
        self.gate.e_score_correction_bias = (
            nn.Parameter(torch.empty(num_routed))
            if config.topk_method == "noaux_tc"
            else None
        )

        shared = None
        if config.n_shared_experts is not None:
            shared_width = config.moe_intermediate_size * config.n_shared_experts
            # Built with reduce_results=False and passed to FusedMoE below;
            # presumably FusedMoE handles the cross-rank reduction — confirm.
            shared = Dots1MLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_width,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        self.shared_experts = shared

        # apply_routed_scale_to_output=True: judging by the name, FusedMoE
        # multiplies the routed output by routed_scaling_factor — confirm.
        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            num_experts=num_routed,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scale_to_output=True,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through the experts.

        ``hidden_states`` must be 2-D ``(num_tokens, hidden_dim)``; the
        result has the same shape.
        """
        num_tokens, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)
        router_logits, _ = self.gate(tokens)
        mixed = self.experts(hidden_states=tokens, router_logits=router_logits)
        return mixed.view(num_tokens, hidden_dim)


class Dots1Attention(nn.Module):
    """Self-attention with per-head RMSNorm applied to queries and keys.

    Query heads are split evenly across the tensor-parallel (TP) group.
    KV heads are split when there are at least as many as TP ranks, and
    replicated otherwise so that every rank holds at least one KV head.

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads (across all TP ranks).
        num_kv_heads: Total number of key/value heads (across all TP ranks).
        config: HF config; ``head_dim``, ``attention_bias``,
            ``rope_parameters`` and ``rms_norm_eps`` are read from it.
        max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the projections and ``Attention``.
        prefix: Parameter-name prefix for the submodules.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        config: Dots1Config,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # ``head_dim`` may be missing from the config or present but None.
        # A plain getattr default only covers the first case; in the second
        # it would yield None and break every size computed below.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        attention_bias = config.attention_bias

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        # Norms act over a single head's features (width head_dim).
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states``.

        Projects to q/k/v, RMS-normalizes q and k per head, applies rotary
        embeddings at ``positions``, attends, and projects back to
        ``hidden_size``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # View each row as (heads, head_dim) so the norm is per head, then
        # restore the packed layout expected by rotary_emb / attn.
        q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(q.shape)
        k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape(
            k.shape
        )
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Dots1DecoderLayer(nn.Module):
    """One Dots1 transformer block: pre-norm attention, then a dense or MoE MLP.

    The layer index is the integer after the last ``"."`` in ``prefix``.
    The MLP is a ``Dots1MoE`` when all of the following hold:
    ``config.n_routed_experts`` is set, the index is
    ``>= config.first_k_dense_replace``, and the index is divisible by
    ``config.moe_layer_freq``. Otherwise it is a dense ``Dots1MLP`` of
    width ``config.intermediate_size``.

    Args:
        config: HF model config.
        prefix: Parameter-name prefix ending in the layer index.
        model_config: Accepted for interface compatibility; not used here.
        cache_config: Forwarded to the attention module.
        quant_config: Forwarded to attention and the MLP.
    """

    def __init__(
        self,
        config: Dots1Config,
        prefix: str,
        model_config: ModelConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        self.self_attn = Dots1Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            config=config,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        ):
            self.mlp = Dots1MoE(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Dots1MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        # Kept as a public attribute; not read inside this layer.
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the MLP, each preceded by an RMSNorm.

        Args:
            positions: Token positions forwarded to the attention module.
            hidden_states: Output of the previous layer, or the embeddings.
            residual: ``None`` for the first layer on the first pipeline
                stage; otherwise the residual returned by the previous layer.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual,
            returned separately. The caller passes both to the next RMSNorm
            call (the next layer's ``input_layernorm`` or the model's norm).
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Dots1Model(nn.Module):
    """Dots1 decoder stack: token embeddings, decoder layers, final norm.

    Under pipeline parallelism only the first rank owns ``embed_tokens`` and
    only the last rank owns ``norm``; other ranks hold ``PPMissingLayer``
    placeholders. ``make_layers`` returns the ``[start_layer, end_layer)``
    slice of decoder layers executed on this rank.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.vocab_size = config.vocab_size

        pp_group = get_pp_group()
        self.embed_tokens = (
            VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
            if pp_group.is_first_rank
            else PPMissingLayer()
        )

        # make_layers invokes the factory with a ``prefix=`` keyword, so the
        # parameter must keep that exact name.
        def _build_layer(prefix):
            return Dots1DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, _build_layer, prefix=f"{prefix}.layers"
        )

        self.norm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if pp_group.is_last_rank
            else PPMissingLayer()
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of decoder layers.

        The first rank starts from ``inputs_embeds`` (or embeds
        ``input_ids``); later ranks resume from ``intermediate_tensors``.
        Non-last ranks return ``IntermediateTensors`` holding
        ``hidden_states`` and ``residual``; the last rank applies the final
        norm and returns hidden states.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.embed_input_ids(input_ids)
            )
            residual = None

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return FusedMoE's (param, ckpt-weight, expert_id, shard_id) mapping
        for the checkpoint's gate_proj / up_proj / down_proj expert names."""
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint name is tried against, in order:
          1. stacked mappings (q/k/v_proj -> qkv_proj, gate/up_proj ->
             gate_up_proj) with the shard id passed to the loader;
          2. the FusedMoE expert mapping from ``get_expert_mapping``;
          3. a direct load (after ``maybe_remap_kv_scale_name``) with the
             parameter's ``weight_loader`` or ``default_weight_loader``.
        ``rotary_emb.inv_freq`` entries, ``.bias`` names with no matching
        parameter, names remapped to ``None``, and parameters not present on
        this pipeline rank are skipped. The name is rewritten in place as
        mappings match, so later steps see the rewritten name.

        Returns:
            The (possibly rewritten) names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        for target, tensor in weights:
            if "rotary_emb.inv_freq" in target:
                continue

            for fused_name, part_name, shard in stacked_params_mapping:
                if part_name not in target:
                    continue
                # Routed-expert projections reuse the gate_proj/up_proj names;
                # defer them to the expert mapping unless they map directly.
                if "mlp.experts." in target and target not in params_dict:
                    continue
                target = target.replace(part_name, fused_name)
                if target.endswith(".bias") and target not in params_dict:
                    continue
                if is_pp_missing_parameter(target, self):
                    continue
                param = params_dict[target]
                param.weight_loader(param, tensor, shard)
                break
            else:
                for fused_name, part_name, expert_id, shard in expert_params_mapping:
                    if part_name not in target:
                        continue
                    target = target.replace(part_name, fused_name)
                    if is_pp_missing_parameter(target, self):
                        continue
                    param = params_dict[target]
                    param.weight_loader(
                        param,
                        tensor,
                        target,
                        shard_id=shard,
                        expert_id=expert_id,
                    )
                    break
                else:
                    if target.endswith(".bias") and target not in params_dict:
                        continue
                    target = maybe_remap_kv_scale_name(target, params_dict)
                    if target is None:
                        continue
                    if is_pp_missing_parameter(target, self):
                        continue
                    param = params_dict[target]
                    loader = getattr(param, "weight_loader", default_weight_loader)
                    loader(param, tensor)
            loaded_params.add(target)
        return loaded_params


class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """Causal-LM wrapper around ``Dots1Model``.

    Adds a ``ParallelLMHead`` on the last pipeline rank (a
    ``PPMissingLayer`` elsewhere) and a ``LogitsProcessor``.
    ``packed_modules_mapping`` lists which checkpoint projections are fused
    into each packed module.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        self.model = Dots1Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = (
            ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if get_pp_group().is_last_rank
            else PPMissingLayer()
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        # Reuse the inner model's factory for pipeline-parallel placeholders.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; returns whatever ``Dots1Model`` returns."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``; returns loaded names."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expose the inner model's FusedMoE expert parameter mapping."""
        return self.model.get_expert_mapping()