"docs/vscode:/vscode.git/clone" did not exist on "81456af8c532739b41e1cbb345953b993d70316f"
Unverified Commit 7f19e083 authored by tarinkk, committed by GitHub

Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)


Co-authored-by: Cheng Wan <cwan39@gatech.edu>
parent 98a2cfa9
@@ -90,7 +90,7 @@ Please consult the documentation below to learn more about the parameters you ma
 ### Expert parallelism
 * enable_ep_moe: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
 * ep_size: The size of EP. Please shard the model weights with tp_size=ep_size, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, ep_size will be automatically set to tp_size.
-* enable_deepep_moe: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP. Currently DeepEP is bind to DP Attention. Please set --enable-dp-attention --enable-deepep-moe, perfer tp_size=dp_size=ep_size.
+* enable_deepep_moe: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
 ## Memory and scheduling
@@ -184,7 +184,7 @@ Please consult the documentation below to learn more about the parameters you ma
 *Note: Some of these options are still in experimental stage.*
 * enable_mixed_chunk: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163).
-* enable_dp_attention: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose dp_size = tp_size for this.
+* enable_dp_attention: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models.
 * enable_torch_compile: Torch compile the model. Note that compiling a model takes a long time but have a great performance boost. The compiled model can also be [cached for future use](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile).
 * torch_compile_max_bs: The maximum batch size when using torch_compile.
 * cuda_graph_max_bs: Adjust the maximum batchsize when using cuda graph. By default this is chosen for you based on GPU specifics.
...
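The doc change above drops the old requirement of launching DeepEP with tp_size=dp_size=ep_size. Per the commit title, DP attention with DeepEP now accepts any 1 <= dp_size <= tp_size. Below is a minimal sketch of the relaxed rule; the helper is hypothetical (not part of this PR), and the divisibility check is an assumption about how TP ranks are grouped into attention-DP groups.

```python
def check_dp_attention_config(tp_size: int, dp_size: int) -> None:
    # Relaxed rule from this commit: dp_size no longer has to equal tp_size.
    assert 1 <= dp_size <= tp_size, "need 1 <= dp_size <= tp_size"
    # Assumption: TP ranks are split evenly into dp_size attention groups.
    assert tp_size % dp_size == 0, "tp_size must be divisible by dp_size"


check_dp_attention_config(tp_size=8, dp_size=2)  # newly supported: dp < tp
check_dp_attention_config(tp_size=8, dp_size=8)  # previously the only supported layout: dp == tp
```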
@@ -5,7 +5,7 @@ import logging
 import os
 from contextlib import contextmanager
 from functools import wraps
-from typing import Callable, List, Optional, TypeVar, Union
+from typing import Any, Callable, List, Optional, TypeVar, Union
 
 import torch
 import torch.distributed as dist
...
@@ -439,6 +439,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def reduce_scatter(
+        self,
+        output: torch.Tensor,
+        input_list: List[torch.Tensor],
+    ) -> None:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
+        return output
+
     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
...
@@ -456,11 +465,23 @@ class GroupCoordinator:
                 output, input, group_name=self.unique_name
             )
 
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(
+        self,
+        input_: torch.Tensor,
+        dim: int = -1,
+        tensor_list: List[torch.Tensor] = None,
+    ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
+
+        if tensor_list is not None:
+            # TODO(ch-wan): support other backends
+            return torch.distributed.all_gather(
+                tensor_list, input_, group=self.device_group
+            )
+
         assert (
             -input_.dim() <= dim < input_.dim()
         ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
...
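The two GroupCoordinator additions are thin wrappers around torch.distributed.reduce_scatter and torch.distributed.all_gather on the group's device group. A plain-tensor sketch of what those collectives compute; no process group is created, and rank indices are simulated with list positions.

```python
import torch

world_size, hidden = 4, 8
# inputs[r] is the tensor held by rank r (here: world_size rows per rank).
inputs = [torch.randn(world_size, hidden) for _ in range(world_size)]

# all_gather(input_, tensor_list=outs): afterwards, on every rank,
# outs[r] equals rank r's input tensor.
outs = [t.clone() for t in inputs]

# reduce_scatter(output, input_list): rank r's output is the elementwise sum of
# entry r of every rank's input_list (here each rank splits its own tensor).
input_lists = [list(t.tensor_split(world_size)) for t in inputs]
outputs = [
    sum(input_lists[src][r] for src in range(world_size)) for r in range(world_size)
]
assert outputs[0].shape == (1, hidden)
```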
@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, List
 
 import torch
 import triton
...
@@ -249,3 +249,14 @@ def dp_scatter(
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
+
+
+def tp_reduce_scatter(
+    output: torch.Tensor,
+    input_list: List[torch.Tensor],
+):
+    return get_attention_tp_group().reduce_scatter(output, input_list)
+
+
+def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+    return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
@@ -1186,7 +1186,7 @@ class Scheduler(
         ret = None
 
         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
             ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
...
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -245,8 +246,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
-
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -288,7 +289,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
 
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -369,7 +370,7 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [
@@ -471,7 +472,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
@@ -497,7 +498,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
...
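With SP layernorm enabled, the CUDA-graph runner now pads by the total token count across DP ranks, exactly as it already did for DP attention. A small sketch of the padding choice made in the hunks above; the values are illustrative, and `capture_bs` / `global_num_tokens_cpu` stand in for the runner's fields.

```python
import bisect

capture_bs = [1, 2, 4, 8, 16, 32]     # batch sizes captured as CUDA graphs
global_num_tokens_cpu = [3, 5]        # decode tokens per DP rank this step

total_global_tokens = sum(global_num_tokens_cpu)           # 8
index = bisect.bisect_left(capture_bs, total_global_tokens)
padded_bs = capture_bs[index]                              # smallest captured size >= total
assert padded_bs == 8
```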
@@ -281,9 +281,6 @@ class ModelRunner:
         if server_args.enable_deepep_moe:
             logger.info("DeepEP is turned on.")
-            assert (
-                server_args.enable_dp_attention == True
-            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
...
@@ -39,6 +39,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    tp_all_gather,
+    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -278,7 +280,11 @@ class DeepseekV2MoE(nn.Module):
         topk_weights = torch.empty(
             (0, self.top_k), dtype=torch.float32, device=hidden_states.device
         )
-        if forward_mode is not None and not forward_mode.is_idle():
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             if self.n_shared_experts is not None:
@@ -969,6 +975,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
+
+        def is_sparse_layer(l: int):
+            return (
+                config.n_routed_experts is not None
+                and l >= config.first_k_dense_replace
+                and l % config.moe_layer_freq == 0
+            )
+
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -977,6 +991,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
@@ -1019,16 +1035,13 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("self_attn", prefix),
             )
 
-        if is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
-        ):
+        if is_nextn or is_sparse_layer(layer_id):
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1037,6 +1050,14 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = False
+
+        self.input_is_scattered = (
+            is_sparse_layer(layer_id - 1)
+            and global_server_args_dict["enable_deepep_moe"]
+        )
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -1049,6 +1070,23 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
+            return self.forward_deepep(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            return self.forward_normal(
+                positions, hidden_states, forward_batch, residual
+            )
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
@@ -1065,29 +1103,35 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch=forward_batch,
         )
 
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
+
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
-                if global_server_args_dict["enable_deepep_moe"] and isinstance(
-                    self.mlp, DeepseekV2MoE
-                ):
-                    if hidden_states.shape[0] != 0:
-                        hidden_states, residual = self.post_attention_layernorm(
-                            hidden_states, residual
-                        )
-                    hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
-                    return hidden_states, residual
-                else:
-                    if get_attention_tp_rank() == 0:
-                        hidden_states += residual
-                    hidden_states, local_hidden_states = (
-                        forward_batch.gathered_buffer,
-                        hidden_states,
-                    )
-                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                    dp_scatter(residual, hidden_states, forward_batch)
-                    hidden_states = self.post_attention_layernorm(hidden_states)
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                 hidden_states, residual = self.post_attention_layernorm(
@@ -1101,6 +1145,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states = self.mlp(hidden_states)
 
+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
         if self.dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -1113,6 +1158,82 @@ class DeepseekV2DecoderLayer(nn.Module):
 
         return hidden_states, residual
 
+    def forward_deepep(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
+
+        return hidden_states, residual
+
 
 class DeepseekV2Model(nn.Module):
...
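In forward_deepep, each attention-TP rank holds only its shard of the DP rank's tokens between sparse layers (input_is_scattered), all-gathers before attention, and reduce-scatters afterwards. Below is a sketch of the rank layout and shard sizes for tp_size=8, dp_size=2; the rank arithmetic mirrors what get_attention_tp_rank / get_attention_tp_size are expected to return, which is an assumption here rather than code from this PR.

```python
import torch

tp_size, dp_size = 8, 2
attn_tp_size = tp_size // dp_size            # 4 attention-TP ranks per DP rank

# tp_rank -> (attn_dp_rank, attn_tp_rank); assumes consecutive TP ranks share a DP group.
layout = [(r // attn_tp_size, r % attn_tp_size) for r in range(tp_size)]
assert layout[5] == (1, 1)

# Tokens owned by one DP rank, sharded across its attention-TP group.
num_tokens, hidden = 10, 16
hidden_states = torch.randn(num_tokens, hidden)
shards = hidden_states.tensor_split(attn_tp_size)
assert [s.shape[0] for s in shards] == [3, 3, 2, 2]   # per-attention-TP-rank shard sizes
```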
@@ -290,12 +290,17 @@ class ServerArgs:
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )
 
+        self.enable_sp_layernorm = False
         # DeepEP MoE
         if self.enable_deepep_moe:
-            self.ep_size = self.dp_size
+            self.ep_size = self.tp_size
+            self.enable_sp_layernorm = (
+                self.dp_size < self.tp_size if self.enable_dp_attention else True
+            )
             logger.info(
-                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]."
+                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
 
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
...
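When enable_deepep_moe is set, the new ServerArgs logic derives ep_size and enable_sp_layernorm instead of requiring dp_size == tp_size. A sketch that mirrors the diffed logic for a few configurations (the standalone function is illustrative, not part of the PR):

```python
def derive_deepep_settings(tp_size: int, dp_size: int, enable_dp_attention: bool):
    # Mirrors the new ServerArgs logic: EP follows TP, and sequence-parallel
    # layernorm turns on whenever attention ranks hold only a slice of the tokens.
    ep_size = tp_size
    enable_sp_layernorm = (dp_size < tp_size) if enable_dp_attention else True
    return ep_size, enable_sp_layernorm


assert derive_deepep_settings(8, 8, True) == (8, False)   # dp == tp: plain DP attention
assert derive_deepep_settings(8, 2, True) == (8, True)    # 1 <= dp < tp: SP layernorm on
assert derive_deepep_settings(8, 1, False) == (8, True)   # pure TP + DeepEP
```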
@@ -12,7 +12,42 @@ from sglang.test.test_utils import (
 )
 
 
-class TestDeepEPMoE(CustomTestCase):
+class TestPureTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--enable-deepep-moe",
+                "--disable-cuda-graph",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.5)
+
+
+class TestDPAttn(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
...