Unverified Commit f5f6b3b4 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

Refactor fused_add_rmsnorm import logic (#10207)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent 94fb4e9e
...@@ -26,6 +26,7 @@ from sglang.srt.utils import ( ...@@ -26,6 +26,7 @@ from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
is_cpu, is_cpu,
is_cuda, is_cuda,
is_flashinfer_available,
is_hip, is_hip,
is_npu, is_npu,
is_xpu, is_xpu,
...@@ -33,6 +34,7 @@ from sglang.srt.utils import ( ...@@ -33,6 +34,7 @@ from sglang.srt.utils import (
) )
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_hip = is_hip() _is_hip = is_hip()
_is_npu = is_npu() _is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
...@@ -41,7 +43,10 @@ _is_cpu = is_cpu() ...@@ -41,7 +43,10 @@ _is_cpu = is_cpu()
_is_xpu = is_xpu() _is_xpu = is_xpu()
if _is_cuda: if _is_cuda:
from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm if _is_flashinfer_available:
from flashinfer.norm import fused_add_rmsnorm
else:
from sgl_kernel import fused_add_rmsnorm
from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
if _use_aiter: if _use_aiter:
...@@ -84,9 +89,7 @@ class RMSNorm(CustomOp): ...@@ -84,9 +89,7 @@ class RMSNorm(CustomOp):
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
if residual is not None: if residual is not None:
flashinfer_fused_add_rmsnorm( fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
x, residual, self.weight.data, self.variance_epsilon
)
return x, residual return x, residual
out = rmsnorm(x, self.weight.data, self.variance_epsilon) out = rmsnorm(x, self.weight.data, self.variance_epsilon)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment