Unverified commit bc5fc332, authored by fzyzcjy, committed by GitHub
Browse files

Fix slow fused add RMSNorm (#10141)

parent f3440adc
...@@ -39,12 +39,8 @@ _is_cpu_amx_available = cpu_has_amx_support() ...@@ -39,12 +39,8 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu() _is_cpu = is_cpu()
if _is_cuda: if _is_cuda:
from sgl_kernel import ( from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
fused_add_rmsnorm, from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
)
if _use_aiter: if _use_aiter:
from aiter import rmsnorm2d_fwd as rms_norm from aiter import rmsnorm2d_fwd as rms_norm
...@@ -86,7 +82,9 @@ class RMSNorm(CustomOp): ...@@ -86,7 +82,9 @@ class RMSNorm(CustomOp):
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
if residual is not None: if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) flashinfer_fused_add_rmsnorm(
x, residual, self.weight.data, self.variance_epsilon
)
return x, residual return x, residual
out = rmsnorm(x, self.weight.data, self.variance_epsilon) out = rmsnorm(x, self.weight.data, self.variance_epsilon)
return out return out
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment