import math

import pytest
import torch
from packaging import version
from torch.nn import functional as F

try:
    from colossalai.kernel.triton import int8_rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

try:
    from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention

    HAS_TORCH_INT = True
except ImportError:
    HAS_TORCH_INT = False
    print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")


TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.4")


def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    """
    adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
    """
    xq = xq.view(bs, seqlen, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
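    # causal mask: the lower triangle keeps 1.0, positions above the diagonal get a large
    # negative bias (the uniform +1.0 on visible positions cancels out in the softmax)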
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.0] = -100000000.0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
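    # standard scaled dot-product attention; the result is reshaped to (bs * seqlen, num_head, head_dim)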
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
    scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
    output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)

    return output


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT,
    reason="triton requires cuda version to be higher than 11.4 or not install torch_int",
)
def test_llama_context_attention():
    head_num = 2
    seq_len = 32
    head_dim = 64
    dtype = torch.float
    hidden_size = head_num * head_dim

    smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num)

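    # use all-ones int8 weights, except out_proj which only reads input channel 0,
    # so the expected outputs are simple to reason about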
    smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
    smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8)

    qkv_weight_scale = 1.0

    ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda")

    smooth_attn = smooth_attn.to("cuda")

    input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda")
    input_scale = 1 / 20.0

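    # calibrate the q/k/v output scale: with all-ones weights, a float matmul gives the true
    # pre-quantization magnitudes, so max|output| / 127 is the int8 output scale and
    # .a = input_scale * weight_scale / output_scale is the GEMM multiplier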
    output = torch.matmul(input.to(torch.float) * input_scale, ones)
    qkv_max_out = torch.max(torch.abs(output)) / 127
    smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
    smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
    smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)

    q = smooth_attn.q_proj(input)
    k = smooth_attn.k_proj(input)
    v = smooth_attn.v_proj(input)

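    # rotary embedding with cos = 1 and sin = 0 is an identity rotation; apply it in int8
    # with the triton kernel so q/k stay on the quantized path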
    cos_shape = (seq_len, head_dim // 2)
    cos = torch.ones(cos_shape, dtype=dtype, device="cuda")
    sin = torch.zeros(cos_shape, dtype=dtype, device="cuda")
    in_scale = torch.tensor([qkv_max_out], device="cuda")
    out_scale = torch.tensor([qkv_max_out], device="cuda")
    int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
    int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())

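    # dequantize q/k/v and run the float torch reference attention; its max magnitude
    # defines the attention-output quantization scale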
    q = q.to(torch.float) * out_scale
    k = k.to(torch.float) * out_scale
    v = v.to(torch.float) * out_scale
    torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim)
    attn_out_max = torch.max(torch.abs(torch_out)) / 127

    output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones)
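    # hand the calibrated scales to the module so its fused int8 path dequantizes and
    # requantizes consistently with the reference computation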
    smooth_attn.q_output_scale = torch.tensor(qkv_max_out)
    smooth_attn.k_output_scale = torch.tensor(qkv_max_out)
    smooth_attn.v_output_scale = torch.tensor(qkv_max_out)

    smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out)
    smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out)

    smooth_attn.attn_output_scale = torch.tensor(attn_out_max)
    smooth_attn.out_proj.a = torch.tensor([attn_out_max])

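    # quantize the reference attention output and push it through the same int8 out_proj
    # to obtain the torch-side result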
    torch_out = (
        (torch_out / smooth_attn.attn_output_scale)
        .round()
        .clamp(-128, 127)
        .to(torch.int8)
        .view(-1, seq_len, head_num * head_dim)
    )

    torch_out = smooth_attn.out_proj(torch_out)
    torch_out = torch_out.to(torch.float)

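    # run the fused smoothquant attention on the original int8 input and compare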
    smooth_attn = smooth_attn.to("cuda")
    smooth_out, _, _ = smooth_attn(input, (cos, sin))
    smooth_out = smooth_out.to(torch.float)

    assert torch.allclose(
        torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1
    ), "outputs from triton and torch are not matched"


if __name__ == "__main__":
    test_llama_context_attention()