Unverified Commit 09e4576f authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Kernel] Add non-gated support for NVFP4 CUTLASS MoE (#37320)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 3ed7b1e6
...@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data( ...@@ -262,7 +262,8 @@ void get_cutlass_moe_mm_data(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets); const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset, const torch::Tensor& expert_first_token_offset,
......
...@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, ...@@ -17,8 +17,11 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes2, int32_t* problem_sizes2,
int32_t* atomic_buffer, int32_t* atomic_buffer,
const int topk_length, const int n, const int topk_length, const int n,
const int k) { const int k, const bool is_gated) {
int expert_id = blockIdx.x; int expert_id = blockIdx.x;
// For gated activations (gate + up), first GEMM output is 2*n.
// For non-gated activations (up only), first GEMM output is n.
int const n1 = is_gated ? 2 * n : n;
int occurrences = 0; int occurrences = 0;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) { for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
...@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, ...@@ -31,13 +34,13 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int final_occurrences = atomic_buffer[expert_id]; int final_occurrences = atomic_buffer[expert_id];
if constexpr (!SWAP_AB) { if constexpr (!SWAP_AB) {
problem_sizes1[expert_id * 3] = final_occurrences; problem_sizes1[expert_id * 3] = final_occurrences;
problem_sizes1[expert_id * 3 + 1] = 2 * n; problem_sizes1[expert_id * 3 + 1] = n1;
problem_sizes1[expert_id * 3 + 2] = k; problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = final_occurrences; problem_sizes2[expert_id * 3] = final_occurrences;
problem_sizes2[expert_id * 3 + 1] = k; problem_sizes2[expert_id * 3 + 1] = k;
problem_sizes2[expert_id * 3 + 2] = n; problem_sizes2[expert_id * 3 + 2] = n;
} else { } else {
problem_sizes1[expert_id * 3] = 2 * n; problem_sizes1[expert_id * 3] = n1;
problem_sizes1[expert_id * 3 + 1] = final_occurrences; problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k; problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k; problem_sizes2[expert_id * 3] = k;
...@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, ...@@ -107,13 +110,11 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
} }
namespace { namespace {
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, inline void launch_compute_problem_sizes(
torch::Tensor& problem_sizes1, const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
torch::Tensor& atomic_buffer, int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
int64_t num_experts, int64_t n, const bool swap_ab, const bool is_gated) {
int64_t k, cudaStream_t stream,
const bool swap_ab) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
auto const* topk_ptr = topk_ids.data_ptr<int32_t>(); auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
...@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, ...@@ -125,7 +126,7 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>( compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n), static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k)); static_cast<int>(k), is_gated);
}); });
} }
} // namespace } // namespace
...@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller( ...@@ -222,7 +223,8 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) { const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
...@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller( ...@@ -236,7 +238,7 @@ void get_cutlass_moe_mm_data_caller(
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream, atomic_buffer, num_experts, n, k, stream,
may_swap_ab); may_swap_ab, is_gated);
if (blockscale_offsets.has_value()) { if (blockscale_offsets.has_value()) {
// fp4 path // fp4 path
......
...@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller( ...@@ -75,7 +75,8 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets); const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset, const torch::Tensor& expert_first_token_offset,
...@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data( ...@@ -278,7 +279,8 @@ void get_cutlass_moe_mm_data(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) { const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe // This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for. // mm to run it for.
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
...@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data( ...@@ -288,7 +290,7 @@ void get_cutlass_moe_mm_data(
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation, problem_sizes2, input_permutation,
output_permutation, num_experts, n, k, output_permutation, num_experts, n, k,
blockscale_offsets); blockscale_offsets, is_gated);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
......
...@@ -489,8 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -489,8 +489,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, " " Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, " " Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, " " Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets) -> " " int n, int k, Tensor? blockscale_offsets, "
"()"); " bool is_gated) -> ()");
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// compute per-expert problem sizes from expert_first_token_offset // compute per-expert problem sizes from expert_first_token_offset
......
model_name: "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
accuracy_threshold: 0.29
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --moe-backend=cutlass"
...@@ -15,3 +15,4 @@ Mixtral-8x7B-BF16-fi-cutlass.yaml ...@@ -15,3 +15,4 @@ Mixtral-8x7B-BF16-fi-cutlass.yaml
Mixtral-8x7B-BF16-triton.yaml Mixtral-8x7B-BF16-triton.yaml
Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml Nemotron-Nano-30B-Fp8-ModelOpt-fi-trtllm.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml Nemotron-Nano-30B-NvFp4-ModelOpt-fi-cutlass.yaml
Nemotron-Nano-30B-NvFp4-ModelOpt-vllm-cutlass.yaml
...@@ -989,6 +989,7 @@ def get_cutlass_moe_mm_data( ...@@ -989,6 +989,7 @@ def get_cutlass_moe_mm_data(
n: int, n: int,
k: int, k: int,
blockscale_offsets: torch.Tensor | None = None, blockscale_offsets: torch.Tensor | None = None,
is_gated: bool = True,
): ):
""" """
Prepare data necessary to perform CUTLASS grouped matrix multiplications Prepare data necessary to perform CUTLASS grouped matrix multiplications
...@@ -1012,6 +1013,8 @@ def get_cutlass_moe_mm_data( ...@@ -1012,6 +1013,8 @@ def get_cutlass_moe_mm_data(
its computation. The number of block scale rows its computation. The number of block scale rows
computed with expert E is blockscale_offsets[E + 1] - computed with expert E is blockscale_offsets[E + 1] -
blockscale_offsets[E] blockscale_offsets[E]
- is_gated: Whether the activation is gated (gate + up). When True, the
first GEMM N dimension is 2*n; when False, it is n.
""" """
return torch.ops._C.get_cutlass_moe_mm_data( return torch.ops._C.get_cutlass_moe_mm_data(
topk_ids, topk_ids,
...@@ -1024,6 +1027,7 @@ def get_cutlass_moe_mm_data( ...@@ -1024,6 +1027,7 @@ def get_cutlass_moe_mm_data(
n, n,
k, k,
blockscale_offsets, blockscale_offsets,
is_gated,
) )
......
...@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4( ...@@ -507,11 +507,12 @@ def run_cutlass_moe_fp4(
# Gemm 1 # Gemm 1
a: Input tensor: [m, k] (half/bfloat16) a: Input tensor: [m, k] (half/bfloat16)
a1_gscale: Activation scale per expert: [e] (float32) a1_gscale: Activation scale per expert: [e] (float32)
w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] w1 (not an argument to cutlass_moe_fp4): [e, w1_n, k]
w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) w1_fp4: [e, w1_n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
where w1_n = 2*n for gated activations (gate+up), n for non-gated (up only).
(Note: `n` is the up projection output dim, `k` is the input dim in (Note: `n` is the up projection output dim, `k` is the input dim in
full precision) full precision)
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) w1_blockscale: [e, w1_n, k // block_size] (float8_e4m3)
(Block size = 16 for NVFP4) (Block size = 16 for NVFP4)
# Gemm 2 # Gemm 2
...@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4( ...@@ -528,6 +529,11 @@ def run_cutlass_moe_fp4(
assumes that topk < k < n to satisfy - up/down projection expectations. assumes that topk < k < n to satisfy - up/down projection expectations.
""" """
is_gated = activation.is_gated
# For gated activations (e.g. SiLU), w1 output is 2*n (gate + up).
# For non-gated activations (e.g. SiLU_NO_MUL), w1 output is n (up only).
w1_n = n * 2 if is_gated else n
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
...@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4( ...@@ -538,7 +544,7 @@ def run_cutlass_moe_fp4(
and w2_blockscale.ndim == 3 and w2_blockscale.ndim == 3
), "All Weights must be of rank 3 for cutlass_moe_fp4" ), "All Weights must be of rank 3 for cutlass_moe_fp4"
m_a, k_a = a.shape m_a, k_a = a.shape
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w1, w1_n_actual, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert e_w1 == e_w2 and e_w1 == e, ( assert e_w1 == e_w2 and e_w1 == e, (
...@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4( ...@@ -548,7 +554,7 @@ def run_cutlass_moe_fp4(
assert k_a == half_k_w1 * 2 and k == k_w2, ( assert k_a == half_k_w1 * 2 and k == k_w2, (
"Hidden size mismatch between a, w1 and w2" "Hidden size mismatch between a, w1 and w2"
) )
assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`" assert w1_n_actual == w1_n and half_n_w2 * 2 == n, "mismatch in expected `n`"
assert m == m_a, "input shape mismatch" assert m == m_a, "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
...@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4( ...@@ -589,6 +595,7 @@ def run_cutlass_moe_fp4(
n, n,
k, k,
blockscale_offsets, blockscale_offsets,
is_gated=is_gated,
) )
a = ops.shuffle_rows(a, a_map) a = ops.shuffle_rows(a, a_map)
...@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4( ...@@ -599,7 +606,7 @@ def run_cutlass_moe_fp4(
blockscale_offsets, blockscale_offsets,
num_topk, num_topk,
) )
c1 = _resize_cache(workspace13, (m * topk, n * 2)) c1 = _resize_cache(workspace13, (m * topk, w1_n))
c2 = _resize_cache(workspace2, (m * topk, n)) c2 = _resize_cache(workspace2, (m * topk, n))
c3 = _resize_cache(workspace13, (m * topk, k)) c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm( ops.cutlass_fp4_moe_mm(
...@@ -681,7 +688,7 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular): ...@@ -681,7 +688,7 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
@staticmethod @staticmethod
def _supports_no_act_and_mul() -> bool: def _supports_no_act_and_mul() -> bool:
return False return True
@staticmethod @staticmethod
def _supports_quant_scheme( def _supports_quant_scheme(
...@@ -695,11 +702,16 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular): ...@@ -695,11 +702,16 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
# SILU uses a fused silu+mul+fp4_quant kernel path. # SILU uses a fused silu+mul+fp4_quant kernel path.
# Other gated activations use the generic apply_moe_activation() # Other gated activations use the generic apply_moe_activation()
# fallback + separate fp4 quantization in run_cutlass_moe_fp4(). # fallback + separate fp4 quantization in run_cutlass_moe_fp4().
# Non-gated activations (_NO_MUL) are also supported for models
# like Nemotron-Nano that don't use gated MLP.
return activation in [ return activation in [
MoEActivation.SILU, MoEActivation.SILU,
MoEActivation.GELU, MoEActivation.GELU,
MoEActivation.SWIGLUOAI, MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP, MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
] ]
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment