Unverified Commit fda47926 authored by Qi Yuhang, committed by GitHub

Update CUTLASS 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)

parent a0b22f2f
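For context on the scale-factor layout change in the title: the per-token-group FP8 quantization used in these kernels produces one scale per (token, 128-wide K block), i.e. an (M, K / 128) tensor. Before this change the SM90 path converted those scales to MN-major (the per_group_transpose calls removed below), while SM100 already consumed them K-major; after this change both architectures take the K-major layout directly. A minimal PyTorch sketch of the two layouts (variable names are illustrative only, not from the repo):

    import torch

    M, K, GROUP = 4, 512, 128
    a_scale = torch.rand(M, K // GROUP)  # one scale per (token, 128-wide K block)

    # MN-major: scales of different tokens for the same K block are contiguous
    # (what the old SM90 path built via a per-expert-group transpose).
    a_scale_mn_major = a_scale.t().contiguous()  # shape (K // GROUP, M), stride (M, 1)

    # K-major: all K-block scales of one token are contiguous
    # (what the updated SM90 kernel, like SM100, consumes directly).
    a_scale_k_major = a_scale.contiguous()  # shape (M, K // GROUP), stride (K // GROUP, 1)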
@@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    if not is_sm100_supported():
-        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
-        w1_scale = w1_scale.contiguous()
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8(
     silu_and_mul(c1, intermediate)
     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    if not is_sm100_supported():
-        a2_scale = per_group_transpose(a2_scale, expert_offsets)
-        w2_scale = w2_scale.contiguous()
     fp8_blockwise_scaled_grouped_mm(
         c2,
......
@@ -8,6 +8,15 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+
+
+# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
+def calc_diff(x, y):
+    x, y = x.double(), y.double()
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
 def get_model_config(tp_size: int):
@@ -69,16 +78,11 @@ def run_test(tp_size, batch_size, model_config, check=False):
     # --- Input Data ---
     # Use bf16/fp16 for input activation based on model config
-    x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001
+    x = torch.randn((batch_size, H), device="cuda", dtype=dtype)
     # --- Weights (Generate in higher precision, then convert to FP8) ---
     # Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
-    w1_hp = (
-        torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
-    )
-    w2_hp = (
-        torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
-        + 0.00001
-    )
+    w1_hp = torch.randn((E, I, H), device="cuda", dtype=torch.float32)
+    w2_hp = torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32)
     w1 = to_fp8(w1_hp)
     w2 = to_fp8(w2_hp)
@@ -149,13 +153,13 @@ def run_test(tp_size, batch_size, model_config, check=False):
     )
     # Note: Triton expects non-transposed weights
+    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
         (topk_weights, topk_ids, "dummy"),
-        inplace=False,
-        activation="silu",  # Assuming SiLU activation common in MoEs
+        moe_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
@@ -221,32 +225,19 @@ def run_test(tp_size, batch_size, model_config, check=False):
         w1,  # Original shape
         w2,  # Original shape
         (topk_weights, topk_ids, "dummy"),
-        inplace=False,  # Important: Use False to get output tensor
-        activation="silu",
+        moe_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         block_shape=block_shape,
     )
-    # Ensure outputs are same dtype for comparison
-    y_cutlass = y_cutlass.to(dtype)
-    y_triton = y_triton.to(dtype)
-    abs_error = torch.abs(y_cutlass - y_triton)
-    rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)
-    max_abs_err = abs_error.max().item()
-    max_rel_err = rel_error.max().item()
-    print("y_cutlass:", y_cutlass[:, :10])
-    print("y_triton:", y_triton[:, :10])
-    print(f"Max absolute error: {max_abs_err:.6f}")
-    print(f"Max relative error: {max_rel_err:.6f}")
+    diff = calc_diff(y_cutlass, y_triton)
+    print(f"Diff: {diff:.6f}")
     # Tolerance might need adjustment based on FP8 specifics and kernel differences
     # FP8 comparisons often require higher tolerance than FP16/BF16
-    assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
+    assert diff < 1e-4, f"Diff too high! {diff}"
     print("Correctness check passed.")
@@ -264,7 +255,21 @@ if __name__ == "__main__":
         "--batch-sizes",
         type=int,
         nargs="+",
-        default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024],  # Adjusted default
+        default=[
+            1,
+            4,
+            8,
+            16,
+            32,
+            64,
+            128,
+            256,
+            512,
+            1024,
+            2048,
+            4096,
+            8192,
+        ],  # Adjusted default
         help="List of batch sizes to test",
     )
     parser.add_argument("--check", action="store_true", help="Enable check mode")
......
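A note on the calc_diff helper added above: it computes 1 - 2·Σxy / (Σx² + Σy²), which is algebraically Σ(x - y)² / Σ(x² + y²), so it is 0 for identical tensors and grows with the overall relative mismatch; the benchmark now gates on this global metric (diff < 1e-4) instead of a per-element relative error. A small self-contained sketch (the tensors here are made up for illustration):

    import torch

    def calc_diff(x, y):
        # Same definition as the helper above (copied from DeepGEMM utils).
        x, y = x.double(), y.double()
        denominator = (x * x + y * y).sum()
        sim = 2 * (x * y).sum() / denominator
        return 1 - sim

    y_ref = torch.randn(64, 256)
    y_close = y_ref + 1e-3 * torch.randn_like(y_ref)  # tiny perturbation
    print(calc_diff(y_ref, y_ref).item())    # 0.0 for identical tensors
    print(calc_diff(y_ref, y_close).item())  # small value, well under 1e-4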
@@ -45,7 +45,7 @@ include(FetchContent)
 FetchContent_Declare(
   repo-cutlass
   GIT_REPOSITORY https://github.com/NVIDIA/cutlass
-  GIT_TAG 664c4f7b3ed1959414905025728eef5568209479
+  GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-cutlass)
......
@@ -457,39 +457,40 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace) {
-  struct MmaConfig0 {
+  struct MmaConfigSmallM {
+    // Swap A/B
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_64, _128, _128>;
+    using MmaTileShape = Shape<_128, _32, _128>;
     using ClusterShape = Shape<_2, _1, _1>;
+    // TODO: Check Pingpong or Cooperative
     using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
     using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using ScaleConfig =
+        cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
-  struct MmaConfig1 {
+  struct MmaConfigH20LargeK {
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_128, _128, _128>;
-    using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using MmaTileShape = Shape<_64, _128, _128>;
+    using ClusterShape = Shape<_2, _1, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+    using ScaleConfig =
+        cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
-  // [NOTE] default for H20
-  struct MmaConfigH20_default {
+  struct MmaConfigHx00AndH20SmallK {
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_64, _128, _128>;
+    using MmaTileShape = Shape<_128, _128, _128>;
     using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+    using ScaleConfig =
+        cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
@@ -497,33 +498,34 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
-  const std::string H20_device_type_str = "NVIDIA H20";
-  bool is_h20_device = isDeviceType(H20_device_type_str);
-  if (is_h20_device) {
-    using execute_gemm_config = MmaConfigH20_default;
-    run_get_group_gemm_starts<
-        execute_gemm_config::LayoutSFA,
-        execute_gemm_config::LayoutSFB,
-        execute_gemm_config::ScaleConfig>(
+  torch::Tensor output_t = output.t();
+  torch::Tensor a_t = a.t();
+  torch::Tensor b_t = b.transpose(1, 2);
+  torch::Tensor scales_a_t = scales_a.t();
+  torch::Tensor scales_b_t = scales_b.transpose(1, 2);
+  const std::string H20_device_type_str("NVIDIA H20");
+  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
+  if (a.size(0) <= 2048) {
+    run_get_group_gemm_starts<MmaConfigSmallM::LayoutSFA, MmaConfigSmallM::LayoutSFB, MmaConfigSmallM::ScaleConfig>(
         expert_offsets,
         a_ptrs,
         b_ptrs,
         out_ptrs,
         a_scales_ptrs,
         b_scales_ptrs,
-        a,
-        b,
-        output,
-        scales_a,
-        scales_b,
+        b_t,
+        a_t,
+        output_t,
+        scales_b_t,
+        scales_a_t,
         layout_sfa,
         layout_sfb,
         problem_sizes,
-        problem_sizes_transpose);
-    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
+        problem_sizes_transpose,
+        true);
+    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigSmallM, cutlass::layout::ColumnMajor>(
         out_ptrs,
         a_ptrs,
         b_ptrs,
@@ -534,13 +536,17 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
         stride_c,
         layout_sfa,
         layout_sfb,
-        problem_sizes,
+        problem_sizes_transpose,
         expert_offsets,
         workspace);
+    output = output_t.t();
   } else {
-    if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
+    if (is_h20_device && a.size(1) > 128) {
       // For H20 with K > 128, use Pingpong Schedule
-      run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
+      run_get_group_gemm_starts<
+          MmaConfigH20LargeK::LayoutSFA,
+          MmaConfigH20LargeK::LayoutSFB,
+          MmaConfigH20LargeK::ScaleConfig>(
           expert_offsets,
           a_ptrs,
           b_ptrs,
@@ -556,7 +562,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
          layout_sfb,
          problem_sizes,
          problem_sizes_transpose);
-      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
+      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigH20LargeK, cutlass::layout::RowMajor>(
          out_ptrs,
          a_ptrs,
          b_ptrs,
@@ -572,7 +578,10 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
          workspace);
     } else {
       // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
-      run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
+      run_get_group_gemm_starts<
+          MmaConfigHx00AndH20SmallK::LayoutSFA,
+          MmaConfigHx00AndH20SmallK::LayoutSFB,
+          MmaConfigHx00AndH20SmallK::ScaleConfig>(
           expert_offsets,
           a_ptrs,
           b_ptrs,
@@ -588,7 +597,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
          layout_sfb,
          problem_sizes,
          problem_sizes_transpose);
-      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
+      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigHx00AndH20SmallK, cutlass::layout::RowMajor>(
          out_ptrs,
          a_ptrs,
          b_ptrs,
......
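To summarize the new SM90 dispatch above (a schematic only; the 2048-row threshold, K > 128 condition, and config names are taken from the C++ code, while pick_sm90_config itself is a hypothetical helper for illustration): small total M goes to the swapped-operand config that runs the transposed problem and writes a column-major output, H20 with K > 128 keeps the pingpong schedule, and every other case uses the cooperative schedule.

    # Schematic of the dispatch logic; not wired to the real kernel.
    def pick_sm90_config(total_m: int, k: int, is_h20: bool) -> str:
        if total_m <= 2048:
            # Swap A/B, run the transposed problem, transpose the output back.
            return "MmaConfigSmallM"
        if is_h20 and k > 128:
            return "MmaConfigH20LargeK"  # pingpong schedule
        return "MmaConfigHx00AndH20SmallK"  # cooperative schedule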
@@ -5,10 +5,6 @@ import pytest
 import torch
 from sgl_kernel import fp8_blockwise_scaled_grouped_mm
-from sglang.srt.layers.quantization.fp8_kernel import (
-    per_token_group_quant_fp8_hopper_moe_mn_major,
-)
 def cdiv(a: int, b: int) -> int:
     return -(a // -b)
@@ -106,24 +102,19 @@ def is_sm90_supported(device=None) -> bool:
     not (is_sm100_supported() or is_sm90_supported()),
     reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
 )
-@pytest.mark.parametrize("num_experts", [8, 16])
+@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])
 @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
-@pytest.mark.parametrize("use_custom_kernel", [True, False])
-def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
-    cc = torch.cuda.get_device_capability(None)[0]
-    if cc == 10 and use_custom_kernel:
-        return
+def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     device = "cuda"
-    alignment = 16
-    n_g = alignment * random.randint(1, 5) * 128
-    k_g = alignment * random.randint(1, 5) * 128
+    alignment = 128
+    n_g = random.randint(1, 64) * 128
+    k_g = random.randint(1, 64) * 128
     expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
     problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
     layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
     layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
-    a_original_tensors = []
     a_tensors = []
     b_tensors = []
     a_scales_tensors = []
@@ -131,7 +122,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
     baseline_tensors = []
     for g in range(num_experts):
-        m_g = alignment * random.randint(1, 64)
+        m_g = random.randint(1, 256)
         expert_offsets[g + 1] = expert_offsets[g] + m_g
         problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
@@ -144,7 +135,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
         b_g, b_scale = per_block_cast_to_fp8(
             b
         )  # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
-        a_original_tensors.append(a)
         a_tensors.append(a_g)
         b_tensors.append(b_g)
         a_scales_tensors.append(a_scale)
@@ -152,9 +142,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
         baseline = torch.mm(a, b)
         baseline_tensors.append(baseline)
-    a_original_stack = torch.empty(
-        (expert_offsets[-1], k_g), device=device, dtype=out_dtype
-    )
     a_stack = torch.empty(
         (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
     )
@@ -162,52 +149,28 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
         (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
     )
     a_scale_stack = torch.empty(
-        (expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
+        (expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32
     )
     b_scale_stack = torch.empty(
-        (num_experts, k_g // 128, n_g // 128), device=device, dtype=torch.float32
+        (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
    )
     for g in range(num_experts):
         # Matrix A is Row-Major.
-        a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
-            a_original_tensors[g]
-        )
-        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
+        a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[
             g
-        ]  # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
+        ]  # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1)
         b_stack[g] = b_tensors[g].t()  # b_stack[g] -- (N, K):(K, 1)
-        if cc == 9:
-            # For SM90, we need MN-Major scale factor
-            # a_scales_tensors[g] -- (M, k):(k, 1)
-            # a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1)
-            a_scale_stack[
-                expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
-            ] = (a_scales_tensors[g].t().contiguous().view(-1))
-            b_scale_stack[g] = b_scales_tensors[g]  # b_scale_stack[g] -- (k, n):(n, 1)
-        elif cc == 10:
-            # For SM100, we need K-Major scale factor
-            # a_scales_tensors[g] -- (M, k):(k, 1)
-            a_scale_stack[
-                expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
-            ] = a_scales_tensors[g].view(-1)
-            b_scale_stack[g] = b_scales_tensors[
-                g
-            ]  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
-    a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
+        # We need K-Major scale factor
+        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[
+            g
+        ]
+        b_scale_stack[g] = b_scales_tensors[
+            g
+        ].t()  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
     b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
-    if cc == 10:
-        b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
-    if use_custom_kernel:
-        # Replace a_stack, a_scale_stack with custom kernel output
-        a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
-            a_original_stack,
-            expert_offsets[:-1],
-            problem_sizes,
-            128,
-            expert_tokens_alignment=alignment,
-        )
+    b_scale_stack = b_scale_stack.transpose(1, 2)
     c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
     a_strides = torch.full(
@@ -250,7 +213,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
     diff = calc_diff(actual, baseline)
     assert diff < 0.001
     print(
-        f"cc={cc}0 num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
+        f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
     )
......