# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
39
40
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import PhiConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
67

68
69
70
71

class PhiAttention(nn.Module):
    """Multi-head self-attention used inside each Phi decoder layer.

    Q, K and V come from one fused ``QKVParallelLinear`` projection (with
    bias); the attention output is projected back to ``hidden_size`` by a
    row-parallel ``dense`` layer. Rotary embeddings are configured with
    ``rotary_dim = partial_rotary_factor * head_size``.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HuggingFace ``PhiConfig`` of the model.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config for the linear layers
                and the attention layer.
            prefix: Fully qualified module name of this attention block; used
                to derive the names of its sub-modules.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads are split evenly across tensor-parallel ranks; each rank keeps
        # num_heads of them.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        # Sub-module prefixes let the quantization config resolve per-layer
        # settings by name; without them every layer reports prefix "".
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        scaling = self.head_size**-0.5
        # Phi uses partial rotary embeddings: only partial_rotary_factor of
        # each head's channels are rotary. The rotary width must be even.
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary self-attention followed by the output projection.

        Args:
            position_ids: Token positions used for the rotary embedding.
            hidden_states: Layer-normalized input activations.

        Returns:
            The attention output after the ``dense`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Equal three-way split: qkv_proj is built without a separate KV-head
        # count, so presumably Q, K and V have the same width — confirm if
        # grouped-query configs are ever supported here.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Feed-forward block of a Phi decoder layer: fc1 -> activation -> fc2."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two tensor-parallel linear layers and the activation.

        Args:
            config: HuggingFace ``PhiConfig`` of the model.
            quant_config: Optional quantization config for ``fc1``/``fc2``.
            prefix: Fully qualified module name of this MLP; used to name
                ``fc1``/``fc2`` so quantization settings can be resolved per
                layer. Defaults to ``""`` for backward compatibility.
        """
        super().__init__()

        # The intermediate width falls back to 4 * hidden_size when the config
        # has no (or a None) ``n_inner``.
        n_inner = getattr(config, "n_inner", None)
        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up with fc1, apply the activation, project down with fc2."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """One Phi decoder layer with parallel attention and MLP branches.

    A single LayerNorm feeds both the attention and the MLP; their outputs are
    summed together with the un-normalized residual.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the shared input LayerNorm, attention and MLP sub-modules.

        Args:
            config: HuggingFace ``PhiConfig`` of the model.
            cache_config: KV-cache configuration forwarded to the attention.
            quant_config: Optional quantization config for sub-layers.
            prefix: Fully qualified module name of this decoder layer.
        """
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.self_attn")
        # Name the MLP like the attention so its linear layers get prefixes.
        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP on the same normalized input and sum them.

        Args:
            position_ids: Token positions forwarded to the attention block.
            hidden_states: Input activations of this layer.

        Returns:
            ``attn(ln(x)) + mlp(ln(x)) + x``.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        # Parallel formulation: the MLP reads the same normalized input as the
        # attention, not the attention output.
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


@support_torch_compile
class PhiModel(nn.Module):
    """Phi decoder stack: token embedding, decoder layers, final LayerNorm.

    Pipeline-parallel aware: only the first rank embeds tokens (or takes
    ``inputs_embeds``), later ranks read ``hidden_states`` from the incoming
    ``IntermediateTensors``, and every rank except the last returns
    ``IntermediateTensors`` instead of applying the final LayerNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final LayerNorm.

        Args:
            vllm_config: Engine configuration; the HF config, cache config
                and quantization config are read from it.
            prefix: Fully qualified module name of this model.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # forward() iterates only layers[start_layer:end_layer]; presumably
        # this is the range owned by the current pipeline rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used only on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Token positions for the rotary embeddings.
            intermediate_tensors: Hidden states from the previous pipeline
                rank; required on every rank but the first.
            inputs_embeds: Optional precomputed input embeddings.

        Returns:
            Final-LayerNorm hidden states on the last pipeline rank, otherwise
            ``IntermediateTensors`` to send to the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors are
        written as shards of the fused ``qkv_proj`` parameter; every other
        tensor is loaded by name.

        Args:
            weights: ``(name, tensor)`` pairs with names relative to this
                module.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Rename a q/k/v checkpoint tensor to the fused parameter exactly
            # once and remember its shard. Stopping at the first match is
            # required: the renamed "qkv_proj" itself contains "v_proj", so
            # continuing the search would rewrite the name a second time.
            shard_id = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip parameters that is_pp_missing_parameter reports as absent
            # on this pipeline-parallel rank.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

295

296
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Phi causal language model: ``PhiModel`` plus a biased LM head."""

    # Fused module name -> the separate checkpoint projections it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head and logits processor.

        Args:
            vllm_config: Engine configuration (HF, quantization and LoRA
                configs are read from it).
            prefix: Fully qualified module name of this model.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        # lm_head use bias, cannot share word embeddings
        assert not config.tie_word_embeddings

        self.lora_config = lora_config

        self.quant_config = quant_config

        self.model = PhiModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))

        # The head is built with quant_config, so give it a name that
        # per-module quantization settings can match.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; see ``PhiModel.forward`` for the arguments."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, including the bias."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Returns:
            The names of the parameters that were loaded.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)