# fused_attn.py
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # The cache is laid out as [cos(freqs) | sin(freqs)] along the last dimension,
        # i.e. shape (seq_len, dim), which is the layout the fused rotary kernel consumes.
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
        )
        return query, key
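
# A pure-PyTorch reference sketch of what `awq_inference_engine.rotary_embedding_neox`
# is assumed to compute given the [cos | sin] cache built above: GPT-NeoX-style rotary
# embedding that rotates the two halves of each head dimension. The helper name and the
# (num_tokens, num_heads, head_dim) / (num_tokens,) shapes are assumptions for
# illustration only, not the kernel's actual API; the fused path above does not use it.
def _rotary_neox_reference(query, key, positions, cos_sin_cache):
    head_dim = query.shape[-1]
    # Gather the cached cos/sin rows for the requested positions: each (num_tokens, head_dim // 2).
    cos, sin = cos_sin_cache[positions].to(query.dtype).chunk(2, dim=-1)
    cos = cos.unsqueeze(-2)  # broadcast over the heads dimension
    sin = sin.unsqueeze(-2)

    def rotate(x):
        x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    return rotate(query), rotate(key)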

class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size,
        num_heads,
        qkv_proj,
        o_proj,
        dev,
        max_new_tokens
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
    ):
        """Input shape: Batch x Time x Channel"""

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)

        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
        # The rotary kernel updates the query and key states in-place, saving VRAM.
        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
        
        del qkv_states
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        is_causal = past_key_value is None

        kv_seq_len = q_len
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        
        value_states = value_states.to(key_states.device)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to work around this.
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
            query_states = query_states.contiguous()

        past_key_value = (key_states, value_states) if use_cache else None

        # with torch.backends.cuda.sdp_kernel(enable_math=False):
        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
        del query_states, key_states, value_states

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
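
# A minimal sketch of how the fused `qkv_proj` expected by both attention classes could
# be assembled from separate q/k/v projections. The helper name and the plain nn.Linear
# stand-ins are assumptions for illustration; in the real AWQ path the fused projection
# is a quantized linear module produced elsewhere in the library.
def _fuse_qkv_reference(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> nn.Linear:
    out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
    fused = nn.Linear(
        q_proj.in_features,
        out_features,
        bias=False,
        device=q_proj.weight.device,
        dtype=q_proj.weight.dtype,
    )
    # Both attention classes above slice the fused output as [q | k | v] along the
    # feature dimension, so the weights are stacked in that order along dim 0.
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    return fused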


class CustomQuantLlamaAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        qkv_proj,
        o_proj,
        dev,
        max_new_tokens
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        # qkv proj
        qkv_states = self.qkv_proj(hidden_states)

        # extract q,k,v
        bsz, q_len, _ = hidden_states.size()
        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # rotary embedding
        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)

        # cache ops
        is_causal = past_key_value is None
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to work around this.
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        past_key_value = (key_states, value_states) if use_cache else None

        # multi-head masked attention
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None if is_causal else attention_mask,
            is_causal=is_causal
        )

        # reshape output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        # out projection
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
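

# A hypothetical end-to-end sketch of how these fused modules could be swapped into a
# Hugging Face Llama model, reusing the `_fuse_qkv_reference` sketch above. The attribute
# names read from the original attention module (hidden_size, num_heads, q_proj, ...)
# follow older `transformers` LlamaAttention layouts and may differ across versions;
# the library's own replacement utility (defined elsewhere) is the supported path.
# Running the result requires the `awq_inference_engine` CUDA extension and a GPU.
def _replace_llama_attention_reference(model, dev="cuda:0", max_new_tokens=2048):
    # Collect the targets first so the module tree is not mutated while iterating.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if module.__class__.__name__ == "LlamaAttention"
    ]
    for name, module in targets:
        qkv_proj = _fuse_qkv_reference(module.q_proj, module.k_proj, module.v_proj).to(dev)
        fused_attn = QuantLlamaAttention(
            module.hidden_size,
            module.num_heads,
            qkv_proj,
            module.o_proj,
            dev,
            max_new_tokens,
        )
        # Re-attach the fused module in place of the original attention block.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, fused_attn)
    return model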