# mixtral_quant.py
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class MixtralMLP(nn.Module):
    """A single Mixtral expert: a gated SiLU feed-forward network.

    Computes ``w2(silu(w1(x)) * w3(x))``. All three projections are
    ``ReplicatedLinear`` layers (not tensor-parallel sharded).
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the expert's three projections.

        Args:
            num_experts: Total number of experts in the enclosing MoE layer.
                Only stored on the module; not used by the computation.
            hidden_size: Input/output width of the expert.
            intermediate_size: Inner feed-forward width.
            quant_config: Optional quantization config forwarded to every
                projection.
        """
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        # hidden -> ffn; its output is passed through SiLU in forward().
        self.w1 = ReplicatedLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   quant_config=quant_config)
        # ffn -> hidden; projects the gated product back to hidden size.
        self.w2 = ReplicatedLinear(self.ffn_dim,
                                   self.hidden_dim,
                                   bias=False,
                                   quant_config=quant_config)
        # hidden -> ffn; multiplied elementwise with silu(w1(x)).
        self.w3 = ReplicatedLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   quant_config=quant_config)

        # TODO: Use vllm's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the expert; the last dim of ``hidden_states`` is hidden size.

        Each linear layer returns a pair; only the first element is used
        (the second is presumably the optional bias — unused here).
        """
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralMoE(nn.Module):
    """Sparse mixture-of-experts block with experts split across TP ranks.

    Each tensor-parallel rank instantiates only the experts assigned to it
    (a contiguous, near-equal split computed with ``np.array_split``); the
    entries of ``self.experts`` that belong to other ranks are ``None``.

    In ``forward``, every local expert runs on *all* tokens. Its output is
    scaled by each token's renormalized routing weight for that expert,
    which is zero for tokens that did not pick it among their top-k. The
    per-rank partial sums are then combined with a tensor-parallel
    all-reduce.
    """

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the router gate and this rank's experts.

        Args:
            config: Model config; reads ``num_local_experts``,
                ``num_experts_per_tok``, ``hidden_size`` and
                ``intermediate_size``.
            quant_config: Optional quantization config for the experts.
                The router gate is always built unquantized.

        Raises:
            ValueError: If the TP world size exceeds the number of experts,
                or if this rank ends up with no experts.
        """
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}.")
        # Split experts equally between ranks. (The attribute name keeps its
        # historical spelling so external references keep working.)
        self.expert_indicies = np.array_split(range(
            self.num_total_experts), self.tp_size)[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(
                f"Rank {self.rank} has no experts assigned to it.")

        # Indexed by global expert id; None for experts owned by other ranks.
        self.experts = nn.ModuleList([
            MixtralMLP(self.num_total_experts,
                       config.hidden_size,
                       config.intermediate_size,
                       quant_config=quant_config)
            if idx in self.expert_indicies else None
            for idx in range(self.num_total_experts)
        ])
        # Router: hidden -> one logit per expert; deliberately unquantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     quant_config=None)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through this rank's experts and reduce across ranks.

        Args:
            hidden_states: 2-D tensor of shape (num_tokens, hidden_dim).

        Returns:
            Tensor of shape (num_tokens, hidden_dim) holding the weighted
            sum of the selected experts' outputs over all ranks.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        # Softmax in float32, keep each token's top-k experts, then
        # renormalize so the kept weights of every token sum to 1.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = None
        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            # Per-token weight for this expert: the matching top-k weight,
            # or 0 when the token did not select this expert.
            expert_mask = (selected_experts == expert_idx)
            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
                                                                 keepdim=True)

            current_hidden_states = expert_layer(hidden_states).mul_(
                expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)

        # Sum the partial results of all ranks' experts.
        return tensor_model_parallel_all_reduce(final_hidden_states).view(
            num_tokens, hidden_dim)


class MixtralAttention(nn.Module):
    """Self-attention with rotary embeddings and optional grouped KV heads.

    Query heads are sharded across tensor-parallel ranks. KV heads are
    sharded when there are at least as many of them as ranks. Otherwise they
    are replicated, and each rank keeps a single KV head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        sliding_window: Optional[int] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Create the fused QKV/output projections, RoPE and attention op.

        Args:
            hidden_size: Model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            max_position: Maximum position given to the rotary embedding.
            rope_theta: Rotary base; truncated to ``int`` before use.
            quant_config: Optional quantization config for the projections
                and the attention layer.
            sliding_window: Optional sliding-window size for attention.
            cache_config: Optional KV-cache configuration.

        Raises:
            AssertionError: If the head counts are not divisible by (or do
                not evenly divide) the TP world size.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=self.sliding_window,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend, and project back.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Input activations.
            kv_cache: This layer's KV cache, passed through to ``self.attn``.
            attn_metadata: Attention metadata, passed through to
                ``self.attn``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output's last dim is laid out as [q | k | v].
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer block: attention followed by a sparse MoE.

    Both sub-layers are pre-normalized with ``RMSNorm``. The residual stream
    is threaded through the norm calls rather than added explicitly here.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MoE and normalization sub-layers.

        Args:
            config: Model config supplying sizes, head counts, RoPE theta
                (default 10000 when absent), sliding window and norm eps.
            cache_config: Optional KV-cache config for the attention layer.
            quant_config: Optional quantization config for both sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            cache_config=cache_config,
            quant_config=quant_config)
        self.block_sparse_moe = MixtralMoE(config=config,
                                           quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the MoE on ``hidden_states``.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or embeddings).
            kv_cache: This layer's KV cache.
            attn_metadata: Attention metadata shared by all layers.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer.

        Returns:
            A ``(hidden_states, residual)`` pair for the next layer.
        """
        # Self Attention
        if residual is None:
            # First layer: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with the residual, RMSNorm returns (normed, residual);
            # presumably it fuses the residual add — confirm in RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    """Mixtral backbone: token embedding, decoder stack and final RMSNorm."""

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` decoder layers
        and the final norm.

        Args:
            config: Model config.
            cache_config: Optional KV-cache config passed to every layer.
            quant_config: Optional quantization config passed to every layer.
        """
        super().__init__()
        # Stored for reference; not passed to the embedding below.
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config,
                                cache_config,
                                quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions for the rotary embeddings.
            kv_caches: One KV cache per layer, indexed by layer number.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            The final normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
        # The final norm also consumes the residual; only the normalized
        # output is returned.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral backbone plus LM head, logits processor and sampler."""

    # Class-level loader flag (read outside this file); presumably it
    # disables falling back to PyTorch checkpoint files during load.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone and the output head.

        Args:
            config: Model config.
            cache_config: Optional KV-cache config for the backbone.
            quant_config: Optional quantization config for the backbone.
                The LM head is built without it.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = MixtralModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (not logits)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Tensors named with ``q_proj``/``k_proj``/``v_proj`` are renamed to the
        fused ``qkv_proj`` parameter. They are passed to its ``weight_loader``
        with shard id ``"q"``/``"k"``/``"v"``. Every other tensor goes through
        the parameter's own ``weight_loader`` if it has one, else through
        ``default_weight_loader``.

        Skipped entries:

        * names containing ``rotary_emb.inv_freq``;
        * ``.bias`` tensors with no matching parameter (extra biases found
          in GPTQ checkpoints);
        * expert weights for experts not instantiated on this rank.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a non-skipped name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor as a whole.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if ("block_sparse_moe.experts." in name
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)