phi.py 13.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Woosuk Kwon's avatar
Woosuk Kwon committed
39
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
40
from collections.abc import Iterable
41
from itertools import islice
42
from typing import Optional, Union
43
44
45

import torch
from torch import nn
46
from transformers import PhiConfig
47

48
from vllm.attention import Attention
49
from vllm.compilation.decorators import support_torch_compile
50
from vllm.config import CacheConfig, VllmConfig
51
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
52
53
54
55
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
56
from vllm.model_executor.layers.logits_processor import LogitsProcessor
57
from vllm.model_executor.layers.quantization import QuantizationConfig
58
from vllm.model_executor.layers.rotary_embedding import get_rope
59
from vllm.model_executor.layers.vocab_parallel_embedding import (
60
    ParallelLMHead, VocabParallelEmbedding)
61
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
62
from vllm.sequence import IntermediateTensors
63

64
from .interfaces import SupportsLoRA, SupportsPP
65
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
66
67
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
68

69
70
71
72

class PhiAttention(nn.Module):
    """Multi-head self-attention block used by Phi decoder layers.

    Q, K and V are produced by one fused, biased projection. Rotary position
    embeddings are applied to ``rotary_dim`` channels of each head, where
    ``rotary_dim = partial_rotary_factor * head_size``. The attended values
    are mapped back to ``hidden_size`` by ``dense``.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        # Heads are sharded evenly across tensor-parallel ranks.
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # Give each projection its fully-qualified module name, the same way
        # ``self.attn`` receives one below; the prefix is presumably consumed
        # by quant_config for per-layer handling (verify in linear layers).
        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        scaling = self.head_size**-0.5
        # Only a fraction of each head's channels is rotated.
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: token positions used by the rotary embedding.
            hidden_states: layer input to project into Q, K and V.

        Returns:
            The attention output after the ``dense`` output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts along the
        # last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Feed-forward block of a Phi layer: ``fc2(act(fc1(x)))``."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the MLP.

        Args:
            config: HF Phi config; reads ``hidden_size``, ``hidden_act`` and
                the optional ``n_inner``.
            quant_config: optional quantization config for both linears.
            prefix: module name of this MLP. It is forwarded as
                ``{prefix}.fc1`` / ``{prefix}.fc2``. The default ``""``
                keeps existing callers working.
        """
        super().__init__()

        # Inner width falls back to 4 * hidden_size when ``n_inner`` is
        # missing from the config or explicitly None.
        n_inner = getattr(config, "n_inner", None)
        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """Phi decoder layer with parallel attention and MLP branches.

    Both branches consume the same layer-normed input, and their outputs are
    added to the un-normed residual: ``x + attn(ln(x)) + mlp(ln(x))``.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.self_attn")
        # Name the MLP like the attention branch so its linears receive
        # fully-qualified prefixes too.
        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        # A single LayerNorm feeds both branches (no separate post-attn norm).
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


200
@support_torch_compile
class PhiModel(nn.Module):
    """Phi transformer backbone: embeddings, decoder stack, final LayerNorm.

    Under pipeline parallelism each rank owns only the layers in
    ``[start_layer, end_layer)``. Non-last ranks hand ``hidden_states`` on as
    ``IntermediateTensors`` instead of applying the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns normed hidden states on the last pipeline rank, and an
        ``IntermediateTensors`` carrying ``hidden_states`` on other ranks.
        """
        if get_pp_group().is_first_rank:
            # Precomputed embeddings (e.g. from the caller) take precedence.
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` parameter with their shard id.
        Everything else is loaded by name.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Resolve the stacked mapping exactly once. Stop at the first
            # match: the rewritten name ("qkv_proj") itself contains
            # "v_proj", so continuing the scan would mangle it further.
            shard_id = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Parameters owned by another pipeline-parallel rank.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                # pylint: disable=E1136
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters must provide a shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

296

297
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper around ``PhiModel`` with a biased output head."""

    # The fused ``qkv_proj`` is built from these separate projections.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        lora_cfg = vllm_config.lora_config

        self.config = hf_config
        # The LM head carries a bias, so it cannot share the embedding
        # weights.
        assert not hf_config.tie_word_embeddings
        self.lora_config = lora_cfg
        self.quant_config = quant_cfg

        # Submodules are registered backbone first, then head.
        self.model = PhiModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                      hf_config.hidden_size,
                                      bias=True,
                                      quant_config=quant_cfg,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; logits are produced later by ``compute_logits``."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, including the head bias."""
        head = self.lm_head
        return self.logits_processor(head, hidden_states, head.bias)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load all weights via ``AutoWeightsLoader``; returns loaded names."""
        return AutoWeightsLoader(self).load_weights(weights)