Unverified Commit eb38c7d1 authored by Pavani Majety, committed by GitHub

[1/2] Add Kernel support for Cutlass based Fused FP4 MoE (#6093)


Signed-off-by: Pavani Majety <pmajety@nvidia.com>
parent df7f61ee
"""Cutlass MoE kernel."""
"""CUTLASS based Fused MoE kernels."""
import functools
import json
......@@ -14,8 +14,10 @@ _is_cuda = is_cuda()
if _is_cuda:
import sgl_kernel
from sgl_kernel import (
cutlass_fp4_group_mm,
fp8_blockwise_scaled_grouped_mm,
prepare_moe_input,
scaled_fp4_experts_quant,
silu_and_mul,
)
......@@ -205,3 +207,178 @@ def cutlass_fused_experts(
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = 448.0
def cutlass_moe_fp4(
a: torch.Tensor,
a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
ab_strides_13: torch.Tensor,
ab_strides_2: torch.Tensor,
c_strides_13: torch.Tensor,
c_strides_2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
):
"""
MoE implementation for FP4 Inputs
# Gemm 1
a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
(Note: `n` is the up projection output dim, `k` is the input dim in
full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4)
# Gemm 2
a2_gscale: Activation scale per expert: [e]
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
Strides for activations, weights and output in logical number of elements.
The activations & output stride is the number of elements to the next row.
The weights stride is the number of elements to the next row per expert.
For example, if the weight is [e, n, k], then the b_stride is a tensor of
shape [e] with each element being k. Similarly for activations, if the
shape is [m, k], then the a_stride has shape [e] with each value k.
For the output, if the output is [m, n], then the c_stride is a tensor of
shape [e] with each element being n.
Note: cutlass_fp4_group_mm is designed to accept the strides of
activations and weights to be the same, so it is passed in as a single
tensor.
ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]
ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]
c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides]
c_strides_2: [e] dtype: int64 [Gemm 2: Output Strides]
topk_weights: [m, topk] dtype: float32
topk_ids: [m, topk] dtype: int32
m, n, k: Unquantized weight shapes, dtype: int
e: number of experts for the current rank, dtype: int
Assumes that topk < k < n to satisfy the up/down projection expectations.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
assert (
w1_fp4.ndim == 3
and w2_fp4.ndim == 3
and w1_blockscale.ndim == 3
and w2_blockscale.ndim == 3
), "All Weights must be of rank 3 for cutlass_moe_fp4"
m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert e_w1 == e_w2 and e_w1 == e, "Number of experts must match between weights."
assert (
k_a // 2 == half_k_w1 and k == k_w2
), "Hidden size mismatch between a, w1 and w2"
assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in expected `n`"
assert m == m_a, "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (
topk_weights.shape[0] == m and topk_ids.shape[0] == m
), "topk must be provided for each row of a"
out_dtype = a.dtype
num_topk = topk_ids.shape[1]
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
prepare_moe_input(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
e,
n,
k,
blockscale_offsets,
)
rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map
)
c1 = cutlass_fp4_group_mm(
rep_a_fp4,
w1_fp4,
rep_a_blockscale,
w1_blockscale,
w1_alphas,
ab_strides_13,
c_strides_13,
problem_sizes1,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype,
device,
)
del rep_a_fp4, rep_a_blockscale
# silu_and_mul halves the gate/up (2 * n) dimension of c1 into an n-sized intermediate.
intermediate = torch.empty(
(m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
)
silu_and_mul(c1, intermediate)
int_fp4, int_blockscale = scaled_fp4_experts_quant(
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
c2 = cutlass_fp4_group_mm(
int_fp4,
w2_fp4,
int_blockscale,
w2_blockscale,
w2_alphas,
ab_strides_2,
c_strides_2,
problem_sizes2,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype,
device,
)
del int_fp4, int_blockscale
out = (
c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()
).sum(dim=1)
return out.to(dtype=out_dtype)
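# A minimal sketch (added for illustration, not part of this commit) of how the
# per-expert stride tensors described in the docstring above can be built for a
# given (e, n, k); the unit test below constructs the same values from the
# quantized weight shapes.
import torch


def make_fp4_moe_strides(e: int, n: int, k: int, device: torch.device):
    # Gemm 1: activations [m * topk, k] and w1 [e, 2n, k] step k logical
    # elements per row; its output [m * topk, 2n] steps 2n elements per row.
    ab_strides_13 = torch.full((e,), k, dtype=torch.int64, device=device)
    c_strides_13 = torch.full((e,), 2 * n, dtype=torch.int64, device=device)
    # Gemm 2: activations [m * topk, n] and w2 [e, k, n] step n logical
    # elements per row; its output [m * topk, k] steps k elements per row.
    ab_strides_2 = torch.full((e,), n, dtype=torch.int64, device=device)
    c_strides_2 = torch.full((e,), k, dtype=torch.int64, device=device)
    return ab_strides_13, c_strides_13, ab_strides_2, c_strides_2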
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from sgl_kernel import scaled_fp4_quant
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.topk import select_experts
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
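# Note (added for clarity): the swizzled scale factors are stored in tiles of
# shape [32, 4, 4], each covering 128 rows and 4 scale columns, the same layout
# emitted by cvt_quant_to_fp4_get_sf_out_offset in the CUDA quantization
# kernels; convert_swizzled_to_linear permutes them back into a row-major
# [m, k // block_size] view.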
def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype=dtype)
def break_fp4_bytes(a, dtype):
assert a.dtype == torch.uint8
m, n = a.shape
# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles
# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()
# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
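# Illustrative worked example (added here, not part of the test suite): the
# packed byte 0x2E holds two E2M1 values. Its low nibble 0xE has the sign bit
# set and magnitude index 6 (maps to 4.0), decoding to -4.0; the high nibble
# 0x2 decodes to +1.0. break_fp4_bytes emits the pair in (low, high) order.
def _break_fp4_bytes_example():
    packed = torch.tensor([[0x2E]], dtype=torch.uint8)
    decoded = break_fp4_bytes(packed, torch.float32)
    assert torch.equal(decoded, torch.tensor([[-4.0, 1.0]]))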
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1024),
(224, 1024, 1024),
(224, 1024, 1536),
]
# Reference MoE implementation in plain PyTorch, used as the numerical baseline.
def torch_moe(a, w1, w2, score, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
):
torch.manual_seed(7)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128)
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty(
(e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
)
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
sf_w2_k = round_up(k, 128)
sf_w2_n = round_up(n // quant_blocksize, 4)
w2_blockscale = torch.empty(
(e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
)
w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = scaled_fp4_quant(
w1[expert], w1_gs[expert]
)
w2_q[expert], w2_blockscale[expert] = scaled_fp4_quant(
w2[expert], w2_gs[expert]
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
use_grouped_topk=False,
renormalize=False,
)
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
# strides for the cutlass moe_fp4 kernel
ab_strides_13 = torch.full(
(e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device
)
c_strides_13 = torch.full(
(e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device
)
ab_strides_2 = torch.full(
(e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
)
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_q,
w1_blockscale=w1_blockscale,
w1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs),
ab_strides_13=ab_strides_13,
ab_strides_2=ab_strides_2,
c_strides_13=c_strides_13,
c_strides_2=c_strides_2,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=e,
device=a.device,
)
# Reference check:
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
).to(torch.float32)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(
w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=w1.dtype,
device=w1.device,
block_size=quant_blocksize,
)
w2_d[idx] = dequantize_nvfp4_to_dtype(
w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=w2.dtype,
device=w2.device,
block_size=quant_blocksize,
)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
if __name__ == "__main__":
test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
......@@ -210,6 +210,7 @@ set(SOURCES
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
"csrc/gemm/nvfp4_expert_quant.cu"
"csrc/gemm/nvfp4_quant_entry.cu"
"csrc/gemm/nvfp4_quant_kernels.cu"
"csrc/gemm/nvfp4_scaled_mm_entry.cu"
......@@ -222,6 +223,7 @@ set(SOURCES
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/moe/nvfp4_blockwise_moe.cu"
"csrc/moe/fp8_blockwise_moe_kernel.cu"
"csrc/moe/prepare_moe_input.cu"
"csrc/moe/ep_moe_reorder_kernel.cu"
......
......@@ -132,6 +132,20 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
// Compute NVFP4 experts quantization.
m.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
m.def(
"cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
"Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
"Tensor ab_strides, Tensor c_strides, Tensor problem_sizes,"
" Tensor expert_offsets, Tensor sf_offsets) -> ()");
m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
/*
* From csrc/moe
*/
......@@ -161,9 +175,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"expert_offsets, Tensor workspace) -> ()");
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
m.def(
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
"input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
" Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
"()");
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
/*
* From csrc/speculative
*/
......
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <torch/all.h>
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
// The PTX instructions used here require sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]),
"f"(array[1]),
"f"(array[2]),
"f"(array[3]),
"f"(array[4]),
"f"(array[5]),
"f"(array[6]),
"f"(array[7]));
return val;
#else
return 0;
#endif
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
// The PTX instructions used here require sm100a.
#if CUDA_VERSION >= 12080
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x),
"f"(array[0].y),
"f"(array[1].x),
"f"(array[1].y),
"f"(array[2].x),
"f"(array[2].y),
"f"(array[3].x),
"f"(array[3].y));
return val;
#else
return 0;
#endif
#endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, int numCols, SFType* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
innerMIdx * innerMStride + innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
#endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
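// Worked example (illustrative numbers, added for clarity): for a 16-value
// block whose absolute maximum is 3.0 with SFScaleVal == 1.0, SFValue =
// 3.0 / 6.0 = 0.5, which is exactly representable in E4M3, so the stored SF
// byte decodes to 0.5 and outputScale = 1 / 0.5 = 2.0; every element is then
// multiplied by 2.0 before the E2M1 conversion, and dequantization multiplies
// by the stored SF to undo it.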
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4(
#else
cvt_fp16_to_fp4(
#endif
int32_t numRows,
int32_t numCols,
Type const* in,
float const* SFScale,
uint32_t* out,
uint32_t* SFout,
uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts,
int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts.
int rowIdx_in_expert = 0;
int expert_idx = 0;
for (int i = 0; i < n_experts; i++) {
if (rowIdx >= input_offset_by_experts[i] && rowIdx < input_offset_by_experts[i + 1]) {
rowIdx_in_expert = rowIdx - input_offset_by_experts[i];
expert_idx = i;
break;
}
}
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
}
#endif
}
template <typename T>
void quant_impl(
void* output,
void* output_scale,
void* input,
void* input_global_scale,
void* input_offset_by_experts,
void* output_scale_offset_by_experts,
int m_topk,
int k,
int n_experts,
cudaStream_t stream) {
// TODO: this multiProcessorCount should be cached.
int device;
cudaGetDevice(&device);
int multiProcessorCount;
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(k / ELTS_PER_THREAD), 512));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM));
cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
m_topk,
k,
reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
}
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
// constexpr auto FP8 = at::ScalarType::Float8_e4m3fn;
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k = input.size(1);
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
auto in_dtype = input.dtype();
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
if (in_dtype == at::ScalarType::Half) {
quant_impl<half>(
output.data_ptr(),
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
m_topk,
k,
n_experts,
stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
quant_impl<__nv_bfloat16>(
output.data_ptr(),
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
m_topk,
k,
n_experts,
stream);
} else {
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
}
}
......@@ -18,6 +18,15 @@ limitations under the License.
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
void scaled_fp4_quant_sm100a(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
void scaled_fp4_experts_quant_sm100a(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant(
......@@ -27,3 +36,17 @@ void scaled_fp4_quant(
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
}
void scaled_fp4_experts_quant(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return scaled_fp4_experts_quant_sm100a(
output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel");
}
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cutlass/arch/arch.h>
#include <torch/all.h>
#include <cassert>
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/tensor_view_io.h"
using namespace cute;
template <
typename ElementAB,
typename ElementC,
typename ElementSF,
typename ElementAccumulator,
typename LayoutSFA,
typename LayoutSFB,
typename ScaleConfig>
__global__ void __get_group_gemm_starts(
ElementAB** a_offsets,
ElementAB** b_offsets,
ElementC** out_offsets,
ElementSF** a_scales_offsets,
ElementSF** b_scales_offsets,
ElementAccumulator** alpha_offsets,
LayoutSFA* layout_sfa_base_as_int,
LayoutSFB* layout_sfb_base_as_int,
ElementAB* a_base_as_int,
ElementAB* b_base_as_int,
ElementC* out_base_as_int,
ElementSF* a_scales_base_as_int,
ElementSF* b_scales_base_as_int,
ElementAccumulator* alphas_base_as_int,
const int32_t* expert_offsets,
const int32_t* sf_offsets,
const int32_t* problem_sizes_as_shapes,
const int K,
const int N) {
int64_t expert_id = threadIdx.x;
if (expert_id >= gridDim.x * blockDim.x) {
return;
}
// Originally int32_t but upcasting to int64_t to avoid overflow
// during offset calculations
int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
int64_t sf_offset = static_cast<int64_t>(sf_offsets[expert_id]);
// size for block in block scale.
int64_t group_size = 16;
int64_t m = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3]);
int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
assert((m >= 0 && n == N && k == K && k % 2 == 0) && "unexpected problem sizes");
int64_t half_k = static_cast<int64_t>(k / 2);
int64_t group_k = static_cast<int64_t>(k / group_size);
// Shape of A as uint8/byte = [M, K // 2]
// Shape of B as uint8/byte = [E, N, K // 2]
a_offsets[expert_id] = a_base_as_int + expert_offset * half_k;
b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k;
// Shape of C = [M, N]
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
// Shape of a_scale = [sum(sf_sizes), K // group_size]
a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k;
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
// Shape of B scale = [E, N, K // group_size]
b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k;
assert((reinterpret_cast<uintptr_t>(b_scales_offsets[expert_id]) % 128) == 0 && "TMA requires 128-byte alignment");
// Shape of alpha = [E]
alpha_offsets[expert_id] = alphas_base_as_int + expert_id;
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(
cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
*layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(
cute::make_shape(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1));
}
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE( \
ELEMENT_AB_TYPE, SF_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
static_cast<ELEMENT_AB_TYPE**>(a_starts.data_ptr()), \
static_cast<ELEMENT_AB_TYPE**>(b_starts.data_ptr()), \
static_cast<C_TYPE**>(out_starts.data_ptr()), \
static_cast<SF_TYPE**>(a_scales_starts.data_ptr()), \
static_cast<SF_TYPE**>(b_scales_starts.data_ptr()), \
static_cast<float**>(alpha_starts.data_ptr()), \
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(a_tensors.data_ptr()), \
static_cast<ELEMENT_AB_TYPE*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<SF_TYPE*>(a_scales.data_ptr()), \
static_cast<SF_TYPE*>(b_scales.data_ptr()), \
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), \
K, \
N); \
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts,
const torch::Tensor& b_starts,
const torch::Tensor& out_starts,
const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts,
const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& alphas,
torch::Tensor const& expert_offsets,
torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes,
int M,
int N,
int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
TORCH_CHECK(out_tensors.size(1) == N, "Output tensor shape doesn't match expected shape");
TORCH_CHECK(
K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t,
cutlass::float_ue4m3_t,
torch::kBFloat16,
cutlass::bfloat16_t,
LayoutSFA,
LayoutSFB,
ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kFloat16, half, LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_blockscale,
const torch::Tensor& b_blockscales,
const torch::Tensor& alphas,
const torch::Tensor& ab_strides,
const torch::Tensor& c_strides,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& sf_offsets,
int M,
int N,
int K) {
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
using ElementSFType = cutlass::float_ue4m3_t;
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using ElementC = OutType;
using ElementD = ElementC;
using ElementAccumulator = float;
// Layout definitions
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Alignment constraints
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
// Architecture definitions
using ArchTag = cutlass::arch::Sm100;
using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag
using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based
// on the tile size
using ClusterShape = Shape<_1, _1, _1>;
struct MMA1SMConfig {
using MmaTileShape = Shape<_128, _128, _128>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
};
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
EpilogueOperatorClass,
typename MMA1SMConfig::MmaTileShape,
ClusterShape,
Shape<_128, _64>,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutC*,
AlignmentC,
ElementD,
LayoutC*,
AlignmentD,
typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
MainloopOperatorClass,
ElementA,
LayoutA*,
AlignmentA,
ElementB,
LayoutB*,
AlignmentB,
ElementAccumulator,
typename MMA1SMConfig::MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename MMA1SMConfig::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = Gemm1SM;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using ScaleConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
alpha_ptrs,
layout_sfa,
layout_sfb,
a,
b,
output,
a_blockscale,
b_blockscales,
alphas,
expert_offsets,
sf_offsets,
problem_sizes,
M,
N,
K);
// Create an instance of the GEMM
Gemm gemm_op;
// Initialize problem_sizes_as_shapes correctly
UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
// Set the Scheduler info
cutlass::KernelHardwareInfo hw_info;
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams<
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
// Mainloop Arguments
typename GemmKernel::MainloopArguments mainloop_args{
static_cast<const ElementType**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(ab_strides.data_ptr()),
static_cast<const ElementType**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(ab_strides.data_ptr()),
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
// Epilogue Arguments
typename GemmKernel::EpilogueArguments epilogue_args{
{}, // epilogue.thread
nullptr,
static_cast<StrideC*>(c_strides.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};
auto& fusion_args = epilogue_args.thread;
fusion_args.alpha_ptr_array = reinterpret_cast<float**>(alpha_ptrs.data_ptr());
fusion_args.dAlpha = {_0{}, _0{}, 1};
// Gemm Arguments
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
mainloop_args,
epilogue_args,
hw_info,
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_blockscale,
const torch::Tensor& b_blockscales,
const torch::Tensor& alphas,
const torch::Tensor& ab_strides,
const torch::Tensor& c_strides,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& sf_offsets) {
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
// Input validation
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
TORCH_CHECK(
a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(
b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output,
a,
b,
a_blockscale,
b_blockscales,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
sf_offsets,
M,
N,
K);
} else {
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
output,
a,
b,
a_blockscale,
b_blockscales,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
sf_offsets,
M,
N,
K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, sgl-kernel must "
"be compiled with ENABLE_NVFP4 for SM100+ and CUDA "
"12.8 or above.");
#endif
}
......@@ -4,6 +4,8 @@
#include <iostream>
#include "cutlass/array.h"
constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes(
......@@ -11,9 +13,9 @@ __global__ void compute_problem_sizes(
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
const int64_t topk_length,
const int64_t n,
const int64_t k) {
int expert_id = blockIdx.x;
int occurrences = 0;
......@@ -26,11 +28,11 @@ __global__ void compute_problem_sizes(
if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id];
problem_sizes1[expert_id * 3] = final_occurrences;
problem_sizes1[expert_id * 3 + 1] = static_cast<int32_t>(2 * n);
problem_sizes1[expert_id * 3 + 2] = static_cast<int32_t>(k);
problem_sizes2[expert_id * 3] = final_occurrences;
problem_sizes2[expert_id * 3 + 1] = static_cast<int32_t>(k);
problem_sizes2[expert_id * 3 + 2] = static_cast<int32_t>(n);
}
}
......@@ -38,7 +40,7 @@ __global__ void compute_expert_offsets(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* atomic_buffer,
const int64_t num_experts) {
int32_t tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
......@@ -48,13 +50,34 @@ __global__ void compute_expert_offsets(
}
}
__global__ void compute_expert_blockscale_offsets(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* blockscale_offsets,
int32_t* atomic_buffer,
const int64_t num_experts) {
int32_t tot_offset = 0;
int32_t tot_rounded_offset = 0;
expert_offsets[0] = 0;
blockscale_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
int num_tokens = problem_sizes1[i * 3];
int rounded_num_tokens = (num_tokens + (128 - 1)) / 128 * 128;
tot_offset += num_tokens;
tot_rounded_offset += rounded_num_tokens;
expert_offsets[i + 1] = tot_offset;
blockscale_offsets[i + 1] = tot_rounded_offset;
}
}
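// Illustrative example (comment added for clarity): with per-expert token
// counts {3, 5, 2}, this kernel produces expert_offsets = {0, 3, 8, 10} and
// blockscale_offsets = {0, 128, 256, 384}, since each expert's scale-factor
// rows are rounded up to a multiple of 128 to match the swizzled NVFP4 scale
// layout.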
__global__ void compute_arg_sorts(
const int32_t* __restrict__ topk_ids,
int32_t* input_permutation,
int32_t* output_permutation,
int32_t* atomic_buffer,
const int64_t topk_length,
const int64_t topk) {
int expert_id = blockIdx.x;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
......@@ -69,6 +92,7 @@ __global__ void compute_arg_sorts(
void get_moe_prepare_input_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const std::optional<torch::Tensor>& blockscale_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
......@@ -80,8 +104,10 @@ void get_moe_prepare_input_caller(
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
uint32_t num_threads = static_cast<uint32_t>(min(THREADS_PER_EXPERT, topk_ids.numel()));
uint32_t num_blocks = static_cast<uint32_t>(num_experts);
compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
......@@ -89,12 +115,21 @@ void get_moe_prepare_input_caller(
topk_ids.numel(),
n,
k);
if (blockscale_offsets.has_value()) {
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
} else {
compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
}
compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
......@@ -106,6 +141,7 @@ void get_moe_prepare_input_caller(
void prepare_moe_input(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const std::optional<torch::Tensor>& blockscale_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
......@@ -117,6 +153,7 @@ void prepare_moe_input(
get_moe_prepare_input_caller(
topk_ids,
expert_offsets,
blockscale_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
......@@ -126,3 +163,92 @@ void prepare_moe_input(
k);
return;
}
template <typename T>
__global__ void shuffleRowsKernel(
const T* input,
const int32_t* dst2src_map,
T* output,
int64_t num_src_rows,
int64_t num_dst_rows,
int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];
if (blockIdx.x < num_dst_rows) {
// Load 128-bits per thread
constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
auto const* source_row_ptr = reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
auto* dest_row_ptr = reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
auto const start_offset = threadIdx.x;
auto const stride = blockDim.x;
auto const num_elems_in_col = num_cols / ELEM_PER_THREAD;
for (auto elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
#define DECLARE_SHUFFLE_ROWS(T) \
__global__ void shuffleRowsKernel( \
const T* input, \
const int32_t* dst2src_map, \
T* output, \
int64_t num_src_rows, \
int64_t num_dest_rows, \
int64_t num_cols);
DECLARE_SHUFFLE_ROWS(float);
DECLARE_SHUFFLE_ROWS(half);
DECLARE_SHUFFLE_ROWS(__nv_bfloat16);
DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3);
DECLARE_SHUFFLE_ROWS(uint8_t);
#define SHUFFLE_ROWS(T) \
shuffleRowsKernel<T><<<blocks, threads, 0, stream>>>( \
reinterpret_cast<const T*>(input), \
static_cast<const int32_t*>(dst2src_map.data_ptr()), \
reinterpret_cast<T*>(output), \
num_src_rows, \
num_dst_rows, \
num_cols)
#define DTYPE_DISPATCH_CASE(T, CUDA_T) \
case T: \
SHUFFLE_ROWS(CUDA_T); \
break;
void shuffle_rows_caller(
const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
TORCH_CHECK(
input_tensor.scalar_type() == output_tensor.scalar_type(),
"Input and output tensors must have the same data type");
auto stream = at::cuda::getCurrentCUDAStream().stream();
uint32_t blocks = static_cast<uint32_t>(output_tensor.size(0));
uint32_t threads = 256;
int64_t num_dst_rows = output_tensor.size(0);
int64_t num_src_rows = input_tensor.size(0);
int64_t num_cols = input_tensor.size(1);
const void* input = input_tensor.data_ptr();
void* output = output_tensor.data_ptr();
switch (input_tensor.scalar_type()) {
DTYPE_DISPATCH_CASE(torch::kFloat16, half);
DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
DTYPE_DISPATCH_CASE(torch::kFloat32, float);
DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
default:
TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
}
return;
}
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
return;
}
......@@ -232,6 +232,7 @@ void fp8_blockwise_scaled_grouped_mm(
void prepare_moe_input(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
const std::optional<torch::Tensor>& blockscale_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
......@@ -251,6 +252,29 @@ void ep_moe_pre_reorder(
int64_t topk,
bool use_per_token_if_dynamic);
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
void cutlass_fp4_group_mm(
torch::Tensor& output,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_blockscale,
const torch::Tensor& b_blockscales,
const torch::Tensor& alphas,
const torch::Tensor& ab_strides,
const torch::Tensor& c_strides,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets,
const torch::Tensor& sf_offsets);
void scaled_fp4_experts_quant(
torch::Tensor& output,
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
/*
* From csrc/speculative
*/
......
......@@ -38,14 +38,17 @@ from sgl_kernel.gemm import (
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
scaled_fp4_experts_quant,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
shuffle_rows,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
cutlass_fp4_group_mm,
ep_moe_pre_reorder,
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,
......
......@@ -241,3 +241,80 @@ def qserve_w4a8_per_group_gemm(
in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
)
return out_feats
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
output_tensor = torch.empty(
output_tensor_shape,
device=input_tensor.device,
dtype=input_tensor.dtype,
)
torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
return output_tensor
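# Usage note (added for clarity): shuffle_rows gathers rows, i.e. row i of the
# output is row dst2src_map[i] of the input. scaled_fp4_experts_quant below
# uses it to replicate each token once per selected expert before quantization
# when an expert_map is supplied.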
def scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize the packed MoE input tensor to FP4 and return the quantized
tensor together with its blockscales.
Args:
input_tensor: The high-precision input tensor to be quantized to FP4
input_global_scale: Per-expert global scale factors, shape [num_experts] (float32)
expert_offsets: Row offsets of each expert's tokens in the packed input
blockscale_offsets: Row offsets of each expert's blockscales (each expert's
scale rows are padded to a multiple of 128)
topk: Number of experts selected per token
expert_map: Optional destination-to-source row map used to replicate and
permute the input rows before quantization
Outputs:
output: The quantized tensor (two FP4 values packed into each uint8)
output_scales: The blockscale tensor in FP8-E4M3
"""
assert (
input_tensor.ndim == 2
), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
if expert_map is not None:
(m, k) = input_tensor.shape
output_tensor_shape = (m * topk, k)
input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)
m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
import os
MAX_TOKENS_PER_EXPERT = int(os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536))
assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
)
scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4
# output is uint8 and packed fp4 values
output = torch.empty(
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
)
output_scales = torch.empty(
MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device,
)
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
output,
output_scales,
input_tensor,
input_global_scale,
expert_offsets,
blockscale_offsets,
)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
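# Shape sketch (illustrative helper added here, not part of the kernel API):
# mirrors the allocation logic above without launching the kernel. For
# example, k = 1024 gives 64 block scales per row, packed into 16 int32
# columns, i.e. 64 FP8 columns once output_scales is reinterpreted via
# .view(torch.float8_e4m3fn).
def _fp4_experts_quant_scale_columns(k: int, block_size: int = 16) -> int:
    scales_k = k // block_size      # one FP8 scale per 16 input elements
    padded_k = (scales_k + 3) // 4  # int32 columns; 4 FP8 scales pack into one int32
    return padded_k * 4             # FP8 columns after the view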
from typing import Optional
import torch
......@@ -138,10 +140,12 @@ def prepare_moe_input(
num_experts,
n,
k,
blockscale_offsets: Optional[torch.Tensor] = None,
):
torch.ops.sgl_kernel.prepare_moe_input.default(
topk_ids,
expert_offsets,
blockscale_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
......@@ -150,3 +154,54 @@ def prepare_moe_input(
n,
k,
)
def cutlass_fp4_group_mm(
a_fp4,
b_fp4,
a_blockscale,
b_blockscale,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
blockscale_offsets,
out_dtype,
device,
):
"""
An FP4 blockscaled grouped GEMM that takes the quantized activations and
expert weights and runs one GEMM per expert based on the specified problem
sizes. This is used as the MoE GEMM during the NVFP4-quantized FusedMoE
forward pass.
- a_fp4/b_fp4: The NVFP4 (packed E2M1) quantized input and expert weights.
- a_blockscale/b_blockscale: The blockscales in FP8-E4M3 precision.
- ab_strides/c_strides: Per-expert strides, in elements, between rows of
the A/B tensors and of the output tensor.
- expert_offsets/blockscale_offsets: Indices that mark at which row each
expert's computation begins. The number of tokens computed for expert E
is expert_offsets[E + 1] - expert_offsets[E], and the number of
scale-factor rows per expert is blockscale_offsets[E + 1] -
blockscale_offsets[E].
- problem_sizes: The (M, N, K) size of each expert's multiplication in the
two grouped MMs used in the fused MoE operation.
"""
m_topk = a_fp4.shape[0]
n = b_fp4.shape[1]
c_shape = (m_topk, n)
c = torch.empty(c_shape, device=device, dtype=out_dtype)
torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
c,
a_fp4,
b_fp4,
a_blockscale,
b_blockscale,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
blockscale_offsets,
)
return c.to(dtype=out_dtype)
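# A minimal sketch (pure Python, added for illustration; mirrors
# compute_problem_sizes in csrc/moe/prepare_moe_input.cu): for per-expert token
# counts counts[e], the two grouped GEMMs in the fused MoE use problem sizes
# (counts[e], 2 * n, k) for the gate/up projection and (counts[e], k, n) for
# the down projection.
def _moe_problem_sizes(counts, n, k):
    problem_sizes1 = [(m_e, 2 * n, k) for m_e in counts]  # Gemm 1 (gate/up)
    problem_sizes2 = [(m_e, k, n) for m_e in counts]  # Gemm 2 (down)
    return problem_sizes1, problem_sizes2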