arctic.py 23.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Inference-only Snowflake Arctic model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10

import torch
from torch import nn

11
from vllm.attention import Attention
12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, VllmConfig
14
15
16
17
18
19
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
20
21
22
23
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
24
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.quantization.deepspeedfp import (
33
34
35
    DeepSpeedFPConfig,
    DeepSpeedFPParameter,
)
36
37
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
39
40
    ParallelLMHead,
    VocabParallelEmbedding,
)
41
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
43
from vllm.platforms import current_platform
44
from vllm.sequence import IntermediateTensors
45
46
from vllm.transformers_utils.configs.arctic import ArcticConfig

47
from .interfaces import SupportsPP, SupportsQuant
48
49
50
51
52
53
54
from .utils import (
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
55

56
57
58
59
# Module-level logger; load_weights uses it to announce that loading from
# 16-bit weights is slow.
logger = init_logger(__name__)


class ArcticMLP(nn.Module):
    """Gated feed-forward block used by Arctic decoder layers.

    A merged gate/up projection (``w13``, holding the checkpoint's ``w1`` and
    ``w3``) feeds ``SiluAndMul``, followed by the down projection ``w2``.
    It serves both as the dense MLP of non-MoE layers (inner width
    ``config.intermediate_size``) and as the residual MLP that runs next to
    the MoE (inner width ``config.hidden_size``).
    """

    def __init__(
        self,
        config: ArcticConfig,
        expert_id: int = -1,
        is_residual_mlp: bool = False,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        """
        Args:
            config: Arctic config; reads ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
            expert_id: Stored as ``self.expert_id``; not read by this class.
            is_residual_mlp: If True, the inner width is ``hidden_size``
                instead of ``intermediate_size``.
            quant_config: Quantization config forwarded to both projections.
            reduce_results: Forwarded to the ``w2`` RowParallelLinear.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before constructing any weight layers so an unsupported
        # config fails fast instead of after allocating parameters.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.hidden_size = config.hidden_size
        self.expert_id = expert_id
        # The residual MLP keeps the model width; the dense MLP expands it.
        self.ffn_dim = (
            config.intermediate_size if not is_residual_mlp else self.hidden_size
        )

        self.w13 = MergedColumnParallelLinear(
            self.hidden_size,
            [self.ffn_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w13",
        )
        self.w2 = RowParallelLinear(
            self.ffn_dim,
            self.hidden_size,
            bias=False,
            reduce_results=reduce_results,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiluAndMul and down projection."""
        gate_up, _ = self.w13(hidden_states)
        hidden_states = self.act_fn(gate_up)
        hidden_states, _ = self.w2(hidden_states)
        return hidden_states


class ArcticMoE(nn.Module):
    """
    Model-parallel implementation of Arctic MoE Layer.

    Layers whose 1-based index is a multiple of ``config.moe_layer_frequency``
    hold a router (``gate``) plus stacked expert weights: ``ws`` (each
    expert's w1 rows followed by its w3 rows) and ``w2s``. Every other layer
    holds a single dense :class:`ArcticMLP` in ``self.mlp``. Expert weights
    are sliced along the intermediate dimension across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: ArcticConfig,
        tp_size: int | None = None,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        """
        Args:
            config: Arctic model config.
            tp_size: Tensor-parallel world size; when falsy, the current
                tensor-parallel world size is used.
            params_dtype: dtype for the router and expert weights; defaults
                to ``torch.get_default_dtype()``.
            quant_config: Quantization config. A ``DeepSpeedFPConfig`` makes
                the expert weights ``DeepSpeedFPParameter`` instances.
            reduce_results: Whether the output is all-reduced across
                tensor-parallel ranks.
            prefix: Module prefix; the layer index is extracted from it.
        """
        super().__init__()

        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        # Per-rank slice of every expert's intermediate dimension.
        self.intermediate_size = config.intermediate_size // self.tp_size

        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            self.mlp = ArcticMLP(
                config,
                quant_config=quant_config,
                reduce_results=reduce_results,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.gate = ReplicatedLinear(
                self.hidden_size,
                self.num_experts,
                bias=False,
                params_dtype=self.params_dtype,
                quant_config=quant_config,
                prefix=f"{prefix}.gate",
            )
            ws_shape = (self.num_experts, 2 * self.intermediate_size, self.hidden_size)
            w2s_shape = (self.num_experts, self.hidden_size, self.intermediate_size)
            if self.is_quant:
                self.ws = DeepSpeedFPParameter(
                    torch.Size(ws_shape),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
                self.w2s = DeepSpeedFPParameter(
                    torch.Size(w2s_shape),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
            else:
                self.ws = nn.Parameter(
                    torch.empty(
                        *ws_shape,
                        device=current_platform.device_type,
                        dtype=self.params_dtype,
                    )
                )
                self.w2s = nn.Parameter(
                    torch.empty(
                        *w2s_shape,
                        device=current_platform.device_type,
                        dtype=self.params_dtype,
                    )
                )
            # Route checkpoint tensors for both stacks through weight_loader.
            for expert_param in (self.ws, self.w2s):
                set_weight_attrs(expert_param, {"weight_loader": self.weight_loader})

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        """Copy this rank's slice of one expert's w1/w2/w3 into ``param``.

        ``w1`` fills rows ``[0, shard)`` and ``w3`` rows ``[shard, 2*shard)``
        of ``param[expert_id]``; ``w2`` fills ``param[expert_id]`` from a
        column slice. For DeepSpeedFP parameters, the whole tensor is
        dequantized, updated and re-quantized.

        Args:
            param: ``ws`` or ``w2s``.
            loaded_weight: Full (unsharded) checkpoint tensor.
            weight_name: Checkpoint name; its suffix selects w1/w2/w3.
            expert_id: Index of the expert being loaded.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.ds_dequantize() if self.is_quant else param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                shard, :
            ]
        elif weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if self.is_quant:
            param.ds_quantize_(param_data)

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and combine their outputs."""
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # Routing weights are only renormalized when several experts are mixed.
        do_normalize = self.top_k > 1
        topk_weights, topk_ids, _ = fused_topk(
            hidden_states, router_logits, self.top_k, renormalize=do_normalize
        )
        # topk_ids: (num_tokens, k)
        if self.is_quant:
            if 2 * num_tokens <= self.num_experts:
                # If much fewer tokens than experts, use selective dequantize.
                selected_ids = topk_ids.flatten()
                ws = self.ws.ds_selective_dequantize(selected_ids)
                w2s = self.w2s.ds_selective_dequantize(selected_ids)
                # We gathered the experts to the tokens so update the mapping.
                topk_ids = torch.arange(
                    0,
                    topk_ids.numel(),
                    device=topk_ids.device,
                ).reshape(topk_ids.shape)
            else:
                ws = self.ws.ds_dequantize()
                w2s = self.w2s.ds_dequantize()
        else:
            ws, w2s = self.ws, self.w2s

        final_hidden_states = fused_experts(
            hidden_states,
            ws,
            w2s,
            topk_weights,
            topk_ids,
            inplace=True,
        )
        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        """Run the MoE path on MoE layers, otherwise the dense MLP."""
        if self.is_moe_layer:
            return self.local_moe_fused(hidden_states)
        return self.mlp(hidden_states)


class ArcticAttention(nn.Module):
    """Tensor-parallel grouped-query self-attention with rotary embeddings."""

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Arctic config; reads head counts, ``hidden_size``,
                ``max_position_embeddings`` and ``rope_theta``.
            cache_config: KV-cache config forwarded to ``Attention``.
            quant_config: Quantization config for projections and attention.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor-parallel size ({tp_size})"
        )
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned evenly across ranks.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_key_value_heads ({self.total_num_kv_heads}) must be "
                f"divisible by the tensor-parallel size ({tp_size})"
            )
        else:
            # Fewer KV heads than ranks: each rank keeps one KV head, so the
            # rank count must be a multiple of the KV head count.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor-parallel size ({tp_size}) must be divisible by "
                f"num_key_value_heads ({self.total_num_kv_heads})"
            )
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class ArcticDecoderLayer(nn.Module):
    """Arctic decoder layer: pre-norm attention followed by the MoE/MLP block.

    On MoE layers with ``config.use_residual``, a residual MLP runs next to
    the MoE. Both branches skip their own tensor-parallel reduction so their
    partial sums can be added first and all-reduced once.
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Arctic model config.
            cache_config: KV-cache config forwarded to attention.
            quant_config: Quantization config for attention and MoE.
            prefix: Module prefix; the layer index is extracted from it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        self.use_residual = config.use_residual and is_moe_layer

        self.self_attn = ArcticAttention(
            config,
            cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # With the residual branch the MoE output is reduced later, together
        # with the residual MLP output, in forward().
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        if self.use_residual:
            self.residual_layernorm = RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            # NOTE(review): unlike the other projections in this layer, the
            # residual MLP is built without quant_config. Confirm this is
            # intentional; changing it may alter the parameter layout that
            # existing checkpoints expect.
            self.residual_mlp = ArcticMLP(
                config,
                is_residual_mlp=True,
                reduce_results=False,
                prefix=f"{prefix}.residual_mlp",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and the MoE/MLP block with residual connections."""
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            hidden_states = self.residual_layernorm(hidden_states)
            hidden_states = self.residual_mlp(hidden_states)
            residual_mlp = hidden_states
            # NOTE: the MoE branch consumes the layer *input* (residual_input),
            # not the attention output; preserved as-is.
            hidden_states = self.post_attention_layernorm(residual_input)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_mlp + hidden_states
            # Single reduction for both unreduced branches.
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = residual_attn + hidden_states
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_attn + hidden_states
        return hidden_states


@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic decoder stack: token embedding, decoder layers and final norm.

    Under pipeline parallelism, each rank runs only layers
    ``[start_layer, end_layer)``. Non-last ranks return
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        # Kept as an attribute; not read elsewhere in this file.
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Returns normalized hidden states on the last pipeline rank, otherwise
        ``IntermediateTensors`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            # First stage: start from caller-provided or looked-up embeddings.
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Later stages resume from the previous stage's activations.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.norm(hidden_states)
        return hidden_states


class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """Arctic causal LM: :class:`ArcticModel` plus an (optionally tied) LM head."""

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = ArcticModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.unpadded_vocab_size = config.vocab_size
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; see :meth:`ArcticModel.forward`."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, remapping names to vLLM's fused layout.

        ``q/k/v_proj`` go into ``qkv_proj``. Dense-MLP ``w1``/``w3`` go into
        ``w13``. Per-expert ``experts.<i>.w1``/``w3`` go into ``ws``, and
        ``experts.<i>.w2`` goes into ``w2s``. Everything else is loaded by
        name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # (param_name, weight_name, shard_id) for the dense gated MLPs.
        # Entries are layer-agnostic: which layers are dense and which are
        # MoE follows config.moe_layer_frequency, so matching on the module
        # path (not a hard-coded layer parity) is correct for any frequency.
        mlp_params_mapping: list[tuple[str, str, int]] = [
            ("residual_mlp.w13.weight", "residual_mlp.w1.weight", 0),
            ("residual_mlp.w13.weight", "residual_mlp.w3.weight", 1),
            ("block_sparse_moe.mlp.w13.weight", "block_sparse_moe.mlp.w1.weight", 0),
            ("block_sparse_moe.mlp.w13.weight", "block_sparse_moe.mlp.w3.weight", 1),
        ]

        # Expert checkpoint names carry no layer index, so one entry per
        # expert matches that expert in every MoE layer.
        expert_params_mapping: list[tuple[str, str, int]] = []
        for expert_id in range(self.config.num_local_experts):
            expert_params_mapping.extend(
                [
                    ("ws", f"experts.{expert_id}.w1.weight", expert_id),
                    ("w2s", f"experts.{expert_id}.w2.weight", expert_id),
                    ("ws", f"experts.{expert_id}.w3.weight", expert_id),
                ]
            )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        logger.info(
            "It will take ~10 minutes loading from the 16-bit weights. "
            "Alternatively, use the prequantized 8-bit weights of arctic "
            "and set load-format to `sharded_state` will accelerate loading."
        )
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, shard_id in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        # ArcticMoE.weight_loader takes the checkpoint suffix
                        # to decide between the w1/w3/w2 slots.
                        weight_loader(
                            param, loaded_weight, weight_name, expert_id=shard_id
                        )
                        break
                    else:
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params