Unverified Commit ee21817c, authored by Huaiyu, Zheng, committed by GitHub

enable llama3.1-8B on xpu (#9434)

parent b7d1f17b
 from torch import nn

-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    is_xpu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_is_xpu = is_xpu()


 class CustomOp(nn.Module):
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
             return self.forward_cpu
         elif _is_npu:
             return self.forward_npu
+        elif _is_xpu:
+            return self.forward_xpu
         else:
             return self.forward_native
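
For context on how the two new lines take effect: CustomOp binds one concrete forward_* implementation per instance based on the detected platform, so adding the _is_xpu branch is all that is needed for every CustomOp subclass to route to its forward_xpu on Intel GPUs. A minimal sketch of that dispatch pattern (hypothetical flags and identity bodies, not sglang's actual class):

```python
import torch
from torch import nn

# Hypothetical stand-ins for sglang's is_cuda()/is_xpu() platform helpers.
_is_cuda = torch.cuda.is_available()
_is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()


class DispatchSketch(nn.Module):
    """Minimal sketch of CustomOp-style dispatch; not the real sglang class."""

    def __init__(self):
        super().__init__()
        # Bind the concrete implementation once, at construction time.
        self._forward_method = self.dispatch_forward()

    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
        elif _is_xpu:
            return self.forward_xpu  # the branch this commit adds
        return self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    # Identity bodies keep the sketch runnable on any platform.
    def forward_native(self, x):
        return x

    def forward_cuda(self, x):
        return x

    def forward_xpu(self, x):
        return x
```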
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     set_weight_attrs,
 )
 from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_hip = is_hip()
+_is_xpu = is_xpu()

-if _is_cuda:
+if _is_cuda or _is_xpu:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 elif _is_hip:
     from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):

     def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
         if _is_cpu_amx_available:
-            d = x.shape[-1] // 2
-            output_shape = x.shape[:-1] + (d,)
             out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
             return out
         else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
         out = torch_npu.npu_swiglu(x)
         return out

+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
         self.approximate = approximate

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
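
The new SiluAndMul.forward_xpu follows the same out-parameter convention as the CUDA path: allocate a tensor with half the input's last dimension and let the fused silu_and_mul kernel from sgl_kernel write into it. A pure-PyTorch reference for what that kernel computes, useful for spot-checking an XPU build (the fused call is only sketched in the trailing comment, since it needs sgl-kernel and an XPU device):

```python
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # SiLU-activate the first half of the last dim and gate it with the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 256)
ref = silu_and_mul_reference(x)  # shape (4, 128)

# On an XPU build with sgl-kernel installed, a comparison could look like:
#   out = torch.empty_like(ref, device="xpu")
#   silu_and_mul(x.to("xpu"), out)
#   torch.testing.assert_close(out.cpu(), ref, rtol=1e-3, atol=1e-3)
```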
@@ -103,6 +106,16 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out

+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
     def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
         y_npu, gelu_npu = torch_npu.npu_geglu(
             x,
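
The GeluAndMul part of the hunk above is a consolidation rather than new math: forward_cuda and forward_xpu now both delegate to one _forward_impl, and forward_native keeps the reference formula. Both gated activations share the same shape contract, where an input of width 2d yields an output of width d; a small CPU-side illustration using the native formulas (illustrative only, the fused kernels produce the same shapes):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4096)  # last dim holds the gate and value halves side by side
d = x.shape[-1] // 2

silu_out = F.silu(x[..., :d]) * x[..., d:]                      # SiluAndMul reference
gelu_out = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]  # GeluAndMul reference

assert silu_out.shape == gelu_out.shape == (2, 8, 2048)
```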
@@ -242,7 +255,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()


-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
+if not (
+    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
+):
     logger.info(
         "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
...
@@ -28,6 +28,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     supports_custom_op,
 )

@@ -37,6 +38,7 @@ _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()

 if _is_cuda:
     from flashinfer.norm import fused_add_rmsnorm as flashinfer_fused_add_rmsnorm
@@ -327,7 +329,9 @@ class Gemma3RMSNorm(CustomOp):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (
+    _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
+):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
...
@@ -8,10 +8,11 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu

 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
         transfer_kv_all_layer_lf_pf,
...
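
With the widened guard, NPU and XPU builds skip the sgl_kernel.kvcacheio import entirely, so KV-cache movement on those platforms has to take a different path than the fused transfer kernels. A hedged sketch of what a plain-copy fallback could look like (hypothetical helper, not sglang's actual code path):

```python
from typing import List

import torch


def transfer_kv_all_layer_naive(
    src_layers: List[torch.Tensor], dst_layers: List[torch.Tensor]
) -> None:
    # Hypothetical fallback: copy each layer's KV buffer with a plain device copy
    # when the fused kvcacheio kernels are unavailable (e.g. on NPU/XPU builds).
    for src, dst in zip(src_layers, dst_layers):
        dst.copy_(src, non_blocking=True)
```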