Commit ce47a56e authored by yangyn's avatar yangyn Committed by zhangzbb
Browse files

更新 vllm/model_executor/layers/layernorm.py, vllm/_custom_ops.py

parent 15883da4
......@@ -370,7 +370,8 @@ def rms_norm_opt_fake(
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None:
op.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
op.fused_add_rms_norm_opt(input, residual, weight, epsilon)
#op.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
def fused_add_rms_norm_opt_fake(
input: torch.Tensor,
......@@ -379,8 +380,8 @@ def fused_add_rms_norm_opt_fake(
epsilon: float,
training: Optional[bool] = False,
inplace: Optional[bool] = False,
) -> torch.Tensor:
return torch.empty_like(input)
) -> None:
return None
def fused_qk_norm_rope(
qkv: torch.Tensor,
......@@ -3626,7 +3627,7 @@ direct_register_custom_op(
direct_register_custom_op(
op_name="fused_add_rms_norm_opt",
op_func=fused_add_rms_norm_opt,
mutates_args=[],
mutates_args=["input", "residual"],
fake_impl=fused_add_rms_norm_opt_fake,
)
......
......@@ -58,8 +58,7 @@ def fused_add_rms_norm(
x + residual, weight, variance_epsilon
), x + residual
if envs.VLLM_USE_OPT_OP:
from lightop import fused_add_rms_norm
fused_add_rms_norm(
torch.ops.vllm.fused_add_rms_norm_opt(
x,
residual,
weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment