mistral.py 12.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
23
"""Inference-only Mistral model compatible with HuggingFace weights."""
24
25
26
27
from typing import List, Optional, Tuple

import torch
from torch import nn
28
from transformers import MistralConfig
29
30
31

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.attention import PagedAttention
33
34
35
36
37
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.rotary_embedding import get_rope
39
from vllm.model_executor.layers.sampler import Sampler
40
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
42
from vllm.model_executor.parallel_utils.parallel_state import (
43
    get_tensor_model_parallel_world_size)
44
from vllm.model_executor.sampling_metadata import SamplingMetadata
45
46
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.sequence import SamplerOutput

# Per-layer KV cache: a pair of tensors that MistralAttention.forward unpacks
# as ``k_cache, v_cache = kv_cache``.
KVCache = Tuple[(torch.Tensor, torch.Tensor)]


class MistralMLP(nn.Module):
    """Feed-forward sub-layer of a Mistral decoder layer.

    A single fused ``gate_up_proj`` produces two ``intermediate_size``-wide
    outputs, which ``SiluAndMul`` combines before ``down_proj`` maps the
    result back to ``hidden_size``.

    Args:
        hidden_size: Model width (input and output size of the block).
        intermediate_size: Width of each of the gate/up projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        linear_method: Optional linear-layer implementation forwarded to
            both projections.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # Validate the config before allocating any (possibly large,
        # tensor-parallel) projection weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Gate and up projections are fused into one matmul: two output
        # shards of ``intermediate_size`` each.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` and return a tensor of the same width."""
        # The parallel linear layers return (output, bias); bias is unused.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MistralAttention(nn.Module):
    """Self-attention for Mistral, sharded across tensor-parallel ranks.

    Query heads are partitioned across ranks. Key/value heads (a separately
    configured, possibly smaller count) are partitioned when there are at
    least as many as ranks, and replicated otherwise. Rotary position
    embeddings are applied to q/k, and attention runs through
    ``PagedAttention`` with an optional sliding window.

    Args:
        hidden_size: Model width; split evenly into ``num_heads`` heads.
        num_heads: Total number of query heads across all ranks.
        num_kv_heads: Total number of key/value heads across all ranks.
        max_position: Maximum position passed to the rotary embedding.
        rope_theta: Rotary embedding base.
        linear_method: Optional linear-layer implementation forwarded to
            the qkv/o projections.
        sliding_window: Optional sliding-window size for ``PagedAttention``.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel size ({tp_size})")
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_kv_heads ({self.total_num_kv_heads}) must be divisible "
                f"by the tensor parallel size ({tp_size})")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor parallel size ({tp_size}) must be divisible by "
                f"num_kv_heads ({self.total_num_kv_heads})")
        # At least one KV head per rank (replication case).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused qkv projection output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads,
                                   sliding_window=self.sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention for one layer.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Input activations.
            kv_cache: ``(k_cache, v_cache)`` pair for this layer.
            input_metadata: Batch metadata forwarded to ``PagedAttention``.
            cache_event: Optional CUDA event forwarded to ``PagedAttention``.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class MistralDecoderLayer(nn.Module):
    """One Mistral transformer layer: pre-norm attention then pre-norm MLP.

    The residual stream is threaded explicitly. ``forward`` takes the
    previous layer's residual and returns ``(hidden_states, residual)`` so
    that the next layer's ``input_layernorm`` (or the model's final norm)
    can fuse the residual add with normalization.

    Args:
        config: HuggingFace ``MistralConfig`` supplying sizes, head counts,
            activation, RMSNorm epsilon, rope theta and sliding window.
        linear_method: Optional linear-layer implementation forwarded to
            the attention and MLP projections.
    """

    def __init__(
        self,
        config: MistralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MistralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            linear_method=linear_method,
            sliding_window=config.sliding_window)
        self.mlp = MistralMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP sub-layers.

        Args:
            positions: Token positions for rotary embeddings.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: ``(k_cache, v_cache)`` pair for this layer.
            input_metadata: Batch metadata forwarded to attention.
            cache_event: Optional CUDA event forwarded to attention.
            residual: Running residual from the previous layer, or ``None``
                for the first layer.

        Returns:
            ``(hidden_states, residual)``; the MLP output has not yet been
            added to ``residual``.
        """
        # Self Attention
        if residual is None:
            # First layer: the input itself seeds the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normed, new_residual); presumably
            # it adds the residual before normalizing — confirm in RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class MistralModel(nn.Module):
    """Mistral decoder stack: token embedding, decoder layers, final RMSNorm.

    Args:
        config: HuggingFace ``MistralConfig`` for the model.
        linear_method: Optional linear-layer implementation forwarded to
            every decoder layer.
    """

    def __init__(
        self,
        config: MistralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            MistralDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions for rotary embeddings.
            kv_caches: One ``(k_cache, v_cache)`` pair per layer, indexed by
                layer number.
            input_metadata: Batch metadata forwarded to every layer.
            cache_events: Optional per-layer CUDA events; ``None`` disables
                them for all layers.

        Returns:
            Final-norm hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer seeds the residual stream from its input.
        residual = None
        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
                residual,
            )
        # The last layer's MLP output still has to be merged with the
        # residual; the two-argument norm takes both.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MistralForCausalLM(nn.Module):
    """Mistral causal LM: ``MistralModel`` plus a parallel LM head and sampler.

    Args:
        config: HuggingFace ``MistralConfig`` for the model.
        linear_method: Optional linear-layer implementation forwarded to
            the decoder stack.
    """

    def __init__(
        self,
        config: MistralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MistralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Return the decoder stack's final hidden states.

        Logits are not computed here; ``sample`` applies the LM head.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample next tokens from ``hidden_states`` using the LM head weight."""
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace checkpoint weights into this model.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters as individual shards.
        Every other tensor is loaded into the parameter of the same name.
        Bias tensors with no matching parameter (extra GPTQ biases) are
        skipped.

        Args:
            model_name_or_path: Checkpoint name or local path.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # Not a parameter of this model; presumably rebuilt by the
            # rotary embedding module — confirm in get_rope.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep the checkpoint name intact. Rewriting ``name`` in place
                # let later mapping entries match inside the fused name
                # (e.g. "v_proj" in "qkv_proj", "up_proj" in "gate_up_proj").
                # It also leaked the rewritten name into the fallback branch.
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)