"wrappers/python/src/vscode:/vscode.git/clone" did not exist on "e3b252043f1a45bad2253e6c755a48efacf38649"
Commit ce47a56e authored by yangyn, committed by zhangzbb
Browse files

更新 vllm/model_executor/layers/layernorm.py, vllm/_custom_ops.py

parent 15883da4
...@@ -370,7 +370,8 @@ def rms_norm_opt_fake( ...@@ -370,7 +370,8 @@ def rms_norm_opt_fake(
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None: weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None:
op.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace) op.fused_add_rms_norm_opt(input, residual, weight, epsilon)
#op.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
def fused_add_rms_norm_opt_fake( def fused_add_rms_norm_opt_fake(
input: torch.Tensor, input: torch.Tensor,
...@@ -379,8 +380,8 @@ def fused_add_rms_norm_opt_fake( ...@@ -379,8 +380,8 @@ def fused_add_rms_norm_opt_fake(
epsilon: float, epsilon: float,
training: Optional[bool] = False, training: Optional[bool] = False,
inplace: Optional[bool] = False, inplace: Optional[bool] = False,
) -> torch.Tensor: ) -> None:
return torch.empty_like(input) return None
def fused_qk_norm_rope( def fused_qk_norm_rope(
qkv: torch.Tensor, qkv: torch.Tensor,
...@@ -3626,7 +3627,7 @@ direct_register_custom_op( ...@@ -3626,7 +3627,7 @@ direct_register_custom_op(
direct_register_custom_op( direct_register_custom_op(
op_name="fused_add_rms_norm_opt", op_name="fused_add_rms_norm_opt",
op_func=fused_add_rms_norm_opt, op_func=fused_add_rms_norm_opt,
mutates_args=[], mutates_args=["input", "residual"],
fake_impl=fused_add_rms_norm_opt_fake, fake_impl=fused_add_rms_norm_opt_fake,
) )
......
...@@ -58,8 +58,7 @@ def fused_add_rms_norm( ...@@ -58,8 +58,7 @@ def fused_add_rms_norm(
x + residual, weight, variance_epsilon x + residual, weight, variance_epsilon
), x + residual ), x + residual
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
from lightop import fused_add_rms_norm torch.ops.vllm.fused_add_rms_norm_opt(
fused_add_rms_norm(
x, x,
residual, residual,
weight, weight,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment