# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/THUDM/CogAgent
"""Inference-only ChatGLM / CogAgent models compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
                    TypedDict)

import torch
from PIL import Image
from torch import nn
from torch.nn import LayerNorm

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

# Module-level logger; the input-mapping / input-processing helpers below use
# it to report chat-template failures before re-raising them.
logger = init_logger(__name__)


def calculate_image_placeholder(vision_config):
    """Return the number of placeholder tokens a single image expands to.

    The side length in patches (``image_size // patch_size``) is halved once
    more (presumably a 2x downsampling in the vision tower — TODO confirm),
    and the placeholder count is the square of the resulting side length.

    Args:
        vision_config: Mapping with integer ``"image_size"`` and
            ``"patch_size"`` entries.
    """
    image_size = vision_config["image_size"]
    patch_size = vision_config["patch_size"]
    side = image_size // patch_size // 2
    return side * side


def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: ModalityData[object],
) -> Dict:
    """Turn raw image data into ``pixel_values`` for the vision tower.

    The image is passed through the tokenizer's chat template with
    ``tokenize=True`` / ``return_dict=True`` only to obtain the preprocessed
    ``images`` entry of the result; the produced token ids are not used here.

    Args:
        ctx: Input context providing the model config (tokenizer path and
            ``trust_remote_code``).
        data: The image item(s) supplied under the ``"image"`` modality.

    Returns:
        ``MultiModalKwargs`` with a single ``"pixel_values"`` entry.

    Raises:
        RuntimeError: If no tokenizer could be obtained.
        Exception: Whatever ``apply_chat_template`` raises is logged and
            re-raised unchanged.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalKwargs({'pixel_values': pixel_values})


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Write vision embeddings over every ``<boi> ... <eoi>`` span in place.

    The i-th ``boi_token_id`` occurrence is paired with the i-th
    ``eoi_token_id`` occurrence, and every position from ``<boi>`` through
    ``<eoi>`` (inclusive) is selected. The selected rows of ``inputs_embeds``
    are then overwritten, in order, with ``vision_embeddings`` flattened to
    ``(-1, hidden)``. The number of selected positions must therefore equal
    the number of flattened vision rows.

    Returns:
        ``inputs_embeds`` itself (mutated in place).

    Raises:
        AssertionError: If a paired ``<eoi>`` does not come after its
            ``<boi>``.
    """
    starts = torch.nonzero(input_ids == boi_token_id, as_tuple=True)[0]
    ends = torch.nonzero(input_ids == eoi_token_id, as_tuple=True)[0]

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, end in zip(starts, ends):
        assert start < end
        image_mask[start:end + 1] = True

    hidden_size = vision_embeddings.shape[-1]
    flat_vision = vision_embeddings.view(-1, hidden_size)
    inputs_embeds[image_mask] = flat_vision
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    """Parsed image input handed to the GLM vision tower."""

    # Preprocessed image tensor, documented as
    # `(batch_size, num_channels, height, width)`. Code in this module may
    # also store None here when no image was supplied.
    pixel_values: torch.Tensor


def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of embedding positions one image occupies.

    ``input_processor_for_glmv`` expands each image into ``<boi>``, then
    ``calculate_image_placeholder(vision_config)`` placeholder tokens, then
    ``<eoi>``. ``merge_glm_vision_embeddings`` overwrites that whole inclusive
    span with vision embeddings, and ``dummy_data_for_glmv`` uses the same
    layout. The per-image count therefore includes the two boundary tokens.

    Returns:
        1 when the config has no (or a None) ``vision_config``; otherwise the
        placeholder count plus 2.

    Raises:
        NotImplementedError: If ``vision_config`` is neither None nor a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)

    vision_config = getattr(hf_config, 'vision_config', None)
    if vision_config is None:
        return 1
    elif isinstance(vision_config, dict):
        # +2 for the <boi>/<eoi> positions, which are also replaced by
        # vision embeddings.
        return calculate_image_placeholder(vision_config) + 2

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]) -> DummyData:
    """Build dummy profiling inputs for ChatGLM / GLM-4V.

    Text-only configs (no or None ``vision_config``) get ``seq_len`` zero
    tokens and no multimodal data.

    Vision configs get a single image laid out as ``<boi>``, one zero token
    per image placeholder, then ``<eoi>``, padded with zeros up to
    ``seq_len``. The multimodal data is a black RGB square of
    ``vision_config["image_size"]`` pixels. If ``seq_len`` is shorter than
    the image span, no padding is added and the sequence exceeds
    ``seq_len``.

    Note: ``mm_counts`` is not consulted; exactly one image is produced.

    Raises:
        NotImplementedError: If ``vision_config`` is neither None nor a dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        seq_data = SequenceData(token_ids)
        return DummyData(seq_data, None)
    elif isinstance(vision_config, dict):
        image_size = vision_config["image_size"]
        image_placeholder_length = calculate_image_placeholder(vision_config)
        token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] +
                          [0] * image_placeholder_length +
                          [hf_config.eoi_token_id])
        # The "- 2" accounts for the <boi>/<eoi> tokens added above.
        token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                           [0] * (seq_len - image_placeholder_length - 2))
        seq_data = SequenceData(token_ids)

        mm_data = {
            "image": Image.new("RGB", (image_size, image_size), color=0)
        }

        return DummyData(seq_data, mm_data)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return, in ascending order, every index of ``input_ids`` equal to
    ``target`` (an empty list if there is none)."""
    positions = []
    for idx, token in enumerate(input_ids):
        if token == target:
            positions.append(idx)
    return positions


def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Re-tokenize an image prompt and expand its image placeholders.

    Inputs without an ``"image"`` entry, and text-only configs, are returned
    unchanged. Otherwise the text prompt and image are rendered through the
    tokenizer's chat template, and the resulting ids replace the incoming
    ``prompt_token_ids``.

    For every paired ``<boi>``/``<eoi>``:
      * the tokens up to and including ``<boi>`` are kept;
      * the token directly after ``<boi>`` is repeated
        ``calculate_image_placeholder(vision_config)`` times;
      * copying resumes at ``<eoi>``.
    Whatever the template emitted between the two markers is dropped.

    Raises:
        ValueError: If the request carries an image but no text prompt (the
            chat template needs the text content).
        NotImplementedError: If ``vision_config`` is neither None nor a dict.
        AssertionError: If the template output has unbalanced or misordered
            ``<boi>``/``<eoi>`` tokens.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    prompt = inputs.get("prompt")
    if prompt is None:
        # Previously this surfaced as a bare KeyError on inputs['prompt'],
        # raised again from inside the error handler.
        raise ValueError("ChatGLM image inputs require a text prompt; "
                         "token-id-only prompts are not supported")

    # Use the configured tokenizer path (which may differ from the model
    # path), consistent with mm_input_mapper_for_glmv.
    tokenizer = cached_get_tokenizer(
        ctx.model_config.tokenizer,
        trust_remote_code=ctx.model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": prompt,
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).data
    except Exception:
        logger.error("Failed to process content (%s)", prompt)
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_token_id = hf_config.boi_token_id
    eoi_token_id = hf_config.eoi_token_id
    boi_positions = find_all_positions(input_ids, boi_token_id)
    eoi_positions = find_all_positions(input_ids, eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        new_input_ids.extend(
            input_ids[final_processed_position:boi_position + 1])
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=prompt,
        multi_modal_data=multi_modal_data,
    )

class GLMAttention(nn.Module):
    """Self-attention sub-layer of a ChatGLM / GLM transformer block.

    Projects hidden states with a fused tensor-parallel QKV linear layer and
    applies rotary embeddings to q and k. It then runs vLLM ``Attention``
    with the given KV cache and attention metadata, and projects the result
    back to ``hidden_size`` with a row-parallel ``dense`` layer.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # With multi-query (grouped) attention, K/V use
        # multi_query_group_num heads; otherwise one per query head.
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        # rope_ratio scales the rotary base (10000 * rope_ratio below).
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        is_neox_style = not config.original_rope
        # Rotary embedding is applied to only half of each head's dims
        # (rotary_dim = head_dim // 2).
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention for this rank's heads and return the projected
        output (the second output of each parallel linear is discarded)."""
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated (SiLU-and-multiply) MLP of a GLM transformer block.

    ``dense_h_to_4h`` is a merged column-parallel projection producing two
    ``config.ffn_hidden_size``-wide halves. Despite the legacy "4h" name, the
    width comes from ``ffn_hidden_size``. ``SiluAndMul`` combines the two
    halves into one ``ffn_hidden_size``-wide activation, and
    ``dense_4h_to_h`` projects back to ``hidden_size``.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.add_bias = config.add_bias_linear

        # Merged gate/up projection: hidden -> 2 x ffn_hidden_size.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )

        self.activation_func = SiluAndMul()

        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, hidden_states):
        # Shape notes assume the flattened [num_tokens, hidden] layout used
        # by GLMBlock; the column-parallel widths are per TP rank.
        # -> [num_tokens, 2 * ffn_hidden_size / tp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        # -> [num_tokens, ffn_hidden_size / tp]
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # -> [num_tokens, hidden_size]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single GLM transformer layer.

    Pre-norm self-attention followed by a residual add, then a second norm,
    the MLP, and another residual add. The input is ``[num_tokens, h]`` and
    the output has the same shape. When
    ``apply_residual_connection_post_layernorm`` is set, each residual is
    taken from the normalized tensor instead of the un-normalized one.

    Note: ``fp32_residual_connection`` and ``hidden_dropout`` are read from
    the config and stored, but ``forward`` does not use them.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.self_attention")
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Stack of ``GLMBlock`` layers with an optional final layer norm.

    ``make_layers`` returns ``(start_layer, end_layer, layers)``. ``forward``
    runs only layers in ``[start_layer, end_layer)`` and indexes
    ``kv_caches`` relative to ``start_layer``. The final layer norm is
    applied only on the last pipeline-parallel rank, and only when
    ``config.post_layer_norm`` is set.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers,
            lambda prefix: GLMBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

        # Factory for the "hidden_states" tensors exchanged between
        # pipeline stages.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i - self.start_layer],
                attn_metadata=attn_metadata,
            )
        # Final layer norm.
        if get_pp_group().is_last_rank and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):
    """ChatGLM backbone: embedding, GLM transformer, LM head and, for
    configs with a ``vision_config``, an ``EVA2CLIPModel`` vision tower."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.embedding")

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config,
                                      cache_config,
                                      quant_config,
                                      prefix=f"{prefix}.encoder")

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.output_layer")

        vision_config_flag = getattr(config, 'vision_config', None)
        if vision_config_flag is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config,
                                        quant_config,
                                        prefix=f"{prefix}.vision")
        else:
            self.vision = None

        self.make_empty_intermediate_tensors = (
            self.encoder.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Pop ``pixel_values`` from ``kwargs`` and normalize it to a single
        tensor.

        When a vision tower exists:
          * a tensor with more than 2 dims is concatenated along its first
            dim;
          * a list of tensors is concatenated.

        Always returns a ``GLMImagePixelInputs``; ``pixel_values`` is None
        when none was given.

        Raises:
            TypeError: If ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None and self.vision is not None:
            if isinstance(pixel_values, torch.Tensor):
                if pixel_values.ndim > 2:
                    pixel_values = torch.concat(list(pixel_values))
            elif isinstance(pixel_values, list):
                # Wrap like the tensor branch: callers index the result
                # with ["pixel_values"], which a bare tensor cannot serve.
                pixel_values = torch.concat(pixel_values)
            else:
                raise TypeError("""pixel_values must be a torch.Tensor
                    or a list of torch.Tensor
                    """)
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Run the vision tower on ``pixel_values`` (cast to
        ``config.torch_dtype``), or return None when there are none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input["pixel_values"] is None:
            return None
        pixel_values = image_input["pixel_values"].to(
            dtype=self.config.torch_dtype)
        vision_embeddings = self.vision(pixel_values)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, then overwrite each ``<boi>..<eoi>`` span
        with ``multimodal_embeddings`` when they are provided."""
        inputs_embeds = self.embedding(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_glm_vision_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                vision_embeddings=multimodal_embeddings,
                boi_token_id=self.config.boi_token_id,
                eoi_token_id=self.config.eoi_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run this rank's transformer layers.

        The input hidden states are chosen in this order:
          1. the ``"hidden_states"`` of ``intermediate_tensors`` (later
             pipeline stages);
          2. else the caller-supplied ``inputs_embeds``;
          3. else embeddings computed here from ``input_ids`` plus any image
             ``kwargs``.

        Returns ``IntermediateTensors`` on non-last pipeline ranks and the
        final hidden states otherwise.
        """
        if intermediate_tensors is not None:
            inputs_embeds = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this branch is for v0 compatibility. A supplied inputs_embeds
            # with no intermediate_tensors (first rank) is used as-is.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights and return the set of parameter names
        loaded.

        The vision ``linear_proj.gate_proj`` / ``linear_proj.dense_h_to_4h``
        weights are loaded as shards 0 / 1 of the fused
        ``linear_proj.merged_proj``. The following are skipped:
          * ``rotary_pos_emb.inv_freq`` buffers;
          * biases that have no matching parameter (e.g. GPTQ checkpoints);
          * parameters not owned by this pipeline rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
            ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "rotary_pos_emb.inv_freq" in name:
                    continue
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper shared by the text-only and vision ChatGLM variants.

    Owns the ``ChatGLMModel`` backbone (``self.transformer``) and the LM head.
    The LM head is the backbone's ``output_layer``, tied to the input
    embedding when ``tie_word_embeddings`` is set. Also owns the logits
    processor and the sampler.
    """

    # Drop ".word_embeddings" from checkpoint weight names so they match the
    # ``embedding`` module used here.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".word_embeddings": ""}, )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config

        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(vllm_config=vllm_config,
                                        prefix=maybe_prefix(
                                            prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = get_sampler()
        # This class declares SupportsPP, whose interface expects the model
        # itself to provide make_empty_intermediate_tensors. Previously only
        # the inner ChatGLMModel exposed it.
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Delegate to ``ChatGLMModel.get_multimodal_embeddings``; returns
        None when no ``pixel_values`` are given."""
        return self.transformer.get_multimodal_embeddings(**kwargs)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Delegate to ``ChatGLMModel.get_input_embeddings`` so a runner can
        build ``inputs_embeds`` from the top-level model."""
        return self.transformer.get_input_embeddings(input_ids,
                                                     multimodal_embeddings)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Forward to the backbone; extra ``kwargs`` (e.g. ``inputs_embeds``,
        ``pixel_values``) are passed through unchanged."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights via ``AutoWeightsLoader``, renaming checkpoint keys
        with ``hf_to_vllm_mapper`` first."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


class ChatGLM(ChatGLMBaseModel):
    """Text-only ChatGLM causal LM.

    ``ChatGLMForCausalLM`` instantiates this class when the HF config
    carries no vision config.
    """

    # Fused vLLM module name -> checkpoint module names packed into it. For
    # the text-only model every entry maps to itself (nothing is merged).
    packed_modules_mapping = dict(
        query_key_value=["query_key_value"],
        dense_h_to_4h=["dense_h_to_4h"],
    )
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]

    embedding_modules = dict()
    embedding_padding_modules = list()


class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
    """ChatGLM with an EVA2-CLIP vision tower (GLM-4V / CogAgent style).

    ``ChatGLMForCausalLM`` instantiates this class when the HF config
    carries a vision config.
    """

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        # Vision projector: gate_proj and dense_h_to_4h are fused into
        # merged_proj (see ChatGLMModel.load_weights' stacked mapping).
        "merged_proj": ["gate_proj", "dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
        # vision
        "fc1",
        "fc2",
        "merged_proj",
        "linear_proj"
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer")


@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
                         SupportsMultiModal):
    """Architecture entry point that dispatches to ``ChatGLM`` or ``ChatGLMV``.

    ``__new__`` inspects the HF config and returns a fully constructed
    instance of the matching concrete class. That class is not a subclass
    of this one, so Python does not call ``ChatGLMForCausalLM.__init__`` on
    it.
    """
    # Ensure that the LoRA support check passes when the class is not
    # initialized, but set all these attributes to empty.
    # These will be updated when an instance class is selected
    packed_modules_mapping = {}
    supported_lora_modules = []
    embedding_modules = {}
    embedding_padding_modules = []

    def __new__(
        cls,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> ChatGLMBaseModel:
        config = vllm_config.model_config.hf_config

        # Same test used by the input processors and ChatGLMModel: an
        # absent or None vision_config means a text-only model.
        has_vision = getattr(config, "vision_config", None) is not None
        instance_cls = ChatGLMV if has_vision else ChatGLM

        # quant_config references base class members,
        # so update values before init is called.
        # Mutate in place (never rebind), so existing references see the
        # update. Only append missing entries, so constructing the model
        # more than once does not grow the class-level lists with
        # duplicates.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        for module in instance_cls.supported_lora_modules:
            if module not in cls.supported_lora_modules:
                cls.supported_lora_modules.append(module)
        cls.embedding_modules.update(instance_cls.embedding_modules)
        for module in instance_cls.embedding_padding_modules:
            if module not in cls.embedding_padding_modules:
                cls.embedding_padding_modules.append(module)
        return instance_cls(vllm_config=vllm_config, prefix=prefix)