Unverified Commit 729b7edf authored by Huaiyu, Zheng, committed by GitHub

enable rmsnorm on XPU (#10248)

parent 4c03dbaa
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
+    is_cuda_alike,
+    is_xpu,
     kill_process_tree,
     require_mlp_sync,
     require_mlp_tp_gather,
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
 )
 from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 
+profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
+    profiler_activity
+    for available, profiler_activity in [
+        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
+        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
+    ]
+    if available
+]
+
 
 @dataclasses.dataclass
 class BenchArgs:
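For readers unfamiliar with the pattern: the comprehension above builds the activity list once at import time, adding CUDA or XPU tracing only on machines where the corresponding backend is up. A standalone sketch of the same idea, using plain torch checks in place of sglang's is_cuda_alike()/is_xpu() helpers (the getattr guard is an addition here, since ProfilerActivity.XPU only exists in recent PyTorch builds):

```python
import torch

# Backend-conditional profiler activities; CPU is always profiled.
profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
    activity
    for available, activity in [
        (torch.cuda.is_available(), torch.profiler.ProfilerActivity.CUDA),
        (
            hasattr(torch, "xpu") and torch.xpu.is_available(),
            # XPU tracing needs a recent PyTorch build; fall back to None.
            getattr(torch.profiler.ProfilerActivity, "XPU", None),
        ),
    ]
    if available and activity is not None
]
print(profile_activities)
```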
@@ -424,10 +435,7 @@ def latency_test_run_once(
     profiler = None
     if profile:
         profiler = torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
+            activities=profile_activities,
             with_stack=True,
             record_shapes=profile_record_shapes,
         )
@@ -460,10 +468,7 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler = None
             profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
+                activities=profile_activities,
                 with_stack=True,
                 record_shapes=profile_record_shapes,
             )
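Both profiler call sites above now share profile_activities, so the trace covers whichever backends were detected at import. A minimal sketch of how such a profiler object is typically driven (the workload and the trace path are illustrative, not the benchmark's actual ones):

```python
import torch

profile_activities = [torch.profiler.ProfilerActivity.CPU]  # portable fallback

profiler = torch.profiler.profile(
    activities=profile_activities,
    with_stack=True,
    record_shapes=True,
)
profiler.start()
for _ in range(8):  # stand-in for the decode iterations being measured
    torch.randn(128, 128) @ torch.randn(128, 128)
profiler.stop()
profiler.export_chrome_trace("trace.json")  # illustrative output path
```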
@@ -52,8 +52,13 @@ if _is_cuda:
         gemma_rmsnorm,
         rmsnorm,
     )
+elif _is_xpu:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
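Note that the XPU branch imports the same four entry points as the CUDA branch, so everything downstream stays backend-agnostic. Where sgl_kernel may be absent entirely, a defensive variant of this import looks like the sketch below (the fallback flag is an addition for illustration, not something this diff introduces):

```python
try:
    from sgl_kernel import (
        fused_add_rmsnorm,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        rmsnorm,
    )

    _has_fused_kernels = True
except ImportError:
    # Without sgl_kernel, callers should route to the pure-PyTorch path.
    _has_fused_kernels = False
```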
@@ -216,6 +221,19 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_with_allreduce_fusion(
         self,
         x: torch.Tensor,
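For reference, the two kernels used by forward_xpu should match the class's native path: rmsnorm normalizes out of place, while fused_add_rmsnorm first adds the residual and then overwrites both tensors in place. A pure-PyTorch equivalent, assuming the usual RMSNorm definition (eps is variance_epsilon):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Scale by the reciprocal root-mean-square over the last dim, then by weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rmsnorm_ref(x, residual, weight, eps):
    # In-place contract of the fused kernel: residual becomes x + residual,
    # and x is overwritten with the normalized sum.
    residual.add_(x)
    x.copy_(rmsnorm_ref(residual, weight, eps))
    return x, residual
```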
@@ -263,6 +281,19 @@ class GemmaRMSNorm(CustomOp):
         if _is_hip:
             self._forward_method = self.forward_native
 
+    def _forward_impl(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
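The shared _forward_impl mirrors what forward_cuda previously inlined. Gemma's RMSNorm differs from the plain variant in one detail: it computes in float32 and applies the learned scale as (1 + weight). A hedged pure-PyTorch reference of what gemma_rmsnorm is expected to compute:

```python
import torch

def gemma_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Gemma normalizes in float32 and scales by (1 + weight), not weight alone.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * (1.0 + weight.float())).to(orig_dtype)
```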
@@ -285,13 +316,7 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is not None:
-            gemma_fused_add_rmsnorm(
-                x, residual, self.weight.data, self.variance_epsilon
-            )
-            return x, residual
-        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-        return out
+        return self._forward_impl(x, residual)
 
     def forward_npu(
         self,
@@ -305,6 +330,13 @@ class GemmaRMSNorm(CustomOp):
         x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        return self._forward_impl(x, residual)
+
 
 class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
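forward_xpu here is a thin alias onto _forward_impl; CustomOp is what routes a module's forward to the backend-specific method. A condensed, hypothetical sketch of that dispatch pattern (not sglang's exact CustomOp, which resolves the method its own way):

```python
import torch
from torch import nn

class CustomOpSketch(nn.Module):
    # Pick the backend-specific forward once, at construction time.
    def __init__(self):
        super().__init__()
        if torch.cuda.is_available():
            self._forward_method = self.forward_cuda
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            self._forward_method = self.forward_xpu
        else:
            self._forward_method = self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    # Backends default to the native path unless a subclass overrides them.
    forward_cuda = forward_native
    forward_xpu = forward_native
```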