import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MptBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, dev):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        # Fused quantized attention; use the device passed to the block rather than a hard-coded "cuda:0".
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=8096, use_alibi=True
        )
        self.ffn = mpt_mlp
        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)

    def forward(
        self, hidden_states, past_key_value, attn_bias, attention_mask, is_causal
    ):
        # attn_bias and is_causal are kept to match the original MPT block signature but are not
        # used here; the fused attention applies ALiBi and causal masking itself (use_alibi=True).
        # Pre-norm attention with a residual connection, followed by a pre-norm MLP.
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h = hidden_states + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value
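

# Minimal usage sketch (not part of the original file): it only illustrates the constructor and
# forward() call signature of MptBlock as defined above. It assumes a CUDA device and the AWQ
# fused kernels are available, and the plain nn.Linear placeholders below stand in for the
# quantized qkv/output/MLP modules the surrounding library would normally pass in.
if __name__ == "__main__":
    import torch

    hidden_size, n_heads = 4096, 32
    dev = "cuda:0"

    # Placeholder sub-modules (assumption: real usage passes quantized linear layers here).
    qkv_layer = nn.Linear(hidden_size, 3 * hidden_size, bias=False).half().to(dev)
    o_proj = nn.Linear(hidden_size, hidden_size, bias=False).half().to(dev)
    mpt_mlp = nn.Sequential(
        nn.Linear(hidden_size, 4 * hidden_size),
        nn.GELU(),
        nn.Linear(4 * hidden_size, hidden_size),
    ).half().to(dev)

    block = MptBlock(hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, dev)

    # forward(hidden_states, past_key_value, attn_bias, attention_mask, is_causal)
    x = torch.randn(1, 8, hidden_size, dtype=torch.float16, device=dev)
    out, _, past_key_value = block(x, None, None, None, True)
    print(out.shape)  # same shape as the input hidden states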