/* * Copyright (c) 2024, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include constexpr int WARP_SIZE = 64; template __device__ __forceinline__ T silu(const T &x) { // x * sigmoid(x) return (T)(((float)x) / (1.0f + expf((float)-x))); } template __device__ __forceinline__ T loadnt(T *addr) { return __builtin_nontemporal_load(addr); } __device__ __forceinline__ float4 load_ntmprl(const float4 *addr) { auto addr_alias = reinterpret_cast(addr); auto dat0 = loadnt(addr_alias); auto dat1 = loadnt(addr_alias + 1); auto dat2 = loadnt(addr_alias + 2); auto dat3 = loadnt(addr_alias + 3); // auto dat0 = *(addr_alias); // auto dat1 = *(addr_alias+1); // auto dat2 = *(addr_alias+2); // auto dat3 = *(addr_alias+3); return make_float4(dat0, dat1, dat2, dat3); } // TBlock fetches entire rows of A, and entire col of B (K dimension); assume // N=1 for time being grid is M/A_NUM_ROWS blocks template __global__ void LLGemm_Silu_kernel(float4 *af4, __half2 *bf4, _Float16 *c, const int d) { __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 * blockDim.x; const int row_addr_d = row_addr + d * blockDim.x; // int row_addr_1 = row_addr + CUDA_NUM_THREADS; // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; const int threadid = threadIdx.x; const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; const int num_warps 
= blockDim.x / WARP_SIZE; const int qwarpid = threadid / 16; const int qthreadid = threadid % 16; float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; // float4 colB_elem4; __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; __half2 acch2; __half2 oval; // rowA_elem4 = af4[row_addr + threadid]; //__syncthreads(); // rowA_elem4_1 = af4[row_addr_1 + threadid]; // rowA_elem4_2 = af4[row_addr_2 + threadid]; // rowA_elem4_3 = af4[row_addr_3 + threadid]; #pragma unroll for (int i = 0; i < NUM_A_ROWS_PER_BLOCK / 2; i++) { rowA_elem4[2 * i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]); rowA_elem4[2 * i + 1] = load_ntmprl(&af4[row_addr_d + i * blockDim.x + threadid]); // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid]; //__syncthreads(); } colB_elem4x = bf4[threadid * 4 + 0]; colB_elem4y = bf4[threadid * 4 + 1]; colB_elem4z = bf4[threadid * 4 + 2]; colB_elem4w = bf4[threadid * 4 + 3]; // __syncthreads(); __half2 Af2; __half2 Bf2; float2 S; // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4); // auto Bf2x = *Bh2ptr; // auto Bf2y = *(Bh2ptr+1); // auto Bf2z = *(Bh2ptr+2); // auto Bf2w = *(Bh2ptr+3); auto Ah2ptr = reinterpret_cast<__half2 *>(&rowA_elem4); __half2 *ah2lptr; #pragma unroll for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { ah2lptr = Ah2ptr + i * 4; Af2 = *(ah2lptr); acch2 = __hmul2(Af2, colB_elem4x); Af2 = *(ah2lptr + 1); acch2 = __hfma2(Af2, colB_elem4y, acch2); Af2 = *(ah2lptr + 2); acch2 = __hfma2(Af2, colB_elem4z, acch2); Af2 = *(ah2lptr + 3); acch2 = __hfma2(Af2, colB_elem4w, acch2); S = __half22float2(acch2); acc[i] = S.x + S.y; } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { #pragma unroll for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { acc[i] += __shfl_xor(acc[i], mask); } } // Warp leaders store the data to shared memory. 
// if (lane == 0) { // #pragma unroll // for (int i=0; i= 1; mask /= 2) { // #pragma unroll // for (int i=0; i
// NOTE(review): the comment line above is all that remains here of the tail
// of LLGemm_Silu_kernel; the text between the '<' and '>' characters was lost
// in this copy of the file. Restore the kernel epilogue from the upstream
// source before building.

/**
 * Host-side launcher for LLGemm_Silu_kernel.
 *
 * Reinterprets the three untyped pointers (in_a as float4*, in_b as __half2*,
 * out_c as _Float16*), derives the launch geometry from M and K, and enqueues
 * the kernel instantiation selected by rows_per_block on `stream`. The
 * pointers are handed to the kernel unchanged.
 *
 * Launch geometry:
 *   - block size = K / 8 threads: every thread loads one float4 of A per row
 *     and reads it as four __half2 values, i.e. 8 half-precision elements.
 *   - grid size  = M / rows blocks, where rows is the effective
 *     rows-per-block value.
 *   - d = M / 2 is forwarded to the kernel, which uses it as the offset (in
 *     rows) of the second row of each pair it loads.
 *
 * @param in_a            A matrix, passed to the kernel as float4*.
 * @param in_b            B vector, passed to the kernel as __half2*.
 * @param out_c           Output buffer, passed to the kernel as _Float16*.
 * @param M               Number of rows of A; must be a positive multiple of
 *                        the effective rows-per-block value.
 * @param K               Reduction length in half elements; must be a
 *                        positive multiple of 8.
 * @param stream          Stream the kernel is enqueued on.
 * @param rows_per_block  2, 4, 8 or 16. Any other value falls back to 4.
 *
 * @throws std::runtime_error if M or K violate the constraints above, or if
 *         the launch itself reports an error.
 */
void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K,
                 cudaStream_t stream, const int rows_per_block = 4) {
  float4 *af4 = reinterpret_cast<float4 *>(in_a);
  auto *bf4 = reinterpret_cast<__half2 *>(in_b);
  auto *c = reinterpret_cast<_Float16 *>(out_c);

  // Only these instantiations exist; everything else uses the 4-row kernel,
  // exactly as the previous if/else chain did.
  const bool supported = rows_per_block == 2 || rows_per_block == 4 ||
                         rows_per_block == 8 || rows_per_block == 16;
  const int rows = supported ? rows_per_block : 4;

  // The kernel has no bounds checks, so a truncated grid or block size would
  // silently leave part of the output uncomputed. Reject such shapes up front
  // instead of launching.
  if (M <= 0 || K <= 0) {
    throw std::runtime_error("LLGemm_Silu: M and K must be positive, got M=" +
                             std::to_string(M) + ", K=" + std::to_string(K));
  }
  if (K % 8 != 0) {
    throw std::runtime_error("LLGemm_Silu: K must be a multiple of 8, got K=" +
                             std::to_string(K));
  }
  if (M % rows != 0) {
    throw std::runtime_error("LLGemm_Silu: M must be a multiple of " +
                             std::to_string(rows) +
                             ", got M=" + std::to_string(M));
  }

  const int d = M / 2;
  const int NUM_THREADS = K / 8;  // same value as the former K * 2 / 16
  const int NUM_BLOCKS = M / rows;

  // NOTE(review): the launch configuration was stripped from this copy of the
  // file; it is restored as (grid, block, no dynamic shared memory, stream)
  // because the kernel declares its shared memory statically. Confirm against
  // upstream.
  switch (rows) {
    case 2:
      LLGemm_Silu_kernel<2>
          <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
      break;
    case 8:
      LLGemm_Silu_kernel<8>
          <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
      break;
    case 16:
      LLGemm_Silu_kernel<16>
          <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
      break;
    default:  // rows == 4
      LLGemm_Silu_kernel<4>
          <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, d);
      break;
  }

  // A launch does not return a status; configuration errors only surface
  // through cudaGetLastError().
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string("CUDA kernel failed : ") +
                             cudaGetErrorString(err) + " (" +
                             std::to_string(static_cast<int>(err)) + ")");
  }
}