import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """
    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
                       mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=False
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask
        )

        h = hidden_states + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value
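
# Illustrative sketch (not part of the original module): building a LlamaLikeBlock
# and running a single prefill pass. The default dimensions are Llama-7B-like
# assumptions, and the quantized `qkv_layer`, `o_proj`, `mlp` plus the two norm
# modules are assumed to be prepared by the caller.
def _llama_like_block_sketch(qkv_layer, o_proj, mlp, norm_1, norm_2,
                             hidden_size=4096, n_heads=32, n_kv_heads=32,
                             dev="cuda:0", max_seq_len=2048, seq_len=16):
    import torch

    block = LlamaLikeBlock(
        hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
        mlp, norm_1, norm_2, dev, max_seq_len
    )
    # The KV cache (and non-ALiBi position handling) lives inside QuantAttentionFused,
    # so the caller only passes hidden states and threads past_key_value along.
    hidden_states = torch.randn(
        1, seq_len, hidden_size, dtype=torch.float16, device=dev
    )
    out, _, past_key_value = block(hidden_states, past_key_value=None)
    return out, past_key_value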

class MPTBlock(nn.Module):
    """
    MPT-style decoder block with fused quantized attention. Positions are encoded
    with ALiBi biases (use_alibi=True) rather than rotary embeddings.
    """
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp,
                       norm_1, norm_2, dev, max_seq_len):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, 
            dev=dev, max_seq_len=max_seq_len, use_alibi=True
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h = hidden_states + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value
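
# Illustrative sketch (assumptions, not part of the original module): driving an
# already-constructed MPTBlock autoregressively. Prefill with the prompt's hidden
# states, then feed one step at a time, threading past_key_value through each call;
# the fused attention also keeps internal cache buffers sized by max_seq_len.
def _mpt_block_decode_sketch(block, prompt_hidden_states, step_hidden_states_iter):
    out, _, past_key_value = block(prompt_hidden_states, past_key_value=None)
    for step_hidden_states in step_hidden_states_iter:
        out, _, past_key_value = block(
            step_hidden_states, past_key_value=past_key_value
        )
    return out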

class FalconDecoderLayer(nn.Module):
    """
    Falcon decoder layer with fused quantized attention. Supports both the new
    decoder architecture (parallel attention/MLP with separate ln_attn and ln_mlp
    norms and grouped KV heads) and the original architecture (a single
    input_layernorm in front of both attention and the MLP).
    """
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len, 
                       input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 8 if new_decoder_arch else 0  # new decoder arch (e.g. Falcon-40B) uses 8 grouped KV heads
        self.hidden_size = hidden_size
        self.new_decoder_arch = new_decoder_arch

        if new_decoder_arch:
            attention_shapes = None
        else:
            attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)
        
        # TODO: Falcon has ALiBi implemented but which model uses it?
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, 
            dev=dev, max_seq_len=max_seq_len, use_alibi=False,
            attention_shapes=attention_shapes
        ).to(dev)
        
        if new_decoder_arch:
            self.ln_attn = ln_attn # before attention
            self.ln_mlp = ln_mlp # before mlp
        else:
            self.input_layernorm = input_layernorm # before attention
        
        self.mlp = mlp
    
    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
        """Cache and view shapes for the original (multi-query) Falcon attention."""
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        
        self.attention_shapes = {
            # following fastertransformer definition
            "cache_v": (batch_size, 1, max_seq_len, head_dim,),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
            "xqkv_view": (n_heads+2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
            "xq_view": (n_heads, head_dim),
            "xk_view": (1, head_dim),
            "xv_view": (1, head_dim),
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (1, head_dim),
            "single_xv_view": (1, head_dim)
        }

        return self.attention_shapes
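
    # Worked example of the shapes above (Falcon-7B-like numbers, used here only
    # for illustration): with n_heads=71, head_dim=64, max_seq_len=2048 and
    # AWQ_BATCH_SIZE=1 this gives
    #   cache_v   -> (1, 1, 2048, 64)
    #   cache_k   -> (1, 1, 8, 2048, 8)   # 64-wide head dim packed 8 fp16 at a time
    #   xqkv_view -> (73, 64)             # 71 query heads + 1 shared K + 1 shared V
    # i.e. multi-query attention with a single key/value head per token.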

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)
        
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h_attn = hidden_states + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)
        
        out = h_attn + h_mlp
        
        return out, None, past_key_value
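
# Illustrative sketch (not part of the original module): constructing a
# FalconDecoderLayer for both architecture variants. The hidden sizes and head
# counts are typical Falcon-40B / Falcon-7B values used only as assumptions; the
# quantized `qkv_layer`, `o_proj`, `mlp` and the layer norms come from the caller.
def _build_falcon_layer_sketch(qkv_layer, o_proj, mlp, dev="cuda:0", max_seq_len=2048,
                               new_decoder_arch=True, input_layernorm=None,
                               ln_attn=None, ln_mlp=None):
    if new_decoder_arch:
        # New arch: parallel attention/MLP with two separate pre-norms.
        return FalconDecoderLayer(
            hidden_size=8192, n_heads=128, qkv_layer=qkv_layer, o_proj=o_proj,
            mlp=mlp, dev=dev, max_seq_len=max_seq_len,
            ln_attn=ln_attn, ln_mlp=ln_mlp, new_decoder_arch=True,
        )
    # Original arch: one input_layernorm feeds both attention and the MLP.
    return FalconDecoderLayer(
        hidden_size=4544, n_heads=71, qkv_layer=qkv_layer, o_proj=o_proj,
        mlp=mlp, dev=dev, max_seq_len=max_seq_len,
        input_layernorm=input_layernorm, new_decoder_arch=False,
    )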