"vscode:/vscode.git/clone" did not exist on "4254aeb56f280609e57e6b0134b3d6268d2fa87f"
Unverified commit 59b2f7b6, authored by Wei Zhao and committed by GitHub
Browse files

[Perf] Fuse Zero Initializer for FP8 DeepGemm Block Quant Kernel (#39547)


Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
parent 92feb999
...@@ -240,8 +240,9 @@ template <typename T, typename DST_DTYPE> ...@@ -240,8 +240,9 @@ template <typename T, typename DST_DTYPE>
__global__ void per_token_group_quant_8bit_packed_kernel( __global__ void per_token_group_quant_8bit_packed_kernel(
const T* __restrict__ input, void* __restrict__ output_q, const T* __restrict__ input, void* __restrict__ output_q,
unsigned int* __restrict__ output_s_packed, const int group_size, unsigned int* __restrict__ output_s_packed, const int group_size,
const int num_groups, const int groups_per_block, const int groups_per_row, const int num_groups_padded, const int groups_per_block,
const int mn, const int tma_aligned_mn, const float eps, const int padded_groups_per_row, const int groups_per_row, const int mn,
const int tma_aligned_mn, const int num_scale_elems, const float eps,
const float min_8bit, const float max_8bit) { const float min_8bit, const float max_8bit) {
const int threads_per_group = 16; const int threads_per_group = 16;
const int64_t local_group_id = threadIdx.x / threads_per_group; const int64_t local_group_id = threadIdx.x / threads_per_group;
...@@ -249,51 +250,62 @@ __global__ void per_token_group_quant_8bit_packed_kernel( ...@@ -249,51 +250,62 @@ __global__ void per_token_group_quant_8bit_packed_kernel(
const int64_t block_group_id = blockIdx.x * groups_per_block; const int64_t block_group_id = blockIdx.x * groups_per_block;
const int64_t global_group_id = block_group_id + local_group_id; const int64_t global_group_id = block_group_id + local_group_id;
if (global_group_id >= num_groups) { if (global_group_id >= num_groups_padded) {
return; return;
} }
const int64_t block_group_offset = global_group_id * group_size; // map flat group id to 2D indices (mn_idx, sf_k_idx)
const int sf_k_idx =
static_cast<int>(global_group_id % padded_groups_per_row);
const int mn_idx = static_cast<int>(global_group_id / padded_groups_per_row);
const T* group_input = input + block_group_offset; // whether it is a valid group (not padding)
DST_DTYPE* group_output = const bool is_valid_group = (mn_idx < mn) && (sf_k_idx < groups_per_row);
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
// shared memory to cache each group's data to avoid double DRAM reads. // shared memory to cache each group's data to avoid double DRAM reads.
extern __shared__ __align__(16) char smem_raw[]; extern __shared__ __align__(16) char smem_raw[];
T* smem = reinterpret_cast<T*>(smem_raw); T* smem = reinterpret_cast<T*>(smem_raw);
T* smem_group = smem + local_group_id * group_size; T* smem_group = smem + local_group_id * group_size;
const float y_s =
ComputeGroupScale<T, true>(group_input, smem_group, group_size, lane_id,
threads_per_group, eps, max_8bit);
// pack 4 scales into a uint32 // compute scale for valid groups
if (lane_id == 0) { float y_s = 0.f;
// map flat group id to 2D indices (mn_idx, sf_k_idx) if (is_valid_group) {
const int sf_k_idx = static_cast<int>(global_group_id % groups_per_row); const T* group_input =
const int mn_idx = static_cast<int>(global_group_id / groups_per_row); input + static_cast<int64_t>(mn_idx) * groups_per_row * group_size +
sf_k_idx * group_size;
y_s = ComputeGroupScale<T, true>(group_input, smem_group, group_size,
lane_id, threads_per_group, eps, max_8bit);
}
if (mn_idx < mn) { // pack 4 scales into a uint32 exponent
if (lane_id == 0) {
// each uint32 in output_s_packed stores 4 packed scales // each uint32 in output_s_packed stores 4 packed scales
const int sf_k_pack_idx = sf_k_idx / 4; const int sf_k_pack_idx = sf_k_idx / 4;
const int pos = sf_k_idx % 4; const int pos = sf_k_idx % 4;
const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx;
if (is_valid_group) {
// reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit // reinterpret the UE8M0 scale y_s as IEEE bits, extract the 8-bit
// exponent, and place it into the correct byte of the 32-bit word. // exponent, and place it into the correct byte of the 32-bit word.
const unsigned int bits = __float_as_uint(y_s); const unsigned int bits = __float_as_uint(y_s);
const unsigned int exponent = (bits >> 23u) & 0xffu; const uint8_t exponent = static_cast<uint8_t>((bits >> 23u) & 0xffu);
const unsigned int contrib = exponent << (pos * 8u); reinterpret_cast<uint8_t*>(output_s_packed)[out_idx * 4 + pos] = exponent;
} else if (out_idx < num_scale_elems) {
const int out_idx = sf_k_pack_idx * tma_aligned_mn + mn_idx; // write zero for padding groups if within bounds of output_s_packed
// atomically OR 8-bit exponent into the packed scales buffer reinterpret_cast<uint8_t*>(output_s_packed)[out_idx * 4 + pos] = 0;
atomicOr(output_s_packed + out_idx, contrib);
} }
} }
__syncthreads(); __syncthreads();
if (is_valid_group) {
DST_DTYPE* group_output =
static_cast<DST_DTYPE*>(output_q) +
static_cast<int64_t>(mn_idx) * groups_per_row * group_size +
sf_k_idx * group_size;
QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id, QuantizeGroup<T, DST_DTYPE>(smem_group, group_output, group_size, lane_id,
threads_per_group, y_s, min_8bit, max_8bit); threads_per_group, y_s, min_8bit, max_8bit);
}
} }
void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
...@@ -310,7 +322,6 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, ...@@ -310,7 +322,6 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
const int64_t mn = input.numel() / k; const int64_t mn = input.numel() / k;
const int64_t groups_per_row = k / group_size; const int64_t groups_per_row = k / group_size;
const int64_t num_groups = mn * groups_per_row;
STD_TORCH_CHECK(output_s_packed.dim() == 2, STD_TORCH_CHECK(output_s_packed.dim() == 2,
"output_s_packed must be 2D, got dim=", output_s_packed.dim(), "output_s_packed must be 2D, got dim=", output_s_packed.dim(),
...@@ -330,21 +341,30 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, ...@@ -330,21 +341,30 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
"output_s_packed shape must be [", mn, ", ", k_num_packed_sfk, "output_s_packed shape must be [", mn, ", ", k_num_packed_sfk,
"], but got [", output_s_packed.size(0), ", ", "], but got [", output_s_packed.size(0), ", ",
output_s_packed.size(1), "]."); output_s_packed.size(1), "].");
// Verify column-major TMA-aligned layout
STD_TORCH_CHECK(output_s_packed.stride(0) == 1 &&
output_s_packed.stride(1) == tma_aligned_mn,
"output_s_packed must have strides [1, ", tma_aligned_mn,
"], but got [", output_s_packed.stride(0), ", ",
output_s_packed.stride(1), "].");
cudaStream_t stream = get_current_cuda_stream(); cudaStream_t stream = get_current_cuda_stream();
constexpr int THREADS_PER_GROUP = 16; constexpr int THREADS_PER_GROUP = 16;
const int groups_per_block = GetGroupsPerBlock(num_groups); // Expand the grid to cover MN and K padding so every byte in
// output_s_packed is written (padding bytes get zeroed by the kernel).
const int64_t padded_groups_per_row = k_num_packed_sfk * 4;
const int64_t num_groups_padded = tma_aligned_mn * padded_groups_per_row;
// Number of elements in output_s_packed.
const int64_t num_scale_elems = mn + (k_num_packed_sfk - 1) * tma_aligned_mn;
const int groups_per_block = GetGroupsPerBlock(num_groups_padded);
auto dst_type = output_q.scalar_type(); auto dst_type = output_q.scalar_type();
const int num_blocks = num_groups / groups_per_block; const int num_blocks = num_groups_padded / groups_per_block;
const int num_threads = groups_per_block * THREADS_PER_GROUP; const int num_threads = groups_per_block * THREADS_PER_GROUP;
// zero-initialize packed scales, since we use atomicOr to accumulate
// exponents from different groups.
torch::stable::zero_(output_s_packed);
#define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \ #define LAUNCH_PACKED_KERNEL(T, DST_DTYPE) \
do { \ do { \
dim3 grid(num_blocks); \ dim3 grid(num_blocks); \
...@@ -355,11 +375,12 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, ...@@ -355,11 +375,12 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
<<<grid, block, smem_bytes, stream>>>( \ <<<grid, block, smem_bytes, stream>>>( \
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \ static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \ reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
static_cast<int>(group_size), static_cast<int>(num_groups), \ static_cast<int>(group_size), static_cast<int>(num_groups_padded), \
groups_per_block, static_cast<int>(groups_per_row), \ groups_per_block, static_cast<int>(padded_groups_per_row), \
static_cast<int>(mn), static_cast<int>(tma_aligned_mn), \ static_cast<int>(groups_per_row), static_cast<int>(mn), \
static_cast<float>(eps), static_cast<float>(min_8bit), \ static_cast<int>(tma_aligned_mn), \
static_cast<float>(max_8bit)); \ static_cast<int>(num_scale_elems), static_cast<float>(eps), \
static_cast<float>(min_8bit), static_cast<float>(max_8bit)); \
} while (0) } while (0)
VLLM_STABLE_DISPATCH_FLOATING_TYPES( VLLM_STABLE_DISPATCH_FLOATING_TYPES(
......
...@@ -48,6 +48,116 @@ def test_per_token_group_quant_fp8( ...@@ -48,6 +48,116 @@ def test_per_token_group_quant_fp8(
assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01) assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
@pytest.mark.parametrize(
    "num_tokens,hidden_dim,group_size",
    [
        # No padding: mn=4 (mult of 4), groups_per_row=56 (mult of 4)
        (4, 7168, 128),
        # MN padding only: mn=1, tma_aligned_mn=4
        (1, 7168, 128),
        # MN padding only: mn=3, tma_aligned_mn=4
        (3, 7168, 128),
        # K padding only: groups_per_row=5 (5%4=1)
        (4, 640, 128),
        # K padding only: groups_per_row=6 (6%4=2)
        (4, 768, 128),
        # Single packed column, K padding only: groups_per_row=3 (3%4=3),
        # k_num_packed=1, mn%4=0
        (4, 384, 128),
        # Both MN and K padding
        (1, 384, 128),
        (3, 640, 128),
        # Larger shapes with no padding
        (64, 7168, 128),
        (128, 14336, 128),
        # Larger shapes with padding
        (127, 7168, 128),
        (253, 640, 128),
        # Non-power-of-2 group size
        (4, 768, 96),  # 768/96=8 groups, no padding
        (3, 768, 96),  # 768/96=8 groups, MN padding
        (4, 480, 96),  # 480/96=5 groups, K padding
        (1, 480, 96),  # both MN and K padding
    ],
)
@pytest.mark.parametrize("poisoned_scales", [False, True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8_packed(
    num_tokens, hidden_dim, group_size, poisoned_scales
):
    """Test the packed DeepGEMM quantization kernel against the Triton
    reference (row-major, UE8M0 scales).

    Args:
        num_tokens: Number of input rows (``mn``).
        hidden_dim: Number of input columns; split into groups of
            ``group_size`` elements.
        group_size: Number of elements that share one scale.
        poisoned_scales: When True, the custom op is invoked directly on a
            scale buffer pre-filled with ``0x7F7F7F7F`` so that any byte the
            kernel fails to overwrite (including padding bytes that must be
            zero) shows up in the comparison. When False, the
            ``per_token_group_quant_fp8_packed_for_deepgemm`` wrapper is used.

    The packed scale buffer holds four 8-bit exponents per int32 word, laid
    out with strides ``(1, tma_aligned_mn)``; the test compares its first
    ``num_scale_elems`` storage elements against an expected buffer built
    from the reference scales.
    """
    device = "cuda"
    torch.manual_seed(42)
    x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8

    # Layout arithmetic for the packed scale buffer.
    mn = num_tokens
    groups_per_row = hidden_dim // group_size
    k_num_packed = (groups_per_row + 3) // 4  # int32 words per row
    tma_aligned_mn = ((mn + 3) // 4) * 4  # mn rounded up to a multiple of 4
    # Storage elements spanned by a (mn, k_num_packed) tensor with strides
    # (1, tma_aligned_mn).
    num_scale_elems = mn + (k_num_packed - 1) * tma_aligned_mn

    if poisoned_scales:
        # Call the kernel with poisoned scale buffer to
        # ensure padded indices are correctly zeroed.
        fp8_dtype = torch.float8_e4m3fn
        finfo = torch.finfo(fp8_dtype)
        out_q = torch.empty_like(x, dtype=fp8_dtype)
        out_s_packed = torch.empty_strided(
            (mn, k_num_packed),
            (1, tma_aligned_mn),
            device=device,
            dtype=torch.int32,
        )
        torch.as_strided(out_s_packed, (num_scale_elems,), (1,)).fill_(0x7F7F7F7F)
        torch.ops._C.per_token_group_fp8_quant_packed(
            x,
            out_q,
            out_s_packed,
            group_size,
            1e-10,
            finfo.min,
            finfo.max,
        )
    else:
        out_q, out_s_packed = fp8_utils.per_token_group_quant_fp8_packed_for_deepgemm(
            x,
            group_size=group_size,
            use_ue8m0=True,
        )

    # Triton reference (row-major float32 scales, UE8M0)
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x,
            group_size,
            use_ue8m0=True,
        )

    # Quantized values must match.
    assert torch.equal(out_q, ref_q), "Quantized output mismatch"

    # Verify packed scales (valid exponents + padding zeros).
    ref_s_flat = ref_s.reshape(mn, groups_per_row)
    # Extract the 8-bit IEEE exponent of every reference scale and move the
    # whole table to the host in one transfer, so the packing loop below does
    # not trigger a device synchronization per element.
    ref_exponents = ((ref_s_flat.view(torch.int32) >> 23) & 0xFF).cpu().tolist()

    # Pack with unbounded Python ints: an exponent >= 128 in byte 3 sets bit
    # 31, which cannot be OR-ed into an int32 tensor as a positive scalar.
    packed_words = [0] * num_scale_elems
    for row, row_exponents in enumerate(ref_exponents):
        for g, exponent in enumerate(row_exponents):
            pack_col = g // 4
            pos = g % 4
            idx = pack_col * tma_aligned_mn + row
            packed_words[idx] |= exponent << (pos * 8)

    # Reinterpret each unsigned 32-bit pattern as the signed int32 with the
    # same bits so it can be compared against the int32 output buffer.
    expected = torch.tensor(
        [w - (1 << 32) if w >= (1 << 31) else w for w in packed_words],
        dtype=torch.int32,
    )

    actual = torch.as_strided(out_s_packed, (num_scale_elems,), (1,)).cpu()
    assert torch.equal(actual, expected), (
        f"Packed scale storage mismatch.\n"
        f"First diff at index "
        f"{(actual != expected).nonzero(as_tuple=True)[0][0].item()}"
    )
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)]) @pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment