Commit 4457d4f5 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove rms_norm_opt and fused_add_rms_norm_opt

parent cbff8d34
......@@ -296,7 +296,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
# "csrc/opt/layernorm_kernels_opt.cu"
"csrc/fused_qknorm_rope_kernel.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
......
......@@ -97,11 +97,11 @@ void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
bool is_neox, torch::Tensor& position_ids);
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);
// void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
// double epsilon);
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
// void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
// torch::Tensor& weight, double epsilon);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
......
......@@ -203,16 +203,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// ops.def(
// "rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
// "()");
// ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt)
ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// ops.def(
// "fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
// "float epsilon) -> ()");
// ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
......
......@@ -6,7 +6,7 @@ requires = [
"packaging>=24.2",
"setuptools>=77.0.3,<81.0.0",
"setuptools-scm>=8.0",
"torch == 2.9.0",
"torch >= 2.7.1",
"wheel",
"jinja2",
]
......
......@@ -25,7 +25,7 @@ quart
fastrlock==0.8.3
cupy==12.3.0
torch == 2.7.1
torch >= 2.7.1
triton == 3.1
flash_attn == 2.6.1
flash_mla == 1.0.0
......
......@@ -349,14 +349,14 @@ def fused_add_rms_norm(
# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
# def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
# epsilon: float) -> None:
# torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
# def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
# weight: torch.Tensor, epsilon: float) -> None:
# torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def fused_qk_norm_rope(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment