Unverified Commit d0359f3e authored by mysterious hhhh's avatar mysterious hhhh Committed by GitHub
Browse files

[Bugfix] Guard mxfp4_experts_quant bindings on ENABLE_NVFP4_SM100 (#40191)


Signed-off-by: default avatarultranationalism <www913363043@gmail.com>
Signed-off-by: default avatarmgoin <mike.goin12@gmail.com>
Co-authored-by: default avatarmgoin <mike.goin12@gmail.com>
Co-authored-by: default avatarClaude Opus 4.7 (1M context) <noreply@anthropic.com>
parent ed0622e3
......@@ -134,20 +134,6 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
void mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
void silu_and_mul_mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
void cutlass_mxfp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
......
......@@ -23,6 +23,7 @@
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
......@@ -420,3 +421,12 @@ void silu_and_mul_mxfp4_experts_quant(
stream);
});
}
// Registered here (not torch_bindings.cpp) because VLLM_GPU_FLAGS is applied
// only under COMPILE_LANGUAGE:CUDA, so ENABLE_NVFP4_SM100 is invisible to
// .cpp files and cannot gate the registration from there.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("mxfp4_experts_quant", TORCH_BOX(&mxfp4_experts_quant));
m.impl("silu_and_mul_mxfp4_experts_quant",
TORCH_BOX(&silu_and_mul_mxfp4_experts_quant));
}
......@@ -252,12 +252,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
ops.impl("mxfp4_experts_quant", TORCH_BOX(&mxfp4_experts_quant));
ops.impl("silu_and_mul_mxfp4_experts_quant",
TORCH_BOX(&silu_and_mul_mxfp4_experts_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
// mxfp4_experts_quant: registered in mxfp4_experts_quant.cu (SM100 only).
// W4A8 ops: registered in w4a8_mm_entry.cu / w4a8_grouped_mm_entry.cu.
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment