Unverified Commit eedc12e1 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Support Deepseek MoE Model (#689)

parent 5a4ef2b5
"""Run the model with cuda graph.""" """Run the model with cuda graph."""
import bisect import bisect
from contextlib import contextmanager
import torch import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer import BatchDecodeWithPagedKVCacheWrapper
...@@ -15,9 +16,10 @@ from sglang.srt.managers.controller.infer_batch import ( ...@@ -15,9 +16,10 @@ from sglang.srt.managers.controller.infer_batch import (
InputMetadata, InputMetadata,
init_flashinfer_args, init_flashinfer_args,
) )
from sglang.srt.utils import monkey_patch_vllm_all_gather
def _to_torch(model: torch.nn.Module, reverse=False): def _to_torch(model: torch.nn.Module, reverse: bool = False):
for sub in model._modules.values(): for sub in model._modules.values():
if isinstance(sub, CustomOp): if isinstance(sub, CustomOp):
if reverse: if reverse:
...@@ -28,13 +30,26 @@ def _to_torch(model: torch.nn.Module, reverse=False): ...@@ -28,13 +30,26 @@ def _to_torch(model: torch.nn.Module, reverse=False):
_to_torch(sub, reverse) _to_torch(sub, reverse)
def get_forward(model: torch.nn.Module, use_torch: bool): @contextmanager
if use_torch: def patch_model(
_to_torch(model, reverse=False) model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
return torch.compile(model.forward, mode="max-autotune-no-cudagraphs") ):
else: backup_ca_comm = None
_to_torch(model, reverse=True)
return model.forward try:
if use_compile:
_to_torch(model)
monkey_patch_vllm_all_gather()
backup_ca_comm = tp_group.ca_comm
tp_group.ca_comm = None
yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
else:
yield model.forward
finally:
if use_compile:
_to_torch(model, reverse=True)
monkey_patch_vllm_all_gather(reverse=True)
tp_group.ca_comm = backup_ca_comm
class CudaGraphRunner: class CudaGraphRunner:
...@@ -86,17 +101,21 @@ class CudaGraphRunner: ...@@ -86,17 +101,21 @@ class CudaGraphRunner:
with graph_capture() as graph_capture_context: with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream self.stream = graph_capture_context.stream
for bs in batch_size_list: for bs in batch_size_list:
forward = get_forward(self.model_runner.model, bs in self.compile_bs) with patch_model(
( self.model_runner.model,
graph, bs in self.compile_bs,
input_buffers, self.model_runner.tp_group,
output_buffers, ) as forward:
flashinfer_handler, (
) = self.capture_one_batch_size(bs, forward) graph,
self.graphs[bs] = graph input_buffers,
self.input_buffers[bs] = input_buffers output_buffers,
self.output_buffers[bs] = output_buffers flashinfer_handler,
self.flashinfer_handlers[bs] = flashinfer_handler ) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.input_buffers[bs] = input_buffers
self.output_buffers[bs] = output_buffers
self.flashinfer_handlers[bs] = flashinfer_handler
def capture_one_batch_size(self, bs, forward): def capture_one_batch_size(self, bs, forward):
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
......
...@@ -22,7 +22,6 @@ from vllm.distributed import ( ...@@ -22,7 +22,6 @@ from vllm.distributed import (
init_distributed_environment, init_distributed_environment,
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config from sglang.global_config import global_config
...@@ -241,7 +240,9 @@ class ModelRunner: ...@@ -241,7 +240,9 @@ class ModelRunner:
self.cuda_graph_runner = None self.cuda_graph_runner = None
return return
logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.") logger.info(
f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
)
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)] batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
self.cuda_graph_runner = CudaGraphRunner( self.cuda_graph_runner = CudaGraphRunner(
self, self,
...@@ -252,7 +253,7 @@ class ModelRunner: ...@@ -252,7 +253,7 @@ class ModelRunner:
self.cuda_graph_runner.capture(batch_size_list) self.cuda_graph_runner.capture(batch_size_list)
except RuntimeError as e: except RuntimeError as e:
raise Exception( raise Exception(
f"Capture cuda graph failed {e}. Possible solutions:\n" f"Capture cuda graph failed: {e}. Possible solutions:\n"
f"1. disable cuda graph by --disable-cuda-graph\n" f"1. disable cuda graph by --disable-cuda-graph\n"
f"2. set --mem-fraction-static to a smaller value\n" f"2. set --mem-fraction-static to a smaller value\n"
f"Open an issue on GitHub with reproducible scripts if you need help.\n" f"Open an issue on GitHub with reproducible scripts if you need help.\n"
......
# Adapted from:
# https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata
class DeepseekMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}."
)
self.experts = nn.ModuleList(
[
DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
for idx in range(self.n_routed_experts)
]
)
self.pack_params()
self.gate = ReplicatedLinear(
config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def pack_params(self):
w1 = []
w2 = []
for expert in self.experts:
w1.append(expert.gate_up_proj.weight)
w2.append(expert.down_proj.weight)
self.w1 = torch._utils._flatten_dense_tensors(w1)
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
for data, param in zip(w1s, w1):
param.data = data
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
self.w2 = torch._utils._flatten_dense_tensors(w2)
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
for data, param in zip(w2s, w2):
param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.config.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True,
)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class DeepseekAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class DeepseekDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = DeepseekAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class DeepseekModel(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[
DeepseekDecoderLayer(
config, layer_id, cache_config, quant_config=quant_config
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = DeepseekModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (
"mlp.experts." in name or "mlp.shared_experts." in name
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (
"mlp.experts." in name or "mlp.shared_experts." in name
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = DeepseekForCausalLM
...@@ -167,7 +167,7 @@ def _set_torch_compile_config(): ...@@ -167,7 +167,7 @@ def _set_torch_compile_config():
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround # FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 128 torch._dynamo.config.accumulated_cache_size_limit = 256
def launch_server( def launch_server(
......
...@@ -411,6 +411,52 @@ def monkey_patch_vllm_dummy_weight_loader(): ...@@ -411,6 +411,52 @@ def monkey_patch_vllm_dummy_weight_loader():
setattr(DummyModelLoader, "load_model", load_model) setattr(DummyModelLoader, "load_model", load_model)
vllm_all_gather_backup = None
def monkey_patch_vllm_all_gather(reverse: bool = False):
"""Monkey patch all-gather to remove in-place operations."""
from torch.distributed import _functional_collectives as funcol
from vllm.distributed.parallel_state import GroupCoordinator
global vllm_all_gather_backup
if vllm_all_gather_backup is None:
vllm_all_gather_backup = GroupCoordinator.all_gather
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty(
(world_size,) + input_size, dtype=input_.dtype, device=input_.device
)
output_tensor = funcol.all_gather_tensor(
input_, gather_dim=0, group=self.device_group
).view((world_size,) + input_size)
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)
return output_tensor
if reverse:
setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
else:
setattr(GroupCoordinator, "all_gather", all_gather)
API_KEY_HEADER_NAME = "X-API-Key" API_KEY_HEADER_NAME = "X-API-Key"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment