import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
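        # Rotary inverse frequencies: inv_freq[i] = 1 / base**(2*i / dim).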
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

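        # Outer product of positions and inverse frequencies gives the rotation
        # angle for every (position, frequency) pair.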
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # The fused rotary kernel expects the cache laid out as [cos(freqs), sin(freqs)]
        # concatenated along the last dimension, i.e. shape (seq_len, dim).
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
    
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query = query.contiguous()
        key = key.contiguous()
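        # The kernel rotates `query` and `key` in place using the precomputed
        # cos/sin cache, indexed by `positions`.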
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
        )
        return query, key

class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size,
        num_heads,
        qkv_proj,
        o_proj,
        dev,
        max_new_tokens
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""

        bsz, q_len, _ = hidden_states.size()

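        # One fused projection produces Q, K and V together; view the result as
        # (bsz, q_len, 3, num_heads, head_dim) so the three can be split apart.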
        qkv_states = self.qkv_proj(hidden_states)
        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)

        # Split Q, K and V out of the fused tensor; the rotary kernel then updates
        # the query and key states in place, saving VRAM.
        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
        
        del qkv_states
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        is_causal = past_key_value is None

        kv_seq_len = q_len
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        
        value_states = value_states.to(key_states.device)

        if past_key_value is not None:
            # Reuse cached k, v: append the new key/value states along the sequence dimension.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
            query_states = query_states.contiguous()

        past_key_value = (key_states, value_states) if use_cache else None

        # with torch.backends.cuda.sdp_kernel(enable_math=False):
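        # `is_causal` is only set during the prompt phase (no cache); when decoding
        # with a cache, the new query may attend to every cached position, so no
        # causal mask is needed.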
        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
        del query_states, key_states, value_states

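        # Merge the heads back into hidden_size and apply the output projection.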
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value