Unverified Commit d0359f3e authored by mysterious hhhh's avatar mysterious hhhh Committed by GitHub
Browse files

[Bugfix] Guard mxfp4_experts_quant bindings on ENABLE_NVFP4_SM100 (#40191)


Signed-off-by: default avatarultranationalism <www913363043@gmail.com>
Signed-off-by: default avatarmgoin <mike.goin12@gmail.com>
Co-authored-by: default avatarmgoin <mike.goin12@gmail.com>
Co-authored-by: default avatarClaude Opus 4.7 (1M context) <noreply@anthropic.com>
parent ed0622e3
...@@ -134,20 +134,6 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out, ...@@ -134,20 +134,6 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& input, torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale); torch::stable::Tensor& input_global_scale);
void mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
void silu_and_mul_mxfp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts,
int64_t n_experts);
void cutlass_mxfp4_group_mm(torch::stable::Tensor& output, void cutlass_mxfp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& b,
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h> #include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h" #include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h" #include "libtorch_stable/dispatch_utils.h"
...@@ -420,3 +421,12 @@ void silu_and_mul_mxfp4_experts_quant( ...@@ -420,3 +421,12 @@ void silu_and_mul_mxfp4_experts_quant(
stream); stream);
}); });
} }
// Registered here (not torch_bindings.cpp) because VLLM_GPU_FLAGS is applied
// only under COMPILE_LANGUAGE:CUDA, so ENABLE_NVFP4_SM100 is invisible to
// .cpp files and cannot gate the registration from there.
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("mxfp4_experts_quant", TORCH_BOX(&mxfp4_experts_quant));
m.impl("silu_and_mul_mxfp4_experts_quant",
TORCH_BOX(&silu_and_mul_mxfp4_experts_quant));
}
...@@ -252,12 +252,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { ...@@ -252,12 +252,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("silu_and_mul_scaled_fp4_experts_quant", ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant)); TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant)); ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
ops.impl("mxfp4_experts_quant", TORCH_BOX(&mxfp4_experts_quant)); // mxfp4_experts_quant: registered in mxfp4_experts_quant.cu (SM100 only).
ops.impl("silu_and_mul_mxfp4_experts_quant", // W4A8 ops: registered in w4a8_mm_entry.cu / w4a8_grouped_mm_entry.cu.
TORCH_BOX(&silu_and_mul_mxfp4_experts_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment