Unverified Commit b3251e9f authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

refine quant kernel code style (#4211)

parent 2cadd51d
...@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output ...@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
max_value = fmaxf(max_value, fabsf(val)); max_value = fmaxf(max_value, fabsf(val));
} }
static __shared__ float warpLevelMaxs[WARP_SIZE]; max_value = blockReduceMax(max_value);
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
if (tid == 0) { if (tid == 0) {
atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX); atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX);
......
...@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel( ...@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel(
max_value = fmaxf(max_value, fabsf(val)); max_value = fmaxf(max_value, fabsf(val));
} }
max_value = warpReduceMax(max_value); max_value = blockReduceMax(max_value);
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
if (warpId == 0) {
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
max_value = warpReduceMax(max_value);
}
__shared__ float block_max; __shared__ float block_max;
if (tid == 0) { if (tid == 0) {
......
...@@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) { ...@@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) {
max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1)); max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
return max_value; return max_value;
} }
__device__ __forceinline__ float blockReduceMax(float max_value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
max_value = warpReduceMax(max_value);
if (laneId == 0) warpLevelMaxs[warpId] = max_value;
__syncthreads();
max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) max_value = warpReduceMax(max_value);
return max_value;
}
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment