"vscode:/vscode.git/clone" did not exist on "7e7376eb4a1a696046d008462d30f8ad541ed8f9"
Unverified Commit 5c34b4f1 authored by Kaixi Hou's avatar Kaixi Hou Committed by GitHub
Browse files

[NVIDIA] [2/N] Optimize `silu_and_mul_scaled_fp4_grouped_quant` perf (#9556)

parent ff9b5618
import argparse
import itertools
import torch
import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper
def _test_accuracy_once(E, M, K, input_dtype, device):
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.full((E,), M, dtype=torch.int32, device=device)
out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
out1, blk_scales1 = scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
)
torch.testing.assert_close(out, out1)
torch.testing.assert_close(blk_scales, blk_scales1)
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
Ks = [2048, 4096, 7168]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "K"],
x_vals=list(itertools.product(Ms, Ks)),
x_log=False,
line_arg="provider",
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
ylabel="ms",
plot_name="fp4 quant",
args={},
)
)
def benchmark(M, K, provider):
E = 6
device = "cuda"
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
fp8_out = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2,
),
device=x.device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
fp8_scales = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2 // scale_block_size,
),
device=x.device,
dtype=torch.float32,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "triton_fp8":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_masked_post_quant_fwd(
x,
fp8_out,
fp8_scales,
scale_block_size,
masks,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
),
quantiles=quantiles,
)
if provider == "cuda_unfused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
),
quantiles=quantiles,
)
if provider == "cuda_fused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_scaled_fp4_grouped_quant(
x,
glb_scales,
masks,
),
quantiles=quantiles,
)
return ms, min_ms, max_ms
def test_accuracy():
E = 6
N_RANKS = 48
Ms = [128, 256, 512, 1024]
Ks = [2048, 4096, 7168]
input_dtype = torch.bfloat16
for M in Ms:
for K in Ks:
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_fp4_quant_res",
help="Path to save fp4 quant benchmark results",
)
args = parser.parse_args()
test_accuracy()
benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
...@@ -159,8 +159,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -159,8 +159,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def( m.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," "Tensor input, Tensor input_global_scale, Tensor mask, bool use_silu_and_mul) -> ()");
"Tensor output_scale_offset_by_experts, Tensor mask) -> ()");
m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant); m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);
m.def( m.def(
......
...@@ -347,7 +347,7 @@ cvt_fp16_to_fp4( ...@@ -347,7 +347,7 @@ cvt_fp16_to_fp4(
} }
} }
// Eerly exit when using masks. // Early exit when using masks.
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
continue; continue;
} }
...@@ -383,6 +383,107 @@ cvt_fp16_to_fp4( ...@@ -383,6 +383,107 @@ cvt_fp16_to_fp4(
#endif #endif
} }
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
#else
cvt_fp16_to_fp4_expert(
#endif
int32_t numRows,
int32_t numCols,
Type const* in,
float const* SFScale,
uint32_t* out,
uint32_t* SFout,
int32_t* mask,
bool use_silu_and_mul,
int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
// Input tensor row/col loops.
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = (gridDim.x * blockDim.x) / n_experts;
int remainder = (gridDim.x * blockDim.x) % n_experts;
int expert_idx;
int tid_in_expert;
int actual_stride;
if (remainder > 0) {
int bound = remainder * (stride + 1);
if (tid < bound) {
expert_idx = tid / (stride + 1);
tid_in_expert = tid % (stride + 1);
actual_stride = stride + 1;
} else {
expert_idx = remainder + (tid - bound) / stride;
tid_in_expert = (tid - bound) % stride;
actual_stride = stride;
}
} else {
expert_idx = tid / stride;
tid_in_expert = tid % stride;
actual_stride = stride;
}
int m = numRows / n_experts;
int padded_m = (m + (128 - 1)) / 128 * 128;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// TODO(kaixih@nvidia): For now, we assume mask is used together with
// silu_and_mal. Maybe we want a more general behavior of mask later. In the
// silu case, the input last dim doubles.
bool use_mask = mask != nullptr;
int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;
// Each global thread processes one element
for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;
globalIdx += actual_stride) {
// Calculate which row and column this global thread should process
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;
// Find index within the experts
int rowIdx_in_expert = rowIdx - expert_idx * m;
// Early exit when using masks.
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
break;
}
int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_silu_and_mul) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
silu_and_mul(in_vec, in_vec_mul);
}
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout;
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
#endif
}
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version) // Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false> template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void __global__ void
...@@ -499,6 +600,7 @@ void quant_impl( ...@@ -499,6 +600,7 @@ void quant_impl(
void* input_offset_by_experts, void* input_offset_by_experts,
void* output_scale_offset_by_experts, void* output_scale_offset_by_experts,
void* mask, void* mask,
bool use_silu_and_mul,
int m_topk, int m_topk,
int k, int k,
int n_experts, int n_experts,
...@@ -522,6 +624,22 @@ void quant_impl( ...@@ -522,6 +624,22 @@ void quant_impl(
block.x = (block.x + 1) / 2; block.x = (block.x + 1) / 2;
} }
// TODO(kaixih@nvidia): Should relax this to allow any grid size.
if (mask != nullptr) {
grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
m_topk,
k,
reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<int32_t*>(mask),
use_silu_and_mul,
n_experts);
return;
}
int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
if (blockRepeat > 1) { if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
...@@ -652,6 +770,7 @@ void scaled_fp4_experts_quant_sm100a( ...@@ -652,6 +770,7 @@ void scaled_fp4_experts_quant_sm100a(
input_offset_by_experts.data_ptr(), input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), output_scale_offset_by_experts.data_ptr(),
nullptr, // mask nullptr, // mask
false, // use_silu_and_mul
m_topk, m_topk,
k, k,
n_experts, n_experts,
...@@ -665,6 +784,7 @@ void scaled_fp4_experts_quant_sm100a( ...@@ -665,6 +784,7 @@ void scaled_fp4_experts_quant_sm100a(
input_offset_by_experts.data_ptr(), input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), output_scale_offset_by_experts.data_ptr(),
nullptr, // mask nullptr, // mask
false, // use_silu_and_mul
m_topk, m_topk,
k, k,
n_experts, n_experts,
...@@ -679,28 +799,21 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( ...@@ -679,28 +799,21 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch::Tensor& output_scale, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& input_global_scale, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& mask,
torch::Tensor const& output_scale_offset_by_experts, bool use_silu_and_mul) {
torch::Tensor const& mask) {
CHECK_INPUT(output, "output must be a CUDA tensor"); CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor"); CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(mask, "mask must be a CUDA tensor"); CHECK_INPUT(mask, "mask must be a CUDA tensor");
TORCH_CHECK(output.dim() == 2); TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2); TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2); TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1); TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(mask.scalar_type() == INT); TORCH_CHECK(mask.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8) // output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32) // output_scale is int32 (four fp8 values are packed into one int32)
...@@ -710,12 +823,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( ...@@ -710,12 +823,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
const int BLOCK_SIZE = 16; const int BLOCK_SIZE = 16;
auto m_topk = input.size(0); auto m_topk = input.size(0);
auto k_by_2 = input.size(1); auto k_by_2 = input.size(1);
TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2"); auto k = k_by_2;
auto k = k_by_2 / 2; if (use_silu_and_mul) {
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
k = k_by_2 / 2;
}
auto n_experts = input_global_scale.size(0); auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(mask.size(0) == n_experts); TORCH_CHECK(mask.size(0) == n_experts);
TORCH_CHECK(output.size(0) == m_topk); TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2); TORCH_CHECK(output.size(1) == k / 2);
...@@ -734,9 +847,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( ...@@ -734,9 +847,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
output_scale.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input.data_ptr(),
input_global_scale.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(), nullptr, // input_offset_by_experts
output_scale_offset_by_experts.data_ptr(), nullptr, // output_scale_offset_by_experts
mask.data_ptr(), mask.data_ptr(),
use_silu_and_mul,
m_topk, m_topk,
k, k,
n_experts, n_experts,
...@@ -747,9 +861,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( ...@@ -747,9 +861,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
output_scale.data_ptr(), output_scale.data_ptr(),
input.data_ptr(), input.data_ptr(),
input_global_scale.data_ptr(), input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(), nullptr, // input_offset_by_experts
output_scale_offset_by_experts.data_ptr(), nullptr, // output_scale_offset_by_experts
mask.data_ptr(), mask.data_ptr(),
use_silu_and_mul,
m_topk, m_topk,
k, k,
n_experts, n_experts,
......
...@@ -32,9 +32,8 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( ...@@ -32,9 +32,8 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch::Tensor& output_scale, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& input_global_scale, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& mask,
torch::Tensor const& output_scale_offset_by_experts, bool use_silu_and_mul);
torch::Tensor const& mask);
#endif #endif
...@@ -65,12 +64,11 @@ void silu_and_mul_scaled_fp4_experts_quant( ...@@ -65,12 +64,11 @@ void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output_scale, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& input_global_scale, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& mask,
torch::Tensor const& output_scale_offset_by_experts, bool use_silu_and_mul) {
torch::Tensor const& mask) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4 #if defined ENABLE_NVFP4 && ENABLE_NVFP4
return silu_and_mul_scaled_fp4_experts_quant_sm100a( return silu_and_mul_scaled_fp4_experts_quant_sm100a(
output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask); output, output_scale, input, input_global_scale, mask, use_silu_and_mul);
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
} }
...@@ -394,9 +394,8 @@ void silu_and_mul_scaled_fp4_experts_quant( ...@@ -394,9 +394,8 @@ void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output_scale, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& input_global_scale, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& mask,
torch::Tensor const& output_scale_offset_by_experts, bool use_silu_and_mul);
torch::Tensor const& mask);
/* /*
* From csrc/moe/cutlass_moe/w4a8 * From csrc/moe/cutlass_moe/w4a8
*/ */
......
...@@ -298,6 +298,7 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape): ...@@ -298,6 +298,7 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
def scaled_fp4_grouped_quant( def scaled_fp4_grouped_quant(
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
input_global_scale: torch.Tensor, input_global_scale: torch.Tensor,
mask: torch.Tensor,
): ):
""" """
Quantize input tensor to FP4 and return quantized tensor and scale, for Quantize input tensor to FP4 and return quantized tensor and scale, for
...@@ -331,22 +332,14 @@ def scaled_fp4_grouped_quant( ...@@ -331,22 +332,14 @@ def scaled_fp4_grouped_quant(
output_scales = torch.empty( output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32 l, padded_m, padded_k_int32, device=device, dtype=torch.int32
) )
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
output_offsets = torch.arange(
0,
(l + 1) * padded_m,
step=padded_m,
dtype=torch.int,
device=device,
)
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default( torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2), output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32), output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k), input_tensor.view(l * m, k),
input_global_scale, input_global_scale,
input_offsets, mask,
output_offsets, use_silu_and_mul=False,
) )
# The physical layout of the output is (l, m, k // 2), but we want to return a # The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm. # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
...@@ -400,23 +393,14 @@ def silu_and_mul_scaled_fp4_grouped_quant( ...@@ -400,23 +393,14 @@ def silu_and_mul_scaled_fp4_grouped_quant(
output_scales = torch.empty( output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32 l, padded_m, padded_k_int32, device=device, dtype=torch.int32
) )
input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device)
output_offsets = torch.arange(
0,
(l + 1) * padded_m,
step=padded_m,
dtype=torch.int,
device=device,
)
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default( torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2), output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32), output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k_by_2), input_tensor.view(l * m, k_by_2),
input_global_scale, input_global_scale,
input_offsets,
output_offsets,
mask, mask,
use_silu_and_mul=True,
) )
# The physical layout of the output is (l, m, k // 2), but we want to return a # The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm. # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
......
...@@ -174,17 +174,22 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: ...@@ -174,17 +174,22 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
@pytest.mark.skipif( @pytest.mark.skipif(
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above." skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
) )
def test_quantize_to_fp4_grouped(): @pytest.mark.parametrize("shape", [(2, 512, 2048), (2, 100, 128), (2, 128, 96)])
def test_quantize_to_fp4_grouped(shape):
torch.manual_seed(42) torch.manual_seed(42)
torch.set_default_device("cuda:0") torch.set_default_device("cuda:0")
l, m, k = 2, 512, 2048 l, m, k = shape
x = torch.randn((l, m, k), dtype=torch.bfloat16) x = torch.randn((l, m, k), dtype=torch.bfloat16)
max_m = m // 2
assert max_m <= m
mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32) tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
output, output_scales = scaled_fp4_grouped_quant( output, output_scales = scaled_fp4_grouped_quant(
x, x,
x_sf_global, x_sf_global,
mask,
) )
# output in logical (m, k, l), but its physical layout is (l, m, k). # output in logical (m, k, l), but its physical layout is (l, m, k).
# So permute first to (l, m, k). # So permute first to (l, m, k).
...@@ -195,23 +200,25 @@ def test_quantize_to_fp4_grouped(): ...@@ -195,23 +200,25 @@ def test_quantize_to_fp4_grouped():
output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1) output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1)
for i in range(l): for i in range(l):
a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i]) a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
torch.testing.assert_close(a_fp4, output[i]) torch.testing.assert_close(a_fp4[: mask[i]], output[i][: mask[i]])
torch.testing.assert_close( # Recover swizzled scales to linear layout and drop padded values, so
a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float) # no extra checks on padding are needed.
) scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k)
scale_ans = recover_swizzled_scales(output_scales[i], m, k)
torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]])
@pytest.mark.skipif( @pytest.mark.skipif(
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above." skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
) )
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)]) @pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048), (6, 6144, 2048)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None: def test_silu_and_mul_quantize_to_fp4_grouped(shape):
torch.manual_seed(42) torch.manual_seed(42)
torch.set_default_device("cuda:0") torch.set_default_device("cuda:0")
l, m, k = shape l, m, k = shape
x = torch.randn((l, m, k * 2), dtype=torch.bfloat16) x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
max_m = 8 max_m = m // 2
assert max_m <= m assert max_m <= m
mask = torch.randint(1, max_m, (l,), dtype=torch.int32) mask = torch.randint(1, max_m, (l,), dtype=torch.int32)
...@@ -221,6 +228,7 @@ def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None: ...@@ -221,6 +228,7 @@ def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
ref_output, ref_output_scales = scaled_fp4_grouped_quant( ref_output, ref_output_scales = scaled_fp4_grouped_quant(
ref_y, ref_y,
y_sf_global, y_sf_global,
mask,
) )
output, output_scales = silu_and_mul_scaled_fp4_grouped_quant( output, output_scales = silu_and_mul_scaled_fp4_grouped_quant(
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment