// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_hip_layers.h"

namespace cg = cooperative_groups;

__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    float2 data[MAX_REG];

    int group_id = blockIdx.x;

    {
        int group_index = id;
        int reg_count = 0;
        int offset = group_id * group_size;
        float max = -10000.0;

        while (group_index < group_size && reg_count < MAX_REG) {
            data[reg_count] = vals_cast[offset + group_index];
            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);

            if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
            if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
            if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
            if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);

            group_index += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale = (1 << num_bits) / (2 * max + 1e-5);
        float q_scale_inv = 1 / q_scale;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + id;
            if (group_index < group_size) {
                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
                float2 q_data[2];
                q_data[0] = __half22float2(data_h[0]);
                q_data[1] = __half22float2(data_h[1]);

                float2 q_data_int[2];
                q_data_int[0].x = roundf(q_data[0].x * q_scale);
                q_data_int[0].y = roundf(q_data[0].y * q_scale);
                q_data_int[1].x = roundf(q_data[1].x * q_scale);
                q_data_int[1].y = roundf(q_data[1].y * q_scale);

                q_data_int[0].x *= q_scale_inv;
                q_data_int[0].y *= q_scale_inv;
                q_data_int[1].x *= q_scale_inv;
                q_data_int[1].y *= q_scale_inv;

                data_h[0] = __float22half2_rn(q_data_int[0]);
                data_h[1] = __float22half2_rn(q_data_int[1]);

                vals_cast[offset + group_index] = data[i];
            }
        }
    }
#endif
}
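// Both quantize_kernel overloads implement in-place symmetric fake
// quantization: each block reduces the absolute maximum over its group
// (warp shuffle reduction, then a shared-memory reduction across warps),
// quantizes every element x to round(x * q_scale) with
//   q_scale = 2^num_bits / (2 * max|x| + 1e-5),
// and immediately dequantizes with q_scale_inv, so `vals` ends up holding
// the values the quantizer would reproduce.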
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[MAX_REG];

    int bid = blockIdx.x;

    int group_index = bid * group_size + id;
    int reg_count = 0;

    float max = -10000.0;

    while (id < group_size && reg_count < MAX_REG) {
        float4 data_reg = vals_cast[group_index];
        data[reg_count] = data_reg;

        if (abs(data_reg.x) > max) max = abs(data_reg.x);
        if (abs(data_reg.y) > max) max = abs(data_reg.y);
        if (abs(data_reg.z) > max) max = abs(data_reg.z);
        if (abs(data_reg.w) > max) max = abs(data_reg.w);

        group_index += blockDim.x;
        id += blockDim.x;
        reg_count++;
    }

    id = threadIdx.x;

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

    __shared__ float partialMax[WARP_SIZE];

    if (lane == 0) partialMax[gid] = max;

    b.sync();

    if (lane < warp_num) max = partialMax[lane];

    b.sync();

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }

    max = g.shfl(max, 0);

    float q_scale = (1 << num_bits) / (2 * max + 1e-5);
    float q_scale_inv = 1 / q_scale;

    for (int i = 0; i < reg_count; i++) {
        group_index = i * blockDim.x + id;
        if (group_index < group_size) {
            float4 q_data;
            q_data = data[i];

            float4 q_data_int;
            q_data_int.x = roundf(q_data.x * q_scale);
            q_data_int.y = roundf(q_data.y * q_scale);
            q_data_int.w = roundf(q_data.w * q_scale);
            q_data_int.z = roundf(q_data.z * q_scale);

            q_data.x = q_data_int.x * q_scale_inv;
            q_data.y = q_data_int.y * q_scale_inv;
            q_data.w = q_data_int.w * q_scale_inv;
            q_data.z = q_data_int.z * q_scale_inv;

            vals_cast[group_index + bid * group_size] = q_data;
        }
    }
}

template <typename T>
void launch_quantize_kernel(T* vals, int total_count, int group_num, int num_bits, hipStream_t stream)
{
    dim3 grid_dim(group_num);
    dim3 block_dim(1024);

    hipLaunchKernelGGL((quantize_kernel),
                       dim3(grid_dim),
                       dim3(block_dim),
                       0,
                       stream,
                       vals,
                       (total_count / group_num) / 4,
                       num_bits);
}

template void launch_quantize_kernel(float* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     hipStream_t stream);
template void launch_quantize_kernel(__half* vals,
                                     int total_count,
                                     int group_num,
                                     int num_bits,
                                     hipStream_t stream);
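/*
Usage sketch (illustrative only; d_vals, total_count, group_num, and stream
are placeholder names). The buffer must hold total_count contiguous elements
forming group_num equal groups, each a multiple of 4 elements, since the
kernels read it as float4 / float2 vectors:

    float* d_vals;
    hipMalloc(&d_vals, total_count * sizeof(float));
    // ... copy data into d_vals ...
    launch_quantize_kernel(d_vals, total_count, group_num, 8, stream);
*/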
__global__ void sr_quantize_kernel(__half* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    __half2 data_low[128];
    __half2 data_high[128];

    int bid = blockIdx.x;

    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;
    int offset = bid * token_size;
    int group_index = bid * token_size + tid;

    int total_count = token_size * token_num;
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];

            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);

            if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
            if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
            if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
            if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);

            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);

                float2 q_data_int[2];
                q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
                q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
                q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
                q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));

                // Stochastic rounding
                float4 rand = hiprand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
                q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;

                q_data_int[0].x =
                    (rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
                        ? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
                        : q_data_int[0].x;
                q_data_int[0].y =
                    (rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
                        ? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
                        : q_data_int[0].y;
                q_data_int[1].x =
                    (rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
                        ? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
                        : q_data_int[1].x;
                q_data_int[1].y =
                    (rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
                        ? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
                        : q_data_int[1].y;

                data_f[0].x = q_data_int[0].x / q_scale_val;
                data_f[0].y = q_data_int[0].y / q_scale_val;
                data_f[1].x = q_data_int[1].x / q_scale_val;
                data_f[1].y = q_data_int[1].y / q_scale_val;

                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);

                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}
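// The stochastic-rounding (sr_) kernels first truncate x * q_scale_val
// toward zero, compute the normalized truncation error
//   q_error = |x - trunc(x * q_scale_val) / q_scale_val| * q_scale_val,
// which lies in [0, 1), and then bump the truncated code one step away from
// zero with probability q_error (provided the code stays strictly inside
// [low_q, high_q]), so the rounded result is unbiased in expectation.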
__global__ void sr_quantize_kernel(float* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    int idx = blockIdx.x * blockDim.x + id;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[128];

    int bid = blockIdx.x;
    int tid = threadIdx.x;
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);

    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            data[reg_count] = vals_cast[group_index];

            if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
            if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
            if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
            if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);

            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        int offset = bid * token_size;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];
                float4 q_data_int;
                q_data_int.x = (float)((int)(q_data.x * q_scale_val));
                q_data_int.y = (float)((int)(q_data.y * q_scale_val));
                q_data_int.w = (float)((int)(q_data.w * q_scale_val));
                q_data_int.z = (float)((int)(q_data.z * q_scale_val));

                // Stochastic rounding
                float4 rand = hiprand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
                q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;

                q_data_int.x = (rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
                                   ? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
                                   : q_data_int.x;
                q_data_int.y = (rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
                                   ? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
                                   : q_data_int.y;
                q_data_int.w = (rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
                                   ? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
                                   : q_data_int.w;
                q_data_int.z = (rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
                                   ? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
                                   : q_data_int.z;

                q_data_int.x /= q_scale_val;
                q_data_int.y /= q_scale_val;
                q_data_int.w /= q_scale_val;
                q_data_int.z /= q_scale_val;

                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}

template <typename T>
void launch_sr_quantize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               hipStream_t stream)
{
    dim3 block_dim(1024);
    dim3 grid_dim(group_num);

    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    hipLaunchKernelGGL((sr_quantize_kernel),
                       dim3(grid_dim),
                       dim3(block_dim),
                       0,
                       stream,
                       vals,
                       (total_count / group_num) / 4,
                       group_num,
                       num_bits,
                       seed);
}

template void launch_sr_quantize_kernel(float* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        hipStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
                                        int total_count,
                                        int group_num,
                                        int num_bits,
                                        hipStream_t stream);
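/*
Usage sketch (illustrative; d_acts, token_count, hidden_dim, and stream are
placeholder names). The launcher derives the per-group vector count from
total_count / group_num, so hidden_dim must again be a multiple of 4:

    launch_sr_quantize_kernel(d_acts, token_count * hidden_dim, token_count, 8, stream);

Each call advances the RNG state through Context::Instance().IncrementOffset,
so repeated launches draw fresh Philox offsets.
*/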
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    float2 data[MAX_REG];

    int group_id = blockIdx.x;

    {
        int group_index = id;
        int reg_count = 0;
        int offset = group_id * group_size;
        float max = -10000.0;
        float min = 10000.0;

        while (group_index < group_size && reg_count < MAX_REG) {
            data[reg_count] = vals_cast[offset + group_index];
            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);

            if (((float)data_h[0]) > max) max = (float)data_h[0];
            if (((float)data_h[1]) > max) max = (float)data_h[1];
            if (((float)data_h[2]) > max) max = (float)data_h[2];
            if (((float)data_h[3]) > max) max = (float)data_h[3];

            if (((float)data_h[0]) < min) min = (float)data_h[0];
            if (((float)data_h[1]) < min) min = (float)data_h[1];
            if (((float)data_h[2]) < min) min = (float)data_h[2];
            if (((float)data_h[3]) < min) min = (float)data_h[3];

            group_index += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_inv = 1 / q_scale;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + id;
            if (group_index < group_size) {
                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
                float2 q_data[2];
                q_data[0] = __half22float2(data_h[0]);
                q_data[1] = __half22float2(data_h[1]);

                float2 q_data_int[2];
                q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
                q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
                q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
                q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);

                q_data_int[0].x = q_data_int[0].x * q_scale + min;
                q_data_int[0].y = q_data_int[0].y * q_scale + min;
                q_data_int[1].x = q_data_int[1].x * q_scale + min;
                q_data_int[1].y = q_data_int[1].y * q_scale + min;

                data_h[0] = __float22half2_rn(q_data_int[0]);
                data_h[1] = __float22half2_rn(q_data_int[1]);

                vals_cast[offset + group_index] = data[i];
            }
        }
    }
#endif
}
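// The asymmetric (_asym) variants reduce both min and max per group and
// quantize against the affine grid
//   q = round((x - min) / q_scale),  q_scale = (max - min + 1e-5) / 2^num_bits,
// dequantizing with x' = q * q_scale + min, so zero need not be a code point.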
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[MAX_REG];

    int bid = blockIdx.x;

    int group_index = bid * group_size + id;
    int reg_count = 0;

    float max = -10000.0;
    float min = 10000.0;

    while (id < group_size && reg_count < MAX_REG) {
        float4 data_reg = vals_cast[group_index];
        data[reg_count] = data_reg;

        if (data_reg.x > max) max = data_reg.x;
        if (data_reg.y > max) max = data_reg.y;
        if (data_reg.w > max) max = data_reg.w;
        if (data_reg.z > max) max = data_reg.z;

        if (data_reg.x < min) min = data_reg.x;
        if (data_reg.y < min) min = data_reg.y;
        if (data_reg.w < min) min = data_reg.w;
        if (data_reg.z < min) min = data_reg.z;

        group_index += blockDim.x;
        id += blockDim.x;
        reg_count++;
    }

    id = threadIdx.x;

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(max, i);
        if (max < temp) max = temp;
    }

#pragma unroll
    for (int i = 1; i < WARP_SIZE; i <<= 1) {
        auto temp = g.shfl_xor(min, i);
        if (min > temp) min = temp;
    }

    __shared__ float partialMax[WARP_SIZE];
    __shared__ float partialMin[WARP_SIZE];

    if (lane == 0) partialMax[gid] = max;
    if (lane == 0) partialMin[gid] = min;

    b.sync();

    if (lane < warp_num) max = partialMax[lane];
    if (lane < warp_num) min = partialMin[lane];

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(max, i);
        if (max < temp) max = temp;
    }

#pragma unroll
    for (int i = 1; i < warp_num; i <<= 1) {
        auto temp = g.shfl_down(min, i);
        if (min > temp) min = temp;
    }

    max = g.shfl(max, 0);
    min = g.shfl(min, 0);

    float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
    float q_scale_inv = 1 / q_scale;

    for (int i = 0; i < reg_count; i++) {
        group_index = i * blockDim.x + id;
        if (group_index < group_size) {
            float4 q_data;
            q_data = data[i];

            float4 q_data_int;
            q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
            q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
            q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
            q_data_int.z = roundf((q_data.z - min) * q_scale_inv);

            q_data.x = q_data_int.x * q_scale + min;
            q_data.y = q_data_int.y * q_scale + min;
            q_data.w = q_data_int.w * q_scale + min;
            q_data.z = q_data_int.z * q_scale + min;

            vals_cast[group_index + bid * group_size] = q_data;
        }
    }
}

template <typename T>
void launch_quantize_kernel_asym(T* vals, int total_count, int group_num, int num_bits, hipStream_t stream)
{
    dim3 grid_dim(group_num);
    dim3 block_dim(1024);

    hipLaunchKernelGGL((quantize_kernel_asym),
                       dim3(grid_dim),
                       dim3(block_dim),
                       0,
                       stream,
                       vals,
                       (total_count / group_num) / 4,
                       num_bits);
}

template void launch_quantize_kernel_asym(float* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          hipStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
                                          int total_count,
                                          int group_num,
                                          int num_bits,
                                          hipStream_t stream);
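/*
Usage sketch (illustrative; names are placeholders). The asymmetric launcher
shares launch_quantize_kernel's contract, but quantizes against each group's
[min, max] range instead of a symmetric range around zero, which yields a
finer grid when the data is skewed:

    launch_quantize_kernel_asym(d_vals, total_count, group_num, 8, stream);
*/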
__global__ void sr_quantize_kernel_asym(__half* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    float2* vals_cast = reinterpret_cast<float2*>(vals);

    __half2 data_low[128];
    __half2 data_high[128];

    int bid = blockIdx.x;

    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;
    int offset = bid * token_size;
    int group_index = bid * token_size + tid;

    int total_count = token_size * token_num;
    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];

            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);

            if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
            if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
            if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
            if (((float)data_f[1].y) > max) max = (float)data_f[1].y;

            if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
            if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
            if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
            if (((float)data_f[1].y) < min) min = (float)data_f[1].y;

            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float q_scale_val_inv = 1 / q_scale_val;
        float high_q = (float)((1 << num_bits) - 1);

        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);

                float2 q_data_int[2];
                q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
                q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
                q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
                q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));

                // Stochastic rounding
                float4 rand = hiprand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[1] = abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
                q_error[2] = abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
                q_error[3] = abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;

                q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
                                      ? (q_data_int[0].x + 1)
                                      : q_data_int[0].x;
                q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
                                      ? (q_data_int[0].y + 1)
                                      : q_data_int[0].y;
                q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
                                      ? (q_data_int[1].x + 1)
                                      : q_data_int[1].x;
                q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
                                      ? (q_data_int[1].y + 1)
                                      : q_data_int[1].y;

                data_f[0].x = q_data_int[0].x * q_scale_val + min;
                data_f[0].y = q_data_int[0].y * q_scale_val + min;
                data_f[1].x = q_data_int[1].x * q_scale_val + min;
                data_f[1].y = q_data_int[1].y * q_scale_val + min;

                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);

                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}
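// In the asymmetric stochastic-rounding path the shifted values (x - min)
// are nonnegative, so truncation already rounds down and the probabilistic
// bump only ever adds +1, clamped below high_q = 2^num_bits - 1.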
__global__ void sr_quantize_kernel_asym(float* vals,
                                        int token_size,
                                        int token_num,
                                        int num_bits,
                                        std::pair<uint64_t, uint64_t> seed)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    int idx = blockIdx.x * blockDim.x + id;

    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[128];

    int bid = blockIdx.x;
    int tid = threadIdx.x;
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);

    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;
    int total_count = token_size * token_num;
    if (group_index < total_count) {
        float min = 10000.0;
        float max = -10000.0;

        while (tid < token_size) {
            float4 data_reg = vals_cast[group_index];
            data[reg_count] = data_reg;

            if (data_reg.x > max) max = data_reg.x;
            if (data_reg.y > max) max = data_reg.y;
            if (data_reg.w > max) max = data_reg.w;
            if (data_reg.z > max) max = data_reg.z;

            if (data_reg.x < min) min = data_reg.x;
            if (data_reg.y < min) min = data_reg.y;
            if (data_reg.w < min) min = data_reg.w;
            if (data_reg.z < min) min = data_reg.z;

            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(min, i);
            if (min > temp) min = temp;
        }

        __shared__ float partialMax[WARP_SIZE];
        __shared__ float partialMin[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;
        if (lane == 0) partialMin[gid] = min;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];
        if (lane < warp_num) min = partialMin[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(min, i);
            if (min > temp) min = temp;
        }

        max = g.shfl(max, 0);
        min = g.shfl(min, 0);

        float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
        float high_q = (float)((1 << num_bits) - 1);

        int offset = bid * token_size;

        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];
                float4 q_data_int;
                q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
                q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
                q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
                q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));

                // Stochastic rounding
                float4 rand = hiprand_uniform4(&state);

                float q_error[4];
                q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
                q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
                q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
                q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;

                q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q)
                                   ? (q_data_int.x + 1)
                                   : q_data_int.x;
                q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q)
                                   ? (q_data_int.y + 1)
                                   : q_data_int.y;
                q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q)
                                   ? (q_data_int.w + 1)
                                   : q_data_int.w;
                q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q)
                                   ? (q_data_int.z + 1)
                                   : q_data_int.z;

                q_data_int.x = q_data_int.x * q_scale_val + min;
                q_data_int.y = q_data_int.y * q_scale_val + min;
                q_data_int.w = q_data_int.w * q_scale_val + min;
                q_data_int.z = q_data_int.z * q_scale_val + min;

                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}

template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    hipStream_t stream)
{
    dim3 block_dim(1024);
    dim3 grid_dim(group_num);

    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    hipLaunchKernelGGL((sr_quantize_kernel_asym),
                       dim3(grid_dim),
                       dim3(block_dim),
                       0,
                       stream,
                       vals,
                       (total_count / group_num) / 4,
                       group_num,
                       num_bits,
                       seed);
}

template void launch_sr_quantize_kernel_asym(float* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             hipStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
                                             int total_count,
                                             int group_num,
                                             int num_bits,
                                             hipStream_t stream);
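// Capacity note: the non-SR kernels stop caching after MAX_REG float4/float2
// vectors per thread and leave any remaining elements of a group untouched;
// the SR kernels use fixed 128-entry register arrays with no bound check, so
// token_size must not exceed 128 * blockDim.x vectors per token.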