unary_operator.cu 7.62 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
/*
 
 * Copyright (c) 2024, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/all.h>
#include "hip_compat.h"
#include "dispatch_utils.h"
#include <torch/torch.h>
#include <cmath>
#include <cstdint>
#include <limits>

#if defined(__HIP_PLATFORM_AMD__) || defined(USE_ROCM)
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <c10/hip/HIPGuard.h>
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#include <hip/hip_fp16.h>
using aiter_stream_t = hipStream_t;
#else
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
using aiter_stream_t = cudaStream_t;
#endif

namespace aiter
{
    // Generic per-element hook. Only declared here: this file neither defines
    // nor calls it. NOTE(review): confirm whether any other unit still needs it.
    template <typename T, typename Operation>
    inline __device__ T performUnaryOperation(T a);

    /// Hyperbolic tangent.
    ///   apply()   - device-side scalar op: widens to float, evaluates ::tanhf,
    ///               and narrows the result back to T.
    ///   compute() - host-side fallback that forwards to torch::tanh.
    struct TanhOp
    {
        template <typename T>
        inline __device__ static T apply(T a)
        {
            // AMDGPU inline-asm variant (currently disabled), kept for reference:
            // float y, x = static_cast<float>(a);
            // float neg_x = -x;
            // const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
            // float tmp = 0, neg_tmp = 0, m = 0, n = 0, emu = 0, neg_emu = 0;
            // asm volatile(
            //              "v_mul_f32 %[v_neg_tmp], %[s_log2e], %[v_neg_x]; log2e*(-x)\n"
            //              "s_nop 8                                       ; hazard for exp\n"
            //              "v_mul_f32 %[v_tmp], %[s_log2e], %[v_x]        ; log2e*x\n"
            //              "s_nop 8                                       ; hazard for exp\n"
            //              "v_exp_f32 %[v_neg_emu], %[v_neg_tmp]          ; neg_emu = exp2(log2e*(-x)) 0.3678794515979072\n"
            //              "s_nop 8                                       ; hazard for exp\n"
            //              "v_exp_f32 %[v_emu], %[v_tmp]                  ; emu = exp2(log2e*x)\n"
            //              "s_nop 8                                       ; hazard for exp\n"
            //              "v_add_f32 %[v_m], %[v_emu], %[v_neg_emu]      ;m=emu+neg_emu\n"
            //              "v_sub_f32 %[v_n], %[v_emu], %[v_neg_emu]      ;n=emu - neg_emu\n"
            //              "v_rcp_f32 %[v_tmp], %[v_m]                      ; 1/m\n"
            //              "s_nop 4                                       ; hazard for rcp \n"
            //              "v_mul_f32 %[v_y], %[v_n], %[v_tmp]              ; n/m\n"
            //              "s_nop 8                                       ; hazard for exp\n"
            //              : [v_y] "=v"(y),
            //                [v_tmp] "+v"(tmp),
            //                [v_neg_tmp] "+v"(neg_tmp),
            //                [v_emu] "+v"(emu),
            //                [v_neg_emu] "+v"(neg_emu),
            //                [v_m] "+v"(m),
            //                [v_n] "+v"(n)
            //              : [v_x] "v"(x), [v_neg_x] "v"(neg_x), [s_log2e] "n" (log2e_)
            //              :);
            // return static_cast<T>(y);
            const float widened = static_cast<float>(a);
            const float result = ::tanhf(widened);
            return static_cast<T>(result);
        }

        static torch::Tensor compute(torch::Tensor &input)
        {
            return torch::tanh(input);
        }
    };

    /// Logistic sigmoid, 1 / (1 + e^-x).
    ///   apply()   - device-side scalar op evaluated in float and narrowed to T.
    ///   compute() - host-side fallback that forwards to torch::sigmoid.
    struct SigmoidOp
    {
        template <typename T>
        inline __device__ static T apply(T x)
        {
            // AMDGPU inline-asm variant (currently disabled), kept for reference:
            //   float y, neg_a = static_cast<float>(-x);
            //   const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
            //   float tmp;
            //   asm volatile("v_mul_f32 %[v_tmp], %[s_log2e], %[v_x]    ; log2e*x\n"
            //                "v_exp_f32 %[v_tmp], %[v_tmp]              ; emu = exp2(log2e*x)\n"
            //                "s_nop 4                                   ; hazard for exp\n"
            //                "v_add_f32 %[v_tmp], %[v_tmp], 1.0         ; emu+1.0f\n"
            //                "v_rcp_f32 %[v_y], %[v_tmp]                ; 1/(emu+1.0f)\n"
            //                "s_nop 4                                   ; hazard for rcp \n"
            //                : [v_y] "=v"(y), [v_tmp] "+v"(tmp)
            //                : [v_x] "v"(neg_a), [s_log2e] "n"(log2e_)
            //                :);
            //   return static_cast<T>(y);

            // Negation happens in T, then the value is widened to float.
            const float negated = static_cast<float>(-x);
            const float denominator = 1.0f + expf(negated);
            return static_cast<T>(1.0f / denominator);
        }

        static torch::Tensor compute(torch::Tensor &input)
        {
            return torch::sigmoid(input);
        }
    };

    /// Applies Operation::apply to every element of a contiguous (M, N, K)
    /// buffer of _T, writing the results to `c` at the same offsets.
    ///
    /// Work split: one thread per tile. A tile is _rows rows x _vec consecutive
    /// elements along K. The _rows rows of a tile are n_tiles = N / _rows apart,
    /// so tile j of a slice touches rows j, j + n_tiles, ..., j + (_rows-1)*n_tiles.
    /// Threads whose flat index is >= M * n_tiles * k_tiles do nothing.
    ///
    /// Preconditions (not checked here; the caller is expected to enforce them):
    /// N % _rows == 0, K % _vec == 0, and a 1-D launch with at least
    /// M * (N / _rows) * (K / _vec) threads in total.
    template <class _T, int _rows, int _vec, typename Operation>
    __global__ void unary_operator_tile_kernel(const void *__restrict a, void *__restrict c, const int M, const int N, const int K)
    {
        const uint32_t n_tiles = N / _rows;
        const uint32_t k_tiles = K / _vec;
        const uint64_t gid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;

        // Guard clause: the grid is rounded up to whole blocks.
        if (gid >= (uint64_t)M * n_tiles * k_tiles)
            return;

        const uint32_t tiles_per_slice = k_tiles * n_tiles;
        const uint32_t slice = gid / tiles_per_slice;          // index along M
        const uint64_t within_slice = gid % tiles_per_slice;
        const uint32_t row_tile = (within_slice / k_tiles) % n_tiles;
        const uint32_t col_tile = within_slice % k_tiles;

        const _T *src_base = static_cast<const _T *>(a);
        _T *dst_base = static_cast<_T *>(c);

        for (int r = 0; r < _rows; ++r)
        {
            const uint64_t offset = (uint64_t)(row_tile + r * n_tiles) * K + col_tile * _vec + (uint64_t)slice * N * K;
            const _T *src = src_base + offset;
            _T *dst = dst_base + offset;
            for (int v = 0; v < _vec; ++v)
                dst[v] = Operation::apply(src[v]);
        }
    }
}

/**
 * Host entry point shared by aiter_sigmoid / aiter_tanh. It applies `Operation`
 * element-wise to `input` and returns the result in a newly allocated tensor.
 *
 * The tiled kernel (aiter::unary_operator_tile_kernel) is used only when `input`:
 *   - is a non-empty, contiguous GPU tensor,
 *   - has rank 2, viewed as (N, K), or rank 3, viewed as (M, N, K),
 *   - has dtype float32 / float16 / bfloat16,
 *   - has N % 8 == 0 and K % (16 / element_size) == 0,
 *   - is small enough that M, N and K fit in `int`, the per-slice tile count
 *     fits in 32 bits, and the launch's total thread count fits in 32 bits.
 * Every other input is delegated to Operation::compute (the PyTorch op).
 *
 * The kernel is launched on the current HIP stream of `input`'s device, so the
 * call is asynchronous with respect to the host.
 *
 * @param input  tensor to transform (read only).
 * @return tensor with the same sizes, dtype and device as `input`.
 */
template <typename Operation>
torch::Tensor unary_operation(torch::Tensor &input)
{
    constexpr int kRows = 8;            // rows per tile; N must be a multiple of this
    constexpr uint32_t kThreads = 256;  // threads per block
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();
    constexpr int64_t kMaxU32 = std::numeric_limits<uint32_t>::max();

    const int64_t dim = input.dim();
    const at::ScalarType dtype = input.scalar_type();
    const bool dtype_ok = dtype == at::ScalarType::Float ||
                          dtype == at::ScalarType::Half ||
                          dtype == at::ScalarType::BFloat16;
    // Validate rank before reading sizes. A 1-D tensor has no size(1), and a
    // rank > 3 tensor would be silently truncated by the (M, N, K) view below.
    // Empty tensors would produce an invalid zero-size grid.
    if (!input.is_cuda() || !input.is_contiguous() || (dim != 2 && dim != 3) ||
        !dtype_ok || input.numel() == 0)
    {
        return Operation::compute(input);
    }

    const int64_t M = dim == 2 ? 1 : input.size(0);
    const int64_t N = input.size(dim - 2);
    const int64_t K = input.size(dim - 1);
    // Elements per 16-byte vector. This must use the element size:
    // sizeof(input.dtype()) is the size of the TypeMeta handle, not of one element.
    const int64_t vec = 16 / static_cast<int64_t>(input.element_size());
    if (N % kRows != 0 || K % vec != 0 || M > kIntMax || N > kIntMax || K > kIntMax)
    {
        return Operation::compute(input);
    }

    // One thread per tile. The kernel walks all M slices, so the grid must
    // cover M * tiles_per_slice tiles (sizing it from N * K alone left every
    // slice but the first unwritten).
    const int64_t tiles_per_slice = (N / kRows) * (K / vec);
    const int64_t total_tiles = M * tiles_per_slice;
    const int64_t grid_x = (total_tiles + kThreads - 1) / kThreads;
    // The kernel indexes tiles within a slice using uint32 arithmetic, and a HIP
    // launch's total thread count per dimension must fit in 32 bits.
    if (tiles_per_slice > kMaxU32 || grid_x * kThreads > kMaxU32)
    {
        return Operation::compute(input);
    }

    // Allocate, select the stream, and launch on the input's device rather than
    // whichever device happens to be current.
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input));
    const aiter_stream_t stream = at::hip::getCurrentHIPStream();
    torch::Tensor output = torch::empty(input.sizes(), input.options());

    const void *buf_a = input.data_ptr();
    void *buf_c = output.data_ptr();
    const int m = static_cast<int>(M);
    const int n = static_cast<int>(N);
    const int k = static_cast<int>(K);

    const dim3 grid_dim(static_cast<uint32_t>(grid_x), 1, 1);
    const dim3 block_dim(kThreads, 1, 1);

    VLLM_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "unary_operator_tile_kernel", [&]
        {
            // Equals the host-side `vec` for every dtype admitted above.
            constexpr int kVec = static_cast<int>(16 / sizeof(scalar_t));
            aiter::unary_operator_tile_kernel<scalar_t, kRows, kVec, Operation>
                <<<grid_dim, block_dim, 0, stream>>>(buf_a, buf_c, m, n, k);
        });

    // Kernel launches do not return errors; surface bad launch configurations here.
    const hipError_t launch_err = hipGetLastError();
    TORCH_CHECK(launch_err == hipSuccess, "unary_operator_tile_kernel launch failed: ",
                hipGetErrorString(launch_err));
    return output;
}

/// Element-wise logistic sigmoid of `input`.
/// Dispatches to unary_operation with aiter::SigmoidOp, which falls back to
/// torch::sigmoid when the tiled kernel cannot handle the input.
torch::Tensor aiter_sigmoid(torch::Tensor &input)
{
    using Op = aiter::SigmoidOp;
    torch::Tensor result = unary_operation<Op>(input);
    return result;
}

/// Element-wise hyperbolic tangent of `input`.
/// Dispatches to unary_operation with aiter::TanhOp, which falls back to
/// torch::tanh when the tiled kernel cannot handle the input.
torch::Tensor aiter_tanh(torch::Tensor &input)
{
    using Op = aiter::TanhOp;
    torch::Tensor result = unary_operation<Op>(input);
    return result;
}