"vscode:/vscode.git/clone" did not exist on "88256082058fdbd41281c4f1f9a19663a4d7a668"
Commit 4440e8c0 authored by zhuwenwen's avatar zhuwenwen
Browse files

use opt layernorm_kernels

parent d8ae62c7
...@@ -52,14 +52,14 @@ class RMSNorm(CustomOp): ...@@ -52,14 +52,14 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if residual is not None: if residual is not None:
# if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
# ops.fused_add_rms_norm_opt( ops.fused_add_rms_norm_opt(
# x, x,
# residual, residual,
# self.weight.data, self.weight.data,
# self.variance_epsilon, self.variance_epsilon,
# ) )
# else: else:
ops.fused_add_rms_norm( ops.fused_add_rms_norm(
x, x,
residual, residual,
...@@ -68,14 +68,14 @@ class RMSNorm(CustomOp): ...@@ -68,14 +68,14 @@ class RMSNorm(CustomOp):
) )
return x, residual return x, residual
out = torch.empty_like(x) out = torch.empty_like(x)
# if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
# ops.rms_norm_opt( ops.rms_norm_opt(
# out, out,
# x, x,
# self.weight.data, self.weight.data,
# self.variance_epsilon, self.variance_epsilon,
# ) )
# else: else:
ops.rms_norm( ops.rms_norm(
out, out,
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment