asm_layernorm.cpp 4.86 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
// SPDX-License-Identifier: MIT
 
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "aiter_hip_common.h"

// Argument block handed to the pre-compiled layer-norm kernels used in this
// file. Each launcher below fills an instance and passes its address together
// with sizeof(KernelArgs) to AiterAsmKernel::launch_kernel, so the member
// order and the packed layout must stay exactly as they are.
// NOTE(review): p2 / p3 are declared outside this file (presumably in
// aiter_hip_common.h) and look like padding fillers placed after every real
// argument — confirm their size against that header and the kernel ABI.
struct __attribute__((packed)) KernelArgs
{
    void *ptr_O;            // out.data_ptr()
    p2 _p0;
    void *ptr_In;           // input.data_ptr()
    p2 _p1;
    void *ptr_Weight;       // weight.data_ptr()
    p2 _p2;
    void *ptr_Bias;         // bias.data_ptr()
    p2 _p3;
    float epsilon;          // epsilon forwarded unchanged from the caller
    p3 _p4;
    unsigned int M;         // input.numel() / N
    p3 _p5;
    unsigned int N;         // input.size(-1)
    p3 _p6;
    void *ptr_OutResidual;  // residual_out.data_ptr()
    p2 _p7;
    void *ptr_InResidual;   // residual_in.data_ptr()
    p2 _p8;
    void *ptr_OutYScale;    // yscale.data_ptr(); only set by the smoothquant launcher
    p2 _p9;
    void *ptr_XScale;       // xscale.data_ptr(); only set by the smoothquant launcher
    p2 _p10;
};

/**
 * @brief Launch the pre-compiled "layer_norm_kernel_func" kernel from
 *        layer_norm.co on the current stream of input's device.
 *
 * The arithmetic lives in the external kernel binary and is not visible here;
 * this function only validates the input, packs a KernelArgs block and
 * launches one 256-thread workgroup per sub_M (= 2) rows.
 *
 * @param out           destination buffer, written by the kernel      [m, n]
 * @param input         bf16, contiguous source tensor                 [m, n]
 * @param residual_in   residual read by the kernel                    [m, n]
 * @param residual_out  residual written by the kernel                 [m, n]
 * @param weight        per-column weight                              [1, n]
 * @param bias          per-column bias                                [1, n]
 * @param epsilon       forwarded unchanged to the kernel
 *
 * @throws c10::Error (via TORCH_CHECK) if input is not bf16, is not
 *         contiguous, has a last dimension other than 8192, or has an odd
 *         number of rows.
 *
 * Only input is validated; the remaining tensors are passed to the kernel by
 * raw pointer and are assumed to match the shapes listed above.
 */
void layernorm2d_with_add_asm(torch::Tensor &out,          // [m ,n]
                              torch::Tensor &input,        // [m ,n]
                              torch::Tensor &residual_in,  // [m ,n]
                              torch::Tensor &residual_out, // [m ,n]
                              torch::Tensor &weight,       // [1 ,n]
                              torch::Tensor &bias,         // [1 ,n]
                              float epsilon)
{
    TORCH_CHECK(input.dtype() == torch::kBFloat16,
                __func__, " for now only support bf16 data type");
    TORCH_CHECK(input.is_contiguous(),
                __func__, " for now only support input.is_contiguous()");

    // Validate n at full width *before* dividing by it: a tensor whose last
    // dimension is 0 must be rejected here rather than trigger an integer
    // division by zero, and no narrowing happens before the comparison.
    const auto n = input.size(-1);
    TORCH_CHECK(n == 8192,
                __func__, " for now only support n == 8192");

    constexpr int sub_M = 2; // rows covered by one workgroup in the launch grid
    const int m = static_cast<int>(input.numel() / n);
    TORCH_CHECK(m % sub_M == 0,
                __func__, " for now only support m % 2 == 0");
    if (m == 0)
    {
        // Nothing to normalise; a grid of size 0 is not a valid launch.
        return;
    }

    // Value-initialise so the padding members and the two scale pointers this
    // launcher never sets reach the kernel as zeros, not indeterminate bytes.
    KernelArgs args{};
    size_t arg_size = sizeof(args);
    args.ptr_O = out.data_ptr();
    args.ptr_In = input.data_ptr();
    args.ptr_Weight = weight.data_ptr();
    args.ptr_Bias = bias.data_ptr();
    args.epsilon = epsilon;
    args.M = static_cast<unsigned int>(m);
    args.N = static_cast<unsigned int>(n);
    args.ptr_OutResidual = residual_out.data_ptr();
    args.ptr_InResidual = residual_in.data_ptr();

    // Make input's device current before the kernel object is first created
    // and before the launch is enqueued.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    static AiterAsmKernel impl("layer_norm_kernel_func", "layer_norm.co");

    impl.launch_kernel({&args,
                        &arg_size,
                        ((m + sub_M - 1) / sub_M), // gdx
                        1,                         // gdy
                        1,                         // gdz
                        256,                       // bdx: 4 wv64
                        1,                         // bdy
                        1,                         // bdz
                        stream});
}

/**
 * @brief Launch the pre-compiled "layer_norm_qnt" kernel from
 *        layer_norm_qnt.co on the current stream of input's device.
 *
 * The arithmetic lives in the external kernel binary and is not visible here;
 * this function only validates the input, packs a KernelArgs block (including
 * the xscale / yscale pointers) and launches one 256-thread workgroup per
 * sub_M (= 2) rows.
 *
 * @param out           destination buffer, written by the kernel      [m, n]
 * @param input         bf16, contiguous source tensor                 [m, n]
 * @param residual_in   residual read by the kernel                    [m, n]
 * @param residual_out  residual written by the kernel                 [m, n]
 * @param xscale        per-column scale passed as ptr_XScale          [1, n]
 * @param yscale        per-row scale buffer passed as ptr_OutYScale   [m, 1]
 * @param weight        per-column weight                              [1, n]
 * @param bias          per-column bias                                [1, n]
 * @param epsilon       forwarded unchanged to the kernel
 *
 * @throws c10::Error (via TORCH_CHECK) if input is not bf16, is not
 *         contiguous, has a last dimension other than 8192, or has an odd
 *         number of rows.
 *
 * Only input is validated; the remaining tensors are passed to the kernel by
 * raw pointer and are assumed to match the shapes listed above.
 */
void layernorm2d_with_add_smoothquant_asm(torch::Tensor &out,          // [m ,n]
                                          torch::Tensor &input,        // [m ,n]
                                          torch::Tensor &residual_in,  // [m ,n]
                                          torch::Tensor &residual_out, // [m ,n]
                                          torch::Tensor &xscale,       // [1 ,n]
                                          torch::Tensor &yscale,       // [m ,1]
                                          torch::Tensor &weight,       // [1 ,n]
                                          torch::Tensor &bias,         // [1 ,n]
                                          float epsilon)
{
    TORCH_CHECK(input.dtype() == torch::kBFloat16,
                __func__, " for now only support bf16 data type");
    TORCH_CHECK(input.is_contiguous(),
                __func__, " for now only support input.is_contiguous()");

    // Validate n at full width *before* dividing by it: a tensor whose last
    // dimension is 0 must be rejected here rather than trigger an integer
    // division by zero, and no narrowing happens before the comparison.
    const auto n = input.size(-1);
    TORCH_CHECK(n == 8192,
                __func__, " for now only support n == 8192");

    constexpr int sub_M = 2; // rows covered by one workgroup in the launch grid
    const int m = static_cast<int>(input.numel() / n);
    TORCH_CHECK(m % sub_M == 0,
                __func__, " for now only support m % 2 == 0");
    if (m == 0)
    {
        // Nothing to normalise; a grid of size 0 is not a valid launch.
        return;
    }

    // Value-initialise so the padding members reach the kernel as zeros
    // instead of indeterminate stack bytes.
    KernelArgs args{};
    size_t arg_size = sizeof(args);
    args.ptr_O = out.data_ptr();
    args.ptr_In = input.data_ptr();
    args.ptr_Weight = weight.data_ptr();
    args.ptr_Bias = bias.data_ptr();
    args.epsilon = epsilon;
    args.M = static_cast<unsigned int>(m);
    args.N = static_cast<unsigned int>(n);
    args.ptr_OutResidual = residual_out.data_ptr();
    args.ptr_InResidual = residual_in.data_ptr();
    args.ptr_OutYScale = yscale.data_ptr();
    args.ptr_XScale = xscale.data_ptr();

    // Make input's device current before the kernel object is first created
    // and before the launch is enqueued.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    static AiterAsmKernel impl("layer_norm_qnt", "layer_norm_qnt.co");

    impl.launch_kernel({&args,
                        &arg_size,
                        ((m + sub_M - 1) / sub_M), // gdx
                        1,                         // gdy
                        1,                         // gdz
                        256,                       // bdx: 4 wv64
                        1,                         // bdy
                        1,                         // bdz
                        stream});
}