dbrx.py 17.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from itertools import islice
6
7
8

import torch
import torch.nn as nn
9
from transformers import DbrxConfig
10

11
from vllm.attention import Attention
12
from vllm.config import CacheConfig, VllmConfig
13
14
15
16
17
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
18
from vllm.model_executor.layers.fused_moe import FusedMoE
19
20
21
22
23
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
24
from vllm.model_executor.layers.logits_processor import LogitsProcessor
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
27
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
29
30
31
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
32
from vllm.model_executor.model_loader.weight_utils import (
33
34
35
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
36
from vllm.sequence import IntermediateTensors
37

38
from .interfaces import SupportsPP
39
40
41
42
43
44
45
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
46

47
48
49
50
51
52
53
54

class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.

    The router is a single bias-free ``ReplicatedLinear`` projection from
    ``d_model`` to ``moe_num_experts``, built with ``quant_config=None`` so
    it is never quantized.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            # The router projection is deliberately left unquantized.
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute per-expert router logits.

        Args:
            hidden_states: Token activations whose last dimension is
                ``d_model``.

        Returns:
            The first output of ``self.layer``; its second (bias) output is
            discarded.
        """
        router_logits, _ = self.layer(hidden_states)
        return router_logits


class DbrxExperts(FusedMoE):
    """``FusedMoE`` expert layer with a checkpoint loader for DBRX.

    DBRX experts are GLUs with three projections: ``w1``, ``v1`` and ``w2``.
    ``weight_loader`` reshapes each checkpoint tensor to its full
    ``(-1, ffn_hidden_size, d_model)`` form and copies this tensor-parallel
    rank's slice of the intermediate dimension into the fused parameter:
    ``w1`` fills the first ``shard_size`` rows of dim 1, ``v1`` the next
    ``shard_size`` rows, and ``w2`` is transposed and sliced on its last dim.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: QuantizationConfig | None = None,
        params_dtype: torch.dtype | None = None,
        prefix: str = "",
    ):
        super().__init__(
            num_experts=config.ffn_config.moe_num_experts,
            top_k=config.ffn_config.moe_top_k,
            hidden_size=config.d_model,
            intermediate_size=config.ffn_config.ffn_hidden_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=get_tensor_model_parallel_world_size(),
            prefix=prefix,
        )
        self.config = config
        self.d_model = config.d_model
        # Width of this rank's slice of the intermediate (ffn_hidden_size) dim.
        self.intermediate_size = self.config.ffn_config.ffn_hidden_size // self.tp_size

    # Define custom weight loader for dbrx model
    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        param_name: str,
    ):
        """Copy one DBRX checkpoint tensor into this rank's fused parameter.

        Args:
            param: Destination parameter; written in place through
                ``param.data``.
            loaded_weight: Checkpoint tensor. For ``*weight`` parameters it is
                reshaped to ``(-1, ffn_hidden_size, d_model)`` (presumably one
                slab per expert — TODO confirm against checkpoint layout).
            weight_name: Checkpoint key fragment; must end in ``w1``, ``v1``
                or ``w2``. Any other value results in no copy.
            param_name: Destination parameter name. Its suffix (``weight``,
                ``weight_scale``, or anything else) selects how the tensor is
                copied.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # Unsharded shape of a DBRX expert projection as stored in the checkpoint.
        full_shape = [-1, self.intermediate_size * self.tp_size, self.d_model]
        # DBRX uses GLU for each experts.
        # GLU has 3 linear layers: w1, v1 and w2.
        if weight_name.endswith("w1"):
            if param_name.endswith("weight"):
                loaded_weight = torch.reshape(loaded_weight, full_shape)
                param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
            elif param_name.endswith("weight_scale"):
                param_data[:, 0] = loaded_weight
            else:
                # Copy in place. Rebinding `param_data` (as the previous code
                # did) left the parameter untouched.
                param_data[:] = loaded_weight
        elif weight_name.endswith("v1"):
            if param_name.endswith("weight"):
                loaded_weight = torch.reshape(loaded_weight, full_shape)
                param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[
                    :, shard, :
                ]
            elif param_name.endswith("weight_scale"):
                param_data[:, 1] = loaded_weight
            else:
                param_data[:] = loaded_weight
        elif weight_name.endswith("w2"):
            if param_name.endswith("weight"):
                # w2 is stored like w1/v1; transpose so the sharded
                # intermediate dim is last: (-1, d_model, ffn_hidden_size).
                loaded_weight = torch.reshape(loaded_weight, full_shape).transpose(
                    1, 2
                )
                param_data[:] = loaded_weight[:, :, shard]
            else:
                param_data[:] = loaded_weight

class DbrxMoE(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: QuantizationConfig | None = None,
        params_dtype: torch.dtype | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        # Fall back to torch's global default dtype when none is given.
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)

        self.experts = DbrxExperts(
            config=config,
            quant_config=quant_config,
            params_dtype=self.params_dtype,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts and return the mixed expert output.

        The input is flattened to ``(-1, d_model)`` for routing and the fused
        expert kernel, and the result is viewed back to the input's shape.
        """
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)


class DbrxAttention(nn.Module):
    """DBRX self-attention with fused QKV, optional clipping, and RoPE.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, and replicated
    otherwise.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.Wqkv",
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        # Each rank holds at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Per-rank widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and project the result back to ``d_model``.

        When ``clip_qkv`` is set, the fused QKV activations are clamped in
        place to ``[-clip_qkv, clip_qkv]`` before the rotary embedding is
        applied to ``q`` and ``k`` at ``position_ids``.
        """
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states


class DbrxFusedNormAttention(nn.Module):
    """Pre-norm attention sub-layer of a DBRX block.

    Computes ``norm_1 -> attention -> residual add`` and then applies
    ``norm_2``. The post-attention residual is returned separately so the
    caller can add it back after the feed-forward network.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the normed attention sub-layer.

        Returns:
            ``(normed_hidden_states, residual)``. ``residual`` is the input
            plus the attention output; ``normed_hidden_states`` is
            ``norm_2(residual)``.
        """
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual


class DbrxBlock(nn.Module):
    """One DBRX decoder layer: normed attention followed by the MoE FFN."""

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.norm_attn_norm"
        )
        self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention, then the MoE FFN, adding the attention residual."""
        hidden_states, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class DbrxModel(nn.Module):
    """DBRX decoder stack: token embedding, blocks, final LayerNorm.

    Only this pipeline-parallel rank's slice of blocks
    (``start_layer``..``end_layer``) is run in ``forward``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.quant_config = quant_config
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: DbrxBlock(config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.d_model
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's blocks.

        On the first pipeline rank, the input is ``inputs_embeds`` when given
        and otherwise the embedding of ``input_ids``. Later ranks read
        ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` on every rank but the last; on the last
            rank, hidden states after ``norm_f``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors
            hidden_states = intermediate_tensors["hidden_states"]
        for block in islice(self.blocks, self.start_layer, self.end_layer):
            hidden_states = block(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        Three kinds of entries are handled:

        * KV-cache scales: names for which ``quant_config.get_cache_scale``
          returns a parameter name.
        * Expert weights: names ending in ``w1``/``v1``/``w2`` get a
          ``_weight`` suffix. ``mlp.w1``/``mlp.v1`` are remapped to ``w13``
          and ``mlp.w2`` to ``w2``, and the parameter's own ``weight_loader``
          is called with the checkpoint fragment and the remapped name.
        * Everything else: FP8 kv-scale names are remapped via
          ``maybe_remap_kv_scale_name`` (and dropped if it returns None),
          then loaded with the parameter's ``weight_loader`` or
          ``default_weight_loader``.

        Weights whose parameters live on another pipeline rank are skipped.

        Returns:
            The set of (remapped) parameter names that were loaded.
        """
        # (fused parameter name, checkpoint fragment): w1 and v1 both load
        # into the stacked "w13" parameter; w2 loads into "w2".
        expert_params_mapping = [
            (
                "w13" if weight_name in ["w1", "v1"] else "w2",
                f"mlp.{weight_name}",
            )
            for weight_name in ["w1", "v1", "w2"]
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Keep a 0-d scale as-is; otherwise take its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            if name.endswith(("w1", "w2", "v1")):
                name = name + "_weight"
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    # Assuming no later mapping entry matches the rewritten
                    # name, the loop ends without `break` and the `else`
                    # branch below re-checks and skips this weight.
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name, name)
                break

            else:
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class DbrxForCausalLM(nn.Module, SupportsPP):
    """DBRX causal LM: a ``DbrxModel`` backbone plus a parallel LM head.

    Configurations with ``tie_word_embeddings`` set are rejected.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        if config.tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
        self.quant_config = quant_config
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
        # Re-export the backbone's factory for pipeline-parallel placeholders.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights by delegating to ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)