chatglm.py 31.4 KB
Newer Older
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
1
# Adapted from
2
# https://github.com/THUDM/GLM-4
Woosuk Kwon's avatar
Woosuk Kwon committed
3
"""Inference-only ChatGLM model compatible with THUDM weights."""
4
5
from argparse import Namespace
from array import array
6
7
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
                    TypedDict)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
8
9

import torch
10
from PIL import Image
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
11
12
from torch import nn
from torch.nn import LayerNorm
zhuwenwen's avatar
zhuwenwen committed
13
import os
zhuwenwen's avatar
zhuwenwen committed
14
import re
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
15

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig, VllmConfig
18
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
19
20
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
21
from vllm.logger import init_logger
22
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
25
26
                                               QKVParallelLinear,
                                               RowParallelLinear)
27
from vllm.model_executor.layers.logits_processor import LogitsProcessor
28
from vllm.model_executor.layers.quantization import QuantizationConfig
29
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
30
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
31
from vllm.model_executor.layers.vocab_parallel_embedding import (
32
    ParallelLMHead, VocabParallelEmbedding)
33
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
34
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
zhuwenwen's avatar
zhuwenwen committed
35

36
from vllm.model_executor.models.module_mapping import MultiModelKeys
37
from vllm.model_executor.sampling_metadata import SamplingMetadata
38
from vllm.multimodal import MULTIMODAL_REGISTRY
39
40
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
                                    NestedTensors)
41
42
43
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
44
from vllm.transformers_utils.configs import ChatGLMConfig
45

zhuwenwen's avatar
zhuwenwen committed
46
47
from .interfaces import SupportsLoRA, SupportsMultiModal

48
49
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
50
51
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
52

zhuwenwen's avatar
zhuwenwen committed
53
from vllm import _custom_ops as ops
54
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
55

56
57
58
59
60
61
62
63
64
65
66
67
68

# Module-level logger, obtained through vLLM's logging helper.
logger = init_logger(__name__)


def calculate_image_placeholder(vision_config):
    """Return how many placeholder tokens a single image occupies.

    The per-side patch count (``image_size // patch_size``) is integer-divided
    by 2 once more and then squared. NOTE(review): the extra ``// 2``
    presumably mirrors a 2x downsampling in the vision encoder; that is not
    visible here.

    Args:
        vision_config: Mapping with ``"image_size"`` and ``"patch_size"``
            entries.
    """
    patches_per_side = (vision_config["image_size"] //
                        vision_config["patch_size"])
    grid_side = patches_per_side // 2
    return grid_side**2


def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: MultiModalData[object],
) -> Dict:
    """Map raw image data to the ``pixel_values`` multimodal keyword argument.

    The image is passed through the tokenizer's ``apply_chat_template``; the
    ``images`` entry of the resulting batch becomes ``pixel_values``.

    Raises:
        RuntimeError: If no tokenizer could be obtained.
        Exception: Anything raised by ``apply_chat_template``; it is logged
            and re-raised unchanged.
    """
    cfg = ctx.model_config
    tok = cached_get_tokenizer(cfg.tokenizer,
                               trust_remote_code=cfg.trust_remote_code)
    if tok is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")

    chat = [{"role": "user", "image": data}]
    try:
        batch = tok.apply_chat_template(conversation=chat,
                                        add_generation_prompt=True,
                                        tokenize=True,
                                        return_tensors="pt",
                                        return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise

    return MultiModalKwargs({'pixel_values': batch['images']})


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Overwrite image spans of ``inputs_embeds`` with vision embeddings.

    The i-th ``boi_token_id`` position is paired with the i-th
    ``eoi_token_id`` position. Every row from the begin marker through the
    end marker (inclusive) is replaced, in order, by the rows of
    ``vision_embeddings`` flattened to ``(-1, vision_embeddings.shape[-1])``.

    Returns:
        ``inputs_embeds`` itself, modified in place.
    """
    starts = torch.nonzero(input_ids == boi_token_id, as_tuple=True)[0]
    ends = torch.nonzero(input_ids == eoi_token_id, as_tuple=True)[0]

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, end in zip(starts, ends):
        assert start < end
        image_mask[start:end + 1] = True

    embed_dim = vision_embeddings.shape[-1]
    inputs_embeds[image_mask] = vision_embeddings.view(-1, embed_dim)
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    """Typed container for pre-processed image pixels."""

    # Shape: (batch_size, num_channels, height, width)
    pixel_values: torch.Tensor


def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of prompt tokens one image expands into.

    Returns 1 when the HF config has no ``vision_config``. For a dict
    ``vision_config``, the count comes from ``calculate_image_placeholder``.

    Raises:
        NotImplementedError: If ``vision_config`` is neither None nor a dict.
    """
    vision_config = getattr(ctx.get_hf_config(ChatGLMConfig),
                            'vision_config', None)
    if vision_config is None:
        return 1
    if not isinstance(vision_config, dict):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)
    return calculate_image_placeholder(vision_config)


def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]) -> DummyData:
    """Build dummy prompt data (and a dummy image) for memory profiling.

    Without a ``vision_config`` the prompt is ``seq_len`` zero tokens and no
    multimodal data is returned. With a dict ``vision_config`` the prompt is
    ``boi_token_id``, then ``calculate_image_placeholder`` zeros, then
    ``eoi_token_id``, then zeros for the remainder of ``seq_len``. Nothing is
    appended if ``seq_len`` is shorter than that span. The prompt is paired
    with an all-zero (black) square RGB image of side ``image_size``.
    ``mm_counts`` is not consulted.

    Raises:
        NotImplementedError: If ``vision_config`` is neither None nor a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        text_only = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        return DummyData(SequenceData(text_only), None)

    if not isinstance(vision_config, dict):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    side = vision_config["image_size"]
    num_placeholders = calculate_image_placeholder(vision_config)
    image_span = ([hf_config.boi_token_id] + [0] * num_placeholders +
                  [hf_config.eoi_token_id])
    padding = [0] * (seq_len - num_placeholders - 2)
    seq_data = SequenceData(
        array(VLLM_TOKEN_ID_ARRAY_TYPE, image_span + padding))

    dummy_image = Image.new("RGB", (side, side), color=0)
    return DummyData(seq_data, {"image": dummy_image})


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return, in ascending order, every index where ``target`` occurs.

    Returns an empty list when ``target`` is absent.
    """
    return [pos for pos, tok in enumerate(input_ids) if tok == target]


def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand every image in the prompt into its full run of placeholders.

    Inputs are returned unchanged when they carry no ``"image"`` multimodal
    data, or when the HF config has no ``vision_config``. Otherwise the prompt
    text and image are re-tokenized through the tokenizer's chat template.
    For each (begin-of-image, end-of-image) pair, the tokens strictly between
    the two markers are replaced by ``calculate_image_placeholder`` copies of
    the token that immediately follows the begin marker.

    A missing ``"prompt"`` key is treated like ``prompt=None``: the template
    receives ``content=None`` and the returned prompt is decoded from the
    expanded ids. NOTE(review): confirm the GLM chat template accepts
    ``content=None``.

    Raises:
        NotImplementedError: If ``vision_config`` is neither None nor a dict.
        Exception: Anything raised by ``apply_chat_template`` (logged first).
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)
    if vision_config is None:
        return inputs
    if not isinstance(vision_config, dict):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)
    image_placeholder_length = calculate_image_placeholder(vision_config)

    tokenizer = cached_get_tokenizer(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code)

    # Read once: indexing inputs['prompt'] raised KeyError for token-only
    # inputs, even inside the except handler below.
    prompt = inputs.get("prompt")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": prompt,
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).data
    except Exception:
        logger.error("Failed to process content (%s)", prompt)
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_positions = find_all_positions(input_ids, hf_config.boi_token_id)
    eoi_positions = find_all_positions(input_ids, hf_config.eoi_token_id)
    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    cursor = 0
    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        # Copy everything up to and including the begin-of-image token ...
        new_input_ids.extend(input_ids[cursor:boi_position + 1])
        # ... then repeat the first in-image token to the full image length.
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        # Resume at the end-of-image token so it is copied on the next pass.
        cursor = eoi_position
    new_input_ids.extend(input_ids[cursor:])

    if prompt is None:
        prompt = tokenizer.decode(new_input_ids)

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )

class GLMAttention(nn.Module):
    """Self-attention sub-layer of a ChatGLM block.

    Q, K and V come from one fused tensor-parallel projection
    (``query_key_value``). When ``config.multi_query_attention`` is set,
    K/V use ``config.multi_query_group_num`` heads. Rotary embeddings
    (``rotary_dim = head_dim // 2``, ``is_neox_style=False``) are applied to
    Q and K. Attention runs through vLLM's ``Attention``, and ``dense``
    projects the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # Read FA_PAD once at construction rather than on every forward call
        # of every layer (the original looked it up on the hot path).
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run fused QKV projection, rotary embedding, attention and output
        projection; returns a tensor with ``hidden_size`` features."""
        qkv, _ = self.query_key_value(hidden_states)
        if self.use_fa_pad and self.quant_method is None:
            # NOTE(review): with FA_PAD=1 the unquantized QKV output is
            # assumed to carry 32 trailing padding columns, added elsewhere
            # (not visible here); drop them so the split below lines up.
            qkv = qkv[..., :-32]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated feed-forward sub-layer of a ChatGLM block.

    ``dense_h_to_4h`` emits two ``ffn_hidden_size``-wide outputs from one
    merged column-parallel matmul; ``SiluAndMul`` combines them; and
    ``dense_4h_to_h`` projects back to ``hidden_size``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden = config.hidden_size
        ffn = config.ffn_hidden_size
        self.add_bias = config.add_bias_linear

        # hidden -> two ffn-sized outputs, fused into a single weight.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            hidden,
            [ffn, ffn],
            bias=self.add_bias,
            quant_config=quant_config,
        )
        self.activation_func = SiluAndMul()
        # ffn -> hidden.
        self.dense_4h_to_h = RowParallelLinear(
            ffn,
            hidden,
            bias=self.add_bias,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        gate_up, _ = self.dense_h_to_4h(hidden_states)
        output, _ = self.dense_4h_to_h(self.activation_func(gate_up))
        return output


class GLMBlock(nn.Module):
    """One ChatGLM transformer layer: norm, attention, norm, MLP.

    Each sub-layer adds a residual. When
    ``config.apply_residual_connection_post_layernorm`` is set, that residual
    is the normalized input; otherwise it is the un-normalized input.
    ``fp32_residual_connection`` and ``hidden_dropout`` are stored but not
    used by ``forward``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)
        self.fp32_residual_connection = config.fp32_residual_connection

        norm_cls = RMSNorm if config.rmsnorm else LayerNorm
        eps = config.layernorm_epsilon

        self.input_layernorm = norm_cls(config.hidden_size, eps=eps)
        self.self_attention = GLMAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attention")
        self.hidden_dropout = config.hidden_dropout
        self.post_attention_layernorm = norm_cls(config.hidden_size, eps=eps)
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual_after_norm = self.apply_residual_connection_post_layernorm

        # Attention sub-layer (hidden_states: [num_tokens, h]).
        attn_in = self.input_layernorm(hidden_states)
        attn_out = self.self_attention(
            hidden_states=attn_in,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attn_residual = attn_in if residual_after_norm else hidden_states
        mlp_pre_norm = attn_residual + attn_out

        # MLP sub-layer.
        mlp_in = self.post_attention_layernorm(mlp_pre_norm)
        mlp_residual = mlp_in if residual_after_norm else mlp_pre_norm
        return self.mlp(mlp_in) + mlp_residual


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final layer norm.

    ``forward`` runs only layers ``start_layer`` .. ``end_layer - 1`` (as
    returned by ``make_layers``), indexing ``kv_caches`` relative to
    ``start_layer``. The final norm is applied only when
    ``config.post_layer_norm`` is set and this is the last pipeline rank.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers

        # make_layers passes the per-layer prefix as the keyword `prefix`.
        def build_block(prefix):
            return GLMBlock(config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            build_block,
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            norm_cls = RMSNorm if config.rmsnorm else LayerNorm
            self.final_layernorm = norm_cls(config.hidden_size,
                                            eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        num_local = self.end_layer - self.start_layer
        for offset in range(num_local):
            block = self.layers[self.start_layer + offset]
            hidden_states = block(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[offset],
                attn_metadata=attn_metadata,
            )

        if get_pp_group().is_last_rank and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class ChatGLMModel(nn.Module):
    """ChatGLM backbone: token embedding, ``GLMTransformer`` encoder, output
    head, and, when the HF config has a ``vision_config``, an
    ``EVA2CLIPModel`` vision encoder. The vision outputs are written into
    the token embeddings between begin/end-of-image markers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        # `config` must be bound before it is stored; the original assigned
        # self.config = config first, which raised UnboundLocalError.
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.output_layer")

        if getattr(config, 'vision_config', None) is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config,
                                        quant_config,
                                        prefix=f"{prefix}.vision")
        else:
            self.vision = None

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Pop ``pixel_values`` from ``kwargs`` and normalize it.

        When a vision encoder exists, a tensor with more than 2 dims is
        concatenated along its first dim, and a list of tensors is
        concatenated. The result is always returned wrapped in
        ``GLMImagePixelInputs``; ``pixel_values`` may be None.

        Raises:
            TypeError: ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None and self.vision is not None:
            if isinstance(pixel_values, torch.Tensor):
                if pixel_values.ndim > 2:
                    pixel_values = torch.concat(list(pixel_values))
            elif isinstance(pixel_values, list):
                # Fall through to the typed dict: callers index the result
                # with ["pixel_values"], which a bare tensor does not support.
                pixel_values = torch.concat(pixel_values)
            else:
                raise TypeError("""pixel_values must be a torch.Tensor
                    or a list of torch.Tensor
                    """)
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Run the vision encoder on ``pixel_values`` from ``kwargs``.

        Returns None when no pixel values were supplied. Pixels are cast to
        ``config.torch_dtype`` before encoding.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input["pixel_values"] is None:
            return None
        pixel_values = image_input["pixel_values"].to(
            dtype=self.config.torch_dtype)
        vision_embeddings = self.vision(pixel_values)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and, if given, splice in vision embeddings."""
        inputs_embeds = self.embedding(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_glm_vision_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                vision_embeddings=multimodal_embeddings,
                boi_token_id=self.config.boi_token_id,
                eoi_token_id=self.config.eoi_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the encoder; return hidden states on the last pipeline rank,
        otherwise ``IntermediateTensors`` for the next stage."""
        if intermediate_tensors is not None:
            # Continue from hidden states handed over by a previous pipeline
            # stage (see the IntermediateTensors return below).
            inputs_embeds = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this branch is for v0 compatibility. Previously a caller-supplied
            # inputs_embeds with no intermediate_tensors crashed on
            # None["hidden_states"]; it is now used as-is.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return hidden_states


626
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Shared implementation of the text-only and vision ChatGLM models.

    Wraps a ``ChatGLMModel`` backbone (``self.transformer``) together with
    the logits processor and sampler, and loads checkpoint weights.

    When the ``LLAMA_NN`` environment variable is ``'1'`` and the model is
    not quantized, selected linear weights are rewritten after loading with
    ``ops.trans_w16_gemm``. They are optionally padded first when ``GEMM_PAD``
    or ``FA_PAD`` is ``'1'``.
    """

    # Parameter-name fragments whose weights get the LLAMA_NN re-layout.
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attention.query_key_value.weight",
        "self_attention.dense.weight",
        "mlp.dense_h_to_4h.weight",
        "mlp.dense_4h_to_h.weight",
    )
    _QKV_WEIGHT_KEY = "self_attention.query_key_value.weight"
    _QKV_BIAS_KEY = "self_attention.query_key_value.bias"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)

        self.transformer = ChatGLMModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(
                                            prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = get_sampler()

        # Name of the quantization method, or None when unquantized; the
        # LLAMA_NN weight re-layout is only applied in the unquantized case.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Layout switches, read once from the environment at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Delegate to ``self.transformer``; extra kwargs are passed through."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states through ``lm_head`` via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        The checkpoint stores the vision projector's ``gate_proj`` and
        ``dense_h_to_4h`` weights separately. Both are collected and
        concatenated along dim 0 into ``merged_proj`` once all weights have
        been seen.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
            "transformer.vision.linear_proj.merged_proj.weight": {
                "transformer.vision.linear_proj.gate_proj.weight": None,
                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
            }
        }

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            is_weight_to_be_merge = False
            for merged_weight_dict in merged_weights_dict.values():
                if name in merged_weight_dict:
                    assert merged_weight_dict[name] is None
                    merged_weight_dict[name] = loaded_weight
                    is_weight_to_be_merge = True
            if is_weight_to_be_merge:
                continue

            if "rotary_pos_emb.inv_freq" in name:
                continue

            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        for combined_name, merged_weight_dict in merged_weights_dict.items():
            # Only models that own the merged parameter (the vision variant)
            # consume the collected halves.
            if combined_name in params_dict:
                param = params_dict[combined_name]
                combined_weight = torch.cat(list(merged_weight_dict.values()),
                                            dim=0)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, combined_weight)
                loaded_params.add(combined_name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict)

        return loaded_params

    def _apply_llama_nn_layout(self,
                               params_dict: Dict[str, nn.Parameter]) -> None:
        """Rewrite selected weights in place for the LLAMA_NN GEMM path.

        Each matched weight of shape ``(rows, cols)`` is first padded when the
        padding switches apply. It is then rewritten by
        ``ops.trans_w16_gemm`` and reshaped to ``(cols, -1)``.

        Side effect: sets the ``LM_NN`` environment variable to ``'1'`` when
        the LM head weight was rewritten, and to ``'0'`` otherwise.
        """
        lm_head_relaid = False
        for layername, weight in params_dict.items():
            # The LM head is only rewritten when its input dim is 4096 and it
            # is not tied to the input embedding. A tied weight is shared
            # with the embedding table, which the in-place reshape would
            # corrupt.
            is_lm_head = ("lm_head.weight" in layername
                          and weight.shape[1] == 4096
                          and not self.config.tie_word_embeddings)
            if is_lm_head:
                lm_head_relaid = True

            if self.use_fa_pad and self._QKV_BIAS_KEY in layername:
                weight.data = pad_weight(weight.data, 32)

            if not (is_lm_head or any(key in layername
                                      for key in self._LLAMA_NN_WEIGHT_KEYS)):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            if (self.use_fa_pad and self._QKV_WEIGHT_KEY in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            # The kernel writes the new layout into a scratch buffer, which
            # is copied back before the parameter is viewed as (cols, -1).
            relaid = torch.zeros_like(weight.data)
            rows, cols = relaid.shape
            ops.trans_w16_gemm(relaid, weight.data, rows, cols)
            weight.data.copy_(relaid)
            weight.data = weight.data.reshape(cols, -1)

        os.environ['LM_NN'] = '1' if lm_head_relaid else '0'
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795


class ChatGLM(ChatGLMBaseModel):
    """Text-only ChatGLM model: the base model plus its LoRA metadata."""

    # Each packed module name maps to a single sub-module of the same name.
    packed_modules_mapping = dict(
        query_key_value=["query_key_value"],
        dense_h_to_4h=["dense_h_to_4h"],
    )

    # LoRA specific attributes: linear layers that may receive adapters.
    supported_lora_modules = [
        "query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"
    ]

    # No embedding modules are registered for LoRA.
    embedding_modules = {}
    embedding_padding_modules = []


796
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
    """Vision-language ChatGLM model (base model plus multimodal support)."""

    # ``merged_proj`` is packed from ``gate_proj`` and ``dense_h_to_4h``;
    # the other packed names map to a single sub-module of the same name.
    packed_modules_mapping = dict(
        query_key_value=["query_key_value"],
        dense_h_to_4h=["dense_h_to_4h"],
        merged_proj=["gate_proj", "dense_h_to_4h"],
    )

    # LoRA specific attributes: language-model layers first, then vision.
    _language_lora_modules = [
        "query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"
    ]
    supported_lora_modules = _language_lora_modules + [
        "fc1", "fc2", "merged_proj", "linear_proj"
    ]
    del _language_lora_modules

    # No embedding modules are registered for LoRA.
    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the multimodal sub-models.

        The vision tower is ``transformer.vision.transformer``, the connector
        is ``transformer.vision.linear_proj``, and the language model is
        ``transformer.encoder``.
        """
        return MultiModelKeys.from_string_field(
            tower_model="transformer.vision.transformer",
            connector="transformer.vision.linear_proj",
            language_model="transformer.encoder",
        )


@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsMultiModal):
    """Entry point that picks the concrete ChatGLM variant from the config.

    ``__new__`` returns a ``ChatGLMV`` when the HF config has a ``visual``
    attribute, and a ``ChatGLM`` otherwise. The returned object is not an
    instance of this class, so Python does not call
    ``ChatGLMForCausalLM.__init__`` on it.
    """

    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(
        cls,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> "ChatGLMBaseModel":
        config = vllm_config.model_config.hf_config
        # A ``visual`` section in the HF config selects the VL variant.
        if hasattr(config, "visual"):
            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
        # Otherwise initialize the text-only LLM.
        return ChatGLM(vllm_config=vllm_config, prefix=prefix)