Commit bf278a88 authored by zhuwenwen's avatar zhuwenwen
Browse files

use ori layernorm_kernels

parent 668ec4ef
...@@ -52,14 +52,14 @@ class RMSNorm(CustomOp): ...@@ -52,14 +52,14 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if residual is not None: if residual is not None:
if envs.VLLM_USE_OPT_OP: # if envs.VLLM_USE_OPT_OP:
ops.fused_add_rms_norm_opt( # ops.fused_add_rms_norm_opt(
x, # x,
residual, # residual,
self.weight.data, # self.weight.data,
self.variance_epsilon, # self.variance_epsilon,
) # )
else: # else:
ops.fused_add_rms_norm( ops.fused_add_rms_norm(
x, x,
residual, residual,
...@@ -68,14 +68,14 @@ class RMSNorm(CustomOp): ...@@ -68,14 +68,14 @@ class RMSNorm(CustomOp):
) )
return x, residual return x, residual
out = torch.empty_like(x) out = torch.empty_like(x)
if envs.VLLM_USE_OPT_OP: # if envs.VLLM_USE_OPT_OP:
ops.rms_norm_opt( # ops.rms_norm_opt(
out, # out,
x, # x,
self.weight.data, # self.weight.data,
self.variance_epsilon, # self.variance_epsilon,
) # )
else: # else:
ops.rms_norm( ops.rms_norm(
out, out,
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment