chatglm.py 26.7 KB
Newer Older
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
1
2
# coding=utf-8
# Adapted from
3
# https://github.com/THUDM/GLM-4
Woosuk Kwon's avatar
Woosuk Kwon committed
4
"""Inference-only ChatGLM model compatible with THUDM weights."""
5
6
7
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
8
9

import torch
10
from PIL import Image
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
11
12
from torch import nn
from torch.nn import LayerNorm
zhuwenwen's avatar
zhuwenwen committed
13
import os
zhuwenwen's avatar
zhuwenwen committed
14
import re
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
15

16
from vllm.attention import Attention, AttentionMetadata
17
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
18
from vllm.distributed import get_tensor_model_parallel_world_size
19
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.activation import SiluAndMul
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
22
from vllm.model_executor.layers.layernorm import RMSNorm
23
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
24
25
                                               QKVParallelLinear,
                                               RowParallelLinear)
26
from vllm.model_executor.layers.logits_processor import LogitsProcessor
27
28
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
29
from vllm.model_executor.layers.rotary_embedding import get_rope
30
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
31
from vllm.model_executor.layers.vocab_parallel_embedding import (
32
    ParallelLMHead, VocabParallelEmbedding)
33
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
34
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
35
from vllm.model_executor.sampling_metadata import SamplingMetadata
36
37
38
39
40
41
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                           SequenceData)
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
42
from vllm.transformers_utils.configs import ChatGLMConfig
43

zhuwenwen's avatar
zhuwenwen committed
44
from vllm import _custom_ops as ops
45
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
46

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from .interfaces import SupportsLoRA, SupportsMultiModal

logger = init_logger(__name__)


def calculate_image_placeholder(vision_config):
    """Return the number of placeholder tokens one image expands to.

    Computed as ``(image_size // patch_size // 2) ** 2`` from the dict-style
    ``vision_config``.
    """
    patches_per_side = (vision_config["image_size"] //
                        vision_config["patch_size"])
    tokens_per_side = patches_per_side // 2
    return tokens_per_side * tokens_per_side


def mm_input_mapper_for_glmv(
    ctx: InputContext,
    data: MultiModalData[object],
) -> Dict:
    """Map raw image data to the ``pixel_values`` input of GLM-4V.

    The image is passed through the tokenizer's chat template, and the
    ``images`` entry of the resulting batch is returned as ``pixel_values``.

    Raises:
        RuntimeError: if no tokenizer could be loaded.
        Exception: any error raised by ``apply_chat_template`` is logged
            and re-raised unchanged.
    """
    model_config = ctx.model_config
    # Honour the user's explicit trust_remote_code choice instead of forcing
    # it on; this matches input_processor_for_glmv and avoids silently
    # executing remote tokenizer code the user did not opt into.
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    if tokenizer is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")
    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": data
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process image (%s)", data)
        raise
    pixel_values = raw_batch_data['images']

    return MultiModalInputs({'pixel_values': pixel_values})


def merge_glm_vision_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    vision_embeddings: torch.Tensor,
    boi_token_id: int,
    eoi_token_id: int,
) -> torch.Tensor:
    """Overwrite image spans of ``inputs_embeds`` with vision embeddings.

    The i-th ``boi_token_id`` is paired with the i-th ``eoi_token_id``, and
    every position from the BOI through the EOI (both inclusive) is marked.
    The marked rows are then filled, in order, from ``vision_embeddings``
    flattened to ``(-1, vision_embeddings.shape[-1])``.

    ``inputs_embeds`` is modified in place and also returned.
    """
    starts = torch.nonzero(input_ids == boi_token_id, as_tuple=True)[0]
    ends = torch.nonzero(input_ids == eoi_token_id, as_tuple=True)[0]

    image_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, end in zip(starts, ends):
        assert start < end
        image_mask[start:end + 1] = True

    flat_vision = vision_embeddings.view(-1, vision_embeddings.shape[-1])
    inputs_embeds[image_mask] = flat_vision
    return inputs_embeds


class GLMImagePixelInputs(TypedDict):
    """Pixel inputs handed to the GLM-4V vision tower."""

    # Shape: `(batch_size, num_channels, height, width)`
    pixel_values: torch.Tensor


def get_max_glmv_image_tokens(ctx: InputContext):
    """Return the maximum number of tokens a single image may occupy.

    Dict-style vision configs use ``calculate_image_placeholder``. Configs
    without a ``vision_config`` report 1.

    Raises:
        NotImplementedError: for any other ``vision_config`` type.
    """
    vision_config = getattr(ctx.get_hf_config(ChatGLMConfig),
                            'vision_config', None)
    if isinstance(vision_config, dict):
        return calculate_image_placeholder(vision_config)
    if vision_config is None:
        return 1
    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")


def dummy_data_for_glmv(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    """Build profiling dummy data for ChatGLM, with or without vision.

    Text-only configs get ``seq_len`` zero tokens and no multimodal data.

    Dict-style vision configs get the following token sequence:

    * a BOI token,
    * ``calculate_image_placeholder(vision_config)`` zeros,
    * an EOI token,
    * zero padding up to ``seq_len`` (none if the image part is longer).

    They also get an all-zero ``image_size`` x ``image_size`` RGB image.

    Raises:
        NotImplementedError: for any other ``vision_config`` type.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        text_only = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
        return SequenceData(text_only), None

    if not isinstance(vision_config, dict):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    side = vision_config["image_size"]
    num_image_tokens = calculate_image_placeholder(vision_config)
    image_span = ([hf_config.boi_token_id] + [0] * num_image_tokens +
                  [hf_config.eoi_token_id])
    padding = [0] * (seq_len - num_image_tokens - 2)
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_span + padding)

    dummy_image = Image.new("RGB", (side, side), color=0)
    return SequenceData(token_ids), {"image": dummy_image}


def find_all_positions(input_ids: List[int], target: int) -> List[int]:
    """Return, in ascending order, every index where ``target`` occurs."""
    return [pos for pos, token in enumerate(input_ids) if token == target]


157
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand GLM-4V image spans in the prompt to a fixed placeholder length.

    The prompt and image are re-tokenized through the tokenizer's chat
    template. For every BOI/EOI pair in the result, the tokens strictly
    between them are replaced by
    ``calculate_image_placeholder(vision_config)`` copies of the token that
    follows BOI. All those placeholders share position id ``boi + 1``.

    Text tokens, including BOI and EOI, keep their template index as their
    position id.

    ``inputs["prompt_token_ids"]`` and ``inputs["position_ids"]`` are
    overwritten in place, and ``inputs`` is returned. ``inputs`` is
    returned unchanged when the model has no vision config or the request
    carries no image.

    Raises:
        NotImplementedError: if ``vision_config`` is neither ``None`` nor a
            dict.
    """
    hf_config = ctx.get_hf_config(ChatGLMConfig)
    vision_config = getattr(hf_config, 'vision_config', None)

    if vision_config is None:
        return inputs
    elif isinstance(vision_config, dict):
        image_placeholder_length = calculate_image_placeholder(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # Text-only requests to a vision model have no image. Indexing
    # ``multi_modal_data["image"]`` below would raise, so pass them through.
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    tokenizer = cached_get_tokenizer(
        ctx.model_config.model,
        trust_remote_code=ctx.model_config.trust_remote_code)

    try:
        raw_batch_data = tokenizer.apply_chat_template(
            conversation=[{
                "role": "user",
                "image": multi_modal_data["image"],
                "content": inputs['prompt']
            }],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True).data
    except Exception:
        logger.error("Failed to process content (%s)", inputs['prompt'])
        raise
    input_ids = raw_batch_data['input_ids'][0].tolist()

    boi_positions = find_all_positions(input_ids, hf_config.boi_token_id)
    eoi_positions = find_all_positions(input_ids, hf_config.eoi_token_id)

    assert len(boi_positions) == len(eoi_positions)

    new_input_ids: List[int] = []
    new_position_ids: List[int] = []
    final_processed_position = 0

    for boi_position, eoi_position in zip(boi_positions, eoi_positions):
        assert boi_position < eoi_position
        # Text since the previous image, up to and including BOI, keeps its
        # original index as position id.
        new_input_ids.extend(
            input_ids[final_processed_position:boi_position + 1])
        new_position_ids.extend(
            range(final_processed_position, boi_position + 1))
        # Fixed-length run of placeholders, all sharing one position id.
        new_input_ids.extend([input_ids[boi_position + 1]] *
                             image_placeholder_length)
        new_position_ids.extend([boi_position + 1] * image_placeholder_length)
        # Resume at EOI so it is copied together with the following text.
        final_processed_position = eoi_position

    new_input_ids.extend(input_ids[final_processed_position:])
    new_position_ids.extend(range(final_processed_position, len(input_ids)))

    assert len(new_input_ids) == len(new_position_ids)

    inputs["prompt_token_ids"] = new_input_ids
    inputs["position_ids"] = new_position_ids
    return inputs

GoHomeToMacDonal's avatar
GoHomeToMacDonal committed
227
228
229

class GLMAttention(nn.Module):
    """Self-attention of one ChatGLM layer.

    Uses a fused QKV projection, optional multi-query KV heads and rotary
    embeddings on half of each head dimension.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        self.quant_method = None
        self.quant_config = quant_config
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        # ChatGLMForCausalLM.load_weights pads the fused QKV weight/bias
        # (pad_weight(..., 32)) only when FA_PAD=1, LLAMA_NN=1 and the model
        # is unquantized. The 32 extra output features must be sliced off
        # under exactly that condition. Slicing unpadded output would break
        # the q/k/v split in forward(). The flag is decided once here
        # instead of re-reading the environment on every forward call.
        self.use_fa_pad = (os.environ.get('FA_PAD') == '1'
                           and os.environ.get('LLAMA_NN') == '1'
                           and quant_config is None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary, attend, and project back out."""
        qkv, _ = self.query_key_value(hidden_states)
        if self.use_fa_pad:
            # Drop the 32 trailing features produced by the padded weights.
            qkv = qkv[..., :-32]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(
            q,
            k,
            v,
            kv_cache,
            attn_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """Gated feed-forward network of a ChatGLM layer.

    The three stages run in order:

    1. ``dense_h_to_4h`` computes two ``ffn_hidden_size``-wide projections
       in a single fused layer.
    2. ``SiluAndMul`` applies the gated activation to that fused output.
    3. ``dense_4h_to_h`` projects the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        use_bias = config.add_bias_linear
        self.add_bias = use_bias

        # Fused gate/up projection: hidden_size -> 2 * ffn_hidden_size.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.activation_func = SiluAndMul()

        # Down projection: ffn_hidden_size -> hidden_size.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

    def forward(self, hidden_states):
        """Apply up-projection, gated activation and down-projection."""
        fused_up, _ = self.dense_h_to_4h(hidden_states)
        projected, _ = self.dense_4h_to_h(self.activation_func(fused_up))
        return projected


class GLMBlock(nn.Module):
    """One ChatGLM transformer layer.

    The layer runs attention then MLP, each preceded by a layer norm and
    followed by a residual add. The residual source is either the normed
    or the un-normed input, depending on
    ``apply_residual_connection_post_layernorm``.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)
        self.fp32_residual_connection = config.fp32_residual_connection

        norm_cls = RMSNorm if config.rmsnorm else LayerNorm
        eps = config.layernorm_epsilon

        self.input_layernorm = norm_cls(config.hidden_size, eps=eps)
        self.self_attention = GLMAttention(config, cache_config, quant_config)
        self.hidden_dropout = config.hidden_dropout
        self.post_attention_layernorm = norm_cls(config.hidden_size, eps=eps)
        self.mlp = GLMMLP(config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers with residual connections."""
        residual_from_norm = self.apply_residual_connection_post_layernorm

        normed = self.input_layernorm(hidden_states)
        attention_output = self.self_attention(
            hidden_states=normed,
            position_ids=position_ids,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attn_residual = normed if residual_from_norm else hidden_states
        layernorm_input = attn_residual + attention_output

        normed = self.post_attention_layernorm(layernorm_input)
        mlp_residual = normed if residual_from_norm else layernorm_input
        return self.mlp(normed) + mlp_residual


class GLMTransformer(nn.Module):
    """Stack of ``config.num_layers`` GLMBlocks plus an optional final norm."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers

        blocks = [
            GLMBlock(config, cache_config, quant_config)
            for _ in range(self.num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        if self.post_layer_norm:
            norm_cls = RMSNorm if config.rmsnorm else LayerNorm
            self.final_layernorm = norm_cls(config.hidden_size,
                                            eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run every layer in order, using ``kv_caches[i]`` for layer i."""
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[layer_idx],
                attn_metadata=attn_metadata,
            )
        if not self.post_layer_norm:
            return hidden_states
        return self.final_layernorm(hidden_states)


class ChatGLMModel(nn.Module):
    """ChatGLM backbone with optional GLM-4V vision tower.

    Holds the embedding, transformer encoder and output layer. When the
    config carries a ``vision_config``, it also holds an ``EVA2CLIPModel``
    vision tower.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.config = config

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size,
                                                quant_config=quant_config)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, cache_config, quant_config)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size,
                                           quant_config=quant_config)

        vision_config_flag = getattr(config, 'vision_config', None)
        if vision_config_flag is not None:
            self.vision_config = Namespace(**config.vision_config)
            self.vision = EVA2CLIPModel(self.config, quant_config)
        else:
            self.vision = None

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> GLMImagePixelInputs:
        """Pop ``pixel_values`` from kwargs and normalise it to one tensor.

        A tensor with more than 2 dims is concatenated along its first axis.
        A list of tensors is concatenated too. The result is always wrapped
        in ``GLMImagePixelInputs``; ``pixel_values`` is ``None`` when absent.

        Raises:
            TypeError: if ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None and self.vision is not None:
            if isinstance(pixel_values, torch.Tensor):
                if pixel_values.ndim > 2:
                    pixel_values = torch.concat(list(pixel_values))
            elif isinstance(pixel_values, list):
                # Concatenate instead of returning the bare tensor, so that
                # forward() can still index the result as
                # image_input["pixel_values"].
                pixel_values = torch.concat(pixel_values)
            else:
                raise TypeError("""pixel_values must be a torch.Tensor 
                    or a list of torch.Tensor
                    """)
        return GLMImagePixelInputs(pixel_values=pixel_values)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Embed tokens, splice in image embeddings if any, run the encoder.

        Returns the encoder's final hidden states. ``intermediate_tensors``
        is accepted for interface compatibility but not used here.
        """
        inputs_embeds = self.embedding(input_ids)

        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input["pixel_values"] is not None:
            pixel_values = image_input["pixel_values"].to(
                dtype=inputs_embeds.dtype)
            image_embeds = self.vision(pixel_values)

            boi_token_id = self.config.boi_token_id
            eoi_token_id = self.config.eoi_token_id

            inputs_embeds = merge_glm_vision_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                vision_embeddings=image_embeds,
                boi_token_id=boi_token_id,
                eoi_token_id=eoi_token_id)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )
        return hidden_states


574
575
576
577
578
@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
    """ChatGLM / GLM-4V causal LM wrapper: backbone, logits and sampling."""

    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()

        self.config = config
        self.lora_config = lora_config
        self.multimodal_config = multimodal_config

        self.quant_config = quant_config
        self.max_position_embeddings = getattr(config, "max_sequence_length",
                                               8192)
        self.transformer = ChatGLMModel(config, cache_config, quant_config)
        if self.config.tie_word_embeddings:
            self.transformer.output_layer.weight = (
                self.transformer.embedding.weight)
        self.lm_head = self.transformer.output_layer
        self.logits_processor = LogitsProcessor(config.padded_vocab_size)
        self.sampler = Sampler()

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        # Weight-layout switches read once at construction. They are
        # consumed by _prepare_weights_for_llama_nn during load_weights.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs) -> torch.Tensor:
        """Return backbone hidden states.

        Extra kwargs such as ``pixel_values`` are forwarded to the backbone.
        """
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, **kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, then apply optional LLAMA_NN re-layout.

        The vision projection's ``gate_proj`` and ``dense_h_to_4h`` weights
        are collected and concatenated (dim 0) into the single
        ``merged_proj`` parameter.

        Raises:
            ValueError: if the model has ``merged_proj`` but one of its
                component weights was not present in ``weights``.
        """
        # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
        merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
            "transformer.vision.linear_proj.merged_proj.weight": {
                "transformer.vision.linear_proj.gate_proj.weight": None,
                "transformer.vision.linear_proj.dense_h_to_4h.weight": None,
            }
        }

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            is_weight_to_be_merge = False
            for merged_weight_dict in merged_weights_dict.values():
                if name in merged_weight_dict:
                    assert merged_weight_dict[name] is None
                    merged_weight_dict[name] = loaded_weight
                    is_weight_to_be_merge = True
            if is_weight_to_be_merge:
                continue

            if "rotary_pos_emb.inv_freq" in name:
                continue
            if "word_embeddings" in name:
                name = name.replace(".word_embeddings", "")
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        for combined_name, merged_weight_dict in merged_weights_dict.items():
            if combined_name not in params_dict:
                continue
            missing = [n for n, w in merged_weight_dict.items() if w is None]
            if missing:
                # Fail clearly instead of letting torch.cat choke on None.
                raise ValueError(f"Cannot build {combined_name}: missing "
                                 f"checkpoint weights {missing}")
            param = params_dict[combined_name]
            combined_weight = torch.cat(list(merged_weight_dict.values()),
                                        dim=0)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, combined_weight)

        if self.use_llama_nn and self.quant_method is None:
            self._prepare_weights_for_llama_nn(params_dict)

    def _prepare_weights_for_llama_nn(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Pad and re-lay-out projection weights for the LLAMA_NN GEMM path.

        Bias handling: with FA_PAD, the fused QKV bias is padded via
        ``pad_weight(..., 32)``.

        Each attention/MLP/lm_head weight is processed in order:

        1. Optionally padded. GEMM_PAD pads it when ``gemm_bank_conf`` flags
           its row count. FA_PAD pads the fused QKV weight when
           ``gemm_bank_conf`` does not flag its current row count.
        2. Rewritten in place with ``ops.trans_w16_gemm``.
        3. Reshaped to ``(original_dim1, -1)``.
        """
        transform_keys = (
            "self_attention.query_key_value.weight",
            "self_attention.dense.weight",
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_4h_to_h.weight",
            "lm_head.weight",
        )
        qkv_weight_key = "self_attention.query_key_value.weight"
        qkv_bias_key = "self_attention.query_key_value.bias"

        for layername, weight in params_dict.items():
            if self.use_fa_pad and qkv_bias_key in layername:
                weight.data = pad_weight(weight.data, 32)

            if not any(key in layername for key in transform_keys):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            # NOTE(review): this re-evaluates gemm_bank_conf on the possibly
            # already GEMM-padded row count. Confirm the intended behaviour
            # when GEMM_PAD and FA_PAD are both enabled.
            if (self.use_fa_pad and qkv_weight_key in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            transposed = torch.zeros_like(weight.data)
            ori_shape = transposed.shape
            ops.trans_w16_gemm(transposed, weight.data, transposed.shape[0],
                               transposed.shape[1])
            weight.data.copy_(transposed)
            weight.data = weight.data.reshape(ori_shape[1], -1)
731