# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
Qing's avatar
Qing committed
8

9
10
import torch
from torch import nn
11
from transformers import PretrainedConfig
Qing's avatar
Qing committed
12

13
from vllm.attention import Attention, AttentionMetadata
14
from vllm.config import CacheConfig
15
from vllm.distributed import get_tensor_model_parallel_world_size
16
from vllm.model_executor.layers.activation import SiluAndMul
17
from vllm.model_executor.layers.layernorm import RMSNorm
18
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
19
20
                                               QKVParallelLinear,
                                               RowParallelLinear)
21
from vllm.model_executor.layers.logits_processor import LogitsProcessor
22
23
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
24
from vllm.model_executor.layers.rotary_embedding import get_rope
25
26
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
27
    ParallelLMHead, VocabParallelEmbedding)
28
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
29
30
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
31
from vllm.utils import print_warning_once
Qing's avatar
Qing committed
32

33
34
35
36
37
38
39
40

class QWenMLP(nn.Module):
    """Feed-forward sub-layer of a QWen decoder block.

    One merged column-parallel projection produces two ``intermediate_size``
    wide outputs side by side. ``SiluAndMul`` combines them, and a
    row-parallel projection maps the result back to ``hidden_size``. Only the
    ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Both halves (shard 0 and shard 1) live in a single fused weight.
        merged_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # The activation kernel below is hard-wired to SiLU, so any other
        # activation name is rejected.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Project ``x`` up, apply the gated activation, project back down."""
        merged, _unused_bias = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _unused_bias = self.c_proj(activated)
        return projected


class QWenAttention(nn.Module):
    """Self-attention sub-layer of a QWen decoder block with rotary embeddings.

    The ``num_heads`` heads are split evenly across tensor-parallel ranks.
    ``c_attn`` is a fused QKV projection (with bias), and ``c_proj`` is the
    bias-free output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the attention projections, rotary embedding and backend.

        Raises:
            ValueError: if ``num_heads`` is not divisible by the tensor
                model parallel world size.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        # Explicit raise instead of `assert` so the check survives `python -O`
        # (consistent with the ValueError raised by QWenMLP for bad config).
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) must be "
                "divisible by the tensor parallel size "
                f"({tensor_model_parallel_world_size}).")
        # Heads handled by this rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        # Rotary embedding is applied over the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and project the result.

        ``positions`` drives the rotary embedding. ``kv_cache`` and
        ``attn_metadata`` are passed straight through to the attention
        backend.
        """
        qkv, _ = self.c_attn(hidden_states)
        # c_attn is built without a separate KV-head count, so the fused
        # output is presumably three equal q/k/v parts along the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded between layers by the caller. ``forward``
    returns ``(hidden_states, residual)``, and the next layer's first norm
    consumes both.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(hidden, eps=norm_eps)

        # Older configs may omit the rope fields; fall back to the defaults.
        self.attn = QWenAttention(
            hidden,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
        )

        self.ln_2 = RMSNorm(hidden, eps=norm_eps)

        # The MLP is built with half of config.intermediate_size (presumably
        # mirroring the HF QWen reference, which halves it for each of w1/w2).
        self.mlp = QWenMLP(
            hidden,
            config.intermediate_size // 2,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention then the MLP, returning ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer, in which case the input
        itself becomes the residual.
        """
        # Attention sub-layer. When a residual is supplied, ln_1 is called
        # with two arguments and returns (normed, residual).
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        attended = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Feed-forward sub-layer.
        normed, residual = self.ln_2(attended, residual)
        return self.mlp(normed), residual


class QWenModel(nn.Module):
    """QWen decoder stack: token embedding, ``num_hidden_layers`` blocks and
    a final RMSNorm. Returns final hidden states, not logits."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder block.

        ``kv_caches[i]`` is handed to block ``i``. The list is indexed per
        layer, so it must have at least one entry per block.
        """
        hidden_states = self.wte(input_ids)
        # The first block receives no residual and seeds it from its input.
        residual = None
        for layer_idx, layer in enumerate(self.h):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        # The final norm also folds in the last block's residual.
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):
    """QWen causal LM: decoder stack, vocab-parallel LM head, logits processor
    and sampler, plus loading from HuggingFace QWen checkpoints."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states (logits are computed
        separately by ``compute_logits``)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy ``(name, tensor)`` checkpoint entries into model parameters.

        The checkpoint's MLP input projections ``w1``/``w2`` are loaded into
        the two shards of the merged ``gate_up_proj`` (``w2`` -> shard 0,
        ``w1`` -> shard 1). Rotary ``inv_freq`` buffers, GPTQ extra biases,
        and Qwen-VL visual weights with no matching parameter are skipped.
        Any other unknown name raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # Skip loading visual weights to support Qwen-VL models
            # in cases with text-only inputs.
            # This check runs before the stacked-param mapping. Otherwise a
            # visual name that happens to contain "w1"/"w2" would be renamed
            # and then hit a KeyError instead of being skipped.
            # TODO: add support for Qwen-VL
            if (name not in params_dict
                    and name.startswith("transformer.visual.")):
                print_warning_once(
                    "Only text inputs are allowed. Images won't be handled "
                    "until Qwen-VL models are fully supported.")
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `break` drops this
                # weight here. A `continue` would only work by falling through
                # to the else-branch's duplicate bias check.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked (merged) parameter: load it one-to-one.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)