# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig

# Per-layer KV cache entry: a (key_cache, value_cache) tensor pair. Attention
# layers unpack it as ``k_cache, v_cache`` and hand both to the paged
# attention kernel.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class QWenMLP(nn.Module):
    """Gated feed-forward network used inside each QWen block.

    A single merged column-parallel projection (``gate_up_proj``) produces
    two halves of width ``intermediate_size`` each. ``SiluAndMul`` combines
    them, and the row-parallel ``c_proj`` maps the result back to
    ``hidden_size``. Neither projection has a bias.

    Args:
        hidden_size: Input and output width of the MLP.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name. Only ``"silu"`` is supported.
        linear_method: Optional linear-layer implementation, forwarded
            unchanged to both parallel linear layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Validate before allocating any (potentially large) weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Both halves live in one merged weight. QWenLMHeadModel.load_weights
        # fills shard 0 from the checkpoint's "w2" and shard 1 from "w1".
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` and return a tensor of width hidden_size."""
        # The parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Query, key and value come from one fused ``c_attn`` projection (with
    bias). Attention over the paged KV cache is delegated to
    ``PagedAttentionWithRoPE``, which also receives the rotary-embedding
    settings. The row-parallel ``c_proj`` (no bias) maps the attention
    output back to ``hidden_size``. Heads are divided evenly across
    tensor-parallel ranks.

    Args:
        hidden_size: Model hidden dimension. ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of attention heads across all
            tensor-parallel ranks.
        max_position_embeddings: Passed as ``max_position`` to the rotary
            attention layer.
        rope_theta: Passed as the rotary embedding ``base``.
        rope_scaling: Optional rope-scaling config, forwarded unchanged.
        linear_method: Optional linear-layer implementation, forwarded to
            ``c_attn`` and ``c_proj``.

    Raises:
        ValueError: If ``num_heads`` is not divisible by the
            tensor-parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Each tensor-parallel rank owns an equal share of the heads. Use an
        # explicit raise, since `assert` is stripped under `python -O`.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) must be "
                f"divisible by the tensor parallel size ({tp_size}).")
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads

        # Fused Q/K/V projection. No separate KV head count is passed, so
        # Q, K and V presumably have equal width (forward relies on this).
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.scaling = self.head_dim**-0.5
        # Rotary embedding covers the full head dimension.
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim,
            base=rope_theta,
            max_position=max_position_embeddings,
            rope_scaling=rope_scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run self-attention and return the output projected by ``c_proj``.

        Args:
            positions: Token positions used for the rotary embedding.
            hidden_states: Input activations.
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Batch metadata forwarded to the attention kernel.
            cache_event: Optional CUDA event forwarded to the attention kernel.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Split the fused projection into three equal parts along the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)

        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One pre-norm QWen transformer layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded through explicitly. ``forward`` returns
    the MLP output and the running residual separately, rather than their
    sum, so the next layer's ``ln_1`` (or the model's final ``ln_f``) can
    consume both.

    Args:
        config: QWen model configuration.
        linear_method: Optional linear-layer implementation, forwarded to
            the attention and MLP sub-modules.
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Optional config fields. The defaults are used when the config
        # lacks them.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.attn = QWenAttention(config.hidden_size,
                                  config.num_attention_heads,
                                  config.max_position_embeddings,
                                  rope_theta=rope_theta,
                                  rope_scaling=rope_scaling,
                                  linear_method=linear_method)

        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # NOTE(review): QWen configs appear to report intermediate_size as the
        # combined width of both MLP projections (w1 + w2), so each projection
        # gets half. Confirm against the upstream modeling_qwen.py.
        self.mlp = QWenMLP(config.hidden_size,
                           config.intermediate_size // 2,
                           linear_method=linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Pass ``residual=None`` for the first layer, where the input itself
        becomes the residual.
        """
        # Self attention.
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        else:
            # With a residual argument, RMSNorm returns an updated
            # (hidden_states, residual) pair (presumably a fused
            # residual-add + norm).
            hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully connected.
        hidden_states, residual = self.ln_2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class QWenModel(nn.Module):
    """QWen decoder stack: token embedding, ``num_hidden_layers`` blocks, final RMSNorm.

    Args:
        config: QWen model configuration.
        linear_method: Optional linear-layer implementation, forwarded to
            every ``QWenBlock``.
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block, and return the final normed states.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to every block.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed by layer number.
            input_metadata: Batch metadata, forwarded to every block.
            cache_events: Optional per-layer CUDA events, indexed by layer
                number. ``None`` means no events.
        """
        hidden_states = self.wte(input_ids)
        # The first block receives None and seeds the residual stream itself.
        residual = None
        for i, layer in enumerate(self.h):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
                residual,
            )
        # Fold the last residual back in and apply the final norm.
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):
    """QWen causal LM: ``QWenModel`` backbone plus a parallel LM head and sampler.

    Args:
        config: QWen model configuration.
        linear_method: Optional linear-layer implementation, forwarded to
            the backbone. The LM head is built without it.
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = QWenModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the backbone and sample next tokens.

        The sampler receives the LM head weight directly, together with the
        final hidden states.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint tensors into this module's parameters.

        The arguments are forwarded unchanged to ``hf_model_weights_iterator``.
        For each ``(name, tensor)`` pair it yields:

        * Names containing ``rotary_emb.inv_freq`` are skipped. They are not
          entries of ``named_parameters()`` here (presumably recomputed by
          the rotary attention layer).
        * Names containing ``w2`` or ``w1`` go into shard 0 or shard 1,
          respectively, of the merged ``gate_up_proj`` parameter.
        * Everything else is loaded by name, using the parameter's
          ``weight_loader`` when it has one and ``default_weight_loader``
          otherwise.

        Raises:
            KeyError: If a checkpoint tensor has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping matched: a one-to-one parameter.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)