block.py
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MPTBlock(nn.Module):
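    """Fused MPT transformer block.

    Pre-norm attention with ALiBi biases (QuantAttentionFused, use_alibi=True)
    followed by a pre-norm feed-forward network, each wrapped in a residual
    connection. The qkv/output projections and the MLP are supplied by the
    caller, typically as AWQ-quantized modules.
    """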
    def __init__(
        self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp,
        norm_1, norm_2, dev, max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
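        # MPT uses ALiBi position biases, so the fused attention is built with use_alibi=True.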
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=True,
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
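        # Pre-norm block: attend over norm_1(x), add the residual, then apply
        # the FFN to norm_2 of the result and add a second residual.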
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h = hidden_states + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value

class FalconDecoderLayer(nn.Module):
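    """Fused Falcon decoder layer.

    When new_decoder_arch is True, separate layernorms (ln_attn, ln_mlp) feed
    the attention and MLP branches; otherwise a single input_layernorm feeds
    both. The two branches run in parallel and are summed with the residual
    stream.
    """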
    def __init__(
        self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
        input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        # TODO: Falcon has ALiBi implemented but which model uses it?
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=False,
        ).to(dev)
        self.new_decoder_arch = new_decoder_arch
        
        if new_decoder_arch:
            self.ln_attn = ln_attn # before attention
            self.ln_mlp = ln_mlp # before mlp
        else:
            self.input_layernorm = input_layernorm # before attention
        
        self.mlp = mlp

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
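        # Attention and the MLP run as parallel branches; both results are
        # added onto the residual stream below.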
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)
        
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h_attn = hidden_states + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)
        
        out = h_attn + h_mlp
        
        return out, None, past_key_value
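
# Rough usage sketch (illustrative only, not from this file): `qkv_layer` and
# `o_proj` are assumed to be AWQ-quantized projection modules taken from an
# existing MPT layer, and `mpt_mlp`, `norm_1`, `norm_2` to come from the same
# layer; the sizes below are placeholders.
#
#   block = MPTBlock(
#       hidden_size=4096, n_heads=32,
#       qkv_layer=qkv_layer, o_proj=o_proj, mpt_mlp=mpt_mlp,
#       norm_1=norm_1, norm_2=norm_2,
#       dev="cuda:0", max_seq_len=2048,
#   )
#   out, _, past_key_value = block(hidden_states, past_key_value=None)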