# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""

from collections.abc import Iterable
from itertools import islice

import torch
from torch import nn
from transformers import PhiConfig

from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class PhiAttention(nn.Module):
    """Self-attention sub-block of a Phi decoder layer.

    The forward pass runs a fused QKV projection, applies rotary embeddings
    to the query and key, calls the attention layer and finally projects the
    result through ``dense``.
    """

    def __init__(
        self,
        config: PhiConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projection, rotary-embedding and attention sub-modules.

        Args:
            config: HuggingFace Phi configuration; ``hidden_size``,
                ``num_attention_heads``, ``rope_parameters`` and (optionally)
                ``max_position_embeddings`` are read from it.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to the linear layers and to ``Attention``.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // config.num_attention_heads

        # Number of heads handled by this tensor-parallel rank; the total head
        # count must divide evenly by the tensor-parallel world size.
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert config.num_attention_heads % tensor_model_parallel_world_size == 0
        self.num_heads = config.num_attention_heads // tensor_model_parallel_world_size

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            config.num_attention_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # Attention scale: 1 / sqrt(head_size).
        scaling = self.head_size**-0.5

        # Fall back to 2048 positions when the config does not define a limit.
        max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_size,
            scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        Args:
            position_ids: Positions passed to the rotary embedding.
            hidden_states: Input activations for the QKV projection.

        Returns:
            The output of the ``dense`` projection.
        """
        # Both linear layers return a 2-tuple; only the first element is used.
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is split into three equal parts on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Feed-forward sub-block of a Phi decoder layer: ``fc1`` -> act -> ``fc2``."""

    def __init__(
        self,
        config: PhiConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the two linear layers and the activation.

        Args:
            config: HuggingFace Phi configuration; ``hidden_size``,
                ``hidden_act`` and (optionally) ``n_inner`` are read from it.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()

        # Inner width defaults to 4 * hidden_size when ``n_inner`` is missing
        # from the config or explicitly set to None.
        n_inner = getattr(config, "n_inner", None)
        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run ``fc1``, the activation and ``fc2`` in sequence.

        Both linear layers return a 2-tuple, of which only the first element
        is used.
        """
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """One Phi decoder layer.

    Attention and the MLP both consume the same layer-normalised input; their
    outputs are summed together with the un-normalised residual.
    """

    def __init__(
        self,
        config: PhiConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the LayerNorm, attention and MLP sub-modules.

        Args:
            config: HuggingFace Phi configuration; ``hidden_size`` and
                ``layer_norm_eps`` are read here, the rest by the sub-modules.
            cache_config: Forwarded to ``PhiAttention``.
            quant_config: Forwarded to ``PhiAttention`` and ``PhiMLP``.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.self_attn = PhiAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
        )
        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the layer to ``hidden_states``.

        Args:
            position_ids: Positions forwarded to the attention block.
            hidden_states: Layer input; also used as the residual.

        Returns:
            ``attention(ln(x)) + mlp(ln(x)) + x``.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        # The MLP is fed the normalised input, not the attention output.
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


206
@support_torch_compile
class PhiModel(nn.Module):
    """Phi decoder stack: token embedding, decoder layers and a final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final LayerNorm.

        Args:
            vllm_config: Source of the HF model config, the cache config and
                the quantization config.
            prefix: Dotted module path used to name the decoder layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        # ``make_layers`` returns the layer index range later used by
        # ``forward`` to slice ``self.layers``, together with the layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder layers in ``[start_layer, end_layer)``.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every decoder layer.
            intermediate_tensors: Must be provided on non-first pipeline
                ranks; its ``"hidden_states"`` entry is the layer input.
            inputs_embeds: Optional pre-computed embeddings used instead of
                embedding ``input_ids`` on the first pipeline rank.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` on every rank
            but the last; the final-LayerNorm output on the last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` parameter with the matching shard
        id; every other tensor is loaded under its own name.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were actually loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Resolve the target parameter name exactly once. Stopping at the
            # first matching shard matters: the rewritten name contains
            # "qkv_proj", which itself contains "v_proj", so scanning further
            # entries would rewrite the name a second time.
            is_stacked = False
            shard_id = None
            for param_name, weight_name, candidate_shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate_shard_id
                    is_stacked = True
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            # pylint: disable=E1136
            param = params_dict[name]
            if is_stacked:
                # Fused parameters must provide a shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

301

302
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Phi decoder with a biased LM head for causal language modelling."""

    # Maps the fused parameter name to the separate projections it is built from.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF model config and quantization config;
                also passed through to ``PhiModel``.
            prefix: Dotted module path prepended to the sub-module names.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        # lm_head use bias, cannot share word embeddings
        assert not config.tie_word_embeddings

        self.quant_config = quant_config

        self.model = PhiModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        # Re-export the inner model's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids`` (delegates to the model)."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner ``PhiModel`` and return its result unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits from ``hidden_states`` via the logits processor.

        The LM head and its bias are both handed to the processor.
        """
        logits = self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns for ``weights``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)