Unverified commit bc5fc332, authored by fzyzcjy and committed by GitHub
Browse files

Fix slow fused add RMSNorm (#10141)

parent f3440adc
@@ -39,12 +39,8 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 if _is_cuda:
-    from sgl_kernel import (
-        fused_add_rmsnorm,
-        gemma_fused_add_rmsnorm,
-        gemma_rmsnorm,
-        rmsnorm,
-    )
+    from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
+    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
@@ -86,7 +82,9 @@ class RMSNorm(CustomOp):
         if self.variance_size_override is not None:
             return self.forward_native(x, residual)
         if residual is not None:
-            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            flashinfer_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment