Commit 4440e8c0 authored by zhuwenwen's avatar zhuwenwen
Browse files

use opt layernorm_kernels

parent d8ae62c7
...@@ -52,36 +52,36 @@ class RMSNorm(CustomOp): ...@@ -52,36 +52,36 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if residual is not None: if residual is not None:
# if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
# ops.fused_add_rms_norm_opt( ops.fused_add_rms_norm_opt(
# x, x,
# residual, residual,
# self.weight.data, self.weight.data,
# self.variance_epsilon, self.variance_epsilon,
# ) )
# else: else:
ops.fused_add_rms_norm( ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
if envs.VLLM_USE_OPT_OP:
ops.rms_norm_opt(
out,
x, x,
residual,
self.weight.data, self.weight.data,
self.variance_epsilon, self.variance_epsilon,
) )
return x, residual else:
out = torch.empty_like(x) ops.rms_norm(
# if envs.VLLM_USE_OPT_OP: out,
# ops.rms_norm_opt( x,
# out, self.weight.data,
# x, self.variance_epsilon,
# self.weight.data, )
# self.variance_epsilon,
# )
# else:
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out return out
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment