dbrx.py 17.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Iterable, Optional, Set, Tuple, Union
4
5
6
7

import torch
import torch.nn as nn

8
from vllm.attention import Attention
9
from vllm.config import CacheConfig, VllmConfig
10
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
11
12
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
13
from vllm.model_executor.layers.linear import (QKVParallelLinear,
14
15
16
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
17
from vllm.model_executor.layers.quantization import QuantizationConfig
18
19
20
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
21
22
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
23
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.sequence import IntermediateTensors
25
26
from vllm.transformers_utils.configs.dbrx import DbrxConfig

27
28
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
29
30
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

class DbrxRouter(nn.Module):
    """Router for DBRX: produces one logit per expert for every token.

    The projection is a bias-free ``ReplicatedLinear`` from ``d_model`` to
    the number of experts, always built with ``quant_config=None``.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        """Create the routing projection.

        Args:
            config: DBRX config; ``d_model`` and
                ``ffn_config.moe_num_experts`` are read from it.
            params_dtype: Passed through to the linear layer as
                ``params_dtype``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.d_model = config.d_model
        self.num_total_experts = config.ffn_config.moe_num_experts
        # quant_config is hard-coded to None for the router projection.
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the router logits for ``hidden_states``."""
        # The linear layer returns a pair; only the first element is used.
        logits, _ = self.layer(hidden_states)
        return logits


60
class DbrxExperts(FusedMoE):
    """``FusedMoE`` experts for DBRX with a DBRX-specific weight loader."""

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
        prefix: str = "",
    ):
        """Configure the fused experts from the DBRX config.

        Args:
            config: DBRX config; expert count, top-k, ``d_model`` and the
                FFN hidden size are read from it.
            quant_config: Forwarded to ``FusedMoE``.
            params_dtype: Forwarded to ``FusedMoE``.
            prefix: Forwarded to ``FusedMoE``.
        """
        super().__init__(
            num_experts=config.ffn_config.moe_num_experts,
            top_k=config.ffn_config.moe_top_k,
            hidden_size=config.d_model,
            intermediate_size=config.ffn_config.ffn_hidden_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=get_tensor_model_parallel_world_size(),
            prefix=prefix,
        )
        self.config = config
        self.d_model = config.d_model
        # Per-rank slice of the FFN hidden dimension.
        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

    # Define custom weight loader for dbrx model
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, param_name: str) -> None:
        """Copy one checkpoint tensor into the matching region of ``param``.

        DBRX uses a GLU for each expert, which has three linear layers:
        w1, v1 and w2.

        * ``w1`` weights go to ``param[:, 0:shard_size, :]`` and ``v1``
          weights to ``param[:, shard_size:2 * shard_size, :]``; their
          ``weight_scale`` values go to columns 0 and 1 respectively.
        * ``w2`` weights are reshaped, transposed on dims 1 and 2, and the
          last dimension is sliced to this rank's shard.
        * Any other parameter kind is copied into ``param`` in full.

        Args:
            param: Destination parameter; written in place via
                ``param.data``.
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Checkpoint-side name; its suffix (``w1``, ``v1``
                or ``w2``) selects the branch.
            param_name: Model-side parameter name; its suffix selects
                between weight, weight scale and everything else.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # Target shape for viewing a checkpoint weight as
        # (-1, full intermediate size, d_model) before slicing the shard.
        full_shape = [-1, self.intermediate_size * self.tp_size, self.d_model]

        if weight_name.endswith("w1"):
            if param_name.endswith("weight"):
                loaded_weight = torch.reshape(loaded_weight, full_shape)
                param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
            elif param_name.endswith("weight_scale"):
                param_data[:, 0] = loaded_weight
            else:
                # Write in place. A plain ``param_data = loaded_weight``
                # would only rebind the local name and drop the tensor.
                param_data[:] = loaded_weight
        elif weight_name.endswith("v1"):
            if param_name.endswith("weight"):
                loaded_weight = torch.reshape(loaded_weight, full_shape)
                param_data[:, shard_size:2 *
                           shard_size, :] = loaded_weight[:, shard, :]
            elif param_name.endswith("weight_scale"):
                param_data[:, 1] = loaded_weight
            else:
                param_data[:] = loaded_weight
        elif weight_name.endswith("w2"):
            if param_name.endswith("weight"):
                loaded_weight = torch.reshape(loaded_weight,
                                              full_shape).transpose(1, 2)
                param_data[:] = loaded_weight[:, :, shard]
            else:
                param_data[:] = loaded_weight
127

128
129
130
131
132
133
134
135
136
137
138
139
140
141

class DbrxMoE(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    runs the forward pass, and the outputs are finally reduced across
    ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
        prefix: str = "",
    ):
        """Build the router and the fused experts.

        Args:
            config: DBRX config shared by the router and the experts.
            quant_config: Passed to the experts only; the router is built
                without it.
            params_dtype: Parameter dtype; ``torch.get_default_dtype()`` is
                used when this is ``None``.
            prefix: Module prefix; the experts receive
                ``f"{prefix}.experts"``.
        """
        super().__init__()
        self.d_model = config.d_model
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        self.router = DbrxRouter(config, self.params_dtype)

        self.experts = DbrxExperts(
            config=config,
            quant_config=quant_config,
            params_dtype=self.params_dtype,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape."""
        input_shape = hidden_states.shape
        # Flatten to (-1, d_model) for the router and the experts.
        flat_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(flat_states)
        expert_output = self.experts(flat_states, router_logits)
        return expert_output.view(input_shape)
164
165
166
167
168
169
170


class DbrxAttention(nn.Module):
    """Attention for a DBRX block.

    Applies a fused QKV projection, optional clamping of the QKV values,
    rotary embeddings on Q and K, the attention op and an output
    projection.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention op.

        Args:
            config: DBRX config; head counts, ``d_model``, ``max_seq_len``
                and the ``attn_config`` fields are read from it.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to both projections and ``Attention``.
            prefix: Module prefix; ``Attention`` receives
                ``f"{prefix}.attn"``.
        """
        super().__init__()
        attn_config = config.attn_config
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = attn_config.kv_n_heads
        self.clip_qkv = attn_config.clip_qkv
        self.rope_theta = attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        world_size = get_tensor_model_parallel_world_size()
        self.tp_size = world_size
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at ``position_ids``."""
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp to [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.out_proj(context)
        return projected


class DbrxFusedNormAttention(nn.Module):
    """Pre-norm attention sub-block followed by a second LayerNorm.

    Computes ``norm_1 -> attention -> residual add -> norm_2`` and returns
    both the normalized output and the pre-``norm_2`` residual.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention module and the two LayerNorms.

        Args:
            config: DBRX config; ``d_model`` sizes the LayerNorms.
            cache_config: Forwarded to ``DbrxAttention``.
            quant_config: Forwarded to ``DbrxAttention``.
            prefix: Module prefix; the attention module receives
                ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply norm, attention, residual add and the second norm.

        Returns:
            A pair ``(hidden_states, residual)``: the ``norm_2`` output and
            the un-normalized sum of the input and the attention output.
        """
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        hidden_states = residual + x
        # The caller adds this residual back after its own sub-layer.
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual


class DbrxBlock(nn.Module):
    """One DBRX layer: fused norm/attention followed by the MoE FFN."""

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the attention sub-block and the MoE feed-forward.

        Args:
            config: DBRX config passed to both sub-modules.
            cache_config: Forwarded to the attention sub-block.
            quant_config: Forwarded to both sub-modules.
            prefix: Module prefix; sub-modules receive
                ``f"{prefix}.norm_attn_norm"`` and ``f"{prefix}.ffn"``.
        """
        super().__init__()
        attn_prefix = f"{prefix}.norm_attn_norm"
        self.norm_attn_norm = DbrxFusedNormAttention(config,
                                                     cache_config,
                                                     quant_config,
                                                     prefix=attn_prefix)
        self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention then the MoE FFN, adding the residual at the end."""
        normed, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        ffn_output = self.ffn(normed)
        return ffn_output + residual


class DbrxModel(nn.Module):
    """DBRX transformer stack: embedding, blocks and final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the block stack and the final norm.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Module prefix; the blocks are created under
                ``f"{prefix}.blocks"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: DbrxBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.d_model))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's blocks.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``. On later ranks it
        is ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` on every rank
            but the last; the ``norm_f`` output on the last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Explicit None check: matches the Optional annotation and does
            # not depend on the truth value of the container.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for block in self.blocks[self.start_layer:self.end_layer]:
            hidden_states = block(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


367
class DbrxForCausalLM(nn.Module, SupportsPP):
    """DBRX causal language model: transformer, LM head and weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, LM head and logits processor.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also passed whole to ``DbrxModel``.
            prefix: Module prefix, combined with ``"transformer"`` via
                ``maybe_prefix``.

        Raises:
            ValueError: If ``config.tie_word_embeddings`` is truthy.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        if hf_config.tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models.")
        self.quant_config = quant_config
        self.unpadded_vocab_size = hf_config.vocab_size
        self.transformer = DbrxModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "transformer"),
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.d_model,
            org_num_embeddings=hf_config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer and return its output unchanged."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were processed.
        """
        # (model-side parameter fragment, checkpoint-side fragment)
        expert_params_mapping = [
            ("w13", "mlp.w1"),
            ("w13", "mlp.v1"),
            ("w2", "mlp.w2"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            scale_name = None
            if self.quant_config is not None:
                scale_name = self.quant_config.get_cache_scale(name)
            if scale_name:
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Non-scalar tensors contribute only their first element.
                if loaded_weight.dim() != 0:
                    loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            if name.endswith(("w1", "w2", "v1")):
                name = name + "_weight"
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    # This advances the mapping loop, not the outer weights
                    # loop. With no ``break`` taken, the ``else`` branch
                    # below re-checks the renamed ``name``.
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name, name)
                break
            else:
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params