import pytest
from packaging import version
import torch
import torch.nn.functional as F

from colossalai.kernel.triton.ops import self_attention_compute_using_triton
from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse('11.4')

@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton and CUDA version higher than 11.4")
def test_qkv_matmul():
    qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
    scale = 1.2
    head_size = 32
    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size

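    # qkv packs the fused projection as (batch, seq, 3 * d_model); slice out
    # q and k, then split d_model into (num_of_heads, head_size).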
    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model:d_model * 2]

    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    q_copy = q.clone()
    k_copy = k.clone()
    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    k = torch.transpose(k, 2, 3).contiguous()

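    # Reference scores: q is (batch, head, seq, head_size) and k is
    # (batch, head, head_size, seq), so the einsum computes Q @ K^T with
    # shape (batch, head, seq, seq).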
    torch_output = torch.einsum('bnij,bnjk->bnik', q, k)
    torch_output *= scale

    q, k = q_copy, k_copy
    batches, M, H, K = q.shape
    N = k.shape[1]
    score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)

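    # Launch grid: one program per (batch, head, tile), where the last axis
    # flattens the 2D tiling of the (M, N) score matrix into a single index.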
    grid = lambda meta: (
        batches,
        H,
        triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
        triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )

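    # q and k are still in (batch, seq, head, head_size) layout; passing their
    # strides in permuted order lets the kernel index q as
    # (batch, head, seq, head_size) and k as its transpose, with no extra copy.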
    qkv_gemm_4d_kernel[grid](
        q, k, score_output,
        M, N, K,
        q.stride(0), q.stride(2), q.stride(1), q.stride(3),
        k.stride(0), k.stride(2), k.stride(3), k.stride(1),
        score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
        scale=scale,
        # block sizes are hand-tuned for now; an autotune config could pick better values later
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_N=32,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )

    check = torch.allclose(torch_output.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5)
    assert check, "the outputs of triton and torch do not match"

def self_attention_compute_using_torch(qkv, input_mask, scale, head_size):
    # Torch reference implementation. input_mask is accepted for signature
    # parity with the triton path but is not applied here.

    batches = qkv.shape[0]
    d_model = qkv.shape[-1] // 3
    num_of_heads = d_model // head_size
    
    q = qkv[:, :, :d_model]
    k = qkv[:, :, d_model:d_model * 2]
    v = qkv[:, :, d_model * 2:]
    q = q.view(batches, -1, num_of_heads, head_size)
    k = k.view(batches, -1, num_of_heads, head_size)
    v = v.view(batches, -1, num_of_heads, head_size)

    q = torch.transpose(q, 1, 2).contiguous()
    k = torch.transpose(k, 1, 2).contiguous()
    v = torch.transpose(v, 1, 2).contiguous()

    k = torch.transpose(k, -1, -2).contiguous()

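    # Attention: softmax(scale * Q @ K^T) @ V, computed per batch and head.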
    score_output = torch.einsum('bnij,bnjk->bnik', q, k)
    score_output *= scale

    softmax_output = F.softmax(score_output, dim=-1)
    res = torch.einsum('bnij,bnjk->bnik', softmax_output, v)
    res = torch.transpose(res, 1, 2).contiguous()

    return res.view(batches, -1, d_model), score_output, softmax_output

@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton and CUDA version higher than 11.4")
def test_self_attention():

    qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16)
    data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch(
        qkv.clone(), input_mask=None, scale=1.2, head_size=32)

    data_output_triton = self_attention_compute_using_triton(
        qkv.clone(),
        alibi=None,
        head_size=32,
        scale=1.2,
        input_mask=None,
        layer_past=None,
        use_flash=False,
        triangular=True)

    check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2)
    assert check, "the triton output does not match the torch output"


if __name__ == "__main__":
    test_qkv_matmul()
    test_self_attention()