import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock

class LlamaLikeModel(nn.Module):
    """
    LlamaLikeModel is intended to be reused across models that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)  # wrap in nn.ModuleList so blocks register as submodules, matching MPTModel/FalconModel
        self.norm = norm
        self.last_forward_num_tokens = 0
    
    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h,
                None,
                attention_mask=mask,
                is_causal=is_causal
            )
        h = self.norm(h)

        # Only the last block's cache handle is surfaced; hidden_states and
        # attentions are left empty because the fused path does not collect them.
        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )
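
# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how LlamaLikeModel might be assembled from a
# Hugging Face Llama-style checkpoint. `build_fused_blocks` is a hypothetical
# helper standing in for the real block-fusion logic elsewhere in the awq
# package, which is not reproduced here.
#
#   from transformers import AutoModelForCausalLM
#
#   hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#   blocks = build_fused_blocks(hf_model.model.layers)   # hypothetical helper
#   hf_model.model = LlamaLikeModel(
#       vocab_size=hf_model.config.vocab_size,
#       blocks=blocks,
#       embedding=hf_model.model.embed_tokens,
#       norm=hf_model.model.norm,
#   )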

class MPTModel(nn.Module):
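    """
    Fused counterpart of the MPT decoder stack. Mirrors LlamaLikeModel but
    keeps MPT's module names for the token embedding (wte) and final norm (norm_f).
    """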
    def __init__(self, vocab_size, blocks, wte, norm_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.wte = wte
        self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
        self.norm_f = norm_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.wte(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h,
                None,
                attention_mask=mask,
                is_causal=is_causal
            )
        h = self.norm_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )

class FalconModel(nn.Module):
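    """
    Fused counterpart of the Falcon decoder stack. Mirrors LlamaLikeModel but
    keeps Falcon's module names for the token embedding (word_embeddings) and
    final norm (ln_f).
    """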
    def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.word_embeddings = word_embeddings
        self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.word_embeddings(input_ids)

        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, mask = fused_utils.prepare_correct_devices(
                layer,
                h,
                mask,
            )
            h, _, past_key_value = layer(
                h,
                None,
                attention_mask=mask,
                is_causal=is_causal
            )
        h = self.ln_f(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=(),
        )
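
# --- Decoding sketch (not part of the original module) ---
# A hedged illustration of the call pattern these fused models are built for:
# the first forward pass caches the prompt, and later passes advance the fused
# attention cache tracked via each block's attn.start_pos. It assumes
# fused_utils.prepare_input_ids trims tokens the cache already holds, and that
# `model`, `lm_head`, `prompt_ids`, and `max_new_tokens` come from the
# surrounding causal-LM wrapper; greedy selection is used only for brevity.
#
#   ids = prompt_ids
#   for _ in range(max_new_tokens):
#       hidden = model(ids).last_hidden_state        # (bsz, new_tokens, hidden_dim)
#       next_id = lm_head(hidden[:, -1]).argmax(-1, keepdim=True)
#       ids = torch.cat([ids, next_id], dim=-1)      # pass the growing sequence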