"tests/models/quantization/untest_fp8.py" did not exist on "d3311562fbe740a883e7f03f0b59620587cabb29"
internlm2_ve.py 5.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional, Union
5
6
7
8
9

import torch
from torch import nn
from transformers import PretrainedConfig

10
from vllm.config import CacheConfig, VllmConfig
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
                                                  InternLM2ForCausalLM,
                                                  InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors


class InternLM2VEDecoderLayer(nn.Module):
    """InternLM2 decoder layer with a second MLP for visual tokens.

    The layer holds one attention module and two MLPs of identical
    configuration: ``feed_forward`` and ``feed_forward_ve``. In ``forward``
    the positions flagged by ``visual_token_mask`` go through
    ``feed_forward_ve``; all remaining positions go through ``feed_forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, the two MLPs and the two RMSNorm modules.

        Args:
            config: Model configuration. ``rope_theta``, ``rope_scaling`` and
                ``max_position_embeddings`` are optional and fall back to
                ``10000``, ``None`` and ``8192`` respectively.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention module and both MLPs.
            prefix: Name prefix under which the submodules are created.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )

        # Both MLPs are configured identically; only their prefix differs.
        mlp_kwargs = {
            "hidden_size": self.hidden_size,
            "intermediate_size": config.intermediate_size,
            "hidden_act": config.hidden_act,
            "quant_config": quant_config,
        }
        self.feed_forward = InternLM2MLP(prefix=f"{prefix}.feed_forward",
                                         **mlp_kwargs)
        self.feed_forward_ve = InternLM2MLP(
            prefix=f"{prefix}.feed_forward_ve", **mlp_kwargs)

        norm_eps = config.rms_norm_eps
        self.attention_norm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention followed by the mask-routed feed-forward stage.

        Args:
            positions: Position tensor handed to the attention module.
            hidden_states: Input activations.
            residual: Running residual, or ``None`` when there is none yet
                (the incoming ``hidden_states`` then becomes the residual).
            visual_token_mask: Optional mask selecting the positions that
                use ``feed_forward_ve``. It is tiled ``hidden_size`` times
                along dim 1, so it is presumably of shape
                ``(num_tokens, 1)`` -- confirm against the caller.

        Returns:
            The ``(hidden_states, residual)`` pair for the next layer.
        """
        # Pre-attention norm; with a residual present the norm returns the
        # updated residual alongside the normalised activations.
        if residual is not None:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)

        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Pre-MLP norm.
        hidden_states, residual = self.ffn_norm(hidden_states, residual)

        # No mask, or a mask that selects nothing: ordinary single-MLP path.
        if visual_token_mask is None or not visual_token_mask.any():
            return self.feed_forward(hidden_states), residual

        # Element-wise boolean masks with one column per hidden feature.
        visual_elems = visual_token_mask.repeat(1, self.hidden_size).bool()
        text_elems = ~visual_elems

        # Masked selection yields a flat tensor: reshape it into rows for the
        # MLP, then flatten the result again for the in-place masked write.
        visual_rows = hidden_states[visual_elems].reshape(
            -1, self.hidden_size)
        hidden_states[visual_elems] = self.feed_forward_ve(
            visual_rows).flatten()

        if text_elems.any():
            text_rows = hidden_states[text_elems].reshape(
                -1, self.hidden_size)
            hidden_states[text_elems] = self.feed_forward(
                text_rows).flatten()

        return hidden_states, residual


class InternLM2VEModel(InternLM2Model):
    """InternLM2 model whose decoder layers are ``InternLM2VEDecoderLayer``.

    ``forward`` additionally accepts a ``visual_token_mask`` and passes it to
    every layer it runs.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Initialise the base model with the VE decoder layer type."""
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=InternLM2VEDecoderLayer,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        visual_token_mask: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the layers in ``[start_layer, end_layer)`` of this rank.

        Args:
            input_ids: Token ids, embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor passed to every layer.
            intermediate_tensors: ``"hidden_states"`` and ``"residual"``
                from the previous pipeline rank; must be provided on every
                rank other than the first.
            inputs_embeds: Optional embeddings used instead of looking up
                ``input_ids``.
            visual_token_mask: Optional mask forwarded to each layer.

        Returns:
            The normalised hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.tok_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer_idx in range(self.start_layer, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                residual,
                visual_token_mask=visual_token_mask,
            )

        if pp_group.is_last_rank:
            # The final norm also folds in the residual; only the normalised
            # activations are returned.
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


class InternLM2VEForCausalLM(InternLM2ForCausalLM):
    """Causal-LM wrapper that uses ``InternLM2VEModel`` as its backbone."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Forward the arguments to the base class with the VE model type."""
        super().__init__(
            model_type=InternLM2VEModel,
            vllm_config=vllm_config,
            prefix=prefix,
        )