Unverified Commit b3251e9f authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

refine quant kernel code style (#4211)

parent 2cadd51d
......@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
max_value = fmaxf(max_value, fabsf(val));
}
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
max_value = blockReduceMax(max_value);
if (tid == 0) {
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
......
......@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel(
max_value = fmaxf(max_value, fabsf(val));
}
max_value = warpReduceMax(max_value);
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
if (warpId == 0) {
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
max_value = warpReduceMax(max_value);
}
max_value = blockReduceMax(max_value);
__shared__ float block_max;
if (tid == 0) {
......
......@@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
return max_value;
}
__device__ __forceinline__ float blockReduceMax(float max_value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
return max_value;
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment