Unverified Commit 63051738 authored by Chunyuan WU, committed by GitHub

Enable CPU device on SGLang (#2806)

parent a8ccacc8
......@@ -40,6 +40,10 @@ srt_xpu = ["sglang[runtime_common]"]
#For Intel Gaudi(device : hpu) follow the installation guide
#https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
srt_hpu = ["sglang[runtime_common]"]
# CPU: currently, there are no pre-built vllm wheels for CPU.
# To install vllm for CPU, please follow the instructions here:
# https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html
srt_cpu = ["sglang[runtime_common]", "torch"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
......@@ -57,11 +61,13 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
dev = ["sglang[all]", "sglang[test]"]
dev_hip = ["sglang[all_hip]", "sglang[test]"]
dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
......
......@@ -10,7 +10,7 @@ class DeviceConfig:
device: Optional[torch.device]
def __init__(self, device: str = "cuda") -> None:
if device in ["cuda", "xpu", "hpu"]:
if device in ["cuda", "xpu", "hpu", "cpu"]:
self.device_type = device
else:
raise RuntimeError(f"Not supported device type: {device}")
......
......@@ -8,6 +8,7 @@ from typing import Callable, Optional
import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.topk import select_experts
......@@ -44,3 +45,71 @@ def fused_moe_forward_native(
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
def moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
)
# Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
len_experts = layer.num_experts
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
layer_w13_weight = layer.w13_weight[i]
layer_w2_weight = layer.w2_weight[i]
gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = SiluAndMul()(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
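The expert loop above relies on the counting-and-argsort trick from the referenced DeepSeek-V2 code to make all tokens routed to the same expert contiguous. A small, self-contained illustration with toy tensors (not part of the commit):

```python
import torch

# 3 tokens, top-2 routing over 3 experts (toy values).
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])
num_experts = 3

# One-hot count per (token, expert), then sum over tokens -> tokens per expert.
cnts = topk_ids.new_zeros((topk_ids.shape[0], num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)          # tensor([2, 2, 2])

# Sorting the flattened expert ids groups the top-k "slots" by expert; integer
# division by top_k recovers which token each sorted slot came from.
idxs = topk_ids.view(-1).argsort()
token_of_slot = idxs // topk_ids.shape[1]

print(tokens_per_expert.tolist())            # [2, 2, 2]
print(token_of_slot.tolist())                # tokens ordered expert 0 first, then 1, then 2
```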
......@@ -19,7 +19,10 @@ from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
is_hip_flag = False
if not is_hip():
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
if torch.cuda.is_available():
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
else:
sgl_moe_align_block_size = None
is_hip_flag = False
else:
......
......@@ -13,6 +13,7 @@ from vllm.distributed import (
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
......@@ -185,8 +186,31 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace=True,
)
def forward_cpu(self, *args, **kwargs):
raise NotImplementedError("The CPU backend currently does not support MoE.")
def forward_cpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function,
correction_bias,
)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")
......
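Following vllm's CustomOp dispatch convention, a CPU-only run routes the layer's forward() to forward_cpu, which now delegates to the pure-PyTorch path instead of raising. A hedged, self-contained sketch of that fallback, calling moe_forward_native directly with a stand-in layer (import path as used in the hunk above; shapes and values are illustrative):

```python
import torch
from types import SimpleNamespace

from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

num_tokens, hidden, inter, num_experts, top_k = 4, 8, 16, 4, 2

# Stand-in for the FusedMoE layer: only the attributes moe_forward_native reads.
layer = SimpleNamespace(
    num_experts=num_experts,
    w13_weight=torch.randn(num_experts, 2 * inter, hidden),  # stacked gate/up projections
    w2_weight=torch.randn(num_experts, hidden, inter),       # down projection
)
x = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, num_experts)

out = moe_forward_native(
    layer,
    x,
    use_grouped_topk=False,
    top_k=top_k,
    router_logits=router_logits,
    renormalize=True,
)
print(out.shape)  # expected: torch.Size([4, 8])
```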
......@@ -15,6 +15,15 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding,
_rotate_gptj,
_rotate_neox,
_yarn_find_correction_range,
_yarn_linear_ramp_mask,
get_rope,
yarn_get_mscale,
)
class MRotaryEmbedding:
......@@ -110,3 +119,242 @@ class MRotaryEmbedding:
)
for _ in range(3)
]
# TODO: in the DeepseekScalingRotaryEmbedding class defined in vllm,
# the device has been hard-coded to "cuda" in these two places:
# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L646
# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L665
# We port the related code into this file so that it also works on CPU.
# Once an optimized rotary embedding kernel is available for CPU, this ported code will be removed.
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
device: Optional[str] = None,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale))
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
self.device = device
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
/ self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (
1
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
) * self.extrapolation_factor
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor,
device=self.device,
dtype=torch.float32,
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
print("Cache shape", cache.shape)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot = query[..., : self.rotary_dim]
key_rot = key[..., : self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim :]
key_pass = key[..., self.rotary_dim :]
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
cos_sin = self.cos_sin_cache[
torch.add(positions, offsets) if offsets is not None else positions
]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
def get_rope_cpu(
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling_args,
dtype,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
assert rope_scaling is not None
scaling_type = rope_scaling["rope_type"]
assert (
scaling_type == "deepseek_yarn"
), "Only deepseek_yarn is supported for CPU for now"
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
extra_kwargs["device"] = device
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
_ROPE_DICT[key] = rotary_emb
return rotary_emb
def get_rope_wrapper(
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
):
if device != "cpu":
return get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling,
dtype,
partial_rotary_factor,
)
return get_rope_cpu(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling,
dtype,
partial_rotary_factor,
device,
)
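get_rope_wrapper keeps vllm's get_rope on non-CPU devices and only diverts to the ported implementation when device == "cpu". A hedged usage sketch with DeepSeek-V2-style YaRN settings (values are illustrative, not taken from a real model config; requires vllm and sglang installed):

```python
import torch

from sglang.srt.layers.rotary_embedding import get_rope_wrapper  # import path as used by DeepseekV2Attention below

rope_scaling = {
    "rope_type": "deepseek_yarn",           # the only rope_type get_rope_cpu accepts for now
    "factor": 40,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 0.707,
    "mscale_all_dim": 0.707,
}
rotary = get_rope_wrapper(
    head_size=64,
    rotary_dim=64,
    max_position=4096 * 40,
    base=10000,
    is_neox_style=False,                    # DeepSeek-V2 uses the GPT-J-style rotation
    rope_scaling=rope_scaling,
    device="cpu",                           # routes to get_rope_cpu -> DeepseekScalingRotaryEmbedding
)

positions = torch.arange(4)
q_pe = torch.randn(4, 2, 64)                # (tokens, heads, qk_rope_head_dim)
k_pe = torch.randn(4, 1, 64)
q_pe, k_pe = rotary.forward(positions, q_pe, k_pe)
print(q_pe.shape, k_pe.shape)
```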
......@@ -65,6 +65,7 @@ global_server_args_dict = {
"enable_nan_detection": ServerArgs.enable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"device": ServerArgs.device,
}
......
......@@ -317,6 +317,8 @@ class Scheduler:
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
# Session info
self.sessions: Dict[str, Session] = {}
......
......@@ -82,6 +82,8 @@ class TpModelWorkerClient:
self.forward_thread.start()
self.parent_process = psutil.Process().parent()
self.scheduler_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.scheduler_stream.synchronize = lambda: None # No-op for CPU
def get_worker_info(self):
return self.worker.get_worker_info()
......
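Both stream hunks above rely on the same trick: fetch the device module's current stream and replace synchronize with a no-op, since CPU execution is synchronous and torch.cpu only provides stub streams. A minimal sketch (stream stub behavior may vary across torch versions):

```python
import torch

# Grab the "current stream" for the CPU device; torch.cpu exposes a stub Stream API.
stream = torch.get_device_module("cpu").current_stream()

# CPU work is synchronous, so synchronize() is patched to a no-op, letting
# device-agnostic scheduler code keep calling stream.synchronize() unconditionally.
stream.synchronize = lambda: None
stream.synchronize()   # no-op
```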
......@@ -106,8 +106,10 @@ class ModelRunner:
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
):
logger.info("MLA optimization is turned on. Use triton backend.")
self.server_args.attention_backend = "triton"
# TODO: add MLA optimization on CPU
if self.server_args.device != "cpu":
logger.info("MLA optimization is turned on. Use triton backend.")
self.server_args.attention_backend = "triton"
if self.server_args.enable_double_sparsity:
logger.info(
......@@ -164,6 +166,7 @@ class ModelRunner:
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device,
}
)
......@@ -221,6 +224,8 @@ class ModelRunner:
backend = "gloo"
elif self.device == "hpu":
backend = "hccl"
elif self.device == "cpu":
backend = "gloo"
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
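On CPU the process group falls back to the "gloo" backend, which supports CPU tensors (nccl is CUDA-only). A standalone single-process sketch (init address and port are illustrative):

```python
import torch.distributed as dist

# Single-process illustration only; real launches use the scheduler's TCP init.
dist.init_process_group(
    backend="gloo",                         # CPU-capable backend
    init_method="tcp://127.0.0.1:29500",    # illustrative address/port
    rank=0,
    world_size=1,
)
print(dist.get_backend())                   # -> "gloo"
dist.destroy_process_group()
```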
......@@ -269,7 +274,8 @@ class ModelRunner:
)
# This can reduce thread conflicts and speed up weight loading.
torch.set_num_threads(1)
if self.device != "cpu":
torch.set_num_threads(1)
if self.device == "cuda":
if torch.cuda.get_device_capability()[0] < 8:
logger.info(
......
......@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -271,13 +272,14 @@ class DeepseekV2Attention(nn.Module):
quant_config=quant_config,
)
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)
if rope_scaling:
......
......@@ -392,7 +392,7 @@ class ServerArgs:
"--device",
type=str,
default="cuda",
choices=["cuda", "xpu", "hpu"],
choices=["cuda", "xpu", "hpu", "cpu"],
help="The device type.",
)
parser.add_argument(
......
......@@ -223,6 +223,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
elif device == "cpu":
# TODO: rename the variables in this function so they are not GPU-specific
free_gpu_memory = psutil.virtual_memory().available
if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
torch.device(device, gpu_id)
......
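In the CPU branch, "free GPU memory" is simply the available host RAM reported by psutil; in the distributed case the value is moved into a tensor for the cross-rank reduction shown (truncated) above. A minimal sketch of the non-distributed part:

```python
import psutil

# CPU branch of get_available_gpu_memory: report available host RAM.
free_bytes = psutil.virtual_memory().available
print(f"available host memory: {free_bytes / (1 << 30):.2f} GiB")
```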