import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    MoeModelOutputWithPast,
)
from awq.modules.fused.block import (
    MPTBlock,
    FalconDecoderLayer,
    LlamaLikeBlock,
    MixtralBlock,
)


class MixtralModel(nn.Module):
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
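        # Keep only the tokens that have not been processed yet; the running
        # count is tracked in self.last_forward_num_tokens.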
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

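        # Make room in each block's fused KV cache for the incoming tokens.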
        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

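        # Build the causal attention mask for the current sequence window.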
        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

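        # Run each decoder block, first moving the hidden states and mask onto
        # that block's device (useful when blocks are sharded across devices).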
        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )

        h = self.norm(h)

        return MoeModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
            router_logits=(),
        )


class LlamaLikeModel(nn.Module):
    """
    LlamaLikeModel is intended to be reused across models that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """

    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0
        
    @property
    def embed_tokens(self):
        return self.embedding

    @property
    def layers(self):
        return self.blocks

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, _ = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )
        h = self.norm(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=None,
            hidden_states=(),
            attentions=(),
        )


class MPTModel(nn.Module):
    def __init__(self, vocab_size, blocks, wte, norm_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.wte = wte
        self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
        self.norm_f = norm_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.wte(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )
        h = self.norm_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )


class FalconModel(nn.Module):
    def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.word_embeddings = word_embeddings
        self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.word_embeddings(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h, None, attention_mask=mask, is_causal=is_causal
            )
        h = self.ln_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )