Unverified Commit 729b7edf authored by Huaiyu, Zheng, committed by GitHub

enable rmsnorm on XPU (#10248)

parent 4c03dbaa
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
+    is_cuda_alike,
+    is_xpu,
     kill_process_tree,
     require_mlp_sync,
     require_mlp_tp_gather,
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
 )
 from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 
+profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
+    profiler_activity
+    for available, profiler_activity in [
+        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
+        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
+    ]
+    if available
+]
+
 
 @dataclasses.dataclass
 class BenchArgs:
@@ -424,10 +435,7 @@ def latency_test_run_once(
     profiler = None
     if profile:
         profiler = torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
+            activities=profile_activities,
             with_stack=True,
             record_shapes=profile_record_shapes,
         )
@@ -460,10 +468,7 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler = None
             profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
+                activities=profile_activities,
                 with_stack=True,
                 record_shapes=profile_record_shapes,
             )
...
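The `profile_activities` list added above selects profiler activities from the detected backend once at module load, so the benchmark profiles CUDA/ROCm and Intel XPU devices through the same code path. Below is a minimal standalone sketch of the same pattern; it substitutes plain `torch` availability checks for sglang's `is_cuda_alike`/`is_xpu` helpers, and the `detect_profile_activities` name is illustrative, not part of this commit.

```python
import torch

def detect_profile_activities():
    # CPU events are always collected; add a device backend if one is present.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        # The CUDA activity also covers ROCm builds of PyTorch.
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # ProfilerActivity.XPU requires a recent PyTorch build with XPU support.
        activities.append(torch.profiler.ProfilerActivity.XPU)
    return activities

with torch.profiler.profile(
    activities=detect_profile_activities(),
    with_stack=True,
) as prof:
    a = torch.randn(512, 512)
    (a @ a).sum()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```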
@@ -52,8 +52,13 @@ if _is_cuda:
         gemma_rmsnorm,
         rmsnorm,
     )
+elif _is_xpu:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
@@ -216,6 +221,19 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_with_allreduce_fusion(
         self,
         x: torch.Tensor,
@@ -263,6 +281,19 @@ class GemmaRMSNorm(CustomOp):
         if _is_hip:
             self._forward_method = self.forward_native
 
+    def _forward_impl(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -285,13 +316,7 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is not None:
-            gemma_fused_add_rmsnorm(
-                x, residual, self.weight.data, self.variance_epsilon
-            )
-            return x, residual
-        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-        return out
+        return self._forward_impl(x, residual)
 
     def forward_npu(
         self,
@@ -305,6 +330,13 @@ class GemmaRMSNorm(CustomOp):
         x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        return self._forward_impl(x, residual)
+
 
 class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
...
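For reference, the `sgl_kernel` entry points used in the new XPU paths follow the usual RMSNorm semantics: `rmsnorm(x, weight, eps)` normalizes `x` and scales by the weight, while `fused_add_rmsnorm(x, residual, weight, eps)` first folds `x` into the residual stream and then normalizes, mutating both tensors in place, which is what lets `forward_xpu` return `x, residual` directly. The `variance_size_override` check falls back to `forward_native`, presumably because the fused kernel normalizes over the full last dimension. A plain-PyTorch sketch of these semantics (out-of-place for clarity; the `*_ref` names are illustrative, not part of sgl_kernel):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # RMSNorm: scale by the reciprocal root-mean-square, then by the weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rmsnorm_ref(x, residual, weight, eps):
    # The fused kernel adds x into the residual stream before normalizing;
    # the real kernel writes the results back into x and residual in place.
    residual = x + residual
    return rmsnorm_ref(residual, weight, eps), residual

def gemma_rmsnorm_ref(x, weight, eps):
    # Gemma's variant computes in float32 and scales by (1 + weight).
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = x.float() * torch.rsqrt(variance + eps) * (1.0 + weight.float())
    return out.to(x.dtype)
```

Factoring `GemmaRMSNorm`'s kernel call into `_forward_impl` lets `forward_cuda` and the new `forward_xpu` share one body instead of duplicating it.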