# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
40
41
42

import torch
from torch import nn
43
from transformers import PhiConfig
44

45
from vllm.attention import Attention
46
from vllm.compilation.decorators import support_torch_compile
47
from vllm.config import CacheConfig, VllmConfig
48
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
49
50
51
52
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
53
from vllm.model_executor.layers.logits_processor import LogitsProcessor
54
from vllm.model_executor.layers.quantization import QuantizationConfig
55
from vllm.model_executor.layers.rotary_embedding import get_rope
56
from vllm.model_executor.layers.vocab_parallel_embedding import (
57
    ParallelLMHead, VocabParallelEmbedding)
58
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
59
from vllm.model_executor.sampling_metadata import SamplingMetadata
60
from vllm.sequence import IntermediateTensors
61

62
from .interfaces import SupportsLoRA, SupportsPP
63
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
64
65
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
66

67
68
69
70

class PhiAttention(nn.Module):
    """Phi self-attention with a partial rotary embedding.

    Q, K and V come from one fused ``QKVParallelLinear`` (with bias). Rotary
    embedding covers ``partial_rotary_factor * head_size`` dimensions of each
    head, and the attention output is projected back by ``dense``. Heads are
    split evenly across tensor-parallel ranks.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Every tensor-parallel rank must own the same number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        # Fully-qualified prefixes let quantization configs that select or
        # skip layers by module name recognise these projections.
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        scaling = self.head_size**-0.5
        # Phi rotates only a fraction of each head's dimensions; the rotary
        # width must be even.
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the ``dense``
        projection of the result.

        Args:
            position_ids: positions used by the rotary embedding.
            hidden_states: input activations (already layer-normed by the
                caller).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is split into three equal parts along the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Phi feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Intermediate width: the legacy remote-code config names it
        # ``n_inner``; the transformers ``PhiConfig`` names it
        # ``intermediate_size``. Fall back to 4 * hidden_size only when
        # neither is set, so non-4x checkpoints are sized correctly.
        n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            n_inner = getattr(config, "intermediate_size", None)
        if n_inner is None:
            n_inner = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Apply fc1, the configured activation, then fc2."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """One Phi decoder layer with parallel attention and MLP branches.

    A single LayerNorm feeds both ``self_attn`` and ``mlp``. Their outputs
    are summed together with the un-normalized residual.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.self_attn")
        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``attn(ln(x)) + mlp(ln(x)) + x``."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        # The MLP consumes the same normalized input as attention (parallel
        # block), not the attention output.
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


198
@support_torch_compile
class PhiModel(nn.Module):
    """Phi decoder stack: token embedding, ``PhiLayer``s and a final
    LayerNorm, with pipeline-parallel hand-off of hidden states."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # forward() only runs self.layers[start_layer:end_layer]; presumably
        # that slice is the part owned by this pipeline rank (TODO confirm
        # against make_layers).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``. Other ranks read
        ``intermediate_tensors["hidden_states"]``. Non-last ranks return
        ``IntermediateTensors``; the last rank returns the final-layernormed
        hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as shards
        of the fused ``qkv_proj`` parameter. ``rotary_emb.inv_freq`` entries,
        biases with no matching parameter (extra GPTQ biases) and parameters
        reported missing by ``is_pp_missing_parameter`` are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Map a per-projection checkpoint name onto the fused parameter.
            # Stop at the first match: "v_proj" is a substring of "qkv_proj",
            # so checking further mappings would rewrite the name again.
            shard_id: Optional[str] = None
            for param_name, weight_name, mapped_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            # pylint: disable=E1136
            param = params_dict[name]
            if shard_id is not None:
                # The fused parameter's loader takes the target shard id.
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

294

295
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Phi causal language model: ``PhiModel`` plus a biased LM head.

    The LM head has a bias, so it is never tied to the input embeddings.
    """

    # The fused qkv_proj module packs the checkpoint's separate q/k/v
    # projections.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # lm_head use bias, cannot share word embeddings
        assert not config.tie_word_embeddings

        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = PhiModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))

        # Prefixed like the decoder so name-based quantization configs can
        # match the head as well.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder; returns hidden states (last pipeline rank) or
        ``IntermediateTensors`` (other ranks)."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, including the LM-head
        bias."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader`` and return the
        names that were loaded."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)