phi.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Woosuk Kwon's avatar
Woosuk Kwon committed
39
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
40

41
from collections.abc import Iterable
42
from itertools import islice
43
44
45

import torch
from torch import nn
46
from transformers import PhiConfig
47

48
from vllm.attention import Attention
49
from vllm.compilation.decorators import support_torch_compile
50
from vllm.config import CacheConfig, VllmConfig
51
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
52
from vllm.model_executor.layers.activation import get_act_fn
53
54
55
56
57
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
58
from vllm.model_executor.layers.logits_processor import LogitsProcessor
59
from vllm.model_executor.layers.quantization import QuantizationConfig
60
from vllm.model_executor.layers.rotary_embedding import get_rope
61
from vllm.model_executor.layers.vocab_parallel_embedding import (
62
63
64
    ParallelLMHead,
    VocabParallelEmbedding,
)
65
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
66
from vllm.sequence import IntermediateTensors
67

68
from .interfaces import SupportsLoRA, SupportsPP
69
70
71
72
73
74
75
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
76

77
78

class PhiAttention(nn.Module):
    """Phi self-attention with a fused QKV projection and partial rotary embedding.

    The attention heads are divided evenly between the tensor-parallel ranks;
    construction fails if the head count is not divisible by the world size.
    """

    def __init__(
        self,
        config: PhiConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the projections, the rotary embedding and the attention op.

        Args:
            config: Model configuration. Reads ``num_attention_heads``,
                ``hidden_size`` and ``partial_rotary_factor``, plus the
                optional ``rope_theta`` (default 10000.0) and
                ``max_position_embeddings`` (default 2048).
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to both linear layers and to ``Attention``.
            prefix: Dotted name of this module; every submodule built here is
                named relative to it.
        """
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size

        # The linear layers receive their own dotted names, the same way the
        # Attention layer below does, so that anything keyed on the layer name
        # can identify them.
        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        scaling = self.head_size**-0.5

        # Only a fraction of each head is rotated; the rotated width must be
        # even.
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = getattr(config, "rope_theta", 10000.0)
        max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_size,
            scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            position_ids: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of the ``dense`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Query, key and value are taken as three equal slices of the last
        # dimension of the fused projection.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(
        self, config: PhiConfig, quant_config: QuantizationConfig | None = None
    ):
        """Create both projections and look up the activation function.

        Args:
            config: Model configuration. Reads ``hidden_size``, ``hidden_act``
                and the optional ``n_inner``.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()

        # When the config has no (or a None) ``n_inner``, the inner width is
        # four times the hidden size.
        inner_dim = getattr(config, "n_inner", None)
        if inner_dim is None:
            inner_dim = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            inner_dim,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            inner_dim,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states):
        """Return ``fc2(act(fc1(hidden_states)))``.

        Both linear layers return a pair; only the first element is used.
        """
        expanded, _ = self.fc1(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.fc2(activated)
        return projected


class PhiLayer(nn.Module):
    """One Phi decoder layer.

    Attention and MLP both consume the same layer-normalised input, and their
    outputs are added to the un-normalised input.
    """

    def __init__(
        self,
        config: PhiConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the layer norm, the attention block and the MLP.

        Args:
            config: Model configuration. Reads ``hidden_size`` and
                ``layer_norm_eps`` here; also passed to the sub-blocks.
            cache_config: Forwarded to ``PhiAttention``.
            quant_config: Forwarded to ``PhiAttention`` and ``PhiMLP``.
            prefix: Dotted name of this layer; the attention block is named
                ``{prefix}.self_attn``.
        """
        super().__init__()
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )
        self.self_attn = PhiAttention(
            config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer and return the updated hidden states.

        Args:
            position_ids: Positions forwarded to the attention block.
            hidden_states: Layer input; also used as the residual.
        """
        normed = self.input_layernorm(hidden_states)
        attn_branch = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
        )
        mlp_branch = self.mlp(normed)
        # Summation order is kept as attention + MLP + residual.
        return attn_branch + mlp_branch + hidden_states


210
@support_torch_compile
class PhiModel(nn.Module):
    """Phi decoder stack: token embedding, decoder layers and final layer norm.

    Under pipeline parallelism only the first rank embeds tokens and only the
    last rank applies the final layer norm; other ranks exchange
    ``IntermediateTensors`` holding ``"hidden_states"``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Supplies ``model_config.hf_config``, ``cache_config``
                and ``quant_config``.
            prefix: Dotted name of this module; layers are created under
                ``{prefix}.layers``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Positions passed to every layer.
            intermediate_tensors: Output of the previous pipeline rank; must
                not be None on any rank other than the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors`` with ``"hidden_states"`` on every rank but
            the last; on the last rank, the hidden states after the final
            layer norm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint tensors named ``q_proj`` / ``k_proj`` / ``v_proj`` are
        renamed to the fused ``qkv_proj`` parameter and loaded as the ``"q"``,
        ``"k"`` or ``"v"`` shard. ``rotary_emb.inv_freq`` entries, ``.bias``
        tensors with no matching parameter, and parameters reported missing by
        ``is_pp_missing_parameter`` are skipped.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly renamed) names of every parameter that was loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Resolve the target parameter name exactly once. Stopping at the
            # first match matters: "v_proj" is a substring of "qkv_proj", so
            # scanning further with the rewritten name would match again and
            # mangle it.
            shard_id = None
            for param_name, weight_name, candidate_shard in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = candidate_shard
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                # pylint: disable=E1136
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters must provide their own shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

305

306
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Phi decoder with a biased language-model head."""

    # Names of the separate projections that are fused into ``qkv_proj``.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder, the LM head and the logits processor.

        Args:
            vllm_config: Supplies ``model_config.hf_config``, ``quant_config``
                and ``lora_config``.
            prefix: Dotted name of this module, combined with ``"model"`` and
                ``"lm_head"`` through ``maybe_prefix``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        # lm_head use bias, cannot share word embeddings
        assert not hf_config.tie_word_embeddings
        self.lora_config = vllm_config.lora_config

        quantization = vllm_config.quant_config
        self.quant_config = quantization

        self.model = PhiModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            bias=True,
            quant_config=quantization,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Expose the inner model's factory under the same name on this class.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner model and return its result unchanged.

        Logits are not computed here; see ``compute_logits``.
        """
        return self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits using the LM head and its bias."""
        head = self.lm_head
        return self.logits_processor(head, hidden_states, head.bias)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader`` and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)