Unverified Commit 7f19e083 authored by tarinkk, committed by GitHub

Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)


Co-authored-by: Cheng Wan <cwan39@gatech.edu>
parent 98a2cfa9
......@@ -90,7 +90,7 @@ Please consult the documentation below to learn more about the parameters you ma
### Expert parallelism
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`; for detailed benchmarking, refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for the DeepSeek-V3 model, based on deepseek-ai/DeepEP. Currently DeepEP is bound to DP attention. Please set `--enable-dp-attention --enable-deepep-moe`; `tp_size=dp_size=ep_size` is preferred.
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for the DeepSeek-V3 model, based on deepseek-ai/DeepEP.
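For example, a DeepEP MoE launch with DP attention might look like the sketch below. It assumes the offline `sgl.Engine` accepts the same keyword arguments as the server flags documented here; the model path and parallel sizes are placeholders, not recommendations.

```python
# Hedged sketch: offline engine launch with DeepEP MoE and DP attention.
# Assumes sgl.Engine forwards these kwargs to ServerArgs; model path and
# sizes are placeholders.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V3",  # placeholder checkpoint
        tp_size=8,                             # weights sharded across 8 GPUs
        dp_size=2,                             # DP attention groups (may be < tp_size)
        enable_dp_attention=True,
        enable_deepep_moe=True,                # ep_size then defaults to tp_size
    )
    out = llm.generate("The capital of France is", {"max_new_tokens": 8})
    print(out)
    llm.shutdown()
```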
## Memory and scheduling
......@@ -184,7 +184,7 @@ Please consult the documentation below to learn more about the parameters you ma
*Note: Some of these options are still in experimental stage.*
* `enable_mixed_chunk`: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163).
* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose `dp_size = tp_size` for this.
* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models.
* `enable_torch_compile`: Torch compile the model. Note that compiling a model takes a long time but gives a great performance boost. The compiled model can also be [cached for future use](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile).
* `torch_compile_max_bs`: The maximum batch size when using `torch_compile`.
* `cuda_graph_max_bs`: Adjust the maximum batch size when using cuda graph. By default this is chosen for you based on GPU specifics.
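With the relaxed requirement above, DP attention no longer has to use `dp_size = tp_size`; this change adds support for `1 <= dp_size < tp_size`. A hedged sketch (assuming `tp_size` is divisible by `dp_size` and that the offline `sgl.Engine` accepts these keyword arguments; the model path is a placeholder):

```python
# Hedged sketch: DP attention with fewer DP groups than TP ranks.
# Assumes tp_size is divisible by dp_size; names and sizes are illustrative.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MLA model
    tp_size=4,
    dp_size=2,                 # previously dp_size had to equal tp_size
    enable_dp_attention=True,
)
print(llm.generate("Hello", {"max_new_tokens": 4}))
llm.shutdown()
```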
......
......@@ -5,7 +5,7 @@ import logging
import os
from contextlib import contextmanager
from functools import wraps
from typing import Callable, List, Optional, TypeVar, Union
from typing import Any, Callable, List, Optional, TypeVar, Union
import torch
import torch.distributed as dist
......
......@@ -439,6 +439,15 @@ class GroupCoordinator:
else:
torch.distributed.all_reduce(input_, group=self.device_group)
def reduce_scatter(
self,
output: torch.Tensor,
input_list: List[torch.Tensor],
    ) -> torch.Tensor:
# TODO(ch-wan): support other backends
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
return output
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
......@@ -456,11 +465,23 @@ class GroupCoordinator:
output, input, group_name=self.unique_name
)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
def all_gather(
self,
input_: torch.Tensor,
dim: int = -1,
        tensor_list: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
if tensor_list is not None:
# TODO(ch-wan): support other backends
return torch.distributed.all_gather(
tensor_list, input_, group=self.device_group
)
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
......
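For reference, the two collectives added to `GroupCoordinator` above follow the usual `torch.distributed` semantics: reduce-scatter sums the per-rank inputs and leaves rank `r` with chunk `r` of the sum, while `all_gather(tensor_list=...)` reassembles the full tensor from those chunks. A minimal single-device sketch of the math (plain tensors, no process group; sizes are illustrative):

```python
# Single-device emulation of the reduce_scatter / all_gather(tensor_list)
# semantics used above; no process group is created.
import torch

attn_tp_size = 4
tokens, hidden = 8, 16

# Each "rank" holds a partial result over the full token set.
per_rank_partials = [torch.randn(tokens, hidden) for _ in range(attn_tp_size)]
summed = torch.stack(per_rank_partials).sum(dim=0)

# reduce_scatter: rank r ends up with chunk r of the element-wise sum.
chunks = list(summed.tensor_split(attn_tp_size))

# all_gather with an explicit tensor_list: every rank reassembles the full sum.
reassembled = torch.cat(chunks, dim=0)
assert torch.allclose(reassembled, summed)
```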
......@@ -3,7 +3,7 @@ from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, List
import torch
import triton
......@@ -249,3 +249,14 @@ def dp_scatter(
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def tp_reduce_scatter(
output: torch.Tensor,
input_list: List[torch.Tensor],
):
return get_attention_tp_group().reduce_scatter(output, input_list)
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
......@@ -1186,7 +1186,7 @@ class Scheduler(
ret = None
# Handle DP attention
if self.server_args.enable_dp_attention:
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
ret, _ = self.prepare_dp_attn_batch(ret)
return ret
......
......@@ -174,6 +174,7 @@ class CudaGraphRunner:
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
......@@ -245,8 +246,8 @@ class CudaGraphRunner:
)
else:
self.encoder_lens = None
if self.enable_dp_attention:
if self.enable_dp_attention or self.enable_sp_layernorm:
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.dp_size * self.num_tokens_per_bs,
......@@ -288,7 +289,7 @@ class CudaGraphRunner:
self.model_runner.token_to_kv_pool.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
if self.enable_dp_attention or self.enable_sp_layernorm:
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
......@@ -369,7 +370,7 @@ class CudaGraphRunner:
encoder_lens = None
mrope_positions = self.mrope_positions[:, :bs]
if self.enable_dp_attention:
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
......@@ -471,7 +472,7 @@ class CudaGraphRunner:
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
if self.enable_dp_attention:
if self.enable_dp_attention or self.enable_sp_layernorm:
index = bisect.bisect_left(
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
)
......@@ -497,7 +498,7 @@ class CudaGraphRunner:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
if hasattr(forward_batch.spec_info, "hidden_states"):
......
......@@ -281,9 +281,6 @@ class ModelRunner:
if server_args.enable_deepep_moe:
logger.info("DeepEP is turned on.")
assert (
server_args.enable_dp_attention == True
), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
......
......@@ -39,6 +39,8 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
tp_all_gather,
tp_reduce_scatter,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -278,7 +280,11 @@ class DeepseekV2MoE(nn.Module):
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if forward_mode is not None and not forward_mode.is_idle():
if (
forward_mode is not None
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
if self.n_shared_experts is not None:
......@@ -969,6 +975,14 @@ class DeepseekV2DecoderLayer(nn.Module):
is_nextn: bool = False,
prefix: str = "",
) -> None:
def is_sparse_layer(l: int):
return (
config.n_routed_experts is not None
and l >= config.first_k_dense_replace
and l % config.moe_layer_freq == 0
)
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
......@@ -977,6 +991,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.dp_size = get_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
if not global_server_args_dict["disable_mla"]:
self.self_attn = DeepseekV2AttentionMLA(
......@@ -1019,16 +1035,13 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=add_prefix("self_attn", prefix),
)
if is_nextn or (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
if is_nextn or is_sparse_layer(layer_id):
self.mlp = DeepseekV2MoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.is_sparse = True
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
......@@ -1037,6 +1050,14 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.is_sparse = False
self.input_is_scattered = (
is_sparse_layer(layer_id - 1)
and global_server_args_dict["enable_deepep_moe"]
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
......@@ -1049,6 +1070,23 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
return self.forward_deepep(
positions, hidden_states, forward_batch, residual
)
else:
return self.forward_normal(
positions, hidden_states, forward_batch, residual
)
def forward_normal(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
......@@ -1065,29 +1103,35 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch=forward_batch,
)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
residual, local_residual = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
residual,
)
tp_all_gather(
list(residual.tensor_split(self.attn_tp_size)), local_residual
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if global_server_args_dict["enable_deepep_moe"] and isinstance(
self.mlp, DeepseekV2MoE
):
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
return hidden_states, residual
else:
if get_attention_tp_rank() == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.post_attention_layernorm(
......@@ -1101,6 +1145,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fully Connected
hidden_states = self.mlp(hidden_states)
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.dp_size != 1:
# important: forward_batch.gathered_buffer is used both after scatter and after gather.
......@@ -1113,6 +1158,82 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual
def forward_deepep(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
if self.attn_tp_size != 1:
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
if self.attn_tp_rank == 0:
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
if self.is_last_layer and self.attn_tp_size != 1:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
residual, local_residual = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
residual,
)
tp_all_gather(
list(residual.tensor_split(self.attn_tp_size)), local_residual
)
return hidden_states, residual
class DeepseekV2Model(nn.Module):
......
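Putting the pieces together, the new `forward_deepep` path moves data roughly as follows when `attn_tp_size > 1`: all-gather the scattered tokens before attention, reduce-scatter the attention output back to per-rank slices, run the DeepEP MoE on the local slice, and all-gather again after the last layer. A simplified single-device emulation (illustrative shapes; the real code works over `forward_batch.gathered_buffer` with `tp_all_gather`/`tp_reduce_scatter`):

```python
# Simplified emulation of the per-layer data movement in forward_deepep.
# Shapes and the 0.5 "partial output" are stand-ins, not the real modules.
import torch

attn_tp_size, tokens_per_rank, hidden = 2, 4, 8

# 1) Input arrives scattered: each attention-TP rank holds its own token slice.
scattered = [torch.randn(tokens_per_rank, hidden) for _ in range(attn_tp_size)]

# 2) tp_all_gather before attention: every rank reconstructs the full batch.
full = torch.cat(scattered, dim=0)

# 3) Attention runs TP-sharded, so each rank produces a partial sum over the
#    full batch; tp_reduce_scatter sums the partials and keeps the local slice.
partials = [full * 0.5 for _ in range(attn_tp_size)]  # stand-in for row-parallel partials
summed = torch.stack(partials).sum(dim=0)
local_slices = list(summed.tensor_split(attn_tp_size, dim=0))

# 4) The DeepEP MoE then dispatches tokens from the local slice only.
assert local_slices[0].shape == (tokens_per_rank, hidden)
```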
......@@ -290,12 +290,17 @@ class ServerArgs:
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
)
# DeepEP MoE
if self.enable_deepep_moe:
self.ep_size = self.dp_size
logger.info(
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]."
)
self.enable_sp_layernorm = False
# DeepEP MoE
if self.enable_deepep_moe:
self.ep_size = self.tp_size
self.enable_sp_layernorm = (
self.dp_size < self.tp_size if self.enable_dp_attention else True
)
logger.info(
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
# Speculative Decoding
if self.speculative_algorithm == "NEXTN":
......
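The new `enable_sp_layernorm` rule above applies only when `--enable-deepep-moe` is set: with DP attention, sequence-parallel layernorm is enabled only when `dp_size < tp_size`; without DP attention it is always enabled. A minimal sketch mirroring that expression, with illustrative values:

```python
# Mirrors the server-args logic above; values are examples, not recommendations.
def sp_layernorm(enable_dp_attention: bool, dp_size: int, tp_size: int) -> bool:
    return dp_size < tp_size if enable_dp_attention else True

assert sp_layernorm(True, dp_size=8, tp_size=8) is False  # dp == tp: plain DP attention
assert sp_layernorm(True, dp_size=2, tp_size=8) is True   # 1 <= dp < tp: this PR's case
assert sp_layernorm(False, dp_size=1, tp_size=8) is True  # DeepEP without DP attention
```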
......@@ -12,7 +12,42 @@ from sglang.test.test_utils import (
)
class TestDeepEPMoE(CustomTestCase):
class TestPureTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"2",
"--enable-deepep-moe",
"--disable-cuda-graph",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.5)
class TestDPAttn(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
......