Commit 4457d4f5 authored by zhuwenwen
Browse files

remove rms_norm_opt and fused_add_rms_norm_opt

parent cbff8d34
...@@ -296,7 +296,7 @@ set(VLLM_EXT_SRC ...@@ -296,7 +296,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu" "csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu" "csrc/opt/activation_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu" # "csrc/opt/layernorm_kernels_opt.cu"
"csrc/fused_qknorm_rope_kernel.cu" "csrc/fused_qknorm_rope_kernel.cu"
# "csrc/layernorm_quant_kernels.cu" # "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu" "csrc/sampler.cu"
......
...@@ -97,11 +97,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q, ...@@ -97,11 +97,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
torch::Tensor& k_weight, torch::Tensor& cos_sin_cache, torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
bool is_neox, torch::Tensor& position_ids); bool is_neox, torch::Tensor& position_ids);
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, // void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon); // double epsilon);
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual, // void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); // torch::Tensor& weight, double epsilon);
void apply_repetition_penalties_(torch::Tensor& logits, void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask, const torch::Tensor& prompt_mask,
......
...@@ -203,16 +203,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -203,16 +203,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant // Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( // ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " // "rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()"); // "()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt); // ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt) // In-place fused Add and RMS Normalization. (opt)
ops.def( // ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, " // "fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()"); // "float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt); // ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Layernorm-quant // Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
......
...@@ -6,7 +6,7 @@ requires = [ ...@@ -6,7 +6,7 @@ requires = [
"packaging>=24.2", "packaging>=24.2",
"setuptools>=77.0.3,<81.0.0", "setuptools>=77.0.3,<81.0.0",
"setuptools-scm>=8.0", "setuptools-scm>=8.0",
"torch == 2.9.0", "torch >= 2.7.1",
"wheel", "wheel",
"jinja2", "jinja2",
] ]
......
...@@ -25,7 +25,7 @@ quart ...@@ -25,7 +25,7 @@ quart
fastrlock==0.8.3 fastrlock==0.8.3
cupy==12.3.0 cupy==12.3.0
torch == 2.7.1 torch >= 2.7.1
triton == 3.1 triton == 3.1
flash_attn == 2.6.1 flash_attn == 2.6.1
flash_mla == 1.0.0 flash_mla == 1.0.0
......
...@@ -349,14 +349,14 @@ def fused_add_rms_norm( ...@@ -349,14 +349,14 @@ def fused_add_rms_norm(
# layer norm ops (opt) # layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, # def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None: # epsilon: float) -> None:
torch.ops._C.rms_norm_opt(out, input, weight, epsilon) # torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, # def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None: # weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon) # torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def fused_qk_norm_rope( def fused_qk_norm_rope(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment