/*
 * Copyright (c) 2024, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// NOTE(review): the angle-bracket include targets were lost in the copy this
// file was restored from; the names below are best guesses -- confirm them
// against the repository before merging.
#include <torch/all.h>
#include "hip_compat.h"
#include "dispatch_utils.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/core/DeviceGuard.h>

#include <cstdint>
#include <limits>

#if defined(__HIP_PLATFORM_AMD__) || defined(USE_ROCM)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>
typedef __hip_bfloat16 nv_bfloat16;
#include <hipcub/util_type.hpp>
using aiter_stream_t = hipStream_t;
#else
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cub/util_type.cuh>
using aiter_stream_t = cudaStream_t;
#endif

namespace aiter {

template <typename T>
inline __device__ T performUnaryOperation(T a);

// Element-wise tanh.
//   apply   -- per-element device functor used by unary_operator_tile_kernel.
//   compute -- host fallback (torch::tanh) used when the tiled kernel does not
//              support the input.
struct TanhOp {
  // Evaluates tanh in fp32 and converts the result back to T.
  template <typename T>
  inline __device__ static T apply(T a) {
    return static_cast<T>(::tanhf(static_cast<float>(a)));
    // Unused hand-written AMDGPU asm variant computing
    // (e^x - e^-x) / (e^x + e^-x) via v_exp_f32 / v_rcp_f32, kept for reference:
    // float y, x = static_cast<float>(a);
    // float neg_x = -x;
    // const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v;
    // float tmp = 0, neg_tmp = 0, m = 0, n = 0, emu = 0, neg_emu = 0;
    // asm volatile(
    //     "v_mul_f32 %[v_neg_tmp], %[s_log2e], %[v_neg_x]; log2e*(-x)\n"
    //     "s_nop 8 ; hazard for exp\n"
    //     "v_mul_f32 %[v_tmp], %[s_log2e], %[v_x] ; log2e*x\n"
    //     "s_nop 8 ; hazard for exp\n"
    //     "v_exp_f32 %[v_neg_emu], %[v_neg_tmp] ; neg_emu = exp2(log2e*(-x)) 0.3678794515979072\n"
    //     "s_nop 8 ; hazard for exp\n"
    //     "v_exp_f32 %[v_emu], %[v_tmp] ; emu = exp2(log2e*x)\n"
    //     "s_nop 8 ; hazard for exp\n"
    //     "v_add_f32 %[v_m], %[v_emu], %[v_neg_emu] ;m=emu+neg_emu\n"
    //     "v_sub_f32 %[v_n], %[v_emu], %[v_neg_emu] ;n=emu - neg_emu\n"
    //     "v_rcp_f32 %[v_tmp], %[v_m] ; 1/m\n"
    //     "s_nop 4 ; hazard for rcp \n"
    //     "v_mul_f32 %[v_y], %[v_n], %[v_tmp] ; n/m\n"
    //     "s_nop 8 ; hazard for exp\n"
    //     : [v_y] "=v"(y),
    //       [v_tmp] "+v"(tmp),
    //       [v_neg_tmp] "+v"(neg_tmp),
    //       [v_emu] "+v"(emu),
    //       [v_neg_emu] "+v"(neg_emu),
    //       [v_m] "+v"(m),
    //       [v_n] "+v"(n)
    //     : [v_x] "v"(x), [v_neg_x] "v"(neg_x), [s_log2e] "n" (log2e_)
    //     :);
    // return static_cast<T>(y);
  }

  static torch::Tensor compute(torch::Tensor &input) {
    return torch::tanh(input);
  }
};

// Element-wise sigmoid, 1 / (1 + e^-x).
//   apply   -- per-element device functor used by unary_operator_tile_kernel.
//   compute -- host fallback (torch::sigmoid) used when the tiled kernel does
//              not support the input.
struct SigmoidOp {
  // Evaluates the sigmoid in fp32 and converts the result back to T.
  template <typename T>
  inline __device__ static T apply(T x) {
    // Unused hand-written AMDGPU asm variant, kept for reference:
    // float y, neg_a = static_cast<float>(-x);
    // const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v;
    // float tmp;
    // asm volatile("v_mul_f32 %[v_tmp], %[s_log2e], %[v_x] ; log2e*x\n"
    //              "v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x)\n"
    //              "s_nop 4 ; hazard for exp\n"
    //              "v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
    //              "v_rcp_f32 %[v_y], %[v_tmp] ; 1/(emu+1.0f)\n"
    //              "s_nop 4 ; hazard for rcp \n"
    //              : [v_y] "=v"(y), [v_tmp] "+v"(tmp)
    //              : [v_x] "v"(neg_a), [s_log2e] "n"(log2e_)
    //              :);
    // return static_cast<T>(y);
    return static_cast<T>(1.0f / (1.0f + expf(static_cast<float>(-x))));
  }

  static torch::Tensor compute(torch::Tensor &input) {
    return torch::sigmoid(input);
  }
};

// Applies Operation::apply element-wise, reading `a` and writing `c`. Both are
// addressed as contiguous row-major [M, N, K] arrays of T.
//
// Launch layout: 1-D grid of 1-D blocks, no shared memory. Each thread owns one
// tile of kRows rows x kVec contiguous columns inside one batch slice. The rows
// of a tile are strided by n_tiles = N / kRows: the thread at row-tile index tj
// touches rows tj, tj + n_tiles, ..., tj + (kRows - 1) * n_tiles. Threads whose
// index is >= M * n_tiles * k_tiles exit immediately, so the grid may be rounded
// up to a whole number of blocks.
//
// Preconditions (enforced by the host launcher): N % kRows == 0,
// K % kVec == 0, and `a`/`c` each hold M * N * K elements.
template <class Operation, typename T, int kRows, int kVec>
__global__ void unary_operator_tile_kernel(const void *__restrict__ a,
                                           void *__restrict__ c, const int M,
                                           const int N, const int K) {
  const uint64_t idx =
      static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const uint64_t n_tiles = static_cast<uint64_t>(N / kRows);
  const uint64_t k_tiles = static_cast<uint64_t>(K / kVec);
  const uint64_t tiles_per_batch = n_tiles * k_tiles;  // 64-bit: no wraparound
  if (idx >= static_cast<uint64_t>(M) * tiles_per_batch) {
    return;
  }

  const uint64_t ti = idx / tiles_per_batch;  // batch slice
  const uint64_t idx_block = idx % tiles_per_batch;
  const uint64_t tj = idx_block / k_tiles;  // row-tile index, < n_tiles
  const uint64_t tk = idx_block % k_tiles;  // column-tile index
  const uint64_t batch_offset = ti * static_cast<uint64_t>(N) * K;

#pragma unroll
  for (int row = 0; row < kRows; ++row) {
    const uint64_t offset =
        (tj + row * n_tiles) * K + tk * kVec + batch_offset;
    const T *pa = static_cast<const T *>(a) + offset;
    T *pc = static_cast<T *>(c) + offset;
#pragma unroll
    for (int col = 0; col < kVec; ++col) {
      pc[col] = Operation::apply(pa[col]);
    }
  }
}

}  // namespace aiter

// Applies `Operation` element-wise to `input` and returns a new tensor with the
// same sizes, dtype and device.
//
// The tiled GPU kernel is used when `input`:
//   * is a contiguous GPU tensor,
//   * is 2-D [N, K] or 3-D [M, N, K],
//   * has dtype Float, Half or BFloat16,
//   * has N a multiple of 8 and K a multiple of 16 bytes' worth of elements,
//   * has every dimension no larger than INT_MAX.
// Every other input is forwarded to Operation::compute (the torch op).
// The kernel is enqueued on the current stream of `input`'s device.
// Throws (via TORCH_CHECK) if the kernel launch is rejected.
template <class Operation>
torch::Tensor unary_operation(torch::Tensor &input) {
  constexpr int kRows = 8;          // rows per thread tile
  constexpr int kVecBytes = 16;     // bytes of each row covered by one tile
  constexpr uint32_t kBlockSize = 256;
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();

  const int64_t dim = input.dim();
  const auto dtype = input.scalar_type();
  // The dtype list mirrors VLLM_DISPATCH_FLOATING_TYPES (Float, Half, BFloat16)
  // -- NOTE(review): keep in sync with dispatch_utils.h. CPU tensors, other
  // ranks and other dtypes go to the torch fallback instead of failing.
  const bool layout_ok = input.is_cuda() && input.is_contiguous() &&
                         (dim == 2 || dim == 3) &&
                         (dtype == at::ScalarType::Float ||
                          dtype == at::ScalarType::Half ||
                          dtype == at::ScalarType::BFloat16);
  if (!layout_ok) {
    return Operation::compute(input);
  }

  const int64_t M = dim == 2 ? 1 : input.size(0);
  const int64_t N = input.size(dim - 2);
  const int64_t K = input.size(dim - 1);
  // Columns per tile = 16 bytes of the element type. (sizeof(input.dtype())
  // would measure the TypeMeta handle, not the element.)
  const int64_t vec = kVecBytes / input.element_size();
  const bool shape_ok = N % kRows == 0 && K % vec == 0 && M <= kIntMax &&
                        N <= kIntMax && K <= kIntMax;
  if (!shape_ok) {
    return Operation::compute(input);
  }

  // Same dtype and device as the input (not the current default device).
  auto output = torch::empty(input.sizes(), input.options());
  if (output.numel() == 0) {
    return output;  // nothing to do; a zero-block launch would be invalid
  }

  // Launch on the input's device and that device's current stream.
  const c10::DeviceGuard device_guard(input.device());
  const aiter_stream_t stream = at::hip::getCurrentHIPStream();

  // One thread per tile across all M batch slices (M * N * K == numel).
  const int64_t num_tiles = input.numel() / (kRows * vec);
  const int64_t grid_x = (num_tiles + kBlockSize - 1) / kBlockSize;
  const dim3 grid_dim(static_cast<unsigned int>(grid_x), 1, 1);
  const dim3 block_dim(kBlockSize, 1, 1);

  const void *buf_a = input.data_ptr();
  void *buf_c = output.data_ptr();
  const int m = static_cast<int>(M);
  const int n = static_cast<int>(N);
  const int k = static_cast<int>(K);

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "unary_operator_tile_kernel", [&] {
        constexpr int kVec = kVecBytes / static_cast<int>(sizeof(scalar_t));
        // The host-side divisibility check above used the same width.
        TORCH_INTERNAL_ASSERT(kVec == vec);
        aiter::unary_operator_tile_kernel<Operation, scalar_t, kRows, kVec>
            <<<grid_dim, block_dim, 0, stream>>>(buf_a, buf_c, m, n, k);
      });
  const hipError_t launch_err = hipGetLastError();
  TORCH_CHECK(launch_err == hipSuccess,
              "unary_operator_tile_kernel launch failed: ",
              hipGetErrorString(launch_err));
  return output;
}

// Element-wise sigmoid of `input`; see unary_operation for when the tiled
// kernel is used versus the torch::sigmoid fallback.
torch::Tensor aiter_sigmoid(torch::Tensor &input) {
  return unary_operation<aiter::SigmoidOp>(input);
}

// Element-wise tanh of `input`; see unary_operation for when the tiled kernel
// is used versus the torch::tanh fallback.
torch::Tensor aiter_tanh(torch::Tensor &input) {
  return unary_operation<aiter::TanhOp>(input);
}