Commit 3f414133 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'gy_0151_rms_norm_opt' into 'v0.15.1-dev'

rms_norm_opt精度问题解决(换了个kernel)

See merge request dcutoolkit/deeplearing/vllm!499
parents 46bb1d6d 9404668a
...@@ -356,7 +356,8 @@ def fused_add_rms_norm( ...@@ -356,7 +356,8 @@ def fused_add_rms_norm(
# layer norm ops (opt) # layer norm ops (opt)
def rms_norm_opt(input: torch.Tensor, weight: torch.Tensor, out: torch.Tensor, def rms_norm_opt(input: torch.Tensor, weight: torch.Tensor, out: torch.Tensor,
epsilon: float, training: Optional[bool]=False) -> None: epsilon: float, training: Optional[bool]=False) -> None:
op.rmsnorm_forward(input, weight, out, epsilon, training) op.rms_norm_opt(out,input, weight, epsilon)
#op.rmsnorm_forward(input, weight, out, epsilon, training)
def rms_norm_opt_fake( def rms_norm_opt_fake(
input: torch.Tensor, input: torch.Tensor,
...@@ -364,8 +365,8 @@ def rms_norm_opt_fake( ...@@ -364,8 +365,8 @@ def rms_norm_opt_fake(
out: torch.Tensor, out: torch.Tensor,
epsilon: float, epsilon: float,
training: Optional[bool] = False, training: Optional[bool] = False,
) -> torch.Tensor: ) -> None:
return torch.empty_like(input) return None
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None: weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None:
...@@ -3618,7 +3619,7 @@ direct_register_custom_op( ...@@ -3618,7 +3619,7 @@ direct_register_custom_op(
direct_register_custom_op( direct_register_custom_op(
op_name="rms_norm_opt", op_name="rms_norm_opt",
op_func=rms_norm_opt, op_func=rms_norm_opt,
mutates_args=[], mutates_args=["out"],
fake_impl=rms_norm_opt_fake, fake_impl=rms_norm_opt_fake,
) )
......
...@@ -27,14 +27,14 @@ def rms_norm( ...@@ -27,14 +27,14 @@ def rms_norm(
if vllm_is_batch_invariant(): if vllm_is_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon) return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x) out = torch.empty_like(x)
# if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if False: torch.ops.vllm.rms_norm_opt(
ops.rms_norm_opt(
x, x,
weight, weight,
out, out,
variance_epsilon, variance_epsilon,
) False,
)#False参数对当前的lightop调用的kernel是多余的
else: else:
ops.rms_norm( ops.rms_norm(
out, out,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment