Unverified commit 1fe691a4 authored by YanbingJiang, committed by GitHub

Fix FP8 block quantization when N or K is not a multiple of 128 (#8648)

parent e2521926
@@ -955,16 +955,16 @@ static inline void check_moe_scales(
   }
 }
 
 #define CHECK_MOE_SCALES_FP8(DIM0, DIM1)                        \
   auto w1s = w1_scale.value();                                  \
   auto w2s = w2_scale.value();                                  \
   auto block_size_val = block_size.value();                     \
   int64_t block_size_N = block_size_val[0];                     \
   int64_t block_size_K = block_size_val[1];                     \
-  TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N);          \
-  TORCH_CHECK(w1s.size(DIM1) == K / block_size_K);              \
-  TORCH_CHECK(w2s.size(DIM0) == K / block_size_N);              \
-  TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
+  TORCH_CHECK(w1s.size(DIM0) == div_up(2 * N, block_size_N));   \
+  TORCH_CHECK(w1s.size(DIM1) == div_up(K, block_size_K));       \
+  TORCH_CHECK(w2s.size(DIM0) == div_up(K, block_size_N));       \
+  TORCH_CHECK(w2s.size(DIM1) == div_up(N, block_size_K))
 
 // hidden_states: [M, K]
 // w1: [E, 2N, K]
...
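The checks above switch from exact division to ceiling division: when N or K is not a multiple of the block size, the last block is partial, so the scale tensors gain one extra entry along that dimension. A minimal Python sketch of the same idea (this `div_up` is an illustrative stand-in for the kernel-side helper used in the diff):

```python
def div_up(a: int, b: int) -> int:
    # Ceiling division: how many blocks of size `b` are needed to cover
    # `a` elements when the last block may be partial.
    return (a + b - 1) // b

# With N = 352 and block_size_N = 128, exact division under-counts the
# scale blocks, while ceiling division matches the actual scale shape.
assert 352 // 128 == 2        # old check: rejects a valid 3-block scale dim
assert div_up(352, 128) == 3  # new check: accepts it
```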
@@ -75,8 +75,8 @@ class TestFusedExperts(CustomTestCase):
     topk_int8 = [3]
 
     M_fp8 = [2, 121]
-    N_fp8 = [512]
-    K_fp8 = [256]
+    N_fp8 = [352, 512]
+    K_fp8 = [256, 320]
     E_fp8 = [8]
     topk_fp8 = [4]
 
@@ -201,8 +201,14 @@ class TestFusedExperts(CustomTestCase):
         w2_fp32 = torch.randn(E, K, N)
         w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
-        w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
-        w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
+        w1s = (
+            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
+            * factor_for_scale
+        )
+        w2s = (
+            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
+            * factor_for_scale
+        )
 
         w1_scaled = scaled_weight(w1, w1s)
         w2_scaled = scaled_weight(w2, w2s)
...
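The new test sizes (N = 352, K = 320) are deliberately not multiples of 128, so the scale tensors must be built with `math.ceil` rather than floor division. A quick shape check, assuming `BLOCK_N = BLOCK_K = 128` as in the 128x128 FP8 block layout the commit title refers to:

```python
import math

BLOCK_N = BLOCK_K = 128  # assumed block sizes (128x128 FP8 blocks)

E, N, K = 8, 352, 320  # the newly added non-multiple test case
w1s_shape = (E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))  # (8, 6, 3)
w2s_shape = (E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))      # (8, 3, 3)

# Floor division would have produced (8, 5, 2) and (8, 2, 2), silently
# dropping the scales that belong to the trailing partial blocks.
assert w1s_shape == (8, 6, 3)
assert w2s_shape == (8, 3, 3)
```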
@@ -136,18 +136,33 @@ def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_facto
 
 def scaled_weight(weight, scales):
     E, N, K = weight.shape
+    pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
+    pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
+    if pad_N > 0 or pad_K > 0:
+        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
+
     weight_block = (
-        weight.view(E, N // BLOCK_N, BLOCK_N, K // BLOCK_K, BLOCK_K)
+        weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
         .permute(0, 1, 3, 2, 4)
         .float()
         .contiguous()
     )
-    return (
-        (weight_block * scales.view(E, N // BLOCK_N, K // BLOCK_K, 1, 1))
+
+    weight_scaled = (
+        (
+            weight_block
+            * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
+        )
         .permute(0, 1, 3, 2, 4)
         .contiguous()
-        .view(E, N, K)
     )
+
+    if pad_N > 0 or pad_K > 0:
+        weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
+        weight_scaled = weight_scaled[..., :N, :K].contiguous()
+    else:
+        weight_scaled = weight_scaled.view(E, N, K)
+    return weight_scaled
 
 
 def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
...
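The reference `scaled_weight` now zero-pads each expert weight up to the next block boundary, applies the per-block scales on the padded view, and slices the result back to the original `[E, N, K]` shape. A small self-contained sketch of just the pad-and-trim round trip (the shapes and the `pad_to_blocks` helper are illustrative, not part of the test file):

```python
import torch

BLOCK_N, BLOCK_K = 128, 128  # illustrative block sizes

def pad_to_blocks(weight: torch.Tensor) -> torch.Tensor:
    """Zero-pad the trailing (N, K) dims up to the next block boundary."""
    E, N, K = weight.shape
    pad_N = (BLOCK_N - N % BLOCK_N) % BLOCK_N
    pad_K = (BLOCK_K - K % BLOCK_K) % BLOCK_K
    # F.pad pads the last dim first: (left_K, right_K, left_N, right_N).
    return torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))

w = torch.randn(4, 352, 320)          # N=352, K=320 are not multiples of 128
padded = pad_to_blocks(w)
assert padded.shape == (4, 384, 384)  # rounded up to whole 128x128 blocks

# After blockwise scaling on the padded view, slicing restores the
# original shape; the padded region never reaches the output.
restored = padded[..., :352, :320]
assert torch.equal(restored, w)
```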