import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

class RotaryEmbeddingNeox(nn.Module):
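    """NeoX-style rotary positional embedding backed by the AWQ CUDA kernel.

    Precomputes a half-precision [cos | sin] cache for up to `seq_len` positions;
    `forward` hands query/key to `awq_inference_engine.rotary_embedding_neox`,
    which is expected to rotate them in place (its result is not captured).
    """
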
    def __init__(self, head_dim, seq_len, device):
        super().__init__()
        self.head_dim = head_dim
        self.seq_len = seq_len
        self.base = 10000

        # inverse frequencies for the rotary embedding: 1 / base^(2i / head_dim)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
        self.register_buffer("inv_freq", inv_freq)

        # precompute the cos/sin cache for up to seq_len positions
        self._set_cos_sin_cache(seq_len=self.seq_len, device=self.inv_freq.device)
    
    def _set_cos_sin_cache(self, seq_len, device):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        
        cos = freqs.cos()
        sin = freqs.sin()
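        # cache layout: [cos | sin] concatenated along the last dim, matching what
        # the rotary_embedding_neox kernel is given below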
        cache = torch.cat((cos, sin), dim=-1)
        
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
    
    def forward(self, positions, query, key):
        batch_size, seq_len, _ = query.shape
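        # flatten batch/sequence dims so each row holds one token's packed q or k heads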
        query = query.view(batch_size * seq_len, -1)
        key = key.view(batch_size * seq_len, -1)
        positions = positions.view(-1).to(query.device)

        query = query.contiguous()
        key = key.contiguous()

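        # the CUDA kernel is expected to rotate query and key in place using the
        # precomputed cos/sin cache (nothing is returned or reassigned here)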
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_dim,
            self.cos_sin_cache,
        )
        query = query.view(batch_size, seq_len, -1)
        key = key.view(batch_size, seq_len, -1)

        return query, key
    
class QuantLlamaAttention(nn.Module):
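    """Fused attention block for quantized Llama layers.

    Takes pre-built (typically quantized) `qkv_proj` and `o_proj` linear modules,
    applies NeoX rotary embeddings via the AWQ kernel, and computes the attention
    itself with `F.scaled_dot_product_attention`.
    """
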
    def __init__(
        self,
        hidden_size,
        num_heads,
        qkv_proj,
        o_proj,
        device,
        max_new_tokens
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.seq_len = max_new_tokens
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_embedding_neox = RotaryEmbeddingNeox(self.head_dim, self.seq_len, device)

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")

    def attn(self, query, key, value, past_key_value, use_cache, attention_mask):
        batch_size, seq_len, _ = query.shape

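        # split heads: (batch, seq_len, hidden_size) -> (batch, num_heads, seq_len, head_dim)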
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        value = value.to(key.device)

        # cache ops
        is_causal = past_key_value is None
        if past_key_value is not None:
            # reuse cached k, v for self-attention: prepend them along the sequence dim
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query, key and value will hold a reference to the
            # original qkv tensor, which can cause excessive memory usage by the cache.
            # `contiguous()` is a convenient way to work around this.
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        past_key_value = (key, value) if use_cache else None

        # multi-head masked attention
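        # on the first (uncached) pass is_causal=True lets SDPA apply the causal mask
        # itself; once a cache exists, the caller-provided attention_mask is used instead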
        attn_output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None if is_causal else attention_mask,
            is_causal=is_causal
        )

        # merge heads back: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)

        return attn_output, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        # fused QKV projection, split into equal query/key/value chunks
        query, key, value = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)

        # rotary embeddings
        query, key = self.rotary_embedding_neox(position_ids, query, key)

        # attention
        attn_output, past_key_value = self.attn(query, key, value, past_key_value, use_cache, attention_mask)

        # out projection
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
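

# Usage sketch (illustrative only; names like `layer` and `fused_qkv_linear` are
# assumptions, not part of this module):
#
#   fused = QuantLlamaAttention(
#       hidden_size=model.config.hidden_size,
#       num_heads=model.config.num_attention_heads,
#       qkv_proj=fused_qkv_linear,      # pre-fused (quantized) linear producing 3 * hidden_size
#       o_proj=layer.self_attn.o_proj,  # existing output projection
#       device="cuda:0",
#       max_new_tokens=2048,
#   )
#   layer.self_attn = fused             # swap the fused module into the decoder layer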