# model.py
import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock
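# The model classes below share the same fused-inference skeleton in forward():
# normalise the incoming input_ids against the running token count
# (fused_utils.prepare_input_ids), size the KV cache for the new sequence length
# (fused_utils.prepare_cache), embed the token ids, build a causal mask anchored
# at the first block's attention start_pos (fused_utils.prepare_attention_mask),
# move activations onto each block's device and run the block
# (fused_utils.prepare_correct_devices), then apply the final norm and wrap the
# result in a transformers ModelOutput.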


class MixtralModel(nn.Module):
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids, self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h,
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h,
                None,
                attention_mask=mask,
                is_causal=is_causal
            )
        
        h = self.norm(h)

        return MoeModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
            router_logits=(),
        )
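        # NOTE: the fused forward pass does not collect per-layer hidden states,
        # attentions or router logits, so those fields are returned as empty
        # tuples; past_key_values simply carries whatever the last block returned.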


class LlamaLikeModel(nn.Module):
    """
    LlamaLikeModel is intended to be reused across models that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0
    
    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h,
                None,
                attention_mask=mask,
                is_causal=is_causal
            )
        h = self.norm(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )
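
# Commented construction sketch (not executed): the constructor above only wires
# together components that are built elsewhere, so usage looks roughly like the
# lines below. `quantized_llama` and `fused_blocks` are hypothetical names for an
# already-quantized Hugging Face Llama checkpoint and its fused LlamaLikeBlock list.
#
#   fused_model = LlamaLikeModel(
#       vocab_size=quantized_llama.config.vocab_size,
#       blocks=fused_blocks,
#       embedding=quantized_llama.model.embed_tokens,
#       norm=quantized_llama.model.norm,
#   )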

class MPTModel(nn.Module):
    def __init__(self, vocab_size, blocks, wte, norm_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.wte = wte
        self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
        self.norm_f = norm_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.wte(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h,
                None,
                attention_mask=mask,
                is_causal=is_causal
            )
        h = self.norm_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )

class FalconModel(nn.Module):
    def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.word_embeddings = word_embeddings
        self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(
        self,
        input_ids,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
        *args,
        **kwargs,
    ):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.word_embeddings(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h,
                None,
                attention_mask=mask,
                is_causal=is_causal
            )
        h = self.ln_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )
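
# Commented usage sketch (assumptions labelled, nothing here runs on import):
# in AutoAWQ these fused models typically replace the inner decoder of a
# Hugging Face *ForCausalLM wrapper, so a bare greedy generation loop might look
# like the lines below. `fused_model`, `lm_head`, `tokenizer`, `device` and
# `max_new_tokens` are hypothetical names for objects constructed elsewhere.
#
#   ids = tokenizer("Hello", return_tensors="pt").input_ids.to(device)
#   out = fused_model(ids)                               # prefill: fills the KV cache
#   for _ in range(max_new_tokens):
#       logits = lm_head(out.last_hidden_state[:, -1])   # project hidden state to vocab
#       next_id = logits.argmax(-1, keepdim=True)        # greedy pick (sketch only)
#       out = fused_model(next_id)                       # decode one new token per call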