"vscode:/vscode.git/clone" did not exist on "ef5ebdbf8af2696a9a7b12627ca18bf94d222947"
Commit 3f414133 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'gy_0151_rms_norm_opt' into 'v0.15.1-dev'

rms_norm_opt精度问题解决(换了个kernel)

See merge request dcutoolkit/deeplearing/vllm!499
parents 46bb1d6d 9404668a
......@@ -356,7 +356,8 @@ def fused_add_rms_norm(
# layer norm ops (opt)
def rms_norm_opt(input: torch.Tensor, weight: torch.Tensor, out: torch.Tensor,
epsilon: float, training: Optional[bool]=False) -> None:
op.rmsnorm_forward(input, weight, out, epsilon, training)
op.rms_norm_opt(out,input, weight, epsilon)
#op.rmsnorm_forward(input, weight, out, epsilon, training)
def rms_norm_opt_fake(
input: torch.Tensor,
......@@ -364,8 +365,8 @@ def rms_norm_opt_fake(
out: torch.Tensor,
epsilon: float,
training: Optional[bool] = False,
) -> torch.Tensor:
return torch.empty_like(input)
) -> None:
return None
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None:
......@@ -3618,7 +3619,7 @@ direct_register_custom_op(
direct_register_custom_op(
op_name="rms_norm_opt",
op_func=rms_norm_opt,
mutates_args=[],
mutates_args=["out"],
fake_impl=rms_norm_opt_fake,
)
......
......@@ -27,14 +27,14 @@ def rms_norm(
if vllm_is_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
# if envs.VLLM_USE_OPT_OP:
if False:
ops.rms_norm_opt(
if envs.VLLM_USE_OPT_OP:
torch.ops.vllm.rms_norm_opt(
x,
weight,
out,
variance_epsilon,
)
False,
)#False参数对当前的lightop调用的kernel是多余的
else:
ops.rms_norm(
out,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment