# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
38
from typing import Iterable, List, Optional, Tuple, Union
39
40
41

import torch
from torch import nn
42
from transformers import PhiConfig
43

44
from vllm.attention import Attention, AttentionMetadata
45
from vllm.config import CacheConfig, LoRAConfig
46
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
47
48
49
50
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
from vllm.model_executor.layers.quantization import QuantizationConfig
53
from vllm.model_executor.layers.rotary_embedding import get_rope
54
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
55
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
57
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
58
from vllm.model_executor.sampling_metadata import SamplingMetadata
59
from vllm.sequence import IntermediateTensors
60

61
62
63
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
64

65
66
67
68

class PhiAttention(nn.Module):
    """Phi self-attention with a fused QKV projection and partial rotary.

    Query, key and value are three equal-width slices of a single fused
    projection (there are as many key/value heads as query heads). Heads
    are divided evenly across tensor-parallel ranks. Rotary embeddings are
    applied with ``rotary_dim = partial_rotary_factor * head_size``.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tp_world = get_tensor_model_parallel_world_size()
        # Every tensor-parallel rank must receive the same number of heads.
        assert self.total_num_heads % tp_world == 0
        self.num_heads = self.total_num_heads // tp_world

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        per_head_dim = config.hidden_size // config.num_attention_heads
        rotary_dim = int(config.partial_rotary_factor * per_head_dim)
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        theta = getattr(config, "rope_theta", 10000.0)
        max_pos = getattr(config, "max_position_embeddings", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_pos,
            base=theta,
        )

        softmax_scale = self.head_size**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              softmax_scale,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply rotary attention to ``hidden_states``, then ``dense``."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.dense(context)
        return projected


class PhiMLP(nn.Module):
    """Phi feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` maps hidden_size -> n_inner and ``fc2`` maps n_inner back to
    hidden_size, with ``config.hidden_act`` applied in between.
    """

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # Intermediate (inner) width of the MLP. Legacy remote-code Phi
        # configs call it ``n_inner``, while transformers' ``PhiConfig``
        # calls it ``intermediate_size``. Honour whichever is set, and fall
        # back to the conventional 4 * hidden_size only when neither is.
        n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            n_inner = getattr(config, "intermediate_size", None)
        if n_inner is None:
            n_inner = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)

    def forward(self, hidden_states):
        """Return ``fc2(act(fc1(hidden_states)))``."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """One Phi decoder layer in the "parallel" formulation.

    A single LayerNorm feeds both attention and the MLP. Their outputs are
    added to the un-normalized input, which acts as the residual.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=eps)
        self.self_attn = PhiAttention(config, cache_config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return ``attn(ln(x)) + mlp(ln(x)) + x`` for input ``x``."""
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_out = self.mlp(normed)
        # Summation order (attn + mlp) + residual is kept deliberately so
        # floating-point results match exactly.
        return attn_out + mlp_out + hidden_states


class PhiModel(nn.Module):
    """Phi decoder stack: token embedding, decoder layers, final LayerNorm.

    Under pipeline parallelism, only the first rank embeds ``input_ids``.
    Later ranks resume from ``intermediate_tensors["hidden_states"]``.
    Every rank except the last returns ``IntermediateTensors`` instead of
    normalized hidden states.
    """

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)

        # ``forward`` only runs layers with index in [start_layer, end_layer).
        # The lambda's parameter must be named ``prefix`` because make_layers
        # may pass it by keyword.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PhiLayer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers")

        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this module's slice of the decoder.

        ``kv_caches`` is indexed relative to ``start_layer``.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = self.embed_tokens(input_ids)

        first = self.start_layer
        for layer_idx in range(first, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx - first],
                attn_metadata,
            )

        if pp_group.is_last_rank:
            return self.final_layernorm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})
246
247


248
class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Phi causal language model: a ``PhiModel`` backbone plus a biased LM head.

    The class attributes below describe how the fused ``qkv_proj`` relates
    to the checkpoint's separate q/k/v projections. Framework code reads
    them for LoRA and bitsandbytes loading.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "dense",
        "fc1",
        "fc2",
    ]

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".fc2.", ".dense."]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PhiConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()

        self.config = config
        # lm_head has its own bias, so it cannot share weights with the word
        # embeddings.
        assert not config.tie_word_embeddings

        self.lora_config = lora_config

        self.quant_config = quant_config

        self.model = PhiModel(config, cache_config, quant_config)

        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone.

        Returns final hidden states, or ``IntermediateTensors`` when this is
        not the last pipeline stage.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits using ``lm_head``.

        ``lm_head.bias`` is passed through to the logits processor.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint ``weights`` into this model's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` weights go into the matching
        shard of the fused ``qkv_proj`` parameter. Every other weight is
        loaded under its own name. Three kinds of entry are skipped:
        rotary ``inv_freq`` buffers, extra biases that have no matching
        parameter (GPTQ checkpoints), and names that
        ``is_pp_missing_parameter`` reports as absent on this rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Build the fused name in a separate variable. "qkv_proj"
                # itself contains "v_proj". Mutating ``name`` and continuing
                # the loop would re-match it against a later entry and
                # produce a bogus name such as "qkqkv_proj".
                fused_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models, and weights that
                # are not materialized on this pipeline-parallel rank. A
                # skipped shard must not fall through to the plain-name path.
                if ((fused_name.endswith(".bias")
                     and fused_name not in params_dict)
                        or is_pp_missing_parameter(fused_name, self)):
                    break
                param = params_dict[fused_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the weight under its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)