test_layernorm2d.py 5.32 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
# SPDX-License-Identifier: MIT
import torch
import torch.nn.functional as F
4
import pandas as pd
Xiaowei.zhang's avatar
Xiaowei.zhang committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter import dtypes


@perftest()
def run_torch(input, weight, bias, eps, residual=None, x_bias=None):
    """Reference PyTorch layernorm with optional bias-add and residual-add.

    ``x_bias`` (when given) is added to ``input`` first. When ``residual`` is
    given, it is added on top and that pre-norm sum is what gets normalized
    and also returned. Normalization is over the last dimension.

    Returns:
        ``(output, residual_out)``; ``residual_out`` is ``None`` when no
        residual was supplied. Callers in this file unpack the
        ``@perftest``-decorated result as ``((output, residual_out), avg_us)``.
    """
    hidden = input if x_bias is None else input + x_bias
    residual_out = None if residual is None else hidden + residual
    # Normalize the residual sum when there is one, otherwise the (biased) input.
    normed_src = hidden if residual_out is None else residual_out
    output = F.layer_norm(
        input=normed_src,
        normalized_shape=(hidden.shape[-1],),
        weight=weight,
        bias=bias,
        eps=eps,
    )
    return output, residual_out


@perftest()
def run_ck(input, weight, bias, eps, residual=None, x_bias=None):
    """Run the aiter (CK) layernorm kernels, mirroring ``run_torch``.

    Without ``residual``, plain ``aiter.layer_norm`` is used; ``x_bias`` (if
    given) is added to ``input`` beforehand, matching ``run_torch``. With
    ``residual``, the fused ``aiter.layernorm2d_fwd_with_add`` kernel writes
    the normalized result into ``output`` and the pre-norm sum into
    ``residual_out``.

    Returns:
        ``(output, residual_out)``; ``residual_out`` is ``None`` when no
        residual was supplied.
    """
    if residual is None:
        residual_out = None
        # ``output`` must be produced on this path too; test_layernorm2d calls
        # without a residual. Call form taken from the commented-out run_asm
        # helper in this file; NOTE(review): confirm aiter.layer_norm's signature.
        # The x_bias pre-add (only when given) is included in the timed region.
        x = input if x_bias is None else input + x_bias
        output = aiter.layer_norm(x, weight, bias, eps)
    else:
        residual_out = torch.empty_like(input)
        output = torch.empty_like(input)
        aiter.layernorm2d_fwd_with_add(
            output, input, residual, residual_out, weight, bias, eps, x_bias
        )
    return output, residual_out


# @perftest()
# def run_asm(input, weight, bias, eps, residual=None):
#     if residual is None:
#         residual_out = None
#         output = aiter.layer_norm(input, weight, bias, eps)
#     else:
#         residual_out = torch.empty_like(input)
#         output = torch.empty_like(input)
#         aiter.layernorm2d_with_add_asm(
#             output, input, residual, residual_out, weight, bias, eps
#         )
#     return output, residual_out


def test_layernorm2d(dtype, m, n):
    """Compare aiter's layernorm against PyTorch on a strided ``(m, n)`` input.

    The input is the ``n``-column slice starting at column ``6 * n`` of a
    wider ``(m, 8 * n)`` tensor, i.e. a view into it rather than a contiguous
    tensor. This exercises the kernel on non-contiguous rows.

    Args:
        dtype: torch dtype for every tensor (e.g. ``dtypes.bf16``).
        m: number of rows.
        n: row width (normalized dimension).

    Returns:
        A summary dict (shape, dtype, timings in us, uplift, accuracy) used
        as one row of the summary DataFrame.
    """
    dim = (m, n)
    weight = torch.randn(n, dtype=dtype, device="cuda")
    bias = torch.randn(n, dtype=dtype, device="cuda")
    hidden_stats = torch.randn(m, n * 8, dtype=dtype, device="cuda")
    # Only the middle chunk is used as the input; torch.split returns views,
    # so ``x`` shares storage with hidden_stats and keeps its row stride.
    _, x, _ = torch.split(hidden_stats, [6 * n, n, n], dim=1)

    (a, *_), avg_a = run_torch(x, weight, bias, 1e-5)
    (b, *_), avg_b = run_ck(x, weight, bias, 1e-5)
    uplift = avg_a / avg_b - 1
    msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {uplift:<5.1%}"

    check_ret = checkAllclose(a, b, msg=msg)
    # NOTE(review): presumably checkAllclose returns 0 on a full match and a
    # mismatch ratio otherwise, so 1 - ratio reports the matching fraction.
    ret_output = "passed" if check_ret == 0 else (1 - check_ret)
    return {
        "m": m,
        "n": n,
        "dtype": str(dtype),
        "torch_us": avg_a,
        "ck_us": avg_b,
        "uplift": f"{uplift:.1%}",
        "accuracy": ret_output,
    }


def test_layernorm2d_fuseAdd(dtype, m, n):
    """Compare aiter's fused residual-add + layernorm against PyTorch.

    Both implementations normalize ``input + res`` over the last dimension
    of a contiguous ``(m, n)`` input and also return that pre-norm sum.
    Both the normalized output and the sum are checked.

    Args:
        dtype: torch dtype for every tensor (e.g. ``dtypes.bf16``).
        m: number of rows.
        n: row width (normalized dimension).

    Returns:
        A summary dict with timings (us), uplift, output accuracy
        (``"accuracy"``) and residual-sum accuracy (``"residual_accuracy"``).
    """
    dim = (m, n)
    x = torch.randn(dim, dtype=dtype, device="cuda")
    weight = torch.randn(n, dtype=dtype, device="cuda")
    bias = torch.randn(n, dtype=dtype, device="cuda")
    res = torch.randn(dim, dtype=dtype, device="cuda")

    (a, res_a, *_), avg_a = run_torch(x, weight, bias, 1e-5, residual=res)
    (b, res_b, *_), avg_b = run_ck(x, weight, bias, 1e-5, residual=res)
    # asm comparison is disabled along with the commented-out run_asm helper:
    # (c, res_c, *_), avg_c = run_asm(x, weight, bias, 1e-5, residual=res)
    # checkAllclose(a, c, atol=0.03, msg="asm")
    # checkAllclose(res_a, res_c, atol=0.01, msg="asm res")

    uplift = avg_a / avg_b - 1
    msg = f"[perf] dim: {str(dim):<20}, dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, uplift: {uplift:<5.1%}"

    # The normalized output gets a looser atol; the plain residual sum is
    # checked with checkAllclose's default tolerances.
    check_ret = checkAllclose(a, b, atol=0.03, msg=msg)
    ret_output = "passed" if check_ret == 0 else (1 - check_ret)
    residual_check_ret = checkAllclose(res_a, res_b, msg="res check")
    residual_output = (
        "passed" if residual_check_ret == 0 else (1 - residual_check_ret)
    )
    return {
        "m": m,
        "n": n,
        "dtype": str(dtype),
        "torch_us": avg_a,
        "ck_us": avg_b,
        "uplift": f"{uplift:.1%}",
        "accuracy": ret_output,
        "residual_accuracy": residual_output,
    }


122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
if __name__ == "__main__":
    # Summary rows for the plain and fused-add tests respectively.
    df = []
    df_fuse_add = []

    # Plain layernorm sweep (disabled):
    # for dtype in [dtypes.fp16, dtypes.bf16]:
    #     for m in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
    #         for n in [4096, 8192, 16384, 32768, 65536]:
    #             ret = test_layernorm2d(dtype, m, n)
    #             if ret is not None:
    #                 df.append(ret)
    # ret = test_layernorm2d(dtypes.bf16, 128, 8192)
    # if ret is not None:
    #     df.append(ret)

    # Fused-add sweep (disabled); only a single bf16 shape runs by default:
    # print('\nstart fuse add test')
    # for dtype in [dtypes.fp16, dtypes.bf16]:
    #     for m in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
    #         for n in [4096, 8192, 16384, 32768, 65536]:
    #             ret = test_layernorm2d_fuseAdd(dtype, m, n)
    #             if ret is not None:
    #                 df_fuse_add.append(ret)
    ret = test_layernorm2d_fuseAdd(dtypes.bf16, 128, 8192)
    if ret is not None:
        df_fuse_add.append(ret)

    # if df:
    #     df = pd.DataFrame(df)
    #     aiter.logger.info(f"layernorm2d summary:\n{df}")
    #     df.to_csv("test_layernorm2d.csv", index=False)

    if df_fuse_add:
        df_fuse_add = pd.DataFrame(df_fuse_add)
        aiter.logger.info(f"layernorm2d fuseAdd summary:\n{df_fuse_add}")
        df_fuse_add.to_csv("test_layernorm2d_fuseAdd.csv", index=False)