import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    MoeModelOutputWithPast,
)
from awq.modules.fused.block import (
    MPTBlock,
    FalconDecoderLayer,
    LlamaLikeBlock,
    MixtralBlock,
)


class MixtralModel(nn.Module):
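    """
    Fused Mixtral decoder stack: token embedding, fused MixtralBlock layers,
    and a final norm. Returns MoeModelOutputWithPast (with empty hidden_states,
    attentions, and router_logits) to mirror the Hugging Face output type.
    """
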
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
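        # Fused decode flow: reconcile input_ids with the tokens already seen
        # by the fused blocks, prepare each block's KV cache for this step,
        # build a causal mask anchored at the current cache position, and run
        # the blocks. The exact cache bookkeeping lives in awq.utils.fused_utils.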
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )

        h = self.norm(h)

        return MoeModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
            router_logits=(),
        )


class LlamaLikeModel(nn.Module):
    """
    LlamaLikeModel is intended to be reused across models that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """

    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
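        # Same decode flow as MixtralModel.forward, minus the MoE router outputs.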
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )

        h = self.norm(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )


class MPTModel(nn.Module):
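    """
    Fused MPT decoder stack: `wte` token embedding, fused MPTBlock layers, and
    a final `norm_f` norm.
    """
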
    def __init__(self, vocab_size, blocks, wte, norm_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.wte = wte
        self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
        self.norm_f = norm_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
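        # Same decode flow as the models above; only the embedding and norm
        # attribute names differ.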
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.wte(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )

        h = self.norm_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )


class FalconModel(nn.Module):
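    """
    Fused Falcon decoder stack: `word_embeddings`, fused FalconDecoderLayer
    blocks, and a final `ln_f` norm.
    """
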
    def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.word_embeddings = word_embeddings
        self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
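        # Same decode flow as the models above, using Falcon's attribute names.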
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.word_embeddings(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )

        h = self.ln_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )