"vllm/benchmarks/datasets/datasets.py" did not exist on "3717a4dd475e6a936df0c84b043743310368e766"
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once

from .utils import is_pp_missing_parameter, make_layers


class QWenMLP(nn.Module):
    """Gated feed-forward sub-layer of a QWen transformer block.

    The input goes through one merged gate/up projection, then
    ``SiluAndMul``, then the ``c_proj`` output projection back to
    ``hidden_size``.

    Args:
        hidden_size: Width of the block input and output.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization settings forwarded to both
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Reject unsupported activations before any layer is built so that a
        # bad config does not pay for weight allocation first.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and output projection."""
        # Both linear layers return a tuple; only the first element (the
        # projected tensor) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Self-attention sub-layer of a QWen transformer block.

    A fused ``c_attn`` projection produces the query/key/value tensor.
    Rotary position embeddings are applied to the queries and keys, attention
    is computed through ``Attention``, and ``c_proj`` produces the output.

    Args:
        hidden_size: Model hidden width.
        num_heads: Total number of attention heads across all
            tensor-parallel ranks.
        max_position_embeddings: Maximum position passed to the rotary
            embedding.
        rope_theta: Base passed to the rotary embedding.
        rope_scaling: Optional rotary scaling settings, forwarded unchanged
            to ``get_rope``.
        cache_config: Optional cache settings forwarded to ``Attention``.
        quant_config: Optional quantization settings forwarded to the linear
            layers and to ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        # Heads are divided evenly between tensor-parallel ranks.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Attention scale factor: 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5

        # rotary_dim equals head_dim, i.e. the full head width is passed as
        # the rotary dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            positions: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor forwarded to ``Attention``.
            attn_metadata: Attention metadata forwarded to ``Attention``.

        Returns:
            The output of the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # The fused projection is split into three equal chunks along the
        # last dimension: query, key, value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen transformer layer: RMSNorm, attention, RMSNorm, MLP.

    The residual stream is carried as a separate tensor. Each norm call
    receives the current hidden states together with the residual and hands
    back an updated pair. The MLP output is returned without the residual
    added; the caller passes both tensors on to the next norm.

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``layer_norm_epsilon``, ``num_attention_heads``,
            ``max_position_embeddings``, ``intermediate_size`` and, when
            present, ``rope_theta`` / ``rope_scaling``.
        cache_config: Optional cache settings forwarded to the attention
            layer.
        quant_config: Optional quantization settings forwarded to the
            attention and MLP layers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Rotary settings are optional in the config; fall back to defaults.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # NOTE(review): the MLP gets half of config.intermediate_size,
        # presumably because the checkpoint config reports the combined width
        # of the two gated projections -- confirm against the HF Qwen config.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Positions forwarded to the attention layer.
            hidden_states: Layer input.
            kv_cache: Cache tensor for this layer.
            attn_metadata: Attention metadata forwarded to the attention
                layer.
            residual: Residual tensor carried over from the previous layer,
                or ``None`` when there is none yet. In that case the layer
                input itself becomes the residual.

        Returns:
            The MLP output and the residual tensor. The two are not summed
            here.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class QWenModel(nn.Module):
    """QWen transformer stack: token embedding, layers and final RMSNorm.

    Layers are created through ``make_layers``, which also reports the
    ``[start_layer, end_layer)`` range that ``forward`` iterates over on this
    pipeline-parallel rank.

    Args:
        config: Model configuration.
        cache_config: Optional cache settings forwarded to every layer.
        quant_config: Optional quantization settings forwarded to every
            layer.
        prefix: Name prefix used to build the prefix handed to
            ``make_layers``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # The factory's ``prefix`` argument is unused because QWenBlock does
        # not take a prefix parameter.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: QWenBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the layer stack.

        Args:
            input_ids: Token ids; only read on the first pipeline rank.
            positions: Positions forwarded to every layer.
            kv_caches: Per-layer cache tensors, indexed relative to
                ``start_layer``.
            attn_metadata: Attention metadata forwarded to every layer.
            intermediate_tensors: ``hidden_states`` and ``residual`` received
                from the previous pipeline rank. Must not be ``None`` on
                ranks other than the first.

        Returns:
            On the last pipeline rank, the hidden states after the final
            norm. On any other rank, an ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed from 0 for the first layer in range.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        # The final norm receives the residual left pending by the last
        # layer together with its output.
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):
    """QWen causal language model: transformer stack plus LM head.

    Wraps ``QWenModel`` with a ``ParallelLMHead``, a ``LogitsProcessor`` and
    a ``Sampler``. Also provides checkpoint loading, including the mapping of
    the checkpoint's separate ``w2`` / ``w1`` MLP weights onto the merged
    ``gate_up_proj`` parameter.

    Args:
        config: Model configuration.
        cache_config: Optional cache settings forwarded to the transformer.
        quant_config: Optional quantization settings forwarded to the
            transformer and the LM head.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        # With tied embeddings the LM head reuses the input embedding matrix.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer stack and return its result unchanged.

        The result is whatever ``self.transformer`` returns for these
        arguments. Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Build zero-filled ``hidden_states`` / ``residual`` placeholders.

        Both tensors have shape ``(batch_size, config.hidden_size)`` and the
        requested ``dtype`` and ``device``.
        """
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by passing the LM head to the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` using the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        A weight is skipped when any of the following holds:

        * its name contains ``rotary_emb.inv_freq``;
        * it is a ``.bias`` with no matching parameter (extra bias in GPTQ
          checkpoints);
        * it has no matching parameter and lives under
          ``transformer.visual.`` (Qwen-VL vision weights; a one-time
          warning is emitted);
        * ``is_pp_missing_parameter`` reports it as missing on this rank.

        Raises:
            KeyError: If a weight that is not skipped has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Resolve the parameter name once, up front. Every skip decision
            # below then acts on the final name and skips the weight
            # outright, rather than advancing only the mapping loop.
            shard_id = None
            for (param_name, weight_name,
                 mapped_shard_id) in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip loading visual weights to support Qwen-VL models
            # in cases with text-only inputs
            # TODO: add support for Qwen-VL
            if (name not in params_dict
                    and name.startswith("transformer.visual.")):
                print_warning_once(
                    "Only text inputs are allowed. Images won't be handled "
                    "until Qwen-VL models are fully supported.")
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Merged parameter: its own loader places the tensor into
                # the shard selected by shard_id.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)