# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
from torch import nn
from transformers import PhiConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class PhiAttention(nn.Module):
    """Self-attention sub-layer of a Phi decoder layer.

    A fused ``qkv_proj`` produces Q, K and V. Rotary embeddings are applied
    to Q and K, attention runs through vLLM's ``Attention`` layer, and
    ``dense`` projects the result back to ``hidden_size``. Heads are split
    evenly across tensor-parallel ranks.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Each tensor-parallel rank must own the same whole number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # Give the linear layers their fully-qualified names, as is already
        # done for `self.attn` below. Without a prefix, quantization configs
        # that select layers by name cannot identify these projections.
        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        scaling = self.head_size**-0.5
        # Rotary embeddings cover only `rotary_dim` of the `head_size`
        # dimensions of each head (Phi's partial rotary factor).
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions fed to the rotary embedding.
            hidden_states: Input activations (last dim ``hidden_size``).

        Returns:
            The ``dense`` projection of the attention output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # qkv_proj is built without a separate KV-head count, so the fused
        # output is split into three equal slices. TODO confirm if GQA
        # support is ever added.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Feed-forward sub-layer of a Phi decoder layer: fc1 -> act -> fc2."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Intermediate (fc1 output) width. Resolution order:
        # 1. the legacy `n_inner` field, when present and not None;
        # 2. HF PhiConfig's `intermediate_size`;
        # 3. 4 * hidden_size as a last resort.
        # Previously `intermediate_size` was ignored, so checkpoints whose
        # intermediate size is not exactly 4x hidden got mis-shaped layers.
        n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            n_inner = getattr(config, "intermediate_size", None)
        if n_inner is None:
            n_inner = 4 * config.hidden_size

        # Prefixes give quantization configs the fully-qualified layer names.
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Apply fc1, the configured activation, then fc2."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """One Phi decoder layer with parallel attention and MLP branches.

    Both branches read the same layer-normed input. Their outputs are summed
    together with the un-normalized residual; there is no second norm.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.self_attn")
        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``attn(norm(x)) + mlp(norm(x)) + x``."""
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        # Keep the original summation order (attn + mlp + residual).
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


@support_torch_compile
class PhiModel(nn.Module):
    """Phi transformer backbone: embeddings, decoder layers, final norm.

    Under pipeline parallelism, each rank builds and runs only its slice
    ``[start_layer, end_layer)`` of the layers. Non-final ranks return
    ``IntermediateTensors`` carrying ``hidden_states`` instead of the
    normalized output.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size,
                                                   hf_config.hidden_size)

        def build_layer(prefix):
            # Factory handed to make_layers; it supplies each layer's prefix.
            return PhiLayer(hf_config,
                            cache_config,
                            quant_config,
                            prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

        self.final_layernorm = nn.LayerNorm(hf_config.hidden_size,
                                            eps=hf_config.layer_norm_eps)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        On the first pipeline rank, hidden states come from ``inputs_embeds``
        if given, otherwise from the token embeddings. Later ranks read them
        from ``intermediate_tensors``. The last rank returns the normalized
        hidden states; every other rank returns ``IntermediateTensors``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))

        for decoder_layer in islice(self.layers, self.start_layer,
                                    self.end_layer):
            hidden_states = decoder_layer(positions, hidden_states)

        if get_pp_group().is_last_rank:
            return self.final_layernorm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Separate q/k/v projections are routed into the fused ``qkv_proj``
        parameter through its shard-aware ``weight_loader``. All other
        tensors go through the parameter's own ``weight_loader`` or
        ``default_weight_loader``.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        # (fused vLLM parameter, checkpoint sub-name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for ckpt_name, tensor in weights:
            # Skip any rotary_emb.inv_freq entries in the checkpoint.
            if "rotary_emb.inv_freq" in ckpt_name:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in ckpt_name:
                    continue
                ckpt_name = ckpt_name.replace(shard_name, fused_name)
                # NOTE: the `continue`s below move on to the next mapping
                # with the already-rewritten name. If nothing else matches,
                # the for-else branch re-applies the same skip checks.
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in params:
                    continue
                if is_pp_missing_parameter(ckpt_name, self):
                    continue
                param = params[ckpt_name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                # Not a stacked shard: load it directly.
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in params:
                    continue
                if is_pp_missing_parameter(ckpt_name, self):
                    continue
                param = params[ckpt_name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, tensor)
            loaded_params.add(ckpt_name)
        return loaded_params


class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Phi causal language model: ``PhiModel`` plus a biased LM head."""

    # The fused qkv_proj parameter packs these checkpoint projections.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config
        # lm_head carries a bias, so it cannot share the embedding weights.
        assert not hf_config.tie_word_embeddings

        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = PhiModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))

        self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                      hf_config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, including the bias."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata, self.lm_head.bias)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)