arctic.py 23.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Inference-only Snowflake Arctic model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
from typing import Optional, Union
8
9
10
11

import torch
from torch import nn

12
from vllm.attention import Attention
13
from vllm.compilation.decorators import support_torch_compile
14
from vllm.config import CacheConfig, VllmConfig
15
16
17
18
19
20
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
21
22
23
24
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
28
29
30
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
31
from vllm.model_executor.layers.logits_processor import LogitsProcessor
32
from vllm.model_executor.layers.quantization import QuantizationConfig
33
from vllm.model_executor.layers.quantization.deepspeedfp import (
34
35
36
    DeepSpeedFPConfig,
    DeepSpeedFPParameter,
)
37
38
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
39
40
41
    ParallelLMHead,
    VocabParallelEmbedding,
)
42
43
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
44
from vllm.platforms import current_platform
45
from vllm.sequence import IntermediateTensors
46
47
from vllm.transformers_utils.configs.arctic import ArcticConfig

48
from .interfaces import SupportsPP, SupportsQuant
49
50
51
52
53
54
55
from .utils import (
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
56

57
58
59
60
# Module-level logger; ArcticForCausalLM.load_weights uses it to report the
# expected weight-loading time.
logger = init_logger(__name__)


class ArcticMLP(nn.Module):
    """Dense gated-SiLU MLP used for non-MoE layers and for the residual MLP.

    ``w1`` and ``w3`` are fused into one column-parallel projection ``w13``
    (shard 0 = ``w1``, shard 1 = ``w3``). ``SiluAndMul`` combines the two
    halves, and the row-parallel ``w2`` projects back to ``hidden_size``.

    Args:
        config: Arctic config. ``hidden_size``, ``intermediate_size`` and
            ``hidden_act`` are read.
        expert_id: Stored as ``self.expert_id``; not otherwise used here.
        is_residual_mlp: When True, the inner dimension is ``hidden_size``
            instead of ``config.intermediate_size``.
        quant_config: Optional quantization config forwarded to both linear
            layers.
        reduce_results: Forwarded to ``w2``. When False, the caller must
            perform the tensor-parallel all-reduce itself.
        prefix: Module-name prefix used to name the child linear layers.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: ArcticConfig,
        expert_id: int = -1,
        is_residual_mlp: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        super().__init__()
        # Validate before allocating any (potentially large) weights.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.hidden_size = config.hidden_size
        self.expert_id = expert_id
        self.ffn_dim = (
            config.intermediate_size if not is_residual_mlp else self.hidden_size
        )

        # Forward the prefix so layer-name-based quant configs can match these.
        self.w13 = MergedColumnParallelLinear(
            self.hidden_size,
            [self.ffn_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w13",
        )
        self.w2 = RowParallelLinear(
            self.ffn_dim,
            self.hidden_size,
            bias=False,
            reduce_results=reduce_results,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``w2(act(w13(hidden_states)))``."""
        gate_up, _ = self.w13(hidden_states)
        hidden_states = self.act_fn(gate_up)
        hidden_states, _ = self.w2(hidden_states)
        return hidden_states


class ArcticMoE(nn.Module):
    """
    Model-parallel implementation of Arctic MoE Layer.

    A layer is an MoE layer when ``(layer_id + 1) % moe_layer_frequency == 0``.
    Otherwise it holds a dense ``ArcticMLP`` in ``self.mlp``.

    For MoE layers, expert weights are stored stacked across experts:
      * ``ws``:  (num_experts, 2 * intermediate_size / tp, hidden_size), with
        ``w1`` in the first half of dim 1 and ``w3`` in the second half;
      * ``w2s``: (num_experts, hidden_size, intermediate_size / tp).
    When ``quant_config`` is a ``DeepSpeedFPConfig``, both are
    ``DeepSpeedFPParameter``s and are dequantized on the fly in forward.
    """

    def __init__(
        self,
        config: ArcticConfig,
        tp_size: Optional[int] = None,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        super().__init__()

        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        self.intermediate_size = config.intermediate_size // self.tp_size

        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            self.mlp = ArcticMLP(
                config,
                quant_config=quant_config,
                reduce_results=reduce_results,
                prefix=f"{prefix}.mlp",
            )
            return

        # The expert weights are sharded by hand in weight_loader; a remainder
        # would be silently dropped there, so reject it up front.
        if config.intermediate_size % self.tp_size != 0:
            raise ValueError(
                f"intermediate_size ({config.intermediate_size}) must be "
                f"divisible by the tensor parallel size ({self.tp_size})."
            )

        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_experts,
            bias=False,
            params_dtype=self.params_dtype,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        ws_shape = (self.num_experts, 2 * self.intermediate_size, self.hidden_size)
        w2s_shape = (self.num_experts, self.hidden_size, self.intermediate_size)
        if self.is_quant:
            self.ws = DeepSpeedFPParameter(
                torch.Size(ws_shape),
                params_dtype=params_dtype,
                quant_config=quant_config,
            )
            self.w2s = DeepSpeedFPParameter(
                torch.Size(w2s_shape),
                params_dtype=params_dtype,
                quant_config=quant_config,
            )
        else:
            self.ws = nn.Parameter(
                torch.empty(
                    *ws_shape,
                    device=current_platform.device_type,
                    dtype=self.params_dtype,
                )
            )
            self.w2s = nn.Parameter(
                torch.empty(
                    *w2s_shape,
                    device=current_platform.device_type,
                    dtype=self.params_dtype,
                )
            )
        for stacked_param in (self.ws, self.w2s):
            set_weight_attrs(stacked_param, {"weight_loader": self.weight_loader})

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        """Copy this rank's shard of one expert's checkpoint weight into ``param``.

        Args:
            param: The stacked ``ws`` or ``w2s`` parameter.
            loaded_weight: A 2-D checkpoint tensor for a single expert.
            weight_name: Checkpoint name. Its suffix (``w1.weight``,
                ``w3.weight`` or ``w2.weight``) selects the destination slice.
            expert_id: Index along dim 0 of ``param``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        # For DeepSpeedFP the whole parameter is dequantized, patched, and
        # re-quantized on every call.
        param_data = param.ds_dequantize() if self.is_quant else param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 fills the first half of ws[expert_id]; rows are sharded.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 fills the second half of ws[expert_id]; rows are sharded.
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                shard, :
            ]
        elif weight_name.endswith("w2.weight"):
            # w2 is sharded along its input (column) dimension.
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if self.is_quant:
            param.ds_quantize_(param_data)

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` (num_tokens, hidden_size) through the experts."""
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # Renormalize routing weights only when more than one expert is picked.
        do_normalize = self.top_k > 1
        topk_weights, topk_ids, _ = fused_topk(
            hidden_states, router_logits, self.top_k, renormalize=do_normalize
        )
        # topk_ids: (num_tokens, k)

        w13_weights, w2_weights = self.ws, self.w2s
        if self.is_quant:
            if 2 * num_tokens <= self.num_experts:
                # If much fewer tokens than experts, use selective dequantize.
                flat_ids = topk_ids.flatten()
                w13_weights = self.ws.ds_selective_dequantize(flat_ids)
                w2_weights = self.w2s.ds_selective_dequantize(flat_ids)
                # We gathered the experts to the tokens so update the mapping.
                topk_ids = torch.arange(
                    0,
                    topk_ids.numel(),
                    device=topk_ids.device,
                ).reshape(topk_ids.shape)
            else:
                w13_weights = self.ws.ds_dequantize()
                w2_weights = self.w2s.ds_dequantize()

        final_hidden_states = fused_experts(
            hidden_states,
            w13_weights,
            w2_weights,
            topk_weights,
            topk_ids,
            inplace=True,
        )
        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Dispatch to the fused MoE path or the dense MLP for this layer."""
        if self.is_moe_layer:
            return self.local_moe_fused(hidden_states)
        return self.mlp(hidden_states)


class ArcticAttention(nn.Module):
    """Tensor-parallel self-attention for Arctic.

    Uses a fused QKV projection, neox-style rotary embedding over the full
    head dimension, and a row-parallel output projection that all-reduces
    its result (``reduce_results=True``).

    Args:
        config: Arctic config (head counts, sizes, rope settings).
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for the projections and
            attention.
        prefix: Module-name prefix used to name the child layers.
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned across TP ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # KV heads are replicated across TP ranks.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim**-0.5

        # Forward the prefix so layer-name-based quant configs can match these.
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class ArcticDecoderLayer(nn.Module):
    """One Arctic transformer layer: attention followed by an MoE/MLP block.

    When ``config.use_residual`` is set and this is an MoE layer, a dense
    residual MLP runs in parallel with the MoE block. Both are built with
    ``reduce_results=False``, so their outputs are summed first and
    all-reduced once.
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        # Same rule as ArcticMoE.is_moe_layer.
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        self.use_residual = config.use_residual and is_moe_layer

        self.self_attn = ArcticAttention(
            config,
            cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # With a residual MLP, the all-reduce is deferred to forward().
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        if self.use_residual:
            self.residual_layernorm = RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            # NOTE(review): quant_config is not forwarded here, so the residual
            # MLP is never quantized — confirm this is intentional.
            self.residual_mlp = ArcticMLP(
                config,
                is_residual_mlp=True,
                reduce_results=False,
                prefix=f"{prefix}.residual_mlp",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention, then the (optionally parallel-residual) MoE block."""
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            # Dense residual MLP on the post-attention state.
            residual_mlp = self.residual_mlp(self.residual_layernorm(hidden_states))
            # The MoE branch reads the layer *input* (residual_input), not the
            # post-attention state.
            moe_out = self.block_sparse_moe(
                self.post_attention_layernorm(residual_input)
            )
            hidden_states = residual_mlp + moe_out
            # Single all-reduce for both branches (built with reduce_results=False).
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            return residual_attn + hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.block_sparse_moe(hidden_states)
        return residual_attn + hidden_states


413
@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic decoder stack: embeddings, decoder layers, and final norm.

    Under pipeline parallelism, each rank owns layers
    ``[start_layer, end_layer)``. Non-first ranks read ``hidden_states`` from
    ``intermediate_tensors``; non-last ranks return them as
    ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.norm(hidden_states)


465
466
class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """Snowflake Arctic causal LM: ``ArcticModel`` plus a parallel LM head."""

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = ArcticModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.unpadded_vocab_size = config.vocab_size
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return hidden states (not logits)."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def _mlp_params_mapping(self) -> list[tuple[str, str, int]]:
        """Map checkpoint dense-MLP ``w1``/``w3`` names onto the fused ``w13``.

        Returns ``(param_name, checkpoint_name, shard_id)`` tuples, where shard
        0 is ``w1`` and shard 1 is ``w3``. Every layer gets residual-MLP
        entries. Only non-MoE layers get ``block_sparse_moe.mlp`` entries.
        """
        moe_layer_frequency = self.config.moe_layer_frequency
        mapping: list[tuple[str, str, int]] = []
        for layer in range(self.config.num_hidden_layers):
            modules = [f"layers.{layer}.residual_mlp"]
            # Must agree with ArcticMoE.is_moe_layer; previously hard-coded as
            # `layer % 2 == 0`, which only matched moe_layer_frequency == 2.
            if (layer + 1) % moe_layer_frequency != 0:
                modules.append(f"layers.{layer}.block_sparse_moe.mlp")
            for module in modules:
                for shard_id, shard_name in enumerate(("w1", "w3")):
                    mapping.append(
                        (
                            f"{module}.w13.weight",
                            f"{module}.{shard_name}.weight",
                            shard_id,
                        )
                    )
        return mapping

    def _expert_params_mapping(self) -> list[tuple[str, str, int]]:
        """Map checkpoint expert weight names onto the stacked ``ws``/``w2s``.

        Returns ``(param_name, checkpoint_suffix, expert_id)`` tuples. The
        entries are not layer-specific, so they are built once rather than
        once per MoE layer.
        """
        mapping: list[tuple[str, str, int]] = []
        for expert_id in range(self.config.num_local_experts):
            mapping.append(("ws", f"experts.{expert_id}.w1.weight", expert_id))
            mapping.append(("w2s", f"experts.{expert_id}.w2.weight", expert_id))
            mapping.append(("ws", f"experts.{expert_id}.w3.weight", expert_id))
        return mapping

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` and return the parameter names loaded.

        Each weight is tried, in order, against the QKV mapping, the dense-MLP
        mapping, and the expert mapping. If none match, it falls back to the
        parameter's own (or the default) loader. Parameters owned by other
        pipeline-parallel ranks are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        mlp_params_mapping = self._mlp_params_mapping()
        expert_params_mapping = self._expert_params_mapping()

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        logger.info(
            "It will take ~10 minutes loading from the 16-bit weights. "
            "Alternatively, use the prequantized 8-bit weights of arctic "
            "and set load-format to `sharded_state` will accelerate loading."
        )
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, expert_id in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        # ArcticMoE.weight_loader uses the checkpoint suffix to
                        # pick the w1/w3/w2 destination slice.
                        weight_loader(
                            param, loaded_weight, weight_name, expert_id=expert_id
                        )
                        break
                    else:
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params