phi.py 13.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Woosuk Kwon's avatar
Woosuk Kwon committed
37
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
38
from typing import Iterable, List, Optional, Tuple, Union
39
40
41

import torch
from torch import nn
42
from transformers import PhiConfig
43

44
from vllm.attention import Attention, AttentionMetadata
45
from vllm.config import CacheConfig, LoRAConfig
46
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
47
48
49
50
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
from vllm.model_executor.layers.quantization import QuantizationConfig
53
from vllm.model_executor.layers.rotary_embedding import get_rope
54
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
55
from vllm.model_executor.layers.vocab_parallel_embedding import (
56
    ParallelLMHead, VocabParallelEmbedding)
57
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
58
from vllm.model_executor.sampling_metadata import SamplingMetadata
59
from vllm.sequence import IntermediateTensors
60

61
62
63
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
64

65
66
67
68

class PhiAttention(nn.Module):
    """Multi-head self-attention with partial rotary embeddings for Phi.

    Q, K and V come from one fused projection (``qkv_proj``). Attention heads
    are split evenly across tensor-parallel ranks, and the output projection
    (``dense``) is row-parallel.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Number of heads owned by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        # Rotary embeddings cover only the first `partial_rotary_factor`
        # fraction of each head's dimensions.
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        # Honor the checkpoint's RoPE base and context length. The fallbacks
        # are the legacy `n_positions` attribute and the values phi-1_5 used.
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(
            config, "max_position_embeddings",
            getattr(config, "n_positions", 2048))
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run fused QKV projection, rotary embedding, attention and output
        projection on ``hidden_states``; returns the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Q, K and V are packed as three equal slices of the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Tensor-parallel feed-forward block: fc1 -> activation -> fc2."""

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # Legacy phi configs call the inner width `n_inner`, while the
        # transformers `PhiConfig` calls it `intermediate_size`. Use the
        # conventional 4x expansion only when neither is set.
        n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            n_inner = getattr(config, "intermediate_size", None)
        if n_inner is None:
            n_inner = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up with fc1, apply the activation, project back with fc2."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """One Phi decoder layer.

    Attention and the MLP both consume the same layer-normed input, and their
    outputs are summed together with the un-normalized input (residual).
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        ln_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=ln_eps)
        self.self_attn = PhiAttention(config, cache_config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the parallel attention + MLP block with a residual add."""
        normed = self.input_layernorm(hidden_states)
        attn_branch = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_branch = self.mlp(normed)
        # Summation order kept as (attention + mlp) + residual.
        return attn_branch + mlp_branch + hidden_states


class PhiModel(nn.Module):
    """Phi transformer backbone, pipeline-parallel aware.

    The first pipeline rank embeds token ids. Every rank runs its own slice of
    decoder layers. Only the last rank applies the final layer norm; other
    ranks hand their activations on as ``IntermediateTensors``.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # `make_layers` invokes the factory with a `prefix=` keyword, so the
        # lambda parameter must keep that name.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        Returns final hidden states on the last rank, otherwise
        ``IntermediateTensors``.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = self.embed_tokens(input_ids)

        # `kv_caches` is indexed relative to this rank's first layer.
        for layer_idx in range(self.start_layer, self.end_layer):
            cache = kv_caches[layer_idx - self.start_layer]
            hidden_states = self.layers[layer_idx](positions, hidden_states,
                                                   cache, attn_metadata)

        if pp_group.is_last_rank:
            return self.final_layernorm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})
245
246


247
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Phi causal LM: a ``PhiModel`` backbone plus a biased LM head."""

    # Checkpoint q/k/v projections are fused into one `qkv_proj` module.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # LoRA specific attributes
    supported_lora_modules = ["qkv_proj", "dense", "fc1", "fc2"]

    # BitandBytes specific attributes
    # shard_name -> (fused weight name, shard index)
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".fc2.", ".dense."]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PhiConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        # The LM head carries its own bias, so it cannot share the input
        # embedding matrix.
        assert not config.tie_word_embeddings
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = PhiModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the backbone and return its output unchanged."""
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits, adding the head bias."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata, self.lm_head.bias)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the model's sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Separate q/k/v checkpoint tensors are routed into the fused
        ``qkv_proj`` parameter. Extra biases that have no matching parameter
        (e.g. GPTQ) are skipped, as are parameters owned by other pipeline
        ranks.
        """
        # (fused param name, checkpoint shard name, shard id)
        fused_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params = dict(self.named_parameters())

        def _should_skip(param_key: str) -> bool:
            # Skip loading extra bias for GPTQ models, and anything that
            # lives on another pipeline stage. Bias check runs first.
            if param_key.endswith(".bias") and param_key not in params:
                return True
            return is_pp_missing_parameter(param_key, self)

        for ckpt_name, tensor in weights:
            if "rotary_emb.inv_freq" in ckpt_name:
                continue

            target = ckpt_name
            handled = False
            for fused_name, shard_name, shard_id in fused_shards:
                if shard_name not in target:
                    continue
                target = target.replace(shard_name, fused_name)
                if _should_skip(target):
                    continue
                fused_param = params[target]
                fused_param.weight_loader(fused_param, tensor, shard_id)
                handled = True
                break

            if handled or _should_skip(target):
                continue
            # pylint: disable=E1136
            plain_param = params[target]
            loader = getattr(plain_param, "weight_loader",
                             default_weight_loader)
            loader(plain_param, tensor)