qwen.py 42.1 KB
Newer Older
Qing's avatar
Qing committed
1
2
3
4
5
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Woosuk Kwon's avatar
Woosuk Kwon committed
6
"""Inference-only QWen model compatible with HuggingFace weights."""
Qing's avatar
Qing committed
7

8
9
10
11
12
13
14
import math
import re
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
                    Optional, Tuple, TypedDict, Union)

import numpy as np
15
import torch
16
from PIL import Image
17
from torch import nn
18
19
from torchvision import transforms
from torchvision.transforms import InterpolationMode
20
from transformers import PretrainedConfig
Qing's avatar
Qing committed
21

gaoqiong's avatar
gaoqiong committed
22
23
24
import os
import re

25
from vllm.attention import Attention, AttentionMetadata
26
from vllm.config import CacheConfig, MultiModalConfig
27
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
28
29
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                         token_inputs)
30
31
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
32
from vllm.model_executor.layers.layernorm import RMSNorm
33
34
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
35
36
                                               QKVParallelLinear,
                                               RowParallelLinear)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
40
from vllm.model_executor.layers.rotary_embedding import get_rope
41
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
    ParallelLMHead, VocabParallelEmbedding)
44
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
45
from vllm.model_executor.sampling_metadata import SamplingMetadata
46
47
48
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
49
from vllm.sequence import IntermediateTensors, SequenceData
50
from vllm.utils import is_list_of
Qing's avatar
Qing committed
51

gaoqiong's avatar
gaoqiong committed
52
from vllm import _custom_ops as ops
53
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
54

55
56
57
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
58

59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Module-level logger.
logger = init_logger(__name__)

# Image delimiter / padding tags used in Qwen prompts.
# NOTE: Qwen models define a few other special tags as well (e.g. ref, bbox,
# quad); for now they are not treated as special at encoding time. This may
# change as vLLM's multimodal API evolves.
IMG_START, IMG_END, IMG_PAD = "<img>", "</img>", "<imgpad>"

# Every image occupies a fixed context of this many tokens.
MAX_QWEN_IMG_TOKENS = 256

# Per-channel mean / std passed to `transforms.Normalize` (CLIP statistics).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


class QwenImagePixelInputs(TypedDict):
    """Raw image pixels that still have to go through the visual encoder.

    ``data`` has shape ``(batch_size * num_images, 3, image_size,
    image_size)``, where ``image_size`` is the vision-config value that the
    normalization transform resizes images to.

    NOTE(review): multi-image pixel input may be restricted to passing image
    embeddings; ``input_mapper_for_qwen`` does stack lists of images, so
    confirm whether multiple raw images are supported here.
    """
    type: Literal["pixel_values"]
    data: torch.Tensor


class QwenImageEmbeddingInputs(TypedDict):
    """Precomputed image features that bypass the visual encoder.

    ``data`` has shape ``(batch_size * num_images, 256, hidden_size)``;
    ``hidden_size`` must match the language model backbone and is stored in
    the model's visual config when it has one.
    """
    type: Literal["image_embeds"]
    data: torch.Tensor


# Either kind of image input accepted by the model.
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]


class VisualAttention(nn.Module):
    """Multi-head self-attention used by the Qwen visual encoder.

    Takes input of shape ``[s, b, h]`` (sequence-first) and returns an output
    of the same shape. The fused ``in_proj`` output is laid out per head as
    ``[q | k | v]`` chunks of ``head_dim`` each.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
    ):
        """
        Args:
            embed_dim: Model width ``h``.
            num_heads: Number of attention heads; must divide ``embed_dim``.
            bias: Accepted for interface compatibility but not used; both
                projections are always created with a bias.
            kdim: Key dimension; must equal ``embed_dim`` (self-attention only).
            vdim: Value dimension; must equal ``embed_dim``.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``num_heads`` or
                if ``kdim`` / ``vdim`` differ from ``embed_dim``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = (self.kdim == embed_dim
                                    and self.vdim == embed_dim)

        self.num_heads = num_heads

        # Explicit exceptions instead of `assert`: asserts are stripped under
        # `python -O`, which would let a bad config yield wrong head sizes.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads "
                f"({num_heads})")

        # Per attention head and per partition values. These mirror the
        # tensor-parallel naming, but plain nn.Linear is used below, so each
        # "partition" is the full layer.
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        if not self._qkv_same_embed_dim:
            raise ValueError(
                'Visual Attention implementation only supports self-attention')

        # Strided (fused QKV) linear layer.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``[sq, b, h]``.
            attn_mask: Optional additive mask added to the pre-softmax scores
                via ``torch.baddbmm``; it must broadcast to
                ``[b * num_heads, sq, sq]``.

        Returns:
            Tensor of shape ``[sq, b, h]``.
        """
        sq, b, _ = x.size()
        num_heads = self.num_attention_heads_per_partition
        head_dim = self.hidden_size_per_attention_head

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        mixed_x_layer = self.in_proj(x)
        mixed_x_layer = mixed_x_layer.view(*mixed_x_layer.size()[:-1],
                                           num_heads, 3 * head_dim)

        # [sq, b, np, 3 * hn] --> 3 x [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(head_dim,
                                                                  dim=-1)

        def to_batched_heads(t: torch.Tensor) -> torch.Tensor:
            # [sq, b, np, hn] --> [b * np, sq, hn]
            return t.view(sq, b * num_heads, head_dim).transpose(0, 1)

        query_layer = to_batched_heads(query_layer)
        key_layer = to_batched_heads(key_layer)
        value_layer = to_batched_heads(value_layer)

        q_scaled = query_layer / self.norm_factor
        key_t = key_layer.transpose(-2, -1)
        if attn_mask is not None:
            # scores = attn_mask + q_scaled @ k^T
            attention_scores = torch.baddbmm(attn_mask, q_scaled, key_t)
        else:
            attention_scores = torch.bmm(q_scaled, key_t)
        attention_probs = attention_scores.softmax(dim=-1)

        # [b * np, sq, sq] @ [b * np, sq, hn] --> [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # [b * np, sq, hn] --> [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.view(b, num_heads, sq, head_dim)
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(sq, b,
                                           self.hidden_size_per_partition)

        return self.out_proj(context_layer)


class QwenVMLP(nn.Module):
    """Feed-forward network (fc -> GELU -> proj) used inside the Qwen visual
    encoder blocks."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Expand hidden_size -> intermediate_size.
        self.c_fc = ColumnParallelLinear(hidden_size,
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config)
        self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
        # Project intermediate_size -> hidden_size.
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=True,
                                        quant_config=quant_config)

    def forward(self, x):
        """Run fc -> activation -> proj. Each parallel linear returns a pair,
        of which only the first element (the output) is used."""
        expanded, _ = self.c_fc(x)
        activated = self.act_fn(expanded)
        projected, _ = self.c_proj(activated)
        return projected


class VisualAttentionBlock(nn.Module):
    """Pre-norm visual transformer block: ``x + attn(ln_1(x))`` followed by
    ``x + mlp(ln_2(x))``."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.attn = VisualAttention(d_model, n_head)
        self.mlp = QwenVMLP(
            hidden_size=d_model,
            intermediate_size=int(d_model * mlp_ratio),
            quant_config=quant_config,
        )

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attention on ``x``; a given mask is first cast to ``x``'s
        dtype."""
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.dtype)
        return self.attn(x, attn_mask=attn_mask)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual."""
        attn_out = self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + attn_out
        return x + self.mlp(self.ln_2(x))


class TransformerBlock(nn.Module):
    """Sequential stack of ``layers`` VisualAttentionBlocks of width
    ``width``, used as the body of the visual encoder."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = [
            VisualAttentionBlock(width,
                                 heads,
                                 mlp_ratio,
                                 norm_layer=norm_layer,
                                 quant_config=quant_config)
            for _ in range(layers)
        ]
        self.resblocks = nn.ModuleList(blocks)

    def _reference_weight(self) -> torch.Tensor:
        # Weight of the first block's MLP input projection; used as the
        # dtype/device reference for inputs.
        return self.resblocks[0].mlp.c_fc.weight

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype of the first block's MLP input weight."""
        return self._reference_weight().dtype

    def get_cast_device(self) -> torch.device:
        """Device of the first block's MLP input weight."""
        return self._reference_weight().device

    def forward(self,
                x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pass ``x`` through every block in order with the same mask."""
        for block in self.resblocks:
            x = block(x, attn_mask=attn_mask)
        return x


class VisionTransformer(nn.Module):
    """Qwen-VL image encoder.

    Images are split into patches by a strided convolution, combined with
    positional embeddings, run through a ``TransformerBlock``, pooled by a
    ``Resampler2`` (grid size ``sqrt(n_queries)``), then normalized and
    projected by ``ln_post`` / ``proj``.
    """

    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float,
                 n_queries: int = 256,
                 output_dim: int = 512,
                 image_start_id: int = 151857,
                 quant_config: Optional[QuantizationConfig] = None,
                 **kwargs):
        """
        Args:
            image_size: Edge length of the square input images.
            patch_size: Edge length of the square patches (also conv stride).
            width: Hidden size of the transformer.
            layers: Number of transformer blocks.
            heads: Attention heads per block.
            mlp_ratio: MLP hidden size as a multiple of ``width``.
            n_queries: Its square root is the resampler grid size (presumably
                the number of pooled output tokens — confirm in Resampler2).
            output_dim: Embedding size of the pooled output.
            image_start_id: Token ID of the image start tag; the end tag is
                taken to be ``image_start_id + 1``.
            quant_config: Quantization config forwarded to the MLP layers.
            **kwargs: Accepted and ignored (extra keys of the HF visual
                config).
        """
        super().__init__()
        image_height, image_width = self.image_size = (image_size, image_size)
        patch_height, patch_width = self.patch_size = (patch_size, patch_size)
        self.grid_size = (image_height // patch_height,
                          image_width // patch_width)
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        # class embeddings and positional embeddings
        scale = width**-0.5
        # 256 learned positions; forward() adapts them to the actual patch
        # grid via `get_abs_pos`.
        self.positional_embedding = nn.Parameter(scale *
                                                 torch.randn(256, width))

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.ln_pre = norm_layer(width)
        self.transformer = TransformerBlock(width,
                                            layers,
                                            heads,
                                            mlp_ratio,
                                            norm_layer=norm_layer,
                                            quant_config=quant_config)

        self.attn_pool = Resampler2(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
            adaptive=False,
            do_post_projection=False,
        ).to(
            device=self.positional_embedding.device,
            dtype=self.positional_embedding.dtype,
        )

        self.ln_post = norm_layer(output_dim)
        self.proj = nn.Parameter(
            (output_dim**-0.5) * torch.randn(output_dim, output_dim))
        self.image_start_id = image_start_id
        self.image_end_id = image_start_id + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: Pixel tensor ``[N, 3, H, W]`` (``conv1`` has 3 input
                channels). It is first cast to the transformer's dtype and
                device.

        Returns:
            Pooled, projected image features; last dimension is
            ``output_dim``.
        """
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )

        # to patches
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(
            x.size(1))))

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj

        return x

    def get_image_positions(self,
                            input_ids: torch.Tensor) -> Optional[torch.Tensor]:
        """Locate image spans in the token sequence.

        Args:
            input_ids: Token IDs, assumed to be a flat 1-D tensor (only the
                first index dimension returned by ``torch.where`` is used).

        Returns:
            ``None`` if no image start token is present. Otherwise a
            ``(num_images, 2)`` tensor whose rows hold the indices of paired
            image start / end tokens.

        Raises:
            ValueError: If the numbers of image start and end tokens differ,
                so they cannot be paired.
        """
        bos_pos = torch.where(input_ids == self.image_start_id)[0]
        if bos_pos.numel() == 0:
            return None
        eos_pos = torch.where(input_ids == self.image_end_id)[0]
        if bos_pos.numel() != eos_pos.numel():
            # Previously surfaced as an opaque size error from torch.stack.
            raise ValueError(
                f"Found {bos_pos.numel()} image start tokens but "
                f"{eos_pos.numel()} image end tokens; they must be paired.")
        return torch.stack((bos_pos, eos_pos), dim=1)
404
405


406
class QWenMLP(nn.Module):
    """Gated MLP of the Qwen language model.

    ``gate_up_proj`` (a MergedColumnParallelLinear) yields two
    ``intermediate_size`` outputs, ``SiluAndMul`` combines them, and
    ``c_proj`` (a RowParallelLinear) maps the result back to ``hidden_size``.
    Only the "silu" activation is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        merged_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, merged_sizes, bias=False, quant_config=quant_config)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate/up projection, SiLU-and-multiply, then down projection."""
        gate_up = self.gate_up_proj(x)[0]
        return self.c_proj(self.act_fn(gate_up))[0]


class QWenAttention(nn.Module):
    """Self-attention of the Qwen language model: fused QKV projection,
    rotary embedding, vLLM ``Attention`` and an output projection."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_world_size = get_tensor_model_parallel_world_size()

        # Heads must split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        self.head_dim = hidden_size // self.total_num_heads

        # Fused Q/K/V projection (with bias) and bias-free output projection.
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.scaling = self.head_dim**-0.5

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Quantization method name, or None when unquantized; `forward`
        # only applies FA_PAD trimming in the unquantized case.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)
        if quant_config is not None:
            self.quant_config = quant_config

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding and attention, then
        project back to ``hidden_size``."""
        qkv, _ = self.c_attn(hidden_states)
        # NOTE(review): with FA_PAD=1 (read from the environment on every
        # call) and no quantization, the last 32 columns of the fused QKV
        # output are dropped. This is presumably padding added to the weights
        # at load time (cf. `pad_weight`); confirm.
        trim_pad = os.environ.get('FA_PAD') == '1' and self.quant_method is None
        if trim_pad:
            qkv = qkv[..., :-32]
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):
    """One Qwen decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is carried next to the hidden states; the
    two-argument ``RMSNorm`` calls take and return ``(hidden, residual)``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.ln_1 = RMSNorm(hidden_size, eps=norm_eps)
        self.attn = QWenAttention(
            hidden_size,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.ln_2 = RMSNorm(hidden_size, eps=norm_eps)
        # NOTE(review): the MLP gets half of `config.intermediate_size`;
        # presumably Qwen configs store the combined gate+up width — confirm.
        self.mlp = QWenMLP(hidden_size,
                           config.intermediate_size // 2,
                           quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer; returns ``(hidden_states, residual)``."""
        # Self attention. Without an incoming residual, the un-normalized
        # input itself becomes the residual.
        if residual is not None:
            hidden_states, residual = self.ln_1(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully connected.
        hidden_states, residual = self.ln_2(hidden_states, residual)
        return self.mlp(hidden_states), residual


class QWenModel(nn.Module):
    """Qwen language model backbone, optionally with a visual encoder.

    Holds the token embedding (``wte``), the decoder layers (``h``), the final
    norm (``ln_f``) and, when the HF config has a ``visual`` section, a
    ``VisionTransformer`` (``visual``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # `forward` runs layers [start_layer, end_layer) on this rank.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: QWenBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.h")
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
        self.visual = VisionTransformer(**config.visual,
                                        quant_config=quant_config) if hasattr(
                                            config, "visual") else None

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        pixel_values: Optional[QwenImageInputs],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline-parallel rank's slice of the model.

        Args:
            input_ids: Token IDs (image spans are located in them by
                ``self.visual.get_image_positions``).
            positions: Position IDs for the rotary embedding.
            kv_caches: One KV cache per local layer.
            attn_metadata: Attention metadata for this step.
            intermediate_tensors: Hidden states / residual from the previous
                pipeline rank; required when this is not the first rank.
            pixel_values: Optional image inputs, either raw pixels (encoded
                here by ``self.visual``) or precomputed ``image_embeds``.

        Returns:
            Final normalized hidden states on the last pipeline rank,
            otherwise ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual``.

        Raises:
            ValueError: If the number of image spans in ``input_ids`` differs
                from the number of image embeddings.
        """
        img_pos = None
        # If pixel / visual embeddings are provided, this is a visual model
        if pixel_values is not None and self.visual is not None:
            if pixel_values["type"] != "image_embeds":
                image_embeds = self.visual(pixel_values["data"])
            else:
                image_embeds = pixel_values["data"]

            # features should be of shape (# images, 256, hidden_dim)
            img_pos = self.visual.get_image_positions(input_ids)
            # `get_image_positions` returns a torch.Tensor or None (never an
            # ndarray), so test for None before comparing counts.
            if (img_pos is not None
                    and img_pos.shape[0] != image_embeds.shape[0]):
                raise ValueError(
                    f"Number of placeholders: {img_pos.shape[0]} "
                    f"does not match number of images {image_embeds.shape[0]}."
                )

        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
            # Merge the image embeddings into the hidden states if actually have
            # visual features and the corresponding image tokens
            if img_pos is not None:
                for idx, (img_bos, img_eos) in enumerate(img_pos):
                    hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states


649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
def get_image_text(image_num: int, padding: bool) -> str:
    """Build the prompt snippet marking where image ``image_num`` goes.

    Without explicit padding, the tag pair is left empty and tokenizing it
    expands it with image pads.

    Args:
        image_num: 1-based index of the image the placeholder refers to.
        padding: If true, insert ``MAX_QWEN_IMG_TOKENS`` copies of
            ``IMG_PAD`` between the tags explicitly.

    Returns:
        ``"Picture <image_num>: <img>[pads]</img>\\n"``.
    """
    pads = IMG_PAD * MAX_QWEN_IMG_TOKENS if padding else ""
    return f"Picture {image_num}: {IMG_START}{pads}{IMG_END}\n"


def input_processor_for_qwen(ctx: InputContext,
                             inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    """Processes the inputs, which may or may not be multimodal.
    Multimodal inputs will only be processed if the model has a "visual"
    component in its model config, otherwise they'll be ignored.

    Args:
        ctx: Context of the loaded model.
        inputs: LLM inputs which may have a multi_modal_data attribute.

    Returns:
        If the model is language only or no multimodal inputs were provided,
        returns inputs unmodified. Otherwise, strips anything between the
        image tags in the prompt and re-tokenizes it so the tokenizer adds
        the fixed-length image placeholders.

    Raises:
        ValueError: If tensor image data is not 2-D or 3-D.
        TypeError: If the image data is not a tensor, a PIL image or a list
            of PIL images.
    """
    multi_modal_data = inputs.get("multi_modal_data")

    # Only process images if we have multimodal data and a visual config
    hf_config = ctx.get_hf_config()
    if (multi_modal_data is None or "image" not in multi_modal_data
            or not hasattr(hf_config, "visual")):
        return inputs

    prompt = inputs.get("prompt")
    prompt_token_ids = inputs["prompt_token_ids"]
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, torch.Tensor):
        # Image embeddings: 2-D for a single image, 3-D with a leading
        # image dimension for several.
        num_dims = len(image_data.shape)
        if num_dims not in (2, 3):
            raise ValueError("Expected img embeds to have 2 or 3 dimensions, "
                             f"got {num_dims}")
        num_images = 1 if num_dims == 2 else image_data.shape[0]
    elif isinstance(image_data, Image.Image):
        num_images = 1
    elif is_list_of(image_data, Image.Image):
        num_images = len(image_data)
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)

    # Drops anything between <img>/</img> tags; encoding with the tokenizer
    # will automatically add the image pads for the context.
    new_prompt, num_matched_images = re.subn(
        r"(Picture \d*: <img>).*?(<\/img>\n)",
        r"\1\2",
        prompt,
    )

    if num_matched_images != num_images:
        logger.warning(
            "Number of matched image placeholders %s doesn't match the number "
            "of expected images %s; check your placeholder formatting.",
            num_matched_images, num_images)

    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return token_inputs(prompt=new_prompt,
                        prompt_token_ids=new_prompt_token_ids,
                        multi_modal_data=multi_modal_data)
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755


def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
    """Maps the input data to its MultiModalInputs (if any).

    Args:
        ctx: Context of the loaded model.
        data: Either a tensor of precomputed image embeddings (2-D for one
            image, 3-D with a leading image dimension for several), or one or
            more images to be resized and normalized into ``pixel_values``
            for ``.forward()`` of a visual QWenLMHeadModel.

    Returns:
        MultiModalInputs containing the stacked normalized images tensor or
        image embeddings; empty if the model has no visual config.

    Raises:
        ValueError: If the tokenizer's image tags do not have consecutive IDs
            or the expected expansion length, or if embeddings have the wrong
            shape.
    """
    # Early exit if we have provided an image to a language only Qwen model
    hf_config = ctx.get_hf_config()
    if not hasattr(hf_config, "visual"):
        logger.warning(
            "Images were provided but this model has no visual config; "
            "multimodal inputs will not be forwarded to the model.")
        return MultiModalInputs()

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    # Sanity-check the tokenizer: "<img></img>" must encode to the start
    # tag, MAX_QWEN_IMG_TOKENS context tokens and the end tag, and the two
    # tag IDs must be consecutive.
    image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
                                      add_special_tokens=False,
                                      return_tensors="pt").squeeze()
    image_start_id = int(image_pair_tok[0])
    image_end_id = int(image_pair_tok[-1])
    if (image_start_id + 1) != image_end_id:
        raise ValueError(
            f"Found image end ID {image_end_id}, but expected "
            f"{image_start_id + 1} (the {IMG_START} ID + 1)")
    num_ctx_tokens = len(image_pair_tok) - 2
    if num_ctx_tokens != MAX_QWEN_IMG_TOKENS:
        raise ValueError(
            f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
            f"but got {num_ctx_tokens}")

    image_size = hf_config.visual["image_size"]
    img_emb_size = hf_config.visual["output_dim"]

    if isinstance(data, torch.Tensor):
        # It's expected that our values have already been processed
        # by the visual transformer; shape is expected to be:
        # (# images, 256, hidden_size)
        if len(data.shape) == 2:
            # Assume only one image embed was provided; unsqueeze the extra dim
            data = data.unsqueeze(0)
        if (len(data.shape) != 3 or data.shape[1] != MAX_QWEN_IMG_TOKENS
                or data.shape[2] != img_emb_size):
            raise ValueError(
                "Expected image embeds to be a tensor of shape "
                f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
                f"received shape [{data.shape}]")
        pixel_values = data
    else:
        transform = build_normalization_transform(image_size)
        if not isinstance(data, (list, tuple)):
            data = [data]
        transformed_images = [transform(datum) for datum in data]
        pixel_values = torch.stack(transformed_images, dim=0)
    return MultiModalInputs({"pixel_values": pixel_values})


def build_normalization_transform(image_size: int) -> transforms.Compose:
    """Create the preprocessing pipeline applied before the visual encoder.

    The pipeline resizes an image to an ``image_size`` x ``image_size``
    square with bicubic interpolation, converts it to a tensor, and then
    normalizes it with the ``CLIP_MEAN`` / ``CLIP_STD`` statistics.

    Args:
        image_size: Side length of the square image fed to the visual
            encoder.

    Returns:
        A ``transforms.Compose`` callable that normalizes and resizes one
        RGB image.
    """
    square = (image_size, image_size)
    pipeline = [
        transforms.Resize(square, interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ]
    return transforms.Compose(pipeline)


def dummy_data_for_qwen(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, Optional[Dict]]:
    """Build dummy data for warming up Qwen models.

    If the HF config has no ``visual`` attribute, the model is treated as a
    text-only LLM. The result is then a sequence built from the
    ``(token 0, seq_len)`` count pair, with no multimodal data.

    Otherwise, one image prompt is built per requested image and tokenized,
    and the token list is zero-padded up to ``seq_len``. Matching blank images
    are returned as the multimodal data. When the image prompts alone are
    longer than ``seq_len``, the sequence is left longer than ``seq_len``.

    Args:
        ctx: Context of the loaded model.
        seq_len: Minimum number of tokens in the dummy text sequence.
        mm_counts: Multimodal data counts; only ``mm_counts["image"]`` is
            read, and only for models with a visual config.

    Returns:
        Tuple of the dummy ``SequenceData`` and the multimodal data dict
        (``None`` for text-only models).

    Raises:
        ValueError: If the tokenizer does not produce exactly
            ``MAX_QWEN_IMG_TOKENS`` image-pad tokens per image.
    """
    hf_config = ctx.get_hf_config()

    # The presence of a visual config indicates this is a multimodal model.
    # If we don't have it, the model is considered an LLM for warmup purposes.
    if not hasattr(hf_config, "visual"):
        seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
        return seq_data, None

    # We have a visual component - use images to warm up
    num_images = mm_counts["image"]
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    # Build the image prompts with no imgpads; the tokenizer will add img pads
    image_prompt = ''.join(
        get_image_text(idx, False) for idx in range(1, num_images + 1))
    toks = tokenizer.encode(image_prompt, add_special_tokens=False)

    # Encode the pad marker the same way as the prompt (no special tokens),
    # so that a tokenizer which prepends e.g. a BOS token cannot make [0]
    # return the wrong id.
    img_pad_id = tokenizer.encode(IMG_PAD, add_special_tokens=False)[0]

    # Make sure we actually get the fixed context size per tok padding
    num_pads = toks.count(img_pad_id)
    if num_pads != (num_images * MAX_QWEN_IMG_TOKENS):
        raise ValueError(
            f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
            f" per image, but got {num_pads} pads for {num_images} image(s)"
            " in total. Are you using a qwen tokenizer?")

    # Ensure the number of tokens is at minimum the sequence length provided
    if len(toks) < seq_len:
        toks += [0] * (seq_len - len(toks))

    seq_data = SequenceData.from_seqs(toks)

    # Build the input images; width/height doesn't actually matter here since
    # the data will get resized and the # of tokens per image is constant
    image = Image.new("RGB", (224, 224), color=0)
    mm_data = {"image": image if num_images == 1 else [image] * num_images}
    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
880
class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
881
882
883

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the Qwen causal LM: transformer trunk, LM head and sampler.

        Args:
            config: HF config of the checkpoint. ``vocab_size``,
                ``hidden_size`` and ``tie_word_embeddings`` are read here.
            multimodal_config: Multimodal settings; stored on the instance.
            cache_config: Optional cache configuration forwarded to
                ``QWenModel``.
            quant_config: Optional quantization config. It is forwarded to
                ``QWenModel`` and ``ParallelLMHead``, and its ``get_name()``
                is recorded in ``self.quant_method``.
        """
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.transformer = QWenModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

        # Quantization method name (None when unquantized). load_weights
        # compares it against "awq" / "compressed_tensors" to choose a
        # post-load weight re-layout pass.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Environment toggles, read once at construction. load_weights uses
        # them to decide whether to re-layout and/or pad loaded weights.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def _get_image_input_type(
            self,
            pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
        """Classify visual input as raw pixel values or image embeddings.

        Args:
            pixel_values: Optional visual data to be turned into visual
                embeddings; it may already hold embeddings.

        Returns:
            ``None`` if there is no visual input or the model has no visual
            transformer. Otherwise the flattened data, wrapped as:

            - ``QwenImageEmbeddingInputs`` when its shape is
              ``(N, MAX_QWEN_IMG_TOKENS, config.visual["output_dim"])``;
            - ``QwenImagePixelInputs`` (still needs the visual transformer)
              for any other shape.
        """
        if pixel_values is None or self.transformer.visual is None:
            return None

        flat = flatten_bn(pixel_values)
        already_embedded = (
            len(flat.shape) == 3
            and flat.shape[1] == MAX_QWEN_IMG_TOKENS
            and flat.shape[2] == self.config.visual["output_dim"])
        if already_embedded:
            return QwenImageEmbeddingInputs(type="image_embeds", data=flat)

        # Any other shape is assumed to still need visual processing.
        return QwenImagePixelInputs(type="pixel_values", data=flat)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        pixel_values: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the Qwen transformer on one batch.

        If ``intermediate_tensors`` is given (presumably a non-first pipeline
        stage), token ids and visual input are discarded. Otherwise
        ``pixel_values`` is first classified via ``_get_image_input_type``.

        Returns:
            Whatever ``self.transformer`` returns: hidden states, or
            ``IntermediateTensors`` per the annotation.
        """
        if intermediate_tensors is None:
            image_input = self._get_image_input_type(pixel_values)
        else:
            input_ids = None
            image_input = None

        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors,
                                image_input)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head.

        Delegates to ``self.logits_processor`` with ``self.lm_head``.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using ``self.sampler``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
982
983
984
985
986
987
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        ]
        params_dict = dict(self.named_parameters())
988
        for name, loaded_weight in weights:
Qing's avatar
Qing committed
989
990
            if "rotary_emb.inv_freq" in name:
                continue
991
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
Qing's avatar
Qing committed
992
993
                if weight_name not in name:
                    continue
CHU Tianxiang's avatar
CHU Tianxiang committed
994
995
996
997
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
998
999
1000
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
CHU Tianxiang's avatar
CHU Tianxiang committed
1001
                param = params_dict[name]
1002
1003
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
Qing's avatar
Qing committed
1004
                break
1005
            else:
CHU Tianxiang's avatar
CHU Tianxiang committed
1006
1007
1008
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
1009
1010
1011
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
1012
1013
1014
1015
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
1016
        if self.use_llama_nn and self.quant_method is None :
gaoqiong's avatar
gaoqiong committed
1017
            lay_key_words = [
zhuwenwen's avatar
zhuwenwen committed
1018
1019
                "attn.c_attn.weight",
                "attn.c_proj.weight",
gaoqiong's avatar
gaoqiong committed
1020
                "mlp.gate_up_proj.weight",
1021
1022
                "mlp.c_proj.weight",
                "lm_head.weight"
gaoqiong's avatar
gaoqiong committed
1023
1024
1025
            ]
            combined_words = "|".join(lay_key_words)
            
zhuwenwen's avatar
zhuwenwen committed
1026
1027
1028
1029
1030
1031
            lay_qkv_words = ["attn.c_attn.weight"]   
            qkv_words = "|".join(lay_qkv_words)  
            
            lay_qkv_bias_words = ["attn.c_attn.bias"]   
            qkv_bias_words = "|".join(lay_qkv_bias_words) 
                      
gaoqiong's avatar
gaoqiong committed
1032
            for layername, weight in params_dict.items():
zhuwenwen's avatar
zhuwenwen committed
1033
1034
1035
                if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
                    weight.data = pad_weight(weight.data, 32)
                
gaoqiong's avatar
gaoqiong committed
1036
                matches = re.findall(combined_words, layername)
zhuwenwen's avatar
zhuwenwen committed
1037
                if matches:         
1038
1039
1040
                    if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                        weight.data = pad_weight(weight.data, 32)  
                        
zhuwenwen's avatar
zhuwenwen committed
1041
1042
1043
                    if self.use_fa_pad and (re.findall(qkv_words, layername)):
                        if not gemm_bank_conf(weight.data.shape[0]):
                            weight.data = pad_weight(weight.data, 32)
1044
                        
gaoqiong's avatar
gaoqiong committed
1045
1046
1047
                    _weight = torch.zeros_like(weight.data)
                    ori_shape =_weight.shape
                    
zhuwenwen's avatar
zhuwenwen committed
1048
                    ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
gaoqiong's avatar
gaoqiong committed
1049
1050
1051
1052
                    weight.data.copy_(_weight)
                    
                    weight.data=weight.data.reshape(ori_shape[1],-1)
                    
zhuwenwen's avatar
zhuwenwen committed
1053
        if self.quant_method == "awq":
gaoqiong's avatar
gaoqiong committed
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
            lay_key_words = [
                "attn.c_attn.qweight",
                "attn.c_proj.qweight",
                "mlp.gate_up_proj.qweight",
                "mlp.c_proj.qweight"
            ]
            combined_words = "|".join(lay_key_words)
            
            for layername, weight in params_dict.items():
                
                matches = re.findall(combined_words, layername)
                if matches:
                    qweight =params_dict[layername]
                    qzeros=params_dict[layername.replace("qweight", "qzeros")]
                    scales=params_dict[layername.replace("qweight", "scales")]
                    zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
                    
                    group_size= self.quant_config.group_size 
                   
                    dim_n = scales.data.shape[1]
                    dim_k = qweight.data.shape[0]
                    pad_group=2              
                    
gaoqiong's avatar
gaoqiong committed
1077
                    _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size)) 
gaoqiong's avatar
gaoqiong committed
1078
                    
gaoqiong's avatar
gaoqiong committed
1079
                    sz = ops.sz_permute(_sz).reshape(-1,dim_n)       
gaoqiong's avatar
gaoqiong committed
1080
1081
1082
1083
1084
1085
1086
1087
                    
                    zeros_and_scalse.data.copy_(sz)
                    qweight.data.copy_(_qw)
                    
                    #reshape
                    zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1)    #[k/greop_size,n]------>[n,k/group_size]
                    qweight.data=qweight.data.reshape(dim_n,-1)                      #[k,n/8]---->[n,k/8]  
                
zhuwenwen's avatar
zhuwenwen committed
1088
                    if dim_k % 4096==0 and self.use_awq_pad:
gaoqiong's avatar
gaoqiong committed
1089
1090
1091
1092
                        zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
                        zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
                        qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
                        qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
zhuwenwen's avatar
zhuwenwen committed
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
                       
        if self.quant_method == "compressed_tensors":
            lay_key_words = [
                "attn.c_attn.weight",
                "attn.c_proj.weight",
                "mlp.gate_up_proj.weight",
                "mlp.c_proj.weight",
            ]
            combined_words = "|".join(lay_key_words)
            
            for layername, weight in params_dict.items():  
                matches = re.findall(combined_words, layername)
                if matches:
                    weight_data =params_dict[layername]
                    k=weight_data.shape[0]
                    _weight=weight_data.T.contiguous().reshape(k,-1)
zhuwenwen's avatar
zhuwenwen committed
1109
                    weight_data.data.copy_(_weight)