Commit 4440e8c0 authored by zhuwenwen's avatar zhuwenwen
Browse files

use opt layernorm_kernels

parent d8ae62c7
......@@ -52,14 +52,14 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops
if residual is not None:
# if envs.VLLM_USE_OPT_OP:
# ops.fused_add_rms_norm_opt(
# x,
# residual,
# self.weight.data,
# self.variance_epsilon,
# )
# else:
if envs.VLLM_USE_OPT_OP:
ops.fused_add_rms_norm_opt(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
else:
ops.fused_add_rms_norm(
x,
residual,
......@@ -68,14 +68,14 @@ class RMSNorm(CustomOp):
)
return x, residual
out = torch.empty_like(x)
# if envs.VLLM_USE_OPT_OP:
# ops.rms_norm_opt(
# out,
# x,
# self.weight.data,
# self.variance_epsilon,
# )
# else:
if envs.VLLM_USE_OPT_OP:
ops.rms_norm_opt(
out,
x,
self.weight.data,
self.variance_epsilon,
)
else:
ops.rms_norm(
out,
x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment