dbrx.py 16.4 KB
Newer Older
1
from typing import Iterable, List, Optional, Tuple, Union
2
3
4
5
6

import torch
import torch.nn as nn

from vllm.attention import Attention, AttentionMetadata
7
from vllm.config import CacheConfig, VllmConfig
8
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
9
10
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
11
from vllm.model_executor.layers.linear import (QKVParallelLinear,
12
13
14
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
15
from vllm.model_executor.layers.quantization import QuantizationConfig
16
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
17
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
18
19
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
20
21
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
22
from vllm.model_executor.sampling_metadata import SamplingMetadata
23
from vllm.sequence import IntermediateTensors
24
25
from vllm.transformers_utils.configs.dbrx import DbrxConfig

26
27
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
28
29
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
30

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

class DbrxRouter(nn.Module):
    """Gate that scores every token against every DBRX expert.

    The gate is a bias-free ``ReplicatedLinear`` from ``d_model`` to
    ``moe_num_experts`` that is always built without a quantization config.
    ``forward`` returns the gate output (the router logits) unchanged.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        # Stored for attribute parity; the replicated gate is not sharded.
        self.tp_size = get_tensor_model_parallel_world_size()
        ffn_cfg = config.ffn_config
        self.num_total_experts = ffn_cfg.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            # The router gate is deliberately left unquantized.
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the router logits for ``hidden_states``."""
        # The linear layer returns a pair; only the first element is used.
        logits, _unused = self.layer(hidden_states)
        return logits


59
class DbrxExperts(FusedMoE):
    """DBRX expert bank implemented on top of vLLM's ``FusedMoE`` layer.

    The fused layer is configured from ``config.ffn_config`` (top-k routing
    with renormalization and a reduction of the results). A custom
    ``weight_loader`` maps DBRX's per-expert GLU checkpoint tensors
    (``w1``, ``v1``, ``w2``) onto this rank's shard of the fused parameters.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        ffn_cfg = config.ffn_config
        super().__init__(
            num_experts=ffn_cfg.moe_num_experts,
            top_k=ffn_cfg.moe_top_k,
            hidden_size=config.d_model,
            intermediate_size=ffn_cfg.ffn_hidden_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=get_tensor_model_parallel_world_size(),
        )
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.d_model = config.d_model
        # Width of the FFN hidden dimension owned by a single TP rank.
        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

    # Define custom weight loader for dbrx model
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        """Copy this rank's shard of one DBRX expert tensor into ``param``.

        Each DBRX expert is a GLU with three matrices: ``w1``, ``v1`` and
        ``w2``. The checkpoint tensor is viewed as
        ``(-1, intermediate_size * tp_size, d_model)``, and the slice of
        dim 1 owned by this rank, ``[rank * intermediate_size,
        (rank + 1) * intermediate_size)``, is copied as follows:

        * ``w1`` goes into rows ``[0, intermediate_size)`` of ``param`` dim 1.
        * ``v1`` goes into rows ``[intermediate_size, 2 * intermediate_size)``.
        * ``w2`` is transposed to put the FFN dim last; its rank slice then
          overwrites all of ``param``.

        A ``weight_name`` that ends in none of these suffixes is ignored.
        """
        rank = get_tensor_model_parallel_rank()
        target = param.data
        per_rank = self.intermediate_size
        lo, hi = rank * per_rank, (rank + 1) * per_rank

        def _as_3d(tensor: torch.Tensor) -> torch.Tensor:
            # Expose the full (unsharded) FFN dimension as axis 1.
            return torch.reshape(
                tensor, [-1, self.intermediate_size * self.tp_size,
                         self.d_model])

        if weight_name.endswith("w1"):
            target[:, 0:per_rank, :] = _as_3d(loaded_weight)[:, lo:hi, :]
        elif weight_name.endswith("v1"):
            target[:, per_rank:2 * per_rank, :] = (
                _as_3d(loaded_weight)[:, lo:hi, :])
        elif weight_name.endswith("w2"):
            transposed = _as_3d(loaded_weight).transpose(1, 2)
            target[:] = transposed[:, :, lo:hi]

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

class DbrxMoE(nn.Module):
    """Tensor-parallel mixture-of-experts feed-forward block for DBRX.

    A replicated router produces per-expert logits. ``DbrxExperts``, a fused
    MoE layer constructed with ``reduce_results=True``, then combines the
    selected experts' outputs.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        # Use torch's current default dtype when the caller does not pick one.
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        self.router = DbrxRouter(config, self.params_dtype)

        self.experts = DbrxExperts(config=config,
                                   quant_config=quant_config,
                                   params_dtype=self.params_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the MoE FFN and return a tensor shaped like the input."""
        input_shape = hidden_states.shape
        # Flatten all leading dims into a single token axis: (-1, d_model).
        tokens = hidden_states.view(-1, self.d_model)
        # Router logits: (num_tokens, n_experts)
        logits = self.router(tokens)
        mixed = self.experts(tokens, logits)
        return mixed.view(input_shape)
148
149
150
151
152
153
154


class DbrxAttention(nn.Module):
    """Grouped-query self-attention used by DBRX.

    A fused QKV projection (optionally clamped to ``[-clip_qkv, clip_qkv]``)
    feeds NeoX-style rotary embeddings and vLLM's ``Attention``. The result
    goes through a row-parallel output projection. Query heads are split
    evenly across tensor-parallel ranks. KV heads are either split across
    ranks or replicated, depending on how their count compares with the TP
    world size.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        attn_cfg = config.attn_config
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = attn_cfg.kv_n_heads
        self.clip_qkv = attn_cfg.clip_qkv
        self.rope_theta = attn_cfg.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp = get_tensor_model_parallel_world_size()
        self.tp_size = tp
        assert self.total_num_heads % tp == 0
        self.num_heads = self.total_num_heads // tp
        if self.total_num_kv_heads < tp:
            # Fewer KV heads than ranks: each KV head is replicated on
            # several ranks, so the rank count must be a multiple of it.
            assert tp % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: partition them evenly.
            assert self.total_num_kv_heads % tp == 0
        # Every rank holds at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return the projected output."""
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp, applied only when clip_qkv is configured.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.out_proj(context)
        return projected


class DbrxFusedNormAttention(nn.Module):
    """Pre-norm attention sub-block of a DBRX decoder layer.

    Computes ``residual = x + attn(norm_1(x))`` and also returns
    ``norm_2(residual)``, so the caller can feed the normalized tensor into
    the feed-forward network and add ``residual`` back afterwards.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, cache_config, quant_config)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply norm -> attention -> residual add -> norm.

        Returns:
            A pair ``(normed, residual)``. ``residual`` is the input plus the
            attention output, and ``normed`` is ``norm_2(residual)``.
            (Previously mis-annotated as returning a single tensor.)
        """
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual


class DbrxBlock(nn.Module):
    """One DBRX decoder layer: fused norm/attention, then the MoE FFN.

    The FFN consumes the normalized attention output, and the pre-norm
    residual returned by ``norm_attn_norm`` is added to the FFN result.
    """

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                     quant_config)
        self.ffn = DbrxMoE(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one decoder layer and return its output hidden states."""
        normed, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        return self.ffn(normed) + residual


class DbrxModel(nn.Module):
    """DBRX decoder stack: token embedding, decoder blocks, final LayerNorm.

    ``forward`` runs blocks ``start_layer`` through ``end_layer - 1``, the
    range returned by ``make_layers``. A non-last pipeline stage returns its
    activations as ``IntermediateTensors`` instead of normalizing them.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: DbrxBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)

        # Remove the bias term in Linear and LayerNorm: every submodule whose
        # ``bias`` is a Parameter has that parameter unregistered.
        for submodule in self.modules():
            existing_bias = getattr(submodule, "bias", None)
            if isinstance(existing_bias, nn.Parameter):
                submodule.register_parameter("bias", None)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.d_model))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's blocks.

        The first stage embeds ``input_ids``; later stages resume from
        ``intermediate_tensors["hidden_states"]``. ``kv_caches`` is indexed
        relative to ``start_layer``. Returns the final-normed hidden states
        on the last stage, otherwise ``IntermediateTensors``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = self.wte(input_ids)

        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.blocks[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if get_pp_group().is_last_rank:
            return self.norm_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


351
class DbrxForCausalLM(nn.Module, SupportsPP):
    """DBRX causal language model wrapper for vLLM.

    Combines the ``DbrxModel`` decoder stack with a parallel LM head, a
    logits processor and a sampler. Tied word embeddings are rejected.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        if config.tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models.")
        self.quant_config = quant_config
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return whatever ``self.transformer`` returns."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute vocabulary logits from ``hidden_states`` via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Expert tensors whose names contain ``mlp.w1`` / ``mlp.v1`` are renamed
        to the fused ``w13_weight`` parameter, and ``mlp.w2`` to
        ``w2_weight``. They are loaded with the parameter's ``weight_loader``,
        which also receives the original ``mlp.*`` suffix. All other tensors
        go through FP8 kv-scale name remapping and then the parameter's
        ``weight_loader`` (or ``default_weight_loader``). Tensors for
        parameters not present on this pipeline stage are skipped.
        """
        expert_params_mapping = [(
            "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
            f"mlp.{weight_name}",
        ) for weight_name in ["w1", "v1", "w2"]]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # A matched expert tensor is either loaded here or skipped
                # outright. A bare ``continue`` would only advance this
                # mapping loop and then fall into the ``else`` branch below
                # with the already-rewritten name.
                if not is_pp_missing_parameter(name, self):
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, weight_name)
                break
            else:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)