# fused_attn.py
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding

class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Unlike the HF implementation, which duplicates freqs for rotate_half, the fused
        # kernel expects cos and sin packed side by side along the last dimension.
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # [max_position, rot_dim]
        # Stored in fp16 for the fused rotary kernel.
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query = query.contiguous()
        key = key.contiguous()

        awq_inference_engine.rotary_embedding(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
            True,  # is_neox
        )

        return query, key
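
# Usage sketch (illustrative only; shapes follow the flattened layout that
# QuantLlamaAttention.forward prepares below, and the tensors are assumed to be
# fp16 CUDA tensors, matching the half-precision cos/sin cache):
#
#   rope = QuantLlamaRotaryEmbedding(dim=128, max_position_embeddings=2048, device="cuda")
#   # query:     [num_tokens, num_heads * head_dim]
#   # key:       [num_tokens, num_kv_heads * head_dim]
#   # positions: [num_tokens] integer position ids
#   query, key = rope(query, key, positions)
#
# The extension kernel writes the rotated values back into query and key, which is
# why both are made contiguous before the call.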

class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        qkv_proj,
        o_proj,
        dev,
        max_new_tokens,
        use_hf_rotary=False
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_size // num_heads
        self.use_hf_rotary = use_hf_rotary

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj

        # Choose between the HF rotary implementation and the fused CUDA rotary kernel.
        if use_hf_rotary:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
        else:
            self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)

        if self.use_hf_rotary:
            # Unpack the fused QKV projection; this path assumes num_kv_heads == num_heads.
            qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
            query, key, value = torch.split(qkv_states, 1, dim=2)
            del qkv_states
            
            # reshape for hf rotary
            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

            kv_seq_len = key.shape[-2]
            if past_key_value is not None:
                kv_seq_len += past_key_value[0].shape[-2]
            
            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)

        else:
            # Split the fused QKV output into three equal chunks (assumes num_kv_heads == num_heads).
            query, key, value = qkv_states.chunk(chunks=3, dim=-1)
            del qkv_states

            # [num_tokens, num_heads * head_size]
            query_batch_size, query_len, _ = query.shape
            query = query.view(query_len*query_batch_size, self.num_heads * self.head_dim)

            # [num_tokens, num_kv_heads * head_size]
            key_batch_size, key_len, _ = key.shape
            key = key.view(key_len*key_batch_size, self.num_kv_heads * self.head_dim)

            # [num_tokens]
            positions = position_ids.view(-1).to(query.device)

            query, key = self.rotary_emb(query, key, positions)

            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Both code paths converge here with query/key/value shaped [bsz, num_heads, q_len, head_dim].
        # A causal mask is only needed for the prefill pass; once a cache is present,
        # the new query token attends to every cached position.
        is_causal = past_key_value is None

        # Ensure value is on the same device as key before cache concat and attention.
        value = value.to(key.device)

        if past_key_value is not None:
            # Append the cached key/value states from previous decoding steps.
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        if use_cache:
            # Because qkv_proj is fused, query/key/value are views into the original qkv_states
            # tensor, which can make the cache hold far more memory than needed.
            # `contiguous()` is a convenient way to work around this.
            key = key.contiguous()
            value = value.contiguous()
            query = query.contiguous()

        past_key_value = (key, value) if use_cache else None

        # Optionally restrict SDPA to fused kernels, e.g. torch.backends.cuda.sdp_kernel(enable_math=False).
        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
        del query, key, value

        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
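
# Integration sketch (hypothetical): QuantLlamaAttention mirrors LlamaAttention's
# forward contract (attn_output, attn_weights=None, past_key_value), so it can be
# swapped into each decoder layer of a quantized model. Building the fused quantized
# `qkv_proj` happens outside this file; `fuse_qkv_projection` below is a placeholder name.
#
#   quant_attn = QuantLlamaAttention(
#       hidden_size=config.hidden_size,
#       num_heads=config.num_attention_heads,
#       num_kv_heads=config.num_attention_heads,   # MHA assumed, as in forward() above
#       qkv_proj=fuse_qkv_projection(old_attn),    # placeholder: fused quantized q/k/v projection
#       o_proj=old_attn.o_proj,
#       dev="cuda:0",
#       max_new_tokens=2048,
#   )
#   decoder_layer.self_attn = quant_attn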