"googlemock/test/gmock-generated-matchers_test.cc" did not exist on "ada23475e27babd85fb9c13250243f6acfd3ffd8"
qwen.py 10.1 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
Qing's avatar
Qing committed
7
from typing import Any, Dict, List, Optional, Tuple
Qing's avatar
Qing committed
8
9
10
11
12
13

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.model_executor.layers.attention import PagedAttention
15
16
17
18
19
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
20
from vllm.model_executor.layers.rotary_embedding import get_rope
Qing's avatar
Qing committed
21
from vllm.model_executor.layers.sampler import Sampler
22
23
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
Qing's avatar
Qing committed
24
from vllm.model_executor.parallel_utils.parallel_state import (
25
    get_tensor_model_parallel_world_size)
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
28
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
29
from vllm.sequence import SamplerOutput
Qing's avatar
Qing committed
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.transformers_utils.configs.qwen import QWenConfig

# One layer's (key_cache, value_cache) pair. QWenAttention unpacks it and
# passes both tensors to PagedAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class QWenMLP(nn.Module):
    """Qwen feed-forward block.

    A merged gate/up column-parallel projection is followed by SiluAndMul
    and then a row-parallel down projection (``c_proj``) back to
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """
        Args:
            hidden_size: Input and output width of the MLP.
            intermediate_size: Output width of each of the two merged
                projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            linear_method: Optional linear method forwarded to both
                projection layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before building the projections so an unsupported
        # config fails fast, without first allocating the weight matrices.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Two projections of width intermediate_size merged into one layer.
        # SiluAndMul is presumably what splits the merged output back into
        # its two halves (confirm in vllm.model_executor.layers.activation).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``; the result has width ``hidden_size``."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):
    """Qwen multi-head self-attention.

    Heads are sharded across tensor-parallel ranks. Rotary position
    embeddings are applied to q and k, and attention over the KV cache is
    delegated to vLLM's PagedAttention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """
        Args:
            hidden_size: Model hidden width; split evenly across all heads.
            num_heads: Total number of attention heads across all
                tensor-parallel ranks.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional RoPE scaling config passed to
                ``get_rope``.
            linear_method: Optional linear method forwarded to the qkv and
                output projections.

        Raises:
            ValueError: If ``num_heads`` is not divisible by the
                tensor-parallel world size, or if ``hidden_size`` is not
                divisible by ``num_heads``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Use explicit raises rather than `assert` so the checks survive
        # `python -O`. The hidden_size check guards against head_dim being
        # silently truncated by the floor division below.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by the "
                f"tensor-parallel world size ({tp_size}).")
        if hidden_size % self.total_num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads}).")
        # Heads handled by this rank.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads

        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run self-attention for one layer.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Layer input.
            kv_cache: This layer's (key_cache, value_cache) pair.
            input_metadata: Batch metadata forwarded to PagedAttention.
            cache_event: Optional CUDA event forwarded to PagedAttention.

        Returns:
            The attention output after the ``c_proj`` projection to
            ``hidden_size``.
        """
        qkv, _ = self.c_attn(hidden_states)
        # c_attn is built without a separate KV-head count, so q, k and v
        # presumably have equal width and an even 3-way split separates them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One Qwen decoder layer.

    The layer applies a pre-norm self-attention sub-layer followed by a
    pre-norm MLP sub-layer. The residual stream is carried between layers
    separately from the hidden states.
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(width, eps=norm_eps)
        # rope_theta / rope_scaling are optional in the config; fall back to
        # 10000 and no scaling when absent.
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            linear_method=linear_method,
        )

        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # The config's intermediate_size is halved before it is passed to
        # QWenMLP. This is presumably because Qwen's config counts both merged
        # branches; confirm against QWenConfig.
        self.mlp = QWenMLP(width,
                           config.intermediate_size // 2,
                           linear_method=linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        Args:
            positions: Token positions, forwarded to attention.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's (key_cache, value_cache) pair.
            input_metadata: Batch metadata forwarded to attention.
            cache_event: Optional CUDA event forwarded to attention.
            residual: Residual stream from the previous layer, or None when
                this is the first layer.

        Returns:
            ``(mlp_output, residual)``. The two are returned separately so
            the next layer's ``ln_1`` (or the model's final norm) receives
            both.
        """
        # Self Attention. With no incoming residual, the layer input seeds
        # it. Otherwise the two-argument RMSNorm call returns
        # (normalized, updated residual).
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully Connected
        hidden_states, residual = self.ln_2(hidden_states, residual)
        return self.mlp(hidden_states), residual
184


class QWenModel(nn.Module):
    """Qwen transformer trunk.

    Consists of a vocabulary-parallel token embedding (``wte``), a stack of
    ``QWenBlock`` layers (``h``) and a final RMSNorm (``ln_f``).
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """
        Args:
            config: Qwen model configuration.
            linear_method: Optional linear method forwarded to every block.
        """
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer and apply the final norm.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to every layer.
            kv_caches: One (key_cache, value_cache) pair per layer, indexed
                by layer number.
            input_metadata: Batch metadata forwarded to every layer.
            cache_events: Optional per-layer CUDA events; None disables them.

        Returns:
            Final normalized hidden states.
        """
        hidden_states = self.wte(input_ids)
        # The first layer seeds the residual stream from its own input.
        residual = None
        for i, layer in enumerate(self.h):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
                residual,
            )
        # The final norm takes the last layer's output together with its
        # residual; the residual it returns is not needed afterwards.
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


class QWenLMHeadModel(nn.Module):
    """Qwen causal language model for vLLM.

    Combines the ``QWenModel`` trunk with a vocabulary-parallel LM head and
    a sampler, and loads HuggingFace Qwen checkpoint weights.
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = QWenModel(config, linear_method)
        # Note: linear_method is not forwarded to the LM head.
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Return the trunk's final hidden states.

        The LM head is not applied here; ``sample`` hands its weight to the
        sampler.
        """
        return self.transformer(input_ids, positions, kv_caches,
                                input_metadata, cache_events)

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Choose next tokens by passing the LM-head weight, the hidden
        states and the sampling metadata to the sampler."""
        return self.sampler(self.lm_head.weight, hidden_states,
                            sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this model's parameters.

        Tensors whose name contains ``w2`` or ``w1`` go into shard 0 or 1
        of the merged ``gate_up_proj`` parameter. ``rotary_emb.inv_freq``
        tensors are skipped. Every other tensor is loaded into the
        parameter of the same name, using that parameter's
        ``weight_loader`` if it has one and ``default_weight_loader``
        otherwise.

        Args:
            model_name_or_path, cache_dir, load_format, revision: Forwarded
                unchanged to ``hf_model_weights_iterator``.
        """
        # (merged parameter name, checkpoint name fragment, shard index).
        # Order matters: the first matching fragment wins. w2 lands in
        # shard 0, presumably because it is the activated branch in HF
        # Qwen's MLP (confirm against modeling_qwen.py).
        merged_shards = (
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        )
        params = dict(self.named_parameters())
        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for ckpt_name, tensor in checkpoint:
            # These checkpoint tensors are not loaded into any parameter.
            if "rotary_emb.inv_freq" in ckpt_name:
                continue
            match = next(
                (entry for entry in merged_shards if entry[1] in ckpt_name),
                None)
            if match is None:
                target = params[ckpt_name]
                loader = getattr(target, "weight_loader",
                                 default_weight_loader)
                loader(target, tensor)
                continue
            merged_name, fragment, shard_id = match
            target = params[ckpt_name.replace(fragment, merged_name)]
            target.weight_loader(target, tensor, shard_id)