qwen.py 10.3 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
7
from typing import Any, Dict, List, Optional, Tuple
Qing's avatar
Qing committed
8

9
10
import torch
from torch import nn
11
from transformers import PretrainedConfig
Qing's avatar
Qing committed
12

13
from vllm.attention import Attention, AttentionMetadata
14
from vllm.model_executor.layers.activation import SiluAndMul
15
from vllm.model_executor.layers.layernorm import RMSNorm
16
17
18
19
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21
from vllm.model_executor.layers.rotary_embedding import get_rope
22
23
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
24
    ParallelLMHead, VocabParallelEmbedding)
25
26
27
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
29
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
30
from vllm.sequence import SamplerOutput
Qing's avatar
Qing committed
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

class QWenMLP(nn.Module):
    """Feed-forward sub-layer of a QWen block.

    A single merged column-parallel projection produces two
    ``intermediate_size`` outputs, which ``SiluAndMul`` combines; a
    row-parallel ``c_proj`` then maps the result back to ``hidden_size``.
    Neither projection has a bias.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Two equal halves are fused into one projection.
        merged_out_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_out_sizes,
            bias=False,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        # Only the fused SiLU-and-multiply activation is wired up.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Project ``x`` up, apply the activation, and project back down."""
        merged, _unused_bias = self.gate_up_proj(x)
        down, _unused_bias = self.c_proj(self.act_fn(merged))
        return down


class QWenAttention(nn.Module):
    """Self-attention sub-layer of a QWen block.

    Uses a fused QKV projection (``c_attn``, with bias), rotary position
    embeddings applied to q and k, vLLM's ``Attention`` op, and an output
    projection (``c_proj``, no bias). Heads are split evenly across
    tensor-parallel ranks; ``num_heads`` below is the per-rank count.

    Raises:
        ValueError: if ``num_heads`` is not divisible by the tensor-parallel
            world size, or ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Explicit errors rather than ``assert``: assertions are stripped
        # under ``python -O``, after which a bad split would silently floor.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"num_heads ({self.total_num_heads}) must be divisible by "
                f"the tensor parallel world size ({tp_size}).")
        if hidden_size % self.total_num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({self.total_num_heads}).")
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5

        # Rotary embedding covers the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        ``positions`` feeds the rotary embedding; ``kv_cache`` and
        ``attn_metadata`` are passed straight through to ``self.attn``.
        """
        qkv, _ = self.c_attn(hidden_states)
        # c_attn's output is split into equal q, k, v thirds on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One QWen decoder layer: pre-norm attention, then pre-norm MLP.

    The residual stream is threaded through the two ``RMSNorm`` calls
    (``ln_1`` and ``ln_2`` take ``residual`` as a second argument and return
    an updated one); ``forward`` returns the MLP output together with that
    residual so the next layer can continue the stream.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden = config.hidden_size
        eps = config.layer_norm_epsilon

        # Submodules are registered in the same order as before:
        # ln_1, attn, ln_2, mlp.
        self.ln_1 = RMSNorm(hidden, eps=eps)
        self.attn = QWenAttention(
            hidden,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            linear_method=linear_method,
        )
        self.ln_2 = RMSNorm(hidden, eps=eps)
        # The config's intermediate_size is halved before building the MLP;
        # presumably it counts both fused projections -- confirm vs. HF config.
        self.mlp = QWenMLP(hidden,
                           config.intermediate_size // 2,
                           linear_method=linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP; return ``(mlp_output, residual)``.

        When ``residual`` is ``None`` the input itself becomes the residual
        and ``ln_1`` normalizes it alone; otherwise ``ln_1`` receives both.
        """
        if residual is not None:
            normed, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.ln_1(hidden_states)

        attn_out = self.attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual


class QWenModel(nn.Module):
    """QWen decoder stack: token embedding, ``QWenBlock`` layers, final norm.

    Returns hidden states (not logits); see ``QWenLMHeadModel`` for the head.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run every layer, returning normed states.

        ``kv_caches[i]`` is handed to layer ``i``; an IndexError is raised if
        fewer caches than layers are supplied.
        """
        hidden_states = self.wte(input_ids)
        # The first layer receives no residual; each layer returns the
        # residual for the next, and ln_f folds in the last one.
        residual = None
        for i, layer in enumerate(self.h):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):
    """QWen causal language model for vLLM.

    Wraps ``QWenModel`` with a parallel LM head, a logits processor, a
    sampler, and a loader for HuggingFace-format checkpoint weights.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = QWenModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder stack and return its final hidden states."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the LM head's weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` via the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace QWen checkpoint tensors into this model.

        Checkpoint ``mlp.w2`` / ``mlp.w1`` tensors are loaded as shards 0 / 1
        of the fused ``gate_up_proj`` parameter. ``rotary_emb.inv_freq``
        entries are skipped, as are bias tensors with no matching parameter
        (extra biases in GPTQ checkpoints). Every other tensor is loaded with
        the parameter's ``weight_loader`` (or ``default_weight_loader``).

        Raises:
            KeyError: if a non-skipped checkpoint tensor has no matching
                parameter in this model.
        """
        # (param_name, checkpoint_name, shard_id). Presumably HF QWen applies
        # SiLU to the w2 output and multiplies by w1's, so w2 is the first
        # (shard 0) half of gate_up_proj -- confirm against HF modeling_qwen.
        stacked_params_mapping = [
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Match a whole dotted component so e.g. "w1" cannot match
                # inside an unrelated segment of the tensor name.
                shard_key = f".{weight_name}."
                if shard_key not in name:
                    continue
                name = name.replace(shard_key, f".{param_name}.")
                # Skip loading extra bias for GPTQ models. ``break`` exits the
                # mapping loop without reaching the fallback ``else`` branch.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked parameter: load it one-to-one.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)