Commit 9f217825 authored by gaoqiong's avatar gaoqiong
Browse files

Merge branch 'v0.0.6_develop_sugon' into 'main'

v0.0.6

See merge request dcutoolkit/deeplearing/autoawq_kernels!1
parents b2c05ad6 1c46b800
Pipeline #1718 failed with stages
in 0 seconds
#pragma once
#include <torch/extension.h>
// Declaration of the GEMV forward entry point; the definition is not part of this header.
// NOTE(review): judging only by the parameter names, this takes the input features, the
// packed quantized weights (`_kernel`), the per-group scaling factors and zero points, and
// the quantization group size. Tensor shapes, dtypes and device requirements are not visible
// here -- confirm against the implementation before relying on them.
torch::Tensor gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int group_size);
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#pragma once
// Expands the 32 bits of `source` (read as eight packed unsigned 4-bit values) into eight fp16
// values and writes them to `*result` as four 32-bit half2 words.
//
// Output layout produced by the masks below, where "nibble k" is bits [4k, 4k+3] of the input:
//   word 0 = {nibble 0, nibble 4}    word 1 = {nibble 1, nibble 5}
//   word 2 = {nibble 2, nibble 6}    word 3 = {nibble 3, nibble 7}
// Each output is the plain unsigned nibble value (0..15); no signed offset is applied here.
//
// `source` is only reinterpreted as a raw 32-bit word; its half2 type carries no numeric meaning.
__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)
{
// uint4 result;
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
// immLut is the lop3 truth table for (a & b) | c.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
// Selects the low nibble of each 16-bit half.
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
// Selects the second nibble (bits 4..7) of each 16-bit half.
static constexpr uint32_t TOP_MASK = 0x00f000f0;
// half2 {1024, 1024}: OR-ing a value into the low mantissa bits yields 1024 + value.
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Words built with BOTTOM_MASK hold 1024 + v, so subtracting 1024 leaves v.
// Words built with TOP_MASK hold 1024 + 16 * v, so (x * 1/16) - 64 leaves v.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// return result;
}
\ No newline at end of file
#include <cuda_fp16.h>
#include "semaphore.h"
#include "gemm_cuda.h"
#include "../dequantize.cuh"
#include <torch/extension.h>
#include <cuda_pipeline_primitives.h>
// Divisor applied to CTA_N wherever the B tile is sized or addressed (CTA_N / kInterleave rows).
#define kInterleave 4
// OP_M/OP_N/OP_K match the m16n8k16 shape of the MMA instruction wrapped by mma_m16n8k16;
// OP_M is also the row step between consecutive 16-row fragments.
#define OP_M 16
#define OP_N 8
#define OP_K 16
// Fragment extents used to derive per-warp loop trip counts (WARP_M / INTRIN_M, WARP_N / INTRIN_N,
// WARP_K / INTRIN_K).
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 16
// Threads per warp; threadIdx.x is treated as the lane id and threadIdx.y as the warp id.
#define WARP_SIZE 32
// Extra half elements appended to every shared-memory row of A / B (currently none).
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
// Number of half elements moved by one 16-byte (uint4) copy.
#define PACK_SIZE 8
// L2_CACHEHINT(size) expands to the ".L2::<size>B" prefetch qualifier that is spliced into the
// cp.async instruction string, or to nothing when the toolkit is too old to accept it.
// The toolkit version is compared as a (major, minor) pair: testing "major >= 11 && minor >= 4"
// independently would wrongly disable the hint on newer major releases whose minor is below 4
// (for example CUDA 12.0 - 12.3).
#if defined(__CUDACC_VER_MAJOR__) && \
    ((__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
// Host-side launch sequence for gemm_w4a16_T1, meant to be expanded inside a launcher function.
//
// Names that must be in scope at the expansion site:
//   compile-time constants: CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK
//   values:                 num_in_feats, num_out_feats, num_out_channels, num_in_channels,
//                           options_int, _out_feats, in_feats, kernel, scales, zeros, out_feats
//
// What it does:
//   1. Allocates one int semaphore per (M tile, N tile). The buffer is zero-initialised because
//      the split-K epilogue of the kernel waits for specific counter values starting from 0;
//      uninitialised memory could let a later K partition run ahead of an earlier one.
//   2. Computes the dynamic shared-memory size and returns `_out_feats` early (after printing a
//      message) when it reaches 99 KiB.
//   3. Raises the kernel's dynamic shared-memory limit and launches a 1-D grid with
//      blockDim = (WARP_SIZE, NUM_WARPS).
//
// The tile count is written with explicit parentheses: `a / b * c / d` evaluates left to right
// and therefore does not compute ceil(M / CTA_M) * ceil(N / CTA_N).
//
// NOTE(review): SCALES_SMEM_SIZE already contains a STAGES factor and is multiplied by STAGES
// again in kSmemByteSize, so the request looks larger than the kernel's own layout needs; left
// unchanged because the kernel also reuses this buffer as float scratch -- confirm intent.
// NOTE(review): the launch uses the default stream and its result is not checked.
#define KERNEL_LAUNCH_CODE \
    int num_mn_tiles = ((num_in_feats + CTA_M - 1) / CTA_M) * ((num_out_channels + CTA_N - 1) / CTA_N); \
    torch::Tensor _semaphores = torch::zeros({num_mn_tiles}, options_int); \
    auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>()); \
    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
    constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
    constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(half); \
    if (kSmemByteSize >= 99 * 1024) \
    { \
        printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \
        return _out_feats; \
    } \
    int j_factors1 = num_out_channels / CTA_N / 1; \
    dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
    dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
    auto kernel_func = gemm_w4a16_T1<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
    cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
    kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
        in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
// Picks a log2 tile factor (0..3) for block-index swizzling.
// N caps the largest factor that may be returned (8 -> 3, 4 -> 2, 2 -> 1); n is the runtime
// extent that must be at least 6, 3 or 2 respectively for that factor to be chosen.
// Returns 0 when no factor qualifies.
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n)
{
    if (N >= 8 && n >= 6)
    {
        return 3;
    }
    if (N >= 4 && n >= 3)
    {
        return 2;
    }
    return (N >= 2 && n >= 2) ? 1 : 0;
}
// Remaps a (x, y) block index pair using a power-of-two tile of size 2^log_tile:
//   result.x = x >> log_tile
//   result.y = (y << log_tile) + (x mod 2^log_tile)
// With log_tile == 0 the pair is returned unchanged.
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
    const int low_bits_mask = (1 << log_tile) - 1;
    const int mapped_x = blockIdx_x >> log_tile;
    const int mapped_y = (blockIdx_y << log_tile) + (blockIdx_x & low_bits_mask);
    return make_uint2(mapped_x, mapped_y);
}
// Barrier helper for K-sliced thread blocks.
// With a single slice this is a plain __syncthreads(). Otherwise it issues `bar.sync` on a
// barrier whose id is derived from slice_id (ids start at 1; ceil(SLICES / 8) consecutive
// slices share an id) with a participant count of NUM_WARPS_MN * WARP_SIZE threads.
template <int SLICES, int NUM_WARPS_MN>
__device__ void sync_slice(int slice_id)
{
    if constexpr (SLICES != 1)
    {
        constexpr int kSlicesPerBarrier = (SLICES + 7) / 8;
        constexpr uint32_t kThreadsPerSlice = NUM_WARPS_MN * WARP_SIZE;
        const uint32_t named_barrier = slice_id / kSlicesPerBarrier + 1;
        asm volatile("bar.sync %0, %1;" : : "r"(named_barrier), "n"(kThreadsPerSlice));
    }
    else
    {
        __syncthreads();
    }
}
// Converts a generic pointer into a 32-bit shared-state-space address: `cvta.to.shared`
// produces the 64-bit shared address and `cvt.u32.u64` narrows it to 32 bits.
// The result is the form expected by the `[addr]` operands of the ldmatrix / cp.async wrappers
// in this file. `ptr` is expected to point into shared memory.
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr)
{
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
return smem_int_ptr;
}
// Warp-wide `ldmatrix ... m8n8.x4 ... b16` load from the shared-memory address `addr`.
// The four 32-bit results are stored into `shared_warp`, starting ax0_0 * 8 half elements in
// (i.e. each call fills one group of 8 halves of the destination array).
__inline__ __device__ void ldmatrix_m8n8_x4_b16(half *shared_warp, int ax0_0, uint32_t addr)
{
    unsigned *frag = (unsigned *)(shared_warp + (ax0_0 * 8));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
        : "r"(addr));
}
// Transposing variant of the warp-wide ldmatrix load (`.trans` qualifier) from the
// shared-memory address `addr`. The four 32-bit results are stored into `shared_warp`,
// starting ax0_0 * 8 half elements in.
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0_0, uint32_t addr)
{
    unsigned *frag = (unsigned *)(shared_warp + (ax0_0 * 8));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
        : "r"(addr));
}
// Issues one predicated 16-byte `cp.async.cg` copy from global memory (`src`) to the
// shared-memory address `smem_int_ptr` (a 32-bit shared-space address).
// When `mask` is false the copy instruction is predicated off, so the destination is left
// untouched (it is NOT zero-filled).
// L2_CACHEHINT(128) splices an optional L2 prefetch qualifier into the instruction, depending
// on how that macro is defined for the current toolkit.
// The copy is asynchronous: this function does not commit or wait on it.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
{
// Bytes transferred per instruction (one uint4).
const int cp_size = 16;
asm volatile("{"
" .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
"}" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
}
// Wrapper around one warp-wide `mma.sync m16n8k16` (row-major A, column-major B, fp16 inputs,
// fp32 accumulation): C_warp[0..3] = A * B + C_warp[0..3].
// Per thread it consumes four 32-bit words from A_shared_warp, two 32-bit words from
// B_shared_warp, and reads/overwrites four floats of C_warp in place.
__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
    float *acc = (float *)C_warp;
    unsigned *a_frag = (unsigned *)A_shared_warp;
    unsigned *b_frag = (unsigned *)B_shared_warp;
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(acc[0]), "=f"(acc[1]), "=f"(acc[2]), "=f"(acc[3])
        : "r"(a_frag[0]), "r"(a_frag[1]), "r"(a_frag[2]), "r"(a_frag[3]),
          "r"(b_frag[0]), "r"(b_frag[1]),
          "f"(acc[0]), "f"(acc[1]), "f"(acc[2]), "f"(acc[3]));
}
// Copies (part of) one CTA_M x CTA_K tile of A from global memory into one shared-memory stage.
//
//   src / dst        global base of A / base of the destination stage in shared memory
//   global_nrows     row bound of A; rows at or beyond it are skipped
//   global_ncols     row stride of A in half elements
//   cta_offset_m/k   tile origin (row / column) inside A; cta_offset_n is unused here
//   global_iter_k    index of the CTA_K-wide K step being loaded
//   shared_iter_k    which of the SHARED_K_ITERS portions of the tile this call covers
//   mask             when false, nothing is copied
//
// Each participating thread moves PACK_SIZE halves (16 bytes) per iteration. The destination
// column is XOR-swizzled with the low three bits of the row. With STAGES > 1 the copy is an
// asynchronous cp.async (caller commits/waits); otherwise it is a direct 16-byte store.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int kPacksPerTile = (CTA_M * CTA_K) / PACK_SIZE;
    constexpr int threads_needed = kPacksPerTile / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = kPacksPerTile / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int kRowsPerCtaPass = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int kRowsPerWarp = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int kThreadsPerRow = CTA_K / PACK_SIZE;
    constexpr int kSmemRowStride = CTA_K + SMEM_PAD_A;

    // Only the first threads_used threads of the block take part in the copy.
    const bool thread_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    const int col_pack = (threadIdx.x % kThreadsPerRow);

#pragma unroll
    for (int it = 0; it < partial_global_iters; ++it)
    {
        const int pass = shared_iter_k * partial_global_iters + it;
        const int row = pass * kRowsPerCtaPass + threadIdx.y * kRowsPerWarp + (threadIdx.x / kThreadsPerRow);
        const int swizzled_col = (col_pack ^ (row & 7)) * PACK_SIZE;
        void *smem_dst = (void *)(dst + row * kSmemRowStride + swizzled_col);
        uint4 *gmem_src = (uint4 *)(src + (row + cta_offset_m) * global_ncols + col_pack * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k);
        // Rows past the end of A are never read.
        const bool do_copy = thread_active & (row + cta_offset_m < global_nrows);
        if constexpr (STAGES > 1)
        {
            cp_async_cg_A(cast_smem_ptr_to_uint(smem_dst), gmem_src, do_copy);
        }
        else
        {
            if (do_copy)
            {
                *(uint4 *)smem_dst = *gmem_src;
            }
        }
    }
}
// Copies (part of) one (CTA_N / kInterleave) x CTA_K tile of B from global memory into one
// shared-memory stage.
//
//   src / dst        global base of B / base of the destination stage in shared memory
//   global_ncols     row stride of B in half elements
//   cta_offset_n/k   tile origin inside B (the row origin is cta_offset_n / kInterleave);
//                    cta_offset_m is unused here
//   global_iter_k    index of the CTA_K-wide K step being loaded
//   shared_iter_k    which of the SHARED_K_ITERS portions of the tile this call covers
//   mask             when false, nothing is copied
//
// Each participating thread moves PACK_SIZE halves (16 bytes) per iteration. The destination
// column pack is XOR-ed with the parity of the row. There is no row bound check for B.
// With STAGES > 1 the copy is an asynchronous cp.async; otherwise a direct 16-byte store.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int kPacksPerTile = (CTA_N / kInterleave * CTA_K) / PACK_SIZE;
    constexpr int threads_needed = kPacksPerTile / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = kPacksPerTile / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int kRowsPerCtaPass = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int kRowsPerWarp = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int kThreadsPerRow = CTA_K / PACK_SIZE;
    constexpr int kSmemRowStride = CTA_K + SMEM_PAD_B;

    // Only the first threads_used threads of the block take part in the copy.
    const bool thread_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);

#pragma unroll
    for (int it = 0; it < partial_global_iters; ++it)
    {
        const int pass = shared_iter_k * partial_global_iters + it;
        const int row = pass * kRowsPerCtaPass + threadIdx.y * kRowsPerWarp + (threadIdx.x / kThreadsPerRow);
        const int col_pack = (threadIdx.x % kThreadsPerRow);
        const int swizzled_pack = col_pack ^ ((row % 2) & 7);
        void *smem_dst = (void *)(dst + (row * kSmemRowStride + swizzled_pack * PACK_SIZE));
        uint4 *gmem_src = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + row * global_ncols + col_pack * PACK_SIZE + cta_offset_k);
        if constexpr (STAGES > 1)
        {
            cp_async_cg_A(cast_smem_ptr_to_uint(smem_dst), gmem_src, thread_active);
        }
        else
        {
            if (thread_active)
            {
                *(uint4 *)smem_dst = *gmem_src;
            }
        }
    }
}
// Copies the scales and zero points needed for one K step from global memory into shared memory.
//
//   src / dst        global scales / destination scales buffer in shared memory
//   src_z / dst_z    global zeros  / destination zeros buffer in shared memory
//   global_ncols     row stride of the scales/zeros arrays in half elements
//   cta_offset_n/k   column origin of the tile and K origin of this block; cta_offset_m and
//                    shared_iter_k are unused here
//   global_iter_k    index of the CTA_K-wide K step being loaded
//   mask             when false, nothing is copied
//
// LD_AMOUNT halves are loaded per array: one row of CTA_N values when a quantization group
// spans at least one K step (G >= CTA_K), otherwise CTA_K / G consecutive rows. Every
// participating thread moves PACK_SIZE halves (16 bytes) of each array.
//
// The (row, column) handled by a thread is derived from the CTA-linear thread id, the same id
// the participation mask uses. Deriving it from threadIdx.x alone is only equivalent while at
// most WARP_SIZE threads participate; beyond that, warps 1+ would repeat warp 0's addresses and
// the later rows would never be loaded.
//
// NOTE(review): there is no loop over the data, so LD_AMOUNT / PACK_SIZE must not exceed
// CTA_SIZE for the whole range to be covered.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
    constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    constexpr int kSmemCol = CTA_N;

    const int tid = threadIdx.y * WARP_SIZE + threadIdx.x;
    const bool local_mask = mask & (tid < threads_used);

    // Quantization-group row that corresponds to the start of this K step.
    const int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;
    const int row = tid / threads_per_row;
    const int col = (tid % threads_per_row) * PACK_SIZE;

    const int smem_offset = row * kSmemCol + col;
    const int gmem_offset = g_idx * global_ncols + cta_offset_n + row * global_ncols + col;

    void *dst_ptr = (void *)(dst + smem_offset);
    uint4 *src_ptr = (uint4 *)(src + gmem_offset);
    void *dst_ptr_z = (void *)(dst_z + smem_offset);
    uint4 *src_ptr_z = (uint4 *)(src_z + gmem_offset);

    if constexpr (STAGES > 1)
    {
        // Asynchronous path: the caller is responsible for committing and waiting.
        cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr), src_ptr, local_mask);
        cp_async_cg_A(cast_smem_ptr_to_uint(dst_ptr_z), src_ptr_z, local_mask);
    }
    else
    {
        if (local_mask)
        {
            *(uint4 *)dst_ptr = *src_ptr;
            *(uint4 *)dst_ptr_z = *src_ptr_z;
        }
    }
}
// Loads the A operand for one 16-wide K sub-step from a shared-memory stage into `dst`.
//
//   src              base of the A stage in shared memory (row stride CTA_K + SMEM_PAD_A)
//   dst              destination array; each of the `shared_iters` ldmatrix calls fills one
//                    group of 8 halves
//   warp_offset_m/k  row / K origin of this warp inside the tile (warp_offset_n is unused)
//   k_0_1            index of the 16-wide K sub-step
//
// The column is un-swizzled with the same XOR (low three bits of the row) that the global ->
// shared loader for A applies.
template <int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void share_to_reg_one_stage_A(half *src, half *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
    constexpr int kSmemRowStride = CTA_K + SMEM_PAD_A;
    // The column depends only on the lane and the K sub-step, not on the fragment index.
    const int col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
    const int col_pack = col / PACK_SIZE;
    for (int frag = 0; frag < shared_iters; ++frag)
    {
        const int row = warp_offset_m + frag * OP_M + (threadIdx.x % 16);
        const int swizzled_col = (col_pack ^ (row & 7)) * PACK_SIZE;
        void *smem_src = (void *)(src + row * kSmemRowStride + swizzled_col);
        ldmatrix_m8n8_x4_b16(dst, frag, cast_smem_ptr_to_uint(smem_src));
    }
}
// Produces the dequantized fp16 B operand for one 16-wide K sub-step.
//
//   src              base of the packed B stage in shared memory (row stride CTA_K + SMEM_PAD_B)
//   src_scales/zeros scales / zeros for the current stage in shared memory (row stride CTA_N)
//   dst              per-thread buffer holding the packed words loaded by ldmatrix
//   dst_fp16         per-thread buffer receiving the dequantized halves
//   warp_offset_n/k  column / K origin of this warp inside the tile (warp_offset_m is unused)
//   k_0_1            index of the 16-wide K sub-step
//
// When the `ldmatrix` template flag is false, `dst` is NOT reloaded: the packed words left in
// it by an earlier call are reused and only the dequantization step runs.
template <int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
// Per-lane (row, column) inside the packed tile, derived from the lane id only.
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
// Same row-parity XOR that the global -> shared loader for B applies to the column pack.
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix)
{
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
// One scale / zero per thread and fragment; the column is selected by threadIdx.x / 4 and the
// parity of k_0_1, the row by the quantization group of warp_offset_k.
half scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
half zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
half2 scale2 = make_half2(scale, scale);
half2 zero2 = make_half2(zero, zero);
half2 loaded[4];
// Pick one packed 32-bit word (two halves of `dst`) and expand its eight 4-bit values to fp16.
dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
// value * scale + zero; the zero term is ADDED, so the stored zeros are presumably already
// negated and scaled -- TODO confirm against the weight packing code.
loaded[i] = __hfma2(loaded[i], scale2, zero2);
}
// Store the eight dequantized halves as one 16-byte write.
*reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);
}
}
// W4A16 GEMM kernel (variant T1) with optional split-K: C (M x N, half) = A (M x K, half) times
// the dequantized 4-bit weights in B.
//
// Template parameters:
//   CTA_M/N/K   tile extents handled by one thread block per K step
//   WARP_M/N/K  tile extents handled by one warp; CTA_K / WARP_K warps ("slices") split the K step
//   STAGES      number of shared-memory stages (1 = synchronous loads, >1 = cp.async pipeline)
//   G           quantization group size used to index scales / zeros
//   SPLITK      number of K partitions; each partition is a separate block (blockIdx_z)
//
// Arguments as used by the indexing below:
//   A           row stride K, rows bounded by M
//   B           packed weights addressed as (N / kInterleave) rows with row stride K
//   scales/zeros  row stride N, row index = K position / G
//   C           row stride N; written as half2 pairs, rows bounded by M
//   semaphores  one int per (m-tile, n-tile), used to order split-K partitions
//
// Launch layout expected: 1-D grid in which blockIdx.x enumerates the (m-tile, n-tile) pairs and
// then the split-K partition; blockDim = (WARP_SIZE, NUM_WARPS). Dynamic shared memory holds,
// in order, the A stages, the B stages, the scales and the zeros; the same buffer is reused as
// float scratch (C_shared) for the cross-slice reduction.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G, int SPLITK>
__global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K)
{
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
// Decompose the 1-D block index: blockIdx_y enumerates (m-tile, n-tile) pairs, blockIdx_z is
// the split-K partition.
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
// get_log_tile<1> always returns 0 (all of its non-zero branches require a template argument
// of at least 2), so the remapping below leaves (blockIdx_m, blockIdx_n) unchanged.
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
// Per-thread fp32 accumulators for this warp's output tile.
float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
// A new scales/zeros row is needed every scales_load_interval K steps (when G >= CTA_K), or
// scales_per_load rows are needed per K step (when G < CTA_K).
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load;
// Carve the dynamic shared memory into A stages | B stages | scales | zeros.
extern __shared__ half mem_shared[];
half *A_shared = mem_shared;
half *B_shared = mem_shared + kSmemSizeA;
half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB;
half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
// The same buffer viewed as floats; only used after the main loop for the slice reduction.
float *C_shared = reinterpret_cast<float *>(mem_shared);
// Per-thread operand buffers, two copies each so the next sub-step can be loaded while the
// current one is consumed. B_shared_warp_tmp_ holds the still-packed B words.
half A_shared_warp_[2][WARP_M * INTRIN_K /
WARP_SIZE];
half B_shared_warp_[2][WARP_N * 32 /
WARP_SIZE];
half B_shared_warp_tmp_[2][WARP_N * 16 /
WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int cta_offset_k = blockIdx_z * (K / SPLITK);
// threadIdx.y is the warp index: the low part selects the (m, n) warp tile, the high part the
// K slice.
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0.0;
// Number of CTA_K-wide K steps this block (one split-K partition) processes.
int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
int k_0_0_ld = 0;
int k_0_0 = 0;
// Prologue: issue the loads for the first prologue_stages stages.
// NOTE(review): the A and B loads here are issued with mask = true even when
// k_0_0_ld >= gemm_iters; only the scales load is gated on gemm_iters. Confirm that callers
// always provide at least prologue_stages K steps per partition.
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_scales<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
// Wait until the first stage has landed, then preload the operands of the first K sub-step.
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
// Main loop: k_0_0 is the K step being computed, k_0_0_ld the K step being loaded.
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
half *A_shared_this_compute_stage;
half *B_shared_this_compute_stage;
half *scales_shared_this_compute_stage;
half *zeros_shared_this_compute_stage;
#pragma unroll
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
// compute_stage is advanced at the end of the second-to-last sub-step (see below), so the
// pointers are recomputed every sub-step and the last prefetch reads from the next stage.
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
// Prefetch the operands of sub-step iter_k + 1 (wrapping to 0) into the alternate buffers.
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
// Packed B words are re-read from shared memory only every kInterleave sub-steps
// (ldmatrix = true); in between, the previously loaded words are dequantized again.
// The packed-word buffer is selected by the parity of compute_stage.
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
// Multiply-accumulate the current sub-step: two m16n8k16 MMAs per 16x16 output fragment.
half *A_shared_warp = A_shared_warp_[iter_k % 2];
half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
// The global -> shared loads for K step k_0_0_ld are spread over the sub-steps: portion
// iter_k is issued here for every sub-step except the last one ...
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
// ... and at the second-to-last sub-step the final portion (iter_k + 1) plus the scales/zeros
// are issued too, the stage is committed, and compute_stage moves on to the next K step.
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
// Drain every outstanding asynchronous copy before shared memory is reused as C_shared.
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
// Cross-slice reduction: the slices take turns (one per barrier-separated round) adding their
// accumulators into C_shared; slice 0 then reads the totals back into its registers.
if constexpr (SLICES > 1)
{
#pragma unroll
for (int z = 0; z < SLICES; ++z)
{
if (slice_id == z)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
if (z > 0)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
}
}
__syncthreads();
}
if (slice_id == 0)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
};
}
}
}
}
// Epilogue: only slice 0 writes C. Split-K partitions are ordered through the per-tile
// semaphore: partition 0 stores its result, every later partition waits for its turn and adds
// to what is already in C, and the last partition resets the semaphore value to 0.
// NOTE(review): Semaphore is declared in semaphore.h, which is not visible here; if its
// wait()/release() use block-wide barriers, calling them inside this slice-0-only branch
// needs checking for SLICES > 1.
if (slice_id == 0)
{
Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
if constexpr (SPLITK > 1)
{
semaphore.fetch();
}
if (blockIdx_z != 0)
{
semaphore.wait(blockIdx_z);
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
// Two adjacent accumulators are converted and written as one half2 per step.
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
half2 *existing_psum_ptr = reinterpret_cast<half2 *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2);
*existing_psum_ptr = __hadd2(*existing_psum_ptr,
__float22half2_rn(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id)));
}
};
}
}
}
else
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
*reinterpret_cast<half2 *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
__float22half2_rn(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
}
}
if constexpr (SPLITK > 1)
{
// Hand the tile over to the next partition, or reset to 0 after the last one.
int lock = 0;
if (SPLITK == blockIdx_z + 1)
{
lock = 0;
}
else
{
lock = blockIdx_z + 1;
}
semaphore.release(lock);
}
}
}
// Copies (part of) one CTA_M x CTA_K tile of A from global memory into one shared-memory stage
// (T2 variant: no split-K column offset).
//
//   src / dst        global base of A / base of the destination stage in shared memory
//   global_nrows     row bound of A; rows at or beyond it are skipped
//   global_ncols     row stride of A in half elements
//   cta_offset_m     row origin of the tile inside A (cta_offset_n is unused here)
//   global_iter_k    index of the CTA_K-wide K step being loaded
//   shared_iter_k    which of the SHARED_K_ITERS portions of the tile this call covers
//   mask             when false, nothing is copied
//
// Each participating thread moves PACK_SIZE halves (16 bytes) per iteration. The destination
// column is XOR-swizzled with the low three bits of the row. With STAGES > 1 the copy is an
// asynchronous cp.async (caller commits/waits); otherwise it is a direct 16-byte store.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int kPacksPerTile = (CTA_M * CTA_K) / PACK_SIZE;
    constexpr int threads_needed = kPacksPerTile / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = kPacksPerTile / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int kRowsPerCtaPass = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int kRowsPerWarp = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int kThreadsPerRow = CTA_K / PACK_SIZE;
    constexpr int kSmemRowStride = CTA_K + SMEM_PAD_A;

    // Only the first threads_used threads of the block take part in the copy.
    const bool thread_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    const int col_pack = (threadIdx.x % kThreadsPerRow);

#pragma unroll
    for (int it = 0; it < partial_global_iters; ++it)
    {
        const int pass = shared_iter_k * partial_global_iters + it;
        const int row = pass * kRowsPerCtaPass + threadIdx.y * kRowsPerWarp + (threadIdx.x / kThreadsPerRow);
        const int swizzled_col = (col_pack ^ (row & 7)) * PACK_SIZE;
        void *smem_dst = (void *)(dst + row * kSmemRowStride + swizzled_col);
        uint4 *gmem_src = (uint4 *)(src + (row + cta_offset_m) * global_ncols + col_pack * PACK_SIZE + global_iter_k * CTA_K);
        // Rows past the end of A are never read.
        const bool do_copy = thread_active & (row + cta_offset_m < global_nrows);
        if constexpr (STAGES > 1)
        {
            cp_async_cg_A(cast_smem_ptr_to_uint(smem_dst), gmem_src, do_copy);
        }
        else
        {
            if (do_copy)
            {
                *(uint4 *)smem_dst = *gmem_src;
            }
        }
    }
}
// Copies (part of) one (CTA_N / kInterleave) x CTA_K tile of B from global memory into one
// shared-memory stage (T2 variant: no split-K column offset).
//
//   src / dst        global base of B / base of the destination stage in shared memory
//   global_ncols     row stride of B in half elements
//   cta_offset_n     column origin of the tile; the row origin in B is cta_offset_n / kInterleave
//                    (cta_offset_m is unused here)
//   global_iter_k    index of the CTA_K-wide K step being loaded
//   shared_iter_k    which of the SHARED_K_ITERS portions of the tile this call covers
//   mask             when false, nothing is copied
//
// Each participating thread moves PACK_SIZE halves (16 bytes) per iteration. The destination
// column pack is XOR-ed with the parity of the row. There is no row bound check for B.
// With STAGES > 1 the copy is an asynchronous cp.async; otherwise a direct 16-byte store.
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int kPacksPerTile = (CTA_N / kInterleave * CTA_K) / PACK_SIZE;
    constexpr int threads_needed = kPacksPerTile / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = kPacksPerTile / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int kRowsPerCtaPass = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int kRowsPerWarp = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int kThreadsPerRow = CTA_K / PACK_SIZE;
    constexpr int kSmemRowStride = CTA_K + SMEM_PAD_B;

    // Only the first threads_used threads of the block take part in the copy.
    const bool thread_active = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);

#pragma unroll
    for (int it = 0; it < partial_global_iters; ++it)
    {
        const int pass = shared_iter_k * partial_global_iters + it;
        const int row = pass * kRowsPerCtaPass + threadIdx.y * kRowsPerWarp + (threadIdx.x / kThreadsPerRow);
        const int col_pack = (threadIdx.x % kThreadsPerRow);
        const int swizzled_pack = col_pack ^ ((row % 2) & 7);
        void *smem_dst = (void *)(dst + (row * kSmemRowStride + swizzled_pack * PACK_SIZE));
        uint4 *gmem_src = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + row * global_ncols + col_pack * PACK_SIZE);
        if constexpr (STAGES > 1)
        {
            cp_async_cg_A(cast_smem_ptr_to_uint(smem_dst), gmem_src, thread_active);
        }
        else
        {
            if (thread_active)
            {
                *(uint4 *)smem_dst = *gmem_src;
            }
        }
    }
}
// Copies one row of CTA_N scales and one row of CTA_N zeros (one uint4 per
// participating thread for each) from global to shared memory.
//
// The source row is the quantization group containing K tile `global_iter_k`
// (row index = global_iter_k * CTA_K / G); the source column starts at
// `cta_offset_n`. With STAGES > 1 both copies go through cp_async_cg_A
// predicated on the mask, otherwise they are plain guarded uint4 stores.
//
//   src / src_z    global scales / zeros, row stride `global_ncols` halves
//   dst / dst_z    destination rows in shared memory
//   cta_offset_m, shared_iter_k   unused here
//   mask           false disables the copy for every thread
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int kVecsPerRow = CTA_N / PACK_SIZE;
    constexpr int kActiveThreads = kVecsPerRow < CTA_SIZE ? kVecsPerRow : CTA_SIZE;

    const bool do_copy = mask && (threadIdx.y * WARP_SIZE + threadIdx.x < kActiveThreads);
    const int group_row = global_iter_k * CTA_K / G;
    // Column (in halves) of the vector this thread moves; shared by all four pointers.
    const auto col = (threadIdx.x % kVecsPerRow) * PACK_SIZE;

    void *smem_scales = (void *)(dst + col);
    uint4 *gmem_scales = (uint4 *)(src + group_row * global_ncols + cta_offset_n + col);
    void *smem_zeros = (void *)(dst_z + col);
    uint4 *gmem_zeros = (uint4 *)(src_z + group_row * global_ncols + cta_offset_n + col);

    if constexpr (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(smem_scales);
        cp_async_cg_A(addr, gmem_scales, do_copy);
        uint32_t addr_z = cast_smem_ptr_to_uint(smem_zeros);
        cp_async_cg_A(addr_z, gmem_zeros, do_copy);
    }
    else if (do_copy)
    {
        *(uint4 *)smem_scales = *gmem_scales;
        *(uint4 *)smem_zeros = *gmem_zeros;
    }
}
// Loads `shared_iters` A fragments for K step `k_0_1` from a shared-memory
// stage into registers via ldmatrix_m8n8_x4_b16.
//
// The shared-memory column is un-swizzled with the low three bits of the row
// before the address is formed.
//
//   src            base of the A stage in shared memory
//   dst            register fragment buffer (indexed by the helper)
//   warp_offset_m  first row of this warp inside the CTA tile
//   warp_offset_n  unused here
//   k_0_1          index of the 16-wide K step inside the stage
template <int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>
__device__ __inline__ void share_to_reg_one_stage_A_T2(half *src, half *dst, int warp_offset_m, int warp_offset_n, int k_0_1)
{
    constexpr int kRowStride = CTA_K + SMEM_PAD_A;

    // The column depends only on the lane and the K step, not on the tile.
    const int col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
    const int col_vec = col / PACK_SIZE;

    for (int tile = 0; tile < shared_iters; ++tile)
    {
        const int row = warp_offset_m + tile * OP_M + (threadIdx.x % 16);
        // `&` binds tighter than `^`: XOR the vector column with (row & 7).
        const int col_swizzled = (col_vec ^ (row & 7)) * PACK_SIZE;
        void *smem_ptr = (void *)(src + row * kRowStride + col_swizzled);
        uint32_t addr = cast_smem_ptr_to_uint(smem_ptr);
        ldmatrix_m8n8_x4_b16(dst, tile, addr);
    }
}
// Brings B fragments for K step `k_0_1` into registers and dequantizes them.
//
// Phase 1 (only when `ldmatrix` is true): loads packed 4-bit weights from the
// shared-memory stage `src` into `dst` with ldmatrix_m8n8_x4_b16.
// Phase 2 (always): for each tile, converts one 32-bit word of `dst` into
// eight halves with dequantize_s4_to_fp16x2, applies `w * scale + zero`
// with the per-channel scale/zero read from shared memory, and stores the
// eight halves into `dst_fp16`.
//
//   src                     packed B stage in shared memory
//   src_scales / src_zeros  scales / zeros rows in shared memory
//   dst                     register buffer holding the packed words
//   dst_fp16                register buffer receiving dequantized halves
//   warp_offset_m           unused here
//   warp_offset_n           first N index of this warp inside the CTA tile
template <int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>
__device__ __inline__ void share_to_reg_one_stage_B_T2(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int k_0_1)
{
    constexpr int kRowStride = CTA_K + SMEM_PAD_B;

    // Lane -> (row, column) mapping for the ldmatrix source address.
    const int lane_r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
    const int lane_c0 = ((threadIdx.x / 8) % 2) * 8;
    const int row = lane_r0 / 4;
    const int col = (lane_r0 % 4) * 16 + lane_c0;
    // `&` binds tighter than `^`: only the row parity is folded into the column.
    const int col_swizzled = ((col / PACK_SIZE) ^ ((row % 2) & 7)) * PACK_SIZE;

    if constexpr (ldmatrix)
    {
        half *const warp_base = src + warp_offset_n / kInterleave * kRowStride;
#pragma unroll
        for (int tile = 0; tile < shared_iters; ++tile)
        {
            void *smem_ptr = (void *)(warp_base + tile * 16 / kInterleave * kRowStride + k_0_1 * 16 + row * kRowStride + col_swizzled);
            uint32_t addr = cast_smem_ptr_to_uint(smem_ptr);
            ldmatrix_m8n8_x4_b16(dst, tile, addr);
        }
    }

#pragma unroll
    for (int tile = 0; tile < shared_iters; ++tile)
    {
        // One scale/zero pair per lane group of four, alternating halves of the tile with k_0_1 parity.
        const half scale = src_scales[warp_offset_n + 16 * tile + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        const half zero = src_zeros[warp_offset_n + 16 * tile + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        const half2 scale_pair = make_half2(scale, scale);
        const half2 zero_pair = make_half2(zero, zero);

        half2 unpacked[4];
        half *packed_word = dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + tile * 8;
        dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(packed_word), reinterpret_cast<uint4 *>(unpacked));

#pragma unroll
        for (int v = 0; v < 4; v++)
        {
            unpacked[v] = __hfma2(unpacked[v], scale_pair, zero_pair);
        }

        half *out_slot = dst_fp16 + tile * 16 + 8 * (k_0_1 % 2);
        *reinterpret_cast<uint4 *>(out_slot) = *reinterpret_cast<uint4 *>(unpacked);
    }
}
// W4A16 GEMM kernel: multiplies half activations A by packed 4-bit weights B
// (dequantized on the fly using per-group `scales` / `zeros`) and writes the
// half result to C, which is indexed as C[row * N + col].
//
// Launch layout used by this body: only blockIdx.x is read (1-D grid);
// threadIdx.x is the lane inside a warp and threadIdx.y is the warp index.
// Dynamic shared memory holds, in order: STAGES A tiles, STAGES B tiles, the
// scales slots and the zeros slots (sizes computed below).
//
// Template parameters: CTA_* is the tile handled by one block, WARP_* the
// tile handled by one warp, STAGES the number of shared-memory stages and G
// the quantization group size along K.
//
// NOTE(review): the epilogue guards rows (`write_row < M`) but not columns;
// this presumably assumes N is a multiple of CTA_N — confirm against callers.
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int M, int N, int K)
{
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
// Number of CTA tiles along N and M (ceil division).
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
// blockIdx_x and blockIdx_z are computed but not read again in this kernel.
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
// Remap the linear block index to an (m, n) tile through helpers defined
// elsewhere (get_log_tile / get_block_idx_mapping); presumably a tile
// swizzle — see those helpers for the exact mapping.
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
// Per-thread float accumulators for this thread's share of the CTA tile.
float C_warp[CTA_M * CTA_N / CTA_SIZE];
// Shared-memory geometry, all counted in halves.
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
// NOTE(review): scales/zeros each get CTA_N * STAGES / 2 halves, while slots
// are addressed below as (stage / scales_load_interval) * CTA_N. That fits
// when G / CTA_K >= 2; it looks too small when G == CTA_K, and
// scales_load_interval is 0 (a division by zero below) when G < CTA_K —
// confirm which (G, CTA_K) pairs this kernel is instantiated with.
constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
constexpr int kSmemSizeZeros = CTA_N * STAGES / 2;
// A new scales/zeros row is needed once every `scales_load_interval` K tiles.
constexpr int scales_load_interval = G / CTA_K;
extern __shared__ half mem_shared[];
half *A_shared = mem_shared;
half *B_shared = mem_shared + kSmemSizeA;
half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB;
half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
// Double-buffered register fragments: A, dequantized B, and packed B.
half A_shared_warp_[2][WARP_M * INTRIN_K /
WARP_SIZE];
half B_shared_warp_[2][WARP_N * 32 /
WARP_SIZE];
half B_shared_warp_tmp_[2][WARP_N * 16 /
WARP_SIZE];
// Origin of this CTA's tile in C and of this warp's tile inside the CTA tile.
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
C_warp[i] = 0.0;
// Number of CTA_K-wide K tiles (ceil division).
int gemm_iters = (K + CTA_K - 1) / CTA_K;
// k_0_0_ld counts K tiles whose loads have been issued; k_0_0 counts K tiles computed.
int k_0_0_ld = 0;
int k_0_0 = 0;
// Prologue: issue the loads for the first STAGES - 1 tiles (one tile when STAGES == 1).
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_B_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
// Scales/zeros are only fetched on tiles that start a new load interval.
global_to_share_one_stage_scales_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
// Wait until the first tile has landed, then make it visible to the whole block.
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
// Preload the register fragments for the first K step of stage 0.
share_to_reg_one_stage_A_T2<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
// Main loop: one iteration per K tile. Each iteration computes from
// `compute_stage` while issuing the loads for tile k_0_0_ld into `ld_stage`.
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
half *A_shared_this_compute_stage;
half *B_shared_this_compute_stage;
half *scales_shared_this_compute_stage;
half *zeros_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
// compute_stage is advanced near the end of the second-to-last step (see
// below), so these pointers already target the next tile on the last step.
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
// Prefetch the A fragment for the next K step into the other register buffer.
share_to_reg_one_stage_A_T2<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
// Prefetch/dequantize B for the next K step. Packed words are re-read from
// shared memory (ldmatrix = true) only when the next step index is a
// multiple of kInterleave; otherwise only the dequantization runs on the
// words already held in registers. The packed buffer used is selected by
// the parity of compute_stage.
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
__syncthreads();
// Multiply-accumulate the fragments prepared for the current K step.
half *A_shared_warp = A_shared_warp_[iter_k % 2];
half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
// Two mma_m16n8k16 calls per 16-wide N tile: accumulator offsets +0 and +4,
// B fragment offsets +0 and +8.
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
// Spread the global->shared loads for tile k_0_0_ld over the K steps:
// slice `iter_k` is issued here on every step except the last one.
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
// On the second-to-last step also issue the final slice (iter_k + 1) and the
// scales/zeros row, close the async group, wait for the next compute tile
// and switch compute_stage to it.
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
scales, scales_shared + (ld_stage / scales_load_interval) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
// Epilogue: convert the float accumulators to half2 pairs and store them to C.
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
// local_id advances by 2 because each store writes two adjacent columns.
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
// Rows past M belong to the padded part of the last M tile and are skipped.
if (write_row < M)
{
*reinterpret_cast<half2 *>(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
__float22half2_rn(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
}
}
// Host entry point for the prefill W4A16 GEMM.
//
// Allocates an output tensor shaped like `_in_feats` with the last dimension
// replaced by `_kernel.size(0) * kInterleave`, selects a tile configuration
// from the number of output rows, launches the kernel and returns the output.
//
// Arguments (as read by this function):
//   _in_feats  half tensor; all leading dimensions are flattened into rows
//   _kernel    int16 tensor holding the packed weights
//   _scales    half tensor
//   _zeros     half tensor
//
// NOTE(review): the first four branches expand KERNEL_LAUNCH_CODE, which is
// defined outside this view. It presumably reads the locals declared here
// (e.g. `options_int`, `num_in_feats`, `G`, `CTA_*`, `SPLITK`, `STAGES`) by
// name, so none of them should be renamed or removed without checking it.
torch::Tensor gemm_forward_cuda_prefill(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scales,
torch::Tensor _zeros)
{
std::vector<int64_t> output_shape = _in_feats.sizes().vec();
output_shape.back() = _kernel.size(0) * kInterleave;
// Rows and columns of the flattened input matrix.
int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
int num_in_channels = _in_feats.size(-1);
// data_ptr<T>() throws if a tensor's dtype does not match T.
auto in_feats = reinterpret_cast<half *>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<half *>(_kernel.data_ptr<int16_t>());
auto scales = reinterpret_cast<half *>(_scales.data_ptr<at::Half>());
auto zeros = reinterpret_cast<half *>(_zeros.data_ptr<at::Half>());
auto options =
torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
// Not used by the code visible here; presumably consumed by KERNEL_LAUNCH_CODE.
auto options_int =
torch::TensorOptions().dtype(torch::kInt32).device(_in_feats.device());
// torch::empty leaves the contents uninitialized until the kernel writes them.
at::Tensor _out_feats = torch::empty(output_shape, options);
int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
int num_out_channels = _out_feats.size(-1);
auto out_feats = reinterpret_cast<half *>(_out_feats.data_ptr<at::Half>());
// Tile configuration is chosen by the number of output rows.
if (num_out_feats <= 32)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 64)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 128)
{
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 192)
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
// More than 192 rows: launch gemm_w4a16_T2 directly instead of going through the macro.
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
// Bytes of dynamic shared memory: per stage one A tile, one B tile and CTA_N halves for scales plus zeros.
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(half);
// NOTE(review): on this path the function only prints a message and returns
// the uninitialized output tensor; callers get no error signal.
if (kSmemByteSize >= 99 * 1024)
{
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
// Number of CTA tiles along N. Integer division: NOTE(review) columns beyond
// the last full CTA_N tile are not covered — presumably num_out_channels is
// always a multiple of CTA_N; confirm.
int j_factors1 = num_out_channels / CTA_N / 1;
// 1-D grid with one block per (M tile, N tile); block is WARP_SIZE lanes by NUM_WARPS warps.
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
// Raise the kernel's dynamic shared-memory limit to the requested size.
// The return code is not checked, and neither is the launch below.
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
// No stream argument is given, so the launch goes to the default stream.
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
return _out_feats;
}
\ No newline at end of file
#include <torch/extension.h>
torch::Tensor gemm_forward_cuda_prefill(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros);
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
// namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore
{
public:
    int *lock;        // global-memory word used as the flag
    bool wait_thread; // true for the single thread that reads/writes the flag
    int state;        // last value fetched from *lock (-1 before any fetch)

public:
    /// Binds the semaphore to `lock_`. The calling thread becomes the
    /// polling/writing thread when `thread_id` is zero or negative.
    __host__ __device__ Semaphore(int *lock_, int thread_id)
        : lock(lock_),
          wait_thread(thread_id <= 0),
          state(-1)
    {
    }

    /// Reads the current flag value into `state` (polling thread only), so the
    /// load can be issued before wait() is called.
    __device__ void fetch()
    {
        if (!wait_thread)
        {
            return;
        }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
        asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
    }

    /// Returns the most recently fetched flag value.
    __device__ int get_state() const
    {
        return state;
    }

    /// Blocks the whole CTA until the flag equals `status`.
    /// Every thread of the block must call this: the loop condition is a
    /// block-wide barrier that stays true only while all threads still see
    /// `state != status`, so it ends once the polling thread observes the value.
    __device__ void wait(int status = 0)
    {
        for (;;)
        {
            if (!__syncthreads_and(state != status))
            {
                break;
            }
            fetch();
        }
        __syncthreads();
    }

    /// Publishes `status` to the flag after the whole CTA has reached this
    /// point. Every thread of the block must call this.
    __device__ void release(int status = 0)
    {
        __syncthreads();
        if (!wait_thread)
        {
            return;
        }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
        asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
    }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// } // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include "gemv_cuda.h"
#include "../dequantize.cuh"
#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
// Converts a half value to float.
static inline __device__ float to_float(half src)
{
    const float widened = __half2float(src);
    return widened;
}
// Identity overload so generic code can call to_float on a float.
static inline __device__ float to_float(float src) { return src; }
// Converts a float value to half.
static inline __device__ half to_half(float src)
{
    const half narrowed = __float2half(src);
    return narrowed;
}
// Identity overload so generic code can call to_half on a half.
static inline __device__ half to_half(half src) { return src; }
// Partial warp reduction of `Num` per-thread half sums, followed by a store of
// the per-warp results to shared memory.
//
// XOR shuffles with offsets 16, 8 and 1 add together the eight lanes that
// share bits 1 and 2 of the lane index (e.g. T0 + T1 + T8 + T9 + T16 + T17 +
// T24 + T25). Lanes 0, 2, 4 and 6 then hold the four distinct sums and write
// them to out_smem[warp][i * 4 + lane / 2]. The factor 4 is the interleave.
//
// Must be called by every thread of the block: it uses full-mask shuffles and
// two block-wide barriers.
template <int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4])
{
    float partial[Num];
#pragma unroll
    for (int n = 0; n < Num; ++n)
    {
        // Accumulate in float; the shuffle order (16, 8, 1) is kept as is.
        float acc = to_float(psum[n]);
        acc += __shfl_xor_sync(~0, acc, 16);
        acc += __shfl_xor_sync(~0, acc, 8);
        acc += __shfl_xor_sync(~0, acc, 1);
        partial[n] = acc;
    }
    __syncthreads();

    const int warp_id = threadIdx.x / WarpSize;
    const int lane_id = threadIdx.x % WarpSize;
    // Lanes 0, 2, 4, 6: one representative per group of summed lanes.
    const bool is_writer = (lane_id < 8) && (lane_id % 2 == 0);
    if (is_writer)
    {
        const int slot = lane_id / 2;
#pragma unroll
        for (int n = 0; n < Num; ++n)
        {
            out_smem[warp_id][n * 4 + slot] = partial[n];
        }
    }
    __syncthreads();
}
// Returns c divided by divisor, rounded up (for positive operands).
__device__ __forceinline__ int make_divisible(int c, int divisor){
    const int rounded_up = c + divisor - 1;
    return rounded_up / divisor;
}
// Batched GEMV with 4-bit packed weights.
//
// Each block produces NPerBlock * kInterleave (= 4) consecutive output
// channels for `Batch` input rows: outputs[b * OC + blk_row_offset + oc].
// Weights are dequantized per thread as w * scale + zero, accumulated in half
// (`psum`), reduced across lanes and warps in float, and stored as half.
//
// Template parameters:
//   NPerBlock  groups of kInterleave output channels per block (the MAC loop
//              steps by two, so it presumably must be even — confirm)
//   Batch      number of input rows handled by the kernel
//   BlockSize  threads per block (1-D, threadIdx.x only)
//   GroupSize  quantization group size along the input-channel axis
//
// Arguments: IC / OC are the input / output channel counts. `inputs` is read
// as row b at inputs + b * IC; `scales` / `zeros` are read with row stride OC.
//
// Shared memory: the launch must supply at least
//   (BlockSize / WARP_SIZE) * NPerBlock * Batch * kInterleave * sizeof(float)
// bytes of dynamic shared memory (one row of partial sums per warp).
template <int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const half* inputs, const uint32_t* weight, const half* scales, const half* zeros, half* outputs,
const int IC, const int OC)
{
const int kStride = 64;
// MEM_ACCESS_SIZE bits per load / 4 bits per weight = 4-bit values per thread per iteration.
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
// assert(MEM_ACCESS_SIZE == 128);
// kShuffleSize is not referenced in this kernel.
static constexpr int kShuffleSize = 32;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
// Per-thread scratch: activations, packed weights, unpacked weights,
// dequantized weights (element-major, NPerBlock entries per element),
// scale/zero per output group and the half accumulators.
half local_inputs[kElemsPerThread];
uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
half half_weight_buffer[kElemsPerThread];
half dequantized_weight[kElemsPerThread * NPerBlock];
half local_scale[NPerBlock];
half local_scaled_zeros[NPerBlock];
half psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = to_half(0.f);
// Dynamic shared memory viewed as [warp][Num * kInterleave] floats.
extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
// First output channel of this block, and this thread's channel inside an interleave group.
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
// First input channel this thread reads, and the quantization group it falls in.
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride
+ (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half* inputs_ptr = inputs + act_k_offset;
// Per-iteration advance of the activation pointer and of the scale/zero pointers.
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
// Main loop iteration, each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
*((float4*)(local_qweights)) =
*((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
// Read the pair at position (i + j * 4) and write it to position (i * 4 + j):
// a 4x4 transpose of half2 pairs.
half2 w =
*reinterpret_cast<half2*>(
half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile
);
w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)
* NPerBlock + idx]
= w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)
* NPerBlock + idx]
= w.y;
}
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const half* local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
// load activation, 8 halves (128 bits) / step.
*((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
// Two output groups are updated per half2 FMA, so x walks NPerBlock in pairs.
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
*reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2)
= __hfma2(*reinterpret_cast<half2*>(dequantized_weight + y * NPerBlock + x * 2),
__half2half2(local_inputs[y]),
*reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2));
}
}
}
inputs_ptr += act_forward_step;
scale_ptr += scale_forward_step;
zeros_ptr += scale_forward_step;
}
// Combine lanes inside each warp and stage one row of partial sums per warp in shared memory.
warp_reduce<Num, WARP_SIZE>(psum, out_smem);
// Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
// Sum the partial results of all warps in the block.
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = to_half(acc);
}
}
/*
Computes a batched GEMV with 4-bit quantized weights (PyTorch interface).

Args:
  _in_feats: half tensor; the output has the same shape with the last
    dimension replaced by `n` (documented upstream as [B, IC]);
  _kernel: packed weight tensor, read as 32-bit words (documented upstream
    as [OC, IC // 8]);
  _scaling_factors: half tensor of per-group scales;
  _zeros: half tensor of per-group zeros (read as half, not as packed ints);
  m: number of input rows (batch); only 1..7 are supported;
  n: number of output channels (OC);
  k: number of input channels (IC);
  group_size: quantization group size; only 128 is supported.
Returns:
  out_feats: half tensor shaped like `_in_feats` with last dimension `n`.
Throws:
  std::runtime_error for an unsupported `m` or `group_size`.

NOTE(review): the grid is n / (N_PER_BLOCK * K_INTERLEAVE) blocks with integer
division, so `n` is presumably always a multiple of 8 — confirm with callers.
*/
torch::Tensor gemv_forward_cuda_decode(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int m,
    int n,
    int k,
    int group_size)
{
    std::vector<int64_t> output_shape = _in_feats.sizes().vec();
    output_shape.back() = n;

    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
    auto zeros = reinterpret_cast<half*>(_zeros.data_ptr<at::Half>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());

    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty(output_shape, options);
    half * out_feats = reinterpret_cast<half *>(_out_feats.data_ptr());

    static constexpr int N_PER_BLOCK = 2;
    static constexpr int K_INTERLEAVE = 4;
    static constexpr int BLOCK_SIZE = 256;
    // gemv_kernel stages one row of float partial sums per warp in dynamic
    // shared memory: N_PER_BLOCK * Batch * K_INTERLEAVE floats per warp. The
    // launch must request that space explicitly; this is the size for Batch == 1.
    static constexpr size_t SMEM_BYTES_PER_BATCH =
        (BLOCK_SIZE / WARP_SIZE) * N_PER_BLOCK * K_INTERLEAVE * sizeof(float);

    dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
    dim3 num_threads(BLOCK_SIZE);

    // Only group_size == 128 is implemented (a group_size == 64 path is not available).
    if (group_size == 128)
    {
        // The batch size is a template parameter of the kernel, so dispatch on it.
        switch (m)
        {
        case 1:
            gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads, 1 * SMEM_BYTES_PER_BATCH>>>(
                in_feats, kernel, scaling_factors, zeros, out_feats, k, n
            );
            break;
        case 2:
            gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads, 2 * SMEM_BYTES_PER_BATCH>>>(
                in_feats, kernel, scaling_factors, zeros, out_feats, k, n
            );
            break;
        case 3:
            gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads, 3 * SMEM_BYTES_PER_BATCH>>>(
                in_feats, kernel, scaling_factors, zeros, out_feats, k, n
            );
            break;
        case 4:
            gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads, 4 * SMEM_BYTES_PER_BATCH>>>(
                in_feats, kernel, scaling_factors, zeros, out_feats, k, n
            );
            break;
        case 5:
            gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads, 5 * SMEM_BYTES_PER_BATCH>>>(
                in_feats, kernel, scaling_factors, zeros, out_feats, k, n
            );
            break;
        case 6:
            gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads, 6 * SMEM_BYTES_PER_BATCH>>>(
                in_feats, kernel, scaling_factors, zeros, out_feats, k, n
            );
            break;
        case 7:
            gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads, 7 * SMEM_BYTES_PER_BATCH>>>(
                in_feats, kernel, scaling_factors, zeros, out_feats, k, n
            );
            break;
        default:
            throw std::runtime_error("Unsupported batch size for gemv kernel.\n");
        }
    }
    else
    {
        throw std::runtime_error("Unsupported group size for gemv kernel.\n");
    }
    return _out_feats;
}
#pragma once
#include <torch/extension.h>
torch::Tensor gemv_forward_cuda_decode(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int m,
int n,
int k,
int group_size);
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#define VLLM_LDG(arg) *(arg)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// SiLU activation, x * sigmoid(x), evaluated in float and cast back to T.
template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  const float x_neg = (float) -x;
  const float denom = 1.0f + expf(x_neg);
  return (T) (((float) x) / denom);
}
// Computes out[t, i] = silu(input[t, i]) * input[t, d + i] for every i in [0, d).
// One block handles one token row (blockIdx.x); its threads stride over the
// d columns by blockDim.x.
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., 2, d]
  const int d) {
  const int64_t token_idx = blockIdx.x;
  // Element offsets of this token's first half, second half and output row.
  const int64_t first_half = token_idx * 2 * d;
  const int64_t second_half = first_half + d;
  const int64_t out_row = token_idx * d;
  int64_t idx = threadIdx.x;
  while (idx < d) {
    const scalar_t x = VLLM_LDG(&input[first_half + idx]);
    const scalar_t y = VLLM_LDG(&input[second_half + idx]);
    out[out_row + idx] = silu(x) * y;
    idx += blockDim.x;
  }
}
// Launches silu_and_mul_kernel on the current CUDA stream of `input`'s device.
//
//   out    destination tensor, [..., d]
//   input  source tensor, [..., 2 * d]; its last dimension is split in half
//
// One block is launched per token row with min(d, 1024) threads. Empty inputs
// (no elements, or d == 0) return without launching anything.
void silu_and_mul(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input)    // [..., 2 * d]
{
  // An empty tensor has nothing to compute. Returning here also avoids the
  // host-side division by zero below when the last dimension is empty, and an
  // invalid launch with a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;
  // A last dimension of size 1 yields d == 0, which would request a zero-sized block.
  if (d == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
\ No newline at end of file
void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
\ No newline at end of file
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
// Upper bound on the number of experts. It sizes the static shared-memory
// tables declared in moe_alig_block_size_kernel.
const static size_t NUM_MAX_EXPERTS = 64;
// Dispatch cases for the integral scalar types: Byte, Char, Short, Int, Long.
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
// Dispatches on TYPE over the integral cases above; NAME is the dispatch name
// and the remaining arguments form the body of every case.
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
// Groups token indices by expert id and pads every expert's group to a multiple
// of block_size.
//
// Inputs:
//   topk_ids               expert id for each of the numel entries
//   num_experts            number of experts (bounded by NUM_MAX_EXPERTS through the shared arrays)
//   block_size             padding granularity per expert
//   numel                  number of entries in topk_ids
// Outputs:
//   sorted_token_ids       entry indices, grouped by expert at padded offsets
//   expert_ids             expert id for every block_size-sized block
//   total_tokens_post_pad  total length after padding
//
// The launcher in this file uses a single block of num_experts threads. The
// thread index is used both as "owner of a slice of topk_ids" and, in the
// prefix-sum and expert_ids phases, as the expert index.
// Padding slots of sorted_token_ids are not written by this kernel.
// NOTE(review): entries of topk_ids are used directly as shared-memory column
// indices without a range check — presumably guaranteed by the caller; confirm.
template <typename scalar_t>
__global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids,
int32_t *sorted_token_ids,
int32_t *expert_ids,
int32_t *total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel) {
// Each thread owns a contiguous slice of ceil(numel / blockDim.x) entries.
const size_t tokens_per_thread = ((numel + blockDim.x - 1) / blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
// Row t + 1 holds the per-expert counts of thread t; row 0 is the zero base of the prefix sum.
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
// cumsum[e] is the padded start offset of expert e; cumsum[num_experts] is the padded total.
__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
// Phase 1: clear this thread's row, then count the experts seen in its slice.
for(int i = 0;i < num_experts;i++){
tokens_cnts[threadIdx.x + 1][i] = 0;
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
}
__syncthreads();
// Phase 2: the thread now acts on column threadIdx.x and accumulates down the rows,
// so that afterwards tokens_cnts[i][e] is the number of entries of expert e found in the
// slices of threads 0..i-1 and tokens_cnts[blockDim.x][e] is the total for expert e.
tokens_cnts[0][threadIdx.x] = 0;
for(int i=1;i<=blockDim.x;++i){
tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
}
__syncthreads();
// Phase 3: thread 0 alone builds the padded cumulative offsets and publishes the total.
if(threadIdx.x ==0){
cumsum[0] = 0;
for(int i=1;i<=num_experts;++i){
// Round each expert's count up to a multiple of block_size before accumulating.
cumsum[i] = cumsum[i-1] + (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
// Phase 4: label every block in this expert's padded range with the expert id.
for(int i= cumsum[threadIdx.x];i<cumsum[threadIdx.x + 1];i += block_size){
expert_ids[i / block_size] = threadIdx.x;
}
// Phase 5: scatter. Row threadIdx.x holds the counts contributed by lower-numbered
// threads, i.e. this thread's first free position inside each expert's range; it is
// advanced after every placed entry. Each thread only modifies its own row here.
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[threadIdx.x][expert_id];
}
}
// Host launcher for moe_alig_block_size_kernel.
//
// Launches a single block of `num_experts` threads on the current CUDA stream.
// The kernel writes into sorted_token_ids, experts_ids and num_tokens_post_pad,
// all of which are accessed as int32 buffers.
//
// Raises (via TORCH_CHECK) when num_experts is outside [1, NUM_MAX_EXPERTS] or
// block_size is not positive.
void moe_alig_block_size(
    torch::Tensor topk_ids,
    int num_experts,
    int block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad) {
  // These must hold in release builds too: the kernel indexes fixed-size
  // shared-memory tables by expert id and divides by block_size. A plain
  // assert() disappears under NDEBUG and would let the kernel corrupt memory.
  TORCH_CHECK(
      num_experts > 0 && static_cast<size_t>(num_experts) <= NUM_MAX_EXPERTS,
      "moe_alig_block_size: num_experts must be in [1, ", NUM_MAX_EXPERTS,
      "], got ", num_experts);
  TORCH_CHECK(
      block_size > 0,
      "moe_alig_block_size: block_size must be positive, got ", block_size);

  const at::cuda::OptionalCUDAGuard device_guard_topk_ids(device_of(topk_ids));
  const at::cuda::OptionalCUDAGuard device_guard_sorted(device_of(sorted_token_ids));
  const at::cuda::OptionalCUDAGuard device_guard_experts(device_of(experts_ids));
  const at::cuda::OptionalCUDAGuard device_guard_num_tokens(device_of(num_tokens_post_pad));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
        moe_alig_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(),
            num_experts,
            block_size,
            topk_ids.numel());
      });
}
\ No newline at end of file
// Groups the entries of `topk_ids` by expert and pads every expert's group to a
// multiple of `block_size`. Results are written into the three output tensors:
// `sorted_token_ids`, `experts_ids` and `num_tokens_post_pad`.
void moe_alig_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad
);
\ No newline at end of file
/*
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
* Copyright (c) 2024, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
// Number of threads per warp assumed by the kernels and launch helpers in this file.
static constexpr int WARP_SIZE = 32;
/// Aligned array type.
///
/// A fixed-size array of N elements of type T whose alignment defaults to its
/// total size, so that a whole array can be moved with one aligned access by
/// assigning one AlignedArray to another.
template <
    typename T,
    /// Number of elements in the array
    int N,
    /// Alignment requirement in bytes
    int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
    // Storage uses the element type T, so the object occupies N * sizeof(T)
    // bytes for every T (a hard-coded float element type is only right for
    // T = float and disagrees with the default Alignment otherwise).
    T data[N];
};
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
// Row-wise softmax. Block blockIdx.x processes one row of num_cols values with
// TPB threads in three passes: block-wide max, block-wide sum of
// exp(x - max), then the normalized write to `output`.
// Rows whose `finished` flag is set are skipped entirely.
template <int TPB>
__launch_bounds__(TPB) __global__
    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;
    __shared__ float normalizing_factor;
    __shared__ float float_max;

    // Don't touch finished rows. The flag is per row, so the whole block leaves together.
    if (finished != nullptr && finished[blockIdx.x])
    {
        return;
    }

    const int row_begin = blockIdx.x * num_cols;
    const float* row_in = input + row_begin;
    float* row_out = output + row_begin;

    // Pass 1: maximum of the row.
    float local_max(-FLT_MAX);
    for (int col = threadIdx.x; col < num_cols; col += TPB)
    {
        local_max = max(static_cast<float>(row_in[col]), local_max);
    }
    const float row_max = BlockReduce(tmpStorage).Reduce(local_max, cub::Max());
    if (threadIdx.x == 0)
    {
        float_max = row_max;
    }
    __syncthreads();

    // Pass 2: sum of the shifted exponentials.
    float local_sum = 0;
    for (int col = threadIdx.x; col < num_cols; col += TPB)
    {
        local_sum += exp((static_cast<float>(row_in[col]) - float_max));
    }
    const auto Z = BlockReduce(tmpStorage).Reduce(local_sum, cub::Sum());
    if (threadIdx.x == 0)
    {
        normalizing_factor = 1.f / Z;
    }
    __syncthreads();

    // Pass 3: write the normalized probabilities.
    for (int col = threadIdx.x; col < num_cols; col += TPB)
    {
        const float prob = exp((static_cast<float>(row_in[col]) - float_max)) * normalizing_factor;
        row_out[col] = prob;
    }
}
// Iterative top-k over one row of probabilities per block (gridDim.x rows).
// For each of the k rounds the block arg-maxes over the experts that were not
// selected in an earlier round, and thread 0 records value, index and source row.
// Experts outside [start_expert, end_expert), and all experts of finished rows,
// are reported with index num_experts.
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{
    using cub_kvp = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    cub::ArgMax arg_max;

    const int num_rows = gridDim.x;
    const int block_row = blockIdx.x;
    const bool row_is_active = (finished == nullptr) || !finished[block_row];

    const float* row_probs = inputs_after_softmax + blockIdx.x * num_experts;
    const int* row_winners = indices + k * block_row;

    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        cub_kvp best;
        best.key = 0;
        best.value = -1.f; // This is OK because inputs are probabilities

        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
        {
            cub_kvp candidate;
            candidate.key = expert;
            candidate.value = row_probs[expert];

            // An expert that already won an earlier round is replaced by the running best,
            // which makes it neutral in the arg-max below.
            for (int prior_k = 0; prior_k < k_idx; ++prior_k)
            {
                if (row_winners[prior_k] == expert)
                {
                    candidate = best;
                }
            }

            best = arg_max(candidate, best);
        }

        const cub_kvp winner = BlockReduce(tmpStorage).Reduce(best, arg_max);
        if (threadIdx.x == 0)
        {
            // Ignore experts the node isn't responsible for with expert parallelism
            const int expert = winner.key;
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            const int idx = k * block_row + k_idx;
            output[idx] = winner.value;
            indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
            assert(indices[idx] >= 0);
            source_rows[idx] = k_idx * num_rows + block_row;
        }
        // The next round reads indices[] written above and reuses tmpStorage.
        __syncthreads();
    }
}
// ====================== TopK softmax things ===============================
/*
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
are a small power of 2. This allows us to cleanly share the rows among the threads in
a single warp and eliminate communication between warps (so no need to use shared mem).
It fuses the softmax, max and argmax into a single kernel.
Limitations:
1) This implementation is intended for when the number of experts is a small power of 2.
2) This implementation assumes k is small, but will work for any k.
*/
// Fused softmax + top-k for rows of NUM_EXPERTS floats.
//
// Template parameters:
//   VPT            values held per thread
//   NUM_EXPERTS    row length (power of 2)
//   WARPS_PER_CTA  warps per block; threadIdx.y selects the warp
//   BYTES_PER_LDG  bytes fetched per vectorized load (power of 2, at most 16)
// Each row is shared by THREADS_PER_ROW = NUM_EXPERTS / VPT consecutive lanes of
// one warp, and all reductions use width-limited warp shuffles.
// For every row and every k_idx < k, the lead lane of the row's group writes the
// selected probability to `output`, the expert index to `indices` (NUM_EXPERTS
// when the expert is outside [start_expert, end_expert) or the row is finished)
// and k_idx * num_rows + row to `source_rows`.
// The launch helper in this file uses blockDim = (WARP_SIZE, WARPS_PER_CTA).
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of elements each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
// Thus, each block processes a chunk of rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
// NOTE(review): the shuffles further down pass a full 0xFFFFFFFF mask although lanes that
// return here no longer take part — confirm this is acceptable on the target architectures.
if (thread_row >= num_rows)
{
return;
}
const bool row_is_active = finished ? !finished[thread_row] : true;
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
// Finally, we pull in the data from global mem. Load ii of a thread starts
// ii * THREADS_PER_ROW vectors after its first one, so the lanes of a group read
// consecutive vectors on every load.
float row_chunk[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
{
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
// First, we perform a max reduce within the thread over its VPT float values.
float thread_max = row_chunk[0];
#pragma unroll
for (int ii = 1; ii < VPT; ++ii)
{
thread_max = max(thread_max, row_chunk[ii]);
}
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
}
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float row_sum = 0;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
row_sum += row_chunk[ii];
}
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
}
// From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottleneck and it seems better to more closely match torch and find the
// argmax after computing the softmax.
const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
// Now, row_chunk contains the softmax of the row chunk. Next, we find the topk elements in each row, along
// with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
{
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
{
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=)
if (val > max_val)
{
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with a negative sentinel and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert))
{
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0)
{
// Add a guard to ignore experts not included by this node
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
// single thread per row of the input/output matrices.)
const int idx = k * thread_row + k_idx;
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
if (k_idx + 1 < k)
{
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the "winning" value to the sentinel.
if (thread_group_idx == thread_to_clear_in_group)
{
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
}
}
}
}
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
//   ELTS_PER_LDG     floats fetched by one load of BYTES_PER_LDG bytes
//   VECs_PER_THREAD  loads per thread: EXPERTS / (ELTS_PER_LDG * WARP_SIZE), but at least 1
//   VPT              values held by each thread
//   THREADS_PER_ROW  threads that share one row of EXPERTS values
//   ROWS_PER_WARP    rows handled by one warp
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
// EXPERTS must either be smaller than one full-warp load or an exact multiple of it.
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail
// Launches topkGatingSoftmax for a compile-time expert count.
// Derives the load width and per-thread work from detail::TopkConstants, then
// sizes the grid so that every row is covered by blocks of WARPS_PER_TB warps.
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
    // A single load moves at most 16 bytes, and never more than one whole row.
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
    static constexpr std::size_t ROW_BYTES = sizeof(float) * EXPERTS;
    static constexpr int BYTES_PER_LDG = (ROW_BYTES < MAX_BYTES_PER_LDG) ? ROW_BYTES : MAX_BYTES_PER_LDG;

    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
    static constexpr int VPT = Constants::VPT;
    static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;

    // Rows -> warps -> blocks, rounding up at each step.
    const int warps_needed = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
    const int blocks_needed = (warps_needed + WARPS_PER_TB - 1) / WARPS_PER_TB;

    const dim3 threads(WARP_SIZE, WARPS_PER_TB);
    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<blocks_needed, threads, 0, stream>>>(
        input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}
// Expands to a call of topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>.
// It picks up gating_output, topk_weights, topk_indicies, token_expert_indices,
// num_tokens, topk, num_experts and stream from the enclosing scope, passes no
// `finished` mask and uses the full expert range [0, num_experts).
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indicies, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
// Selects the top-k softmax implementation for the given expert count.
// Powers of two from 1 to 256 use the fused topkGatingSoftmax kernel. Every
// other count falls back to moeSoftmax followed by moeTopK, which needs
// `softmax_workspace` to hold the intermediate num_tokens x num_experts
// probabilities.
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
const int num_experts,
const int topk,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
switch (num_experts) {
case 1:
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
break;
case 2:
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
break;
case 4:
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
break;
case 8:
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
break;
case 16:
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
break;
case 32:
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
break;
case 64:
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
break;
case 128:
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
break;
case 256:
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
break;
// Generic two-kernel path; this also covers powers of two above 256.
default: {
TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2.");
static constexpr int TPB = 256;
// One block per token for both kernels.
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
num_experts, topk, 0, num_experts);
}
}
}
// Computes the softmax of `gating_output` along its last dimension and writes
// the top-k probabilities, their expert indices and the source-row indices into
// the three output tensors. k is taken from topk_weights.size(-1).
//
// The tensors are accessed as float (weights, gating output) and int (indices).
// Work is issued on the current CUDA stream of the device that owns
// `gating_output`. A workspace of num_tokens * num_experts floats is allocated
// only when the expert count is not a power of two or exceeds 256.
//
// Raises (via TORCH_CHECK) when the last dimension of `gating_output` is empty.
// An input without tokens is a no-op.
void topk_softmax(
    torch::Tensor& topk_weights,                // [num_tokens, topk]
    torch::Tensor& topk_indices,                // [num_tokens, topk]
    torch::Tensor& token_expert_indices,        // [num_tokens, topk]
    torch::Tensor& gating_output)               // [num_tokens, num_experts]
{
    const int num_experts = gating_output.size(-1);
    // Guards the division below and the kernels, which need at least one expert.
    TORCH_CHECK(num_experts > 0, "topk_softmax: gating_output must have a non-empty last dimension.");
    const int num_tokens = gating_output.numel() / num_experts;
    const int topk = topk_weights.size(-1);

    // A grid with zero blocks is an invalid launch configuration; there is
    // nothing to compute in that case anyway.
    if (num_tokens == 0) {
        return;
    }

    const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    const bool needs_workspace = !is_pow_2 || num_experts > 256;
    // Widen before multiplying so the element count cannot overflow int.
    const int64_t workspace_size =
        needs_workspace ? static_cast<int64_t>(num_tokens) * num_experts : 0;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
    topkGatingSoftmaxKernelLauncher(
        gating_output.data_ptr<float>(),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        token_expert_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        stream);
}
\ No newline at end of file
#pragma once
#include <torch/extension.h>
// Softmax over the last dimension of `gating_output` followed by top-k selection.
//   topk_weights          [num_tokens, topk]  selected probabilities (output)
//   topk_indices          [num_tokens, topk]  selected expert indices (output)
//   token_expert_indices  [num_tokens, topk]  source-row indices (output)
//   gating_output         [num_tokens, num_experts]  input logits
void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
\ No newline at end of file
#!/bin/bash
# Downloads the release wheels of AutoAWQ_kernels into ./dist and retags the
# Linux wheels as manylinux2014. Requires curl, jq, wget and xargs.

# Set variables
AWQ_KERNELS_VERSION="0.0.6"
RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ_kernels/releases/tags/v${AWQ_KERNELS_VERSION}"

# Create a directory to download the wheels.
# Stop if it cannot be entered, so nothing is downloaded or renamed elsewhere.
mkdir -p dist
cd dist || exit 1

# Download all the wheel files from the GitHub release
# excluding ones with '+cu' or '+rocm' (%2B is + but encoded).
# -r keeps xargs from running wget without a URL when nothing matches.
curl -s "$RELEASE_URL" | \
    jq -r ".assets[].browser_download_url" | \
    grep '\.whl' | \
    grep -v '%2Bcu' | \
    grep -v '%2Brocm' | \
    xargs -r -n 1 -P 4 wget

# Rename the wheels from 'linux_x86_64' to 'manylinux2014_x86_64'
for file in *linux_x86_64.whl; do
    # Without any match the pattern stays literal; skip it instead of failing in mv.
    [ -e "$file" ] || continue
    mv -- "$file" "${file/linux_x86_64/manylinux2014_x86_64}"
done

cd ..
import os
import torch
import subprocess
from pathlib import Path
from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from typing import Optional, Union
# Use g++ as both the C and the C++ compiler for the extension build.
os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"
# Base package version; a local version suffix may be attached further below.
AUTOAWQ_KERNELS_VERSION = "0.0.6"
# PYPI_BUILD=1 in the environment marks a build intended for PyPI.
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
# Toolkit versions: an environment variable wins, otherwise the value reported by torch.
CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda
ROCM_VERSION = os.environ.get("ROCM_VERSION", None) or torch.version.hip
def get_sha(pytorch_root: Union[str, Path]) -> str:
    """Return the HEAD commit hash of the git checkout at ``pytorch_root``.

    Any failure (git missing, directory missing, not a repository, ...) is
    reported as the string ``'Unknown'`` instead of an exception.
    """
    git_command = ['git', 'rev-parse', 'HEAD']
    try:
        raw_output = subprocess.check_output(git_command, cwd=pytorch_root)
        return raw_output.decode('ascii').strip()
    except Exception:
        return 'Unknown'
def get_abi():
    """Return ``'abi<N>'`` where N is gcc's ``_GLIBCXX_USE_CXX11_ABI`` value.

    The value is read from the macros gcc predefines after including
    ``<string>``. If the probe cannot be run or prints nothing (for example
    because gcc is not installed), ``'abiUnknown'`` is returned.
    """
    command = "echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI"
    try:
        result = subprocess.run(command, shell=True, capture_output=True, text=True)
        output = result.stdout.strip()
    except Exception:
        return 'abiUnknown'
    # A failing pipeline does not raise; it just produces no output. Report
    # that as unknown rather than returning a bare "abi".
    if not output:
        return 'abiUnknown'
    # The matching line looks like "#define _GLIBCXX_USE_CXX11_ABI 1".
    return "abi" + output.split(" ")[-1]
def get_version_add(sha: Optional[str] = None) -> str:
    """Rewrite the second line of ``version.py`` with the DCU build version.

    The local version label is assembled from the short git hash, the C++ ABI
    tag, the DTK version (only when ``ROCM_PATH`` is set) and the first five
    characters of the torch version. ``version.py`` next to this file is
    updated in place; nothing is returned.
    """
    autoawq_root = os.path.dirname(os.path.abspath(__file__))
    add_version_path = os.path.join(os.path.join(autoawq_root, ""), "version.py")

    version = ''
    if sha != 'Unknown':
        if sha is None:
            sha = get_sha(autoawq_root)
        version = 'git' + sha[:7]

    # abi
    version += "." + get_abi()

    # dtk version
    if os.getenv("ROCM_PATH"):
        rocm_path = os.getenv('ROCM_PATH', "")
        rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
        with open(rocm_version_path, 'r',encoding='utf-8') as rocm_file:
            rocm_lines = rocm_file.readlines()
        # Drop the last two characters of the first line and all dots.
        rocm_version = rocm_lines[0][:-2].replace(".", "")
        version += ".dtk" + rocm_version

    # torch version
    version += ".torch" + torch.__version__[:5]

    with open(add_version_path, 'r',encoding='utf-8') as version_file:
        version_lines = version_file.readlines()
    version_lines[1] = "__dcu_version__ = '0.0.6+das1.1.{}'\n".format(version)
    with open(add_version_path, encoding="utf-8",mode="w") as version_file:
        version_file.writelines(version_lines)
def get_version():
    """Refresh ``version.py`` and return the ``__dcu_version__`` it defines."""
    get_version_add()
    version_file = 'version.py'
    # Execute into an explicit namespace. Relying on exec() to populate the
    # function's locals and reading them back through locals() no longer works
    # on Python 3.13+, where locals() returns an independent snapshot.
    namespace = {}
    with open(version_file, encoding='utf-8') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__dcu_version__']
if not PYPI_BUILD:
    # A local version label is only attached when not building for PyPI,
    # to comply with PEP 440.
    if CUDA_VERSION:
        CUDA_VERSION = CUDA_VERSION.replace(".", "")[:3]
        AUTOAWQ_KERNELS_VERSION = f"{AUTOAWQ_KERNELS_VERSION}+cu{CUDA_VERSION}"
    elif ROCM_VERSION:
        ROCM_VERSION = ROCM_VERSION.replace(".", "")[:3]
        #AUTOAWQ_KERNELS_VERSION += f"+rocm{ROCM_VERSION}"
        # ROCm builds take the whole version string from version.py.
        AUTOAWQ_KERNELS_VERSION = get_version()
    else:
        raise RuntimeError(
            "Your system must have either Nvidia or AMD GPU to build this package."
        )

print(f"Building AutoAWQ Kernels version {AUTOAWQ_KERNELS_VERSION}")
# Package metadata passed to setup(); the long description is read from the
# README.md that sits next to this file.
common_setup_kwargs = {
"version": AUTOAWQ_KERNELS_VERSION,
"name": "autoawq_kernels",
"author": "Casper Hansen",
"license": "MIT",
"python_requires": ">=3.8.0",
"description": "AutoAWQ Kernels implements the AWQ kernels.",
"long_description": (Path(__file__).parent / "README.md").read_text(
encoding="UTF-8"
),
"long_description_content_type": "text/markdown",
"url": "https://github.com/casper-hansen/AutoAWQ_kernels",
"keywords": ["awq", "autoawq", "quantization", "transformers"],
"platforms": ["linux", "windows"],
"classifiers": [
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: C++",
],
}
# Runtime dependencies of the installed package.
requirements = [
"torch>=2.0.1",
]
def get_include_dirs():
    """Return the extra include directories for the extension build.

    For CUDA builds this is the pip/conda ``nvidia/cuda_runtime/include``
    directory (when it exists) followed by the directory of this file.
    Otherwise the list is empty.
    """
    if not CUDA_VERSION:
        return []

    include_dirs = []
    conda_cuda_include_dir = os.path.join(
        get_python_lib(), "nvidia/cuda_runtime/include"
    )
    if os.path.isdir(conda_cuda_include_dir):
        include_dirs.append(conda_cuda_include_dir)
    include_dirs.append(os.path.dirname(os.path.abspath(__file__)))
    return include_dirs
def get_generator_flag():
    """Return ``["-DOLD_GENERATOR_PATH"]`` when the installed torch ships
    ``include/ATen/CUDAGeneratorImpl.h``, otherwise an empty list."""
    # if CUDA_VERSION:
    generator_header = os.path.join(
        torch.__path__[0], "include", "ATen", "CUDAGeneratorImpl.h"
    )
    if os.path.exists(generator_header):
        return ["-DOLD_GENERATOR_PATH"]
    return []
def get_compute_capabilities(
    compute_capabilities={75, 80, 86, 89, 90}
):
    """Build the nvcc ``-gencode`` flags for the requested compute capabilities.

    Returns an empty list for non-CUDA builds. For CUDA builds every visible
    GPU is checked first; a device below compute capability 7.5 aborts the
    build with ``RuntimeError``.
    """
    if not CUDA_VERSION:
        return []

    # Collect the compute capabilities of all available CUDA GPUs
    for device_index in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(device_index)
        if major * 10 + minor < 75:
            raise RuntimeError(
                "GPUs with compute capability less than 7.5 are not supported."
            )

    # Figure out compute capability
    capability_flags = []
    for cap in compute_capabilities:
        capability_flags.extend(["-gencode", f"arch=compute_{cap},code=sm_{cap}"])
    return capability_flags
def get_extra_compile_args(arch_flags, generator_flags):
    """Return the ``extra_compile_args`` mapping for ``CUDAExtension``.

    Non-CUDA builds get an empty mapping. On Windows only the architecture
    flags are passed to nvcc, and only when ``INCLUDE_ARCH`` is ``"1"`` (the
    default). Everywhere else the full cxx/nvcc flag sets are returned.
    """
    if not CUDA_VERSION:
        return {}

    if os.name == "nt":
        # Relaxed args on Windows
        include_arch = os.getenv("INCLUDE_ARCH", "1") == "1"
        return {"nvcc": arch_flags} if include_arch else {}

    cxx_flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
    nvcc_base_flags = [
        "-O3",
        "-std=c++17",
        "-DENABLE_BF16",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]
    return {
        "cxx": cxx_flags,
        "nvcc": nvcc_base_flags + arch_flags + generator_flags,
    }
def get_extra_link_args():
    """Return extra linker arguments: the cuBLAS library under ``CUDA_PATH``
    for CUDA builds on Windows, nothing otherwise."""
    if os.name != "nt" or not CUDA_VERSION:
        return []
    cuda_path = os.environ.get("CUDA_PATH", None)
    return ["-L", f"{cuda_path}/lib/x64/cublas.lib"]
# Resolve the build configuration once; the values are shared by the extensions below.
include_dirs = get_include_dirs()
extra_link_args = get_extra_link_args()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()
extra_compile_args = get_extra_compile_args(arch_flags, generator_flags)
# Extension modules to build, filled in below.
extensions = []
# NOTE(review): indentation was lost in this copy of the file, so which of the
# following extensions.append(...) calls are nested under this `if` has to be
# confirmed against the original before editing.
if CUDA_VERSION:
# contain un-hipifiable inline PTX
extensions.append(
CUDAExtension(
"awq_ext",
[
"awq_ext/pybind_awq.cpp",
"awq_ext/quantization/gemm_cuda_gen.cu",
"awq_ext/layernorm/layernorm.cu",
"awq_ext/position_embedding/pos_encoding_kernels.cu",
"awq_ext/quantization/gemv_cuda.cu",
"awq_ext/vllm/moe_alig_block.cu",
"awq_ext/vllm/activation.cu",
"awq_ext/vllm/topk_softmax_kernels.cu",
],
extra_compile_args=extra_compile_args,
)
)
# only compatible with ampere
# The v2 kernels are compiled for capabilities 80, 86, 89 and 90 only.
arch_flags = get_compute_capabilities({80, 86, 89, 90})
extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags)
extensions.append(
CUDAExtension(
"awq_v2_ext",
[
"awq_ext/pybind_awq_v2.cpp",
"awq_ext/quantization_new/gemv/gemv_cuda.cu",
"awq_ext/quantization_new/gemm/gemm_cuda.cu",
],
extra_compile_args=extra_compile_args_v2,
)
)
# ExLlama kernels.
extensions.append(
CUDAExtension(
"exl_ext",
[
"awq_ext/exllama/exllama_ext.cpp",
"awq_ext/exllama/cuda_buffers.cu",
"awq_ext/exllama/cuda_func/column_remap.cu",
"awq_ext/exllama/cuda_func/q4_matmul.cu",
"awq_ext/exllama/cuda_func/q4_matrix.cu",
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
)
# ExLlamaV2 kernels.
extensions.append(
CUDAExtension(
"exlv2_ext",
[
"awq_ext/exllamav2/ext.cpp",
"awq_ext/exllamav2/cuda/q_matrix.cu",
"awq_ext/exllamav2/cuda/q_gemm.cu",
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
)
# The FasterTransformer extension is only added for CUDA builds outside Windows.
if os.name != "nt" and CUDA_VERSION:
# FasterTransformer kernels
extensions.append(
CUDAExtension(
"awq_ft_ext",
[
"awq_ext/pybind_awq_ft.cpp",
"awq_ext/attention/ft_attention.cpp",
"awq_ext/attention/decoder_masked_multihead_attention.cu",
],
extra_compile_args=extra_compile_args,
)
)
# Hand the extensions to setuptools and let torch's BuildExtension drive the build.
additional_setup_kwargs = {
"ext_modules": extensions,
"cmdclass": {"build_ext": BuildExtension},
}
common_setup_kwargs.update(additional_setup_kwargs)
setup(
packages=find_packages(),
install_requires=requirements,
include_dirs=include_dirs,
**common_setup_kwargs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment