"vscode:/vscode.git/clone" did not exist on "0b5abba824d548932d7689188bf2c0cc67c9f658"
Unverified Commit f5f6b3b4 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

Refactor fused_add_rmsnorm import logic (#10207)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent 94fb4e9e
......@@ -26,6 +26,7 @@ from sglang.srt.utils import (
get_bool_env_var,
is_cpu,
is_cuda,
is_flashinfer_available,
is_hip,
is_npu,
is_xpu,
......@@ -33,6 +34,7 @@ from sglang.srt.utils import (
)
_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_hip = is_hip()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
......@@ -41,7 +43,10 @@ _is_cpu = is_cpu()
_is_xpu = is_xpu()
if _is_cuda:
from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
if _is_flashinfer_available:
from flashinfer.norm import fused_add_rmsnorm
else:
from sgl_kernel import fused_add_rmsnorm
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
if _use_aiter:
......@@ -84,9 +89,7 @@ class RMSNorm(CustomOp):
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
flashinfer_fused_add_rmsnorm(
x, residual, self.weight.data, self.variance_epsilon
)
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment