Unverified commit 094c116f, authored by YanbingJiang and committed by GitHub

Update python API of activation, topk, norm and rope and remove vllm dependency (#6614)


Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: sdp <sdp@gnr799219.jf.intel.com>
parent e56685ac
@@ -39,6 +39,7 @@ RUN git clone https://github.com/sgl-project/sglang.git && \
    cp pyproject_cpu.toml pyproject.toml && \
    pip install -v .
ENV SGLANG_USE_CPU_ENGINE=1
ENV LD_PRELOAD=/sgl-workspace/miniforge3/lib/libiomp5.so:/sgl-workspace/miniforge3/lib/libtcmalloc.so:/sgl-workspace/miniforge3/lib/libtbbmalloc.so.2
WORKDIR /sgl-workspace/sglang
from torch import nn

from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()


class CustomOp(nn.Module):
@@ -75,5 +77,7 @@ class CustomOp(nn.Module):
            return self.forward_cuda
        elif _is_hip:
            return self.forward_hip
        elif _is_cpu and _is_cpu_amx_available:
            return self.forward_cpu
        else:
            return self.forward_native
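The hunk above is the heart of the change: CustomOp now binds forward_cpu when the CPU engine is enabled on a host whose CPU supports Intel AMX, instead of always falling back to forward_native. A minimal, self-contained sketch of that dispatch pattern; ToyOp and its flag arguments are illustrative, not part of the repository:

# Standalone sketch of the dispatch-once pattern used by CustomOp above
# (hypothetical class; the real flags come from sglang.srt.utils).
import torch
from torch import nn


class ToyOp(nn.Module):
    def __init__(self, on_cpu: bool, cpu_amx: bool):
        super().__init__()
        # Bind the concrete forward exactly once, so there is no per-call branching.
        if on_cpu and cpu_amx:
            self._forward = self.forward_cpu
        else:
            self._forward = self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        # A fused CPU kernel would be called here; the sketch reuses the same math.
        return torch.relu(x)


op = ToyOp(on_cpu=True, cpu_amx=False)
print(op(torch.randn(4)))  # dispatches to forward_native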
@@ -29,11 +29,19 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import (
    cpu_has_amx_support,
    is_cpu,
    is_cuda,
    is_npu,
    set_weight_attrs,
)
from sglang.utils import resolve_obj_by_qualname

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -53,6 +61,15 @@ class SiluAndMul(CustomOp):
        silu_and_mul(x, out)
        return out

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        if _is_cpu_amx_available:
            d = x.shape[-1] // 2
            output_shape = x.shape[:-1] + (d,)
            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
            return out
        else:
            return self.forward_native(x)
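For readers unfamiliar with the fused op: silu_and_mul_cpu is expected to follow the same contract as the CUDA silu_and_mul kernel, i.e. SiLU applied to the first half of the last dimension, multiplied elementwise by the second half. A plain-PyTorch sketch of that assumed semantic; silu_and_mul_reference is an illustrative name:

# Reference for the assumed silu_and_mul semantics: split the last dimension
# in half, SiLU the first half, multiply by the second half.
import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(2, 8)
print(silu_and_mul_reference(x).shape)  # torch.Size([2, 4])

This is effectively what forward_native computes, which is why the AMX branch and the fallback stay interchangeable.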
class GeluAndMul(CustomOp):
    def __init__(self, approximate="tanh"):
@@ -185,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
    return nn.Identity()

if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    logger.info(
        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
@@ -20,12 +20,21 @@ import torch
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,
    is_cpu,
    is_cuda,
    is_hip,
    is_npu,
)

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import (
@@ -122,6 +131,23 @@ class RMSNorm(CustomOp):
        else:
            return x, residual

    def forward_cpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if _is_cpu_amx_available:
            if residual is not None:
                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
                    x, residual, self.weight.data, self.variance_epsilon
                )
                return x, residual
            return torch.ops.sgl_kernel.rmsnorm_cpu(
                x, self.weight.data, self.variance_epsilon
            )
        else:
            return self.forward_native(x, residual)
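The CPU branch mirrors the CUDA path: rmsnorm_cpu normalizes in one shot, and fused_add_rmsnorm_cpu folds the residual add into the same kernel and updates x and residual in place. A plain-PyTorch reference of the assumed math; the helper names are illustrative:

# Assumed RMSNorm semantics: normalize by the root-mean-square of the last
# dimension, then scale by the learned weight. The "fused add" variant adds
# the residual first; unlike the in-place kernel, it returns new tensors.
import torch


def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def fused_add_rmsnorm_reference(x, residual, weight, eps):
    residual = x + residual
    return rmsnorm_reference(residual, weight, eps), residual


x = torch.randn(4, 16, dtype=torch.bfloat16)
w = torch.ones(16, dtype=torch.bfloat16)
print(rmsnorm_reference(x, w, 1e-6).shape)  # torch.Size([4, 16])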
class GemmaRMSNorm(CustomOp):
    def __init__(
@@ -188,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    logger.info(
        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
    )
...
@@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
    sglang_per_token_group_quant_int8,
)
from sglang.srt.utils import (
    cpu_has_amx_support,
    direct_register_custom_op,
    get_bool_env_var,
    get_device_name,
    is_cpu,
    is_cuda,
    is_hip,
    log_info_on_rank0,
@@ -36,9 +38,13 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import gelu_and_mul, silu_and_mul
elif _is_cpu and _is_cpu_amx_available:
    pass
else:
    from vllm import _custom_ops as vllm_ops
    from vllm._custom_ops import scaled_fp8_quant
...
@@ -241,7 +241,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        return moe_forward_native(
            layer,
@@ -260,7 +264,7 @@
    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")

    forward_native = forward_cpu


class FusedMoE(torch.nn.Module):
...
@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
    topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_compiler_backend,
    is_cpu,
    is_cuda,
    is_hip,
)

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import moe_fused_gate
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
    from sgl_kernel import topk_softmax

def fused_topk_torch_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
    return topk_weights, topk_ids

def fused_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    return torch.ops.sgl_kernel.topk_softmax_cpu(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
    )
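topk_softmax_cpu is presumably a fused version of the routing that fused_topk_torch_native spells out: softmax over the router logits, top-k selection per token, and optional renormalization of the kept weights. A compact plain-PyTorch sketch of that routing; the helper name is illustrative, not the kernel's exact signature:

# Assumed topk-softmax routing: softmax -> top-k -> optional renormalization.
import torch


def fused_topk_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(3, 8)  # 3 tokens, 8 experts
w, ids = fused_topk_reference(logits, topk=2, renormalize=True)
print(w.shape, ids.shape)  # torch.Size([3, 2]) torch.Size([3, 2])

Renormalization keeps the selected expert weights summing to one per token, which is what the fused kernel is expected to return as well.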
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(
# This is used by the Deepseek V2/V3/R1 series models
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
    return topk_weights, topk_ids

def grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    assert expert_location_dispatch_info is None
    return torch.ops.sgl_kernel.grouped_topk_cpu(
        hidden_states,
        gating_output,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )
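grouped_topk_cpu forwards straight to the sgl-kernel op; the routing it is expected to implement is the DeepSeek-style group-limited top-k. A hedged plain-PyTorch sketch of that assumed algorithm, ignoring shared experts, routed_scaling_factor, and padding handling; all names are illustrative:

# Assumed group-limited routing: split experts into groups, keep only the
# best topk_group groups, mask the rest, then take top-k over the survivors.
import torch


def grouped_topk_reference(gating_output, topk, num_expert_group, topk_group, renormalize=True):
    num_tokens, num_experts = gating_output.shape
    scores = torch.softmax(gating_output, dim=-1)
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group, num_experts // num_expert_group
    ).reshape(num_tokens, num_experts)
    masked = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


w, ids = grouped_topk_reference(torch.randn(2, 16), topk=4, num_expert_group=4, topk_group=2)
print(w.shape, ids.shape)  # torch.Size([2, 4]) torch.Size([2, 4])

Restricting selection to the best topk_group groups keeps routing concentrated in a few expert groups; the fused kernel additionally handles the shared-expert and scaling arguments shown in the signature above.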
def biased_grouped_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
    return topk_ids

def biased_grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
    )

def biased_grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    compiled: bool = True,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    assert expert_location_dispatch_info is None
    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )

if _is_cpu and _is_cpu_amx_available:
    biased_grouped_topk = biased_grouped_topk_cpu
    grouped_topk = grouped_topk_cpu
    fused_topk_native = fused_topk_cpu
else:
    biased_grouped_topk = biased_grouped_topk_gpu
    grouped_topk = grouped_topk_gpu
    fused_topk_native = fused_topk_torch_native

def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
...
@@ -14,15 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
    all_close_1d,
    cpu_has_amx_support,
    per_tensor_dequantize,
    replace_parameter,
)
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    from vllm import _custom_ops as vllm_ops
    from vllm._custom_ops import scaled_fp8_quant
...
@@ -64,7 +64,9 @@ from sglang.srt.layers.quantization.utils import (
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,
    is_cpu,
    is_cuda,
    is_hip,
    is_npu,
@@ -76,6 +78,8 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_fp8_fnuz = is_fp8_fnuz()
@@ -88,7 +92,7 @@ if _is_hip:
    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
    from aiter.ops.shuffle import shuffle_weight

if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    from vllm._custom_ops import scaled_fp8_quant
...
@@ -6,12 +6,14 @@ from typing import List, Mapping, Tuple, Union
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
    from vllm._custom_ops import scaled_fp8_quant
...
@@ -8,11 +8,13 @@ import torch
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -85,7 +87,9 @@ class RotaryEmbedding(CustomOp):
        if not _is_cuda:
            cache = cache.to(dtype)

        if (
            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
        ) and not (_is_cpu and _is_cpu_amx_available):
            from vllm._custom_ops import rotary_embedding

            self.vllm_rotary_embedding = rotary_embedding
@@ -148,6 +152,26 @@
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        positions = torch.add(positions, offsets) if offsets is not None else positions
        if _is_cpu_amx_available:
            return torch.ops.sgl_kernel.rotary_embedding_cpu(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )
        else:
            return self.forward_native(positions, query, key, offsets)
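rotary_embedding_cpu applies the rotary transform to the rotary slice of q and k using the precomputed cos_sin_cache, working in place. For orientation, a plain-PyTorch sketch of the neox-style rotation it is expected to perform, out of place and with illustrative names:

# Assumed neox-style rotary embedding: rotate the two halves of the rotary
# dimension by position-dependent angles taken from a cos/sin table.
import torch


def apply_neox_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, num_heads, rot_dim]; cos/sin: [num_tokens, rot_dim // 2]
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    cos = cos.unsqueeze(1)  # broadcast over heads
    sin = sin.unsqueeze(1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


tokens, heads, rot_dim = 4, 2, 8
pos = torch.arange(tokens, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
freqs = torch.outer(pos, inv_freq)  # [tokens, rot_dim // 2]
q = torch.randn(tokens, heads, rot_dim)
print(apply_neox_rope(q, freqs.cos(), freqs.sin()).shape)  # torch.Size([4, 2, 8])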
    def forward_cuda(
        self,
        positions: torch.Tensor,
@@ -697,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
            key = key_rot
        return query.to(dtype), key.to(dtype)

    def forward_cpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        positions = torch.add(positions, offsets) if offsets is not None else positions
        if _is_cpu_amx_available:
            return torch.ops.sgl_kernel.rotary_embedding_cpu(
                positions, query, key, self.head_size, self.cos_sin_cache, False
            )
        else:
            return self.forward_native(positions, query, key, offsets)


class Llama3RotaryEmbedding(RotaryEmbedding):
...
@@ -111,6 +111,7 @@ from sglang.srt.utils import (
)

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()

# Use a small KV cache pool size for tests in CI
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -302,7 +303,7 @@ class ModelRunner:
        if (
            server_args.attention_backend == "intel_amx"
            and server_args.device == "cpu"
            and not _is_cpu_amx_available
        ):
            logger.info(
                "The current platform does not support Intel AMX, will fallback to torch_native backend."
...
@@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
@@ -95,8 +95,10 @@ from sglang.srt.utils import (
    LazyValue,
    add_prefix,
    bind_or_assign,
    cpu_has_amx_support,
    get_bool_env_var,
    get_int_env_var,
    is_cpu,
    is_cuda,
    is_hip,
    is_non_idle_and_non_empty,
@@ -107,9 +109,13 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

if _is_cuda:
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
elif _is_cpu and _is_cpu_amx_available:
    pass
else:
    from vllm._custom_ops import awq_dequantize
@@ -665,13 +671,14 @@ class DeepseekV2AttentionMLA(nn.Module):
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
...
@@ -160,7 +160,7 @@ def is_npu() -> bool:
    return hasattr(torch, "npu") and torch.npu.is_available()

def is_host_cpu_x86() -> bool:
    machine = platform.machine().lower()
    return (
        machine in ("x86_64", "amd64", "i386", "i686")
@@ -169,6 +169,10 @@ def is_cpu() -> bool:
    )

def is_cpu() -> bool:
    return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
def is_flashinfer_available():
    """
    Check whether flashinfer is available.
@@ -1452,6 +1456,15 @@ def get_device(device_id: Optional[int] = None) -> str:
            "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
        )

    if is_cpu():
        if cpu_has_amx_support():
            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
        else:
            logger.warning(
                "CPU device enabled, using torch native backend, low performance expected."
            )
        return "cpu"

    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
...
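Two behaviors worth noting in the utils changes: the CPU engine is now opt-in (SGLANG_USE_CPU_ENGINE=1 on an x86 host), and get_device resolves to "cpu" with only a warning when AMX is absent. A small standalone sketch of the opt-in gate; cpu_engine_enabled is an illustrative stand-in for is_cpu, not the library function:

# The CPU path requires both the environment flag and an x86 host.
import os
import platform


def cpu_engine_enabled() -> bool:
    on_x86 = platform.machine().lower() in ("x86_64", "amd64", "i386", "i686")
    return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and on_x86


os.environ["SGLANG_USE_CPU_ENGINE"] = "1"
print(cpu_engine_enabled())

Making the gate opt-in presumably keeps GPU hosts from being routed onto the slower torch-native CPU path by accident.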
@@ -21,7 +21,7 @@ class TestActivation(CustomTestCase):
        ref_out = SiluAndMul(x)

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_activation(self):
        for params in itertools.product(self.M, self.N, self.dtype):
...
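All of the test diffs below follow the same pattern: self.assertTrue(torch.allclose(...)) is replaced by torch.testing.assert_close, which on failure reports the number of mismatched elements and the greatest absolute and relative differences instead of a bare "False is not true". A minimal illustration:

# assert_close passes or raises an AssertionError with a detailed mismatch report.
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.001])
torch.testing.assert_close(a, b, atol=1e-2, rtol=1e-2)  # passes
try:
    torch.testing.assert_close(a, b, atol=1e-5, rtol=1e-5)
except AssertionError as e:
    print(e)  # prints mismatch count and greatest absolute/relative difference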
@@ -60,8 +60,8 @@ class TestGemm(CustomTestCase):
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)
        torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol)

    def test_bf16_gemm(self):
        for params in itertools.product(
@@ -100,13 +100,13 @@ class TestGemm(CustomTestCase):
        out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
            Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        # test the fused version
        fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
            A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol)

    def test_int8_gemm(self):
        for params in itertools.product(
@@ -165,7 +165,7 @@ class TestGemm(CustomTestCase):
            prepack,
        )

        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol)

    def test_fp8_gemm(self):
        for params in itertools.product(
...
@@ -91,9 +91,7 @@ class TestFusedExperts(CustomTestCase):
        fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)

        atol = rtol = precision[torch_output.dtype]
        torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol)

    def test_bf16_moe(self):
        for params in itertools.product(
@@ -171,7 +169,7 @@
        # Increase the tolerance for large input shapes
        if M > 35:
            atol = rtol = 0.02
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_int8_moe(self):
        for params in itertools.product(
@@ -235,7 +233,7 @@
        )

        atol = rtol = precision[dtype]
        torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol)

    def test_fp8_moe(self):
        for params in itertools.product(
...
@@ -47,7 +47,7 @@ class TestNorm(CustomTestCase):
        ref_out = self._forward_native(x, weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        ref_x = x.clone()
        residual = torch.randn([m, hidden_size], dtype=dtype)
@@ -61,8 +61,8 @@
            ref_x, weight, variance_epsilon, ref_residual
        )

        torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
        torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)

    def _l2norm_test(self, m, n, dtype):
@@ -75,7 +75,7 @@
        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_norm(self):
        for params in itertools.product(self.M, self.N, self.dtype):
...
@@ -211,12 +211,12 @@ class TestQKVProjWithROPE(CustomTestCase):
            qk_rope_head_dim,
        )

        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

    def test_int8_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
@@ -302,12 +302,12 @@
            qk_rope_head_dim,
        )

        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

    def test_fp8_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
...
@@ -75,8 +75,8 @@ class TestROPE(CustomTestCase):
        )

        atol = rtol = precision[q_pe.dtype]
        torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_pe, k_pe_clone)

    def test_origin_rope(self):
...