"vllm/vscode:/vscode.git/clone" did not exist on "6e9ff050c8e83ad6d5e5eab621e83549e35933a1"
# Adapted from
# https://github.com/THUDM/GLM-4
"""Inference-only ChatGLM model compatible with THUDM weights."""
4
5
from argparse import Namespace
from array import array
6
7
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
                    TypedDict)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
8
9

import torch
10
from PIL import Image
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
11
12
13
from torch import nn
from torch.nn import LayerNorm

14
from vllm.attention import Attention, AttentionMetadata
15
from vllm.config import CacheConfig, VllmConfig
16
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
17
18
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
21
from vllm.model_executor.layers.layernorm import RMSNorm
22
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
23
24
                                               QKVParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
28
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
29
from vllm.model_executor.layers.vocab_parallel_embedding import (
30
    ParallelLMHead, VocabParallelEmbedding)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
33
from vllm.model_executor.models.module_mapping import MultiModelKeys
34
from vllm.model_executor.sampling_metadata import SamplingMetadata
35
36
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
37
38
39
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
40
41
from vllm.transformers_utils.configs import ChatGLMConfig

42
43
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
44
45
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
46
47
48
49
50
51
52
53
54
55
56
57
58

logger = init_logger(__name__)


def calculate_image_placeholder(vision_config):
    """Return the number of placeholder tokens reserved for one image.

    The count is ``(image_size // patch_size // 2) ** 2``, read from the
    ``"image_size"`` and ``"patch_size"`` keys of ``vision_config``.
    """
    patches_per_side = (vision_config["image_size"] //
                        vision_config["patch_size"])
    halved_per_side = patches_per_side // 2
    return halved_per_side**2


def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: MultiModalData[object],
) -> Dict:
    """Convert image data into the ``pixel_values`` keyword for the model.

    The image is pushed through the tokenizer's chat template as a single
    user turn, and the ``'images'`` entry of the result is returned wrapped
    in ``MultiModalKwargs`` under the key ``'pixel_values'``.

    Args:
        ctx: Input context; its ``model_config`` supplies the tokenizer
            name and the ``trust_remote_code`` flag.
        data: The image payload placed in the ``"image"`` field of the
            chat-template conversation.

    Raises:
        RuntimeError: If no tokenizer could be obtained.
        Exception: Any error raised by ``apply_chat_template`` is logged
            and re-raised unchanged.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        # Log which payload failed, then let the original error propagate.
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalKwargs({'pixel_values': pixel_values})
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Overwrite the image spans of ``inputs_embeds`` with vision embeddings.

    Every span from a ``boi_token_id`` position through its paired
    ``eoi_token_id`` position (both inclusive) is selected, and the selected
    rows are assigned ``vision_embeddings`` flattened to
    ``(-1, vision_embeddings.shape[-1])``.

    ``inputs_embeds`` is modified in place and the same tensor is returned.
    Begin/end markers are paired in order of appearance; each begin marker
    must come before its end marker.
    """
    span_starts = torch.nonzero(input_ids == boi_token_id, as_tuple=True)[0]
    span_stops = torch.nonzero(input_ids == eoi_token_id, as_tuple=True)[0]

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, stop in zip(span_starts, span_stops):
        assert start < stop
        # The end-of-image position itself is part of the span.
        image_mask[start:stop + 1] = True

    flat_vision = vision_embeddings.view(-1, vision_embeddings.shape[-1])
    inputs_embeds[image_mask] = flat_vision
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    """Typed container for the image tensor handed to the vision model."""

    # Shape: `(batch_size, num_channels, height, width)`
    pixel_values: torch.Tensor


def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of tokens a single image may occupy.

    Returns 1 when the HF config has no ``vision_config``. For a dict
    ``vision_config`` the value comes from ``calculate_image_placeholder``.

    Raises:
        NotImplementedError: If ``vision_config`` is set but is not a dict.
    """
    vision_config = getattr(ctx.get_hf_config(ChatGLMConfig), 'vision_config',
                            None)

    if vision_config is None:
        return 1
    if not isinstance(vision_config, dict):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return calculate_image_placeholder(vision_config)


122
123
def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]) -> DummyData:
    """Build dummy sequence (and image) data for this model.

    Without a ``vision_config`` the result is ``seq_len`` zero tokens and no
    multi-modal data. With a dict ``vision_config`` the sequence is
    ``[boi] + [0] * placeholder_length + [eoi]`` followed by zero padding,
    paired with one black RGB image of ``image_size x image_size``.

    Args:
        ctx: Input context used to read the HF config.
        seq_len: Requested dummy sequence length.
        mm_counts: Accepted for the registry signature; not read here.

    Raises:
        NotImplementedError: If ``vision_config`` is set but is not a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        seq_data = SequenceData(token_ids)
        return DummyData(seq_data, None)
    elif isinstance(vision_config, dict):
        image_size = vision_config["image_size"]
        image_placeholder_length = calculate_image_placeholder(vision_config)
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
                          [0] * image_placeholder_length +
                          [hf_config.eoi_token_id])
        # The "- 2" accounts for the boi/eoi markers added above. A
        # non-positive remainder simply adds no padding.
        token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                           [0] * (seq_len - image_placeholder_length - 2))
        seq_data = SequenceData(token_ids)

        mm_data = {
            "image": Image.new("RGB", (image_size, image_size), color=0)
        }

        return DummyData(seq_data, mm_data)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return, in ascending order, every index at which ``target`` occurs."""
    positions: List[int] = []
    for position, token_id in enumerate(input_ids):
        if token_id == target:
            positions.append(position)
    return positions


155
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand image placeholders in the prompt to their final token count.

    When the request carries no image, or the model has no
    ``vision_config``, ``inputs`` is returned untouched. Otherwise the
    prompt text and image are run through the tokenizer's chat template,
    and the tokens of every ``boi``..``eoi`` span are rewritten so that the
    token following ``boi`` is repeated ``image_placeholder_length`` times.

    Note that the incoming ``prompt_token_ids`` are not used: the token ids
    are regenerated from ``inputs['prompt']`` by the chat template.

    Raises:
        NotImplementedError: If ``vision_config`` is set but is not a dict.
        Exception: Any chat-template failure is logged and re-raised.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # Use the configured tokenizer, the same source that
    # mm_input_mapper_for_glmv uses, rather than the model path.
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        trust_remote_code=ctx.model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": inputs['prompt'],
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).data
    except Exception:
        logger.error("Failed to process content (%s)", inputs['prompt'])
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_token_id = hf_config.boi_token_id
    eoi_token_id = hf_config.eoi_token_id
    boi_positions = find_all_positions(input_ids, boi_token_id)
    eoi_positions = find_all_positions(input_ids, eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        # Copy everything up to and including the boi marker ...
        new_input_ids.extend(input_ids[final_processed_position:boi_position +
                                       1])
        # ... then repeat the token right after boi as the placeholder.
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        # Resume at the eoi marker so it is copied by the next slice.
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])

    prompt = inputs.get("prompt")
    if prompt is None:
        prompt = tokenizer.decode(new_input_ids)

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )
224

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
225
226
227

class GLMAttention(nn.Module):
    """Self-attention layer: fused QKV projection, rotary embedding,
    attention over the KV cache, and an output projection."""

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections, rotary embedding and attention op.

        Args:
            config: Model config supplying sizes, head counts and bias flags.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Module-name prefix; ``"{prefix}.attn"`` is passed to
                ``Attention``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # With multi-query attention the KV head count comes from
        # multi_query_group_num; otherwise it equals the query head count.
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # Rotary embedding covers half of each head's dimensions.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected
        output of the ``dense`` layer."""
        qkv, _ = self.query_key_value(hidden_states)
        # Split the fused projection into this rank's q, k and v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """MLP.

    Projects the ``hidden_size`` input through a merged layer holding two
    ``ffn_hidden_size``-wide projections, applies ``SiluAndMul``, and
    projects the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the up- and down-projection layers from ``config``."""
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Project up: two ffn_hidden_size projections merged in one layer.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection in order."""
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Applies input layernorm, self-attention with a residual connection,
    post-attention layernorm, and the MLP with a second residual
    connection. The output has the same size as the input.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the norms, attention and MLP sub-modules.

        ``config.rmsnorm`` selects ``RMSNorm`` over ``LayerNorm`` for both
        normalisation layers.
        """
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attention")
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states`` and return the new states."""
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection: taken from the normalised input when
        # apply_residual_connection_post_layernorm is set, else from the
        # raw input.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final layernorm."""

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the layers via ``make_layers`` and, when
        ``config.post_layer_norm`` is set, the final layernorm."""
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers. make_layers also reports the [start, end)
        # layer range that forward() iterates over.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            lambda prefix: GLMBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run layers ``start_layer``..``end_layer`` over ``hidden_states``."""
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                # kv_caches is indexed relative to the first local layer.
                kv_cache=kv_caches[i - self.start_layer],
                attn_metadata=attn_metadata,
            )
        # Final layer norm, only on the last pipeline rank.
        if get_pp_group().is_last_rank and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """Token embedding, transformer encoder, output head and, when the
    config has a ``vision_config``, an ``EVA2CLIPModel`` vision module."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build all sub-modules from ``vllm_config``.

        ``self.vision`` is ``None`` when the HF config has no
        ``vision_config``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.output_layer")

        vision_config_flag = getattr(config, 'vision_config', None)
        if vision_config_flag is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config,
                                        quant_config,
                                        prefix=f"{prefix}.vision")
        else:
            self.vision = None

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Extract ``pixel_values`` from ``kwargs`` and normalise it.

        When a vision module is present, a tensor with more than two
        dimensions is concatenated along its first axis and a list of
        tensors is concatenated into one tensor. The result is always
        wrapped in ``GLMImagePixelInputs``; ``pixel_values`` is ``None``
        when the keyword was not supplied.

        Raises:
            TypeError: If ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None and self.vision is not None:
            if isinstance(pixel_values, torch.Tensor):
                if pixel_values.ndim > 2:
                    pixel_values = torch.concat(list(pixel_values))
            elif isinstance(pixel_values, list):
                # Keep going to the common return below so the caller
                # always receives a GLMImagePixelInputs mapping; returning
                # the bare tensor here broke forward()'s key lookup.
                pixel_values = torch.concat(pixel_values)
            else:
                raise TypeError("""pixel_values must be a torch.Tensor
                    or a list of torch.Tensor
                    """)
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Embed the inputs (or take them from ``intermediate_tensors``),
        run the encoder, and return hidden states.

        On ranks that are not the last pipeline rank the hidden states are
        returned wrapped in ``IntermediateTensors``.
        """
        if intermediate_tensors is None:
            inputs_embeds = self.embedding(input_ids)
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input["pixel_values"] is not None:
                pixel_values = image_input["pixel_values"].to(
                    dtype=inputs_embeds.dtype)
                image_embeds = self.vision(pixel_values)

                boi_token_id = self.config.boi_token_id
                eoi_token_id = self.config.eoi_token_id

                inputs_embeds = merge_glm_vision_embeddings(
                    input_ids=input_ids,
                    inputs_embeds=inputs_embeds,
                    vision_embeddings=image_embeds,
                    boi_token_id=boi_token_id,
                    eoi_token_id=eoi_token_id)
        else:
            # The hidden states were handed over in intermediate_tensors.
            inputs_embeds = intermediate_tensors["hidden_states"]

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return hidden_states


591
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Shared implementation for the ChatGLM model classes: wraps
    ``ChatGLMModel`` and adds logits computation, sampling and weight
    loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, LM head alias, logits processor and
        sampler from ``vllm_config``."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config

        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(
                                            prefix, "transformer"))
        # With tied embeddings the output layer shares the embedding weight.
        if self.config.tie_word_embeddings:
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = get_sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Delegate to ``self.transformer`` and return its result."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
            "transformer.vision.linear_proj.merged_proj.weight": {
                "transformer.vision.linear_proj.gate_proj.weight": None,
                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
            }
        }

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Stash the pieces of merged weights; they are concatenated and
            # loaded after the main loop.
            is_weight_to_be_merge = False
            for merged_weight_dict in merged_weights_dict.values():
                if name in merged_weight_dict:
                    assert merged_weight_dict[name] is None
                    merged_weight_dict[name] = loaded_weight
                    is_weight_to_be_merge = True
            if is_weight_to_be_merge:
                continue
            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        for combined_name, merged_weight_dict in merged_weights_dict.items():
            # Only models that actually own the merged parameter load it.
            if combined_name in params_dict:
                param = params_dict[combined_name]
                combined_weight = torch.cat(list(merged_weight_dict.values()),
                                            dim=0)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, combined_weight)
                loaded_params.add(combined_name)
        return loaded_params
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709


class ChatGLM(ChatGLMBaseModel):
    """ChatGLM model class whose LoRA metadata lists only the language-model
    modules (no vision entries)."""

    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
    }

    # No embedding modules are declared for LoRA.
    embedding_modules = {}
    embedding_padding_modules = []


710
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
    """ChatGLM model class whose LoRA metadata also covers the vision
    modules, and which reports its multimodal module prefixes."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        "merged_proj": ["gate_proj", "dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        # vision
        "fc1",
        "fc2",
        "merged_proj",
        "linear_proj"
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer")


@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsMultiModal):
    """Entry-point class that dispatches to ``ChatGLMV`` or ``ChatGLM``.

    Constructing this class never yields a ``ChatGLMForCausalLM`` instance:
    ``__new__`` returns a ``ChatGLMV`` when the HF config carries a
    ``vision_config`` and a ``ChatGLM`` otherwise.
    """
    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(
        cls,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> ChatGLMBaseModel:
        """Return the concrete model matching ``vllm_config``.

        The vision-capable class is chosen on ``vision_config``, the same
        attribute the rest of this file uses to detect a vision model.
        """
        config = vllm_config.model_config.hf_config
        # Initialize VL
        if getattr(config, "vision_config", None) is not None:
            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
        # Initialize LLM
        return ChatGLM(vllm_config=vllm_config, prefix=prefix)