Unverified Commit b2e95f62 authored by Cheng Wan, committed by GitHub

Fix two issues related to `--moe-dense-tp-size=1` (#5657)


Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
parent 1ab14c4c
......@@ -24,8 +24,10 @@ if TYPE_CHECKING:
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
_ATTN_TP_SIZE = None
_DP_RANK = None
_DP_SIZE = None
_ATTN_DP_RANK = None
_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_RANK = None
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
......@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
return tp_rank, tp_size, 0
attn_tp_size = tp_size // dp_size
dp_rank = tp_rank // attn_tp_size
attn_dp_rank = tp_rank // attn_tp_size
attn_tp_rank = tp_rank % attn_tp_size
return attn_tp_rank, attn_tp_size, dp_rank
return attn_tp_rank, attn_tp_size, attn_dp_rank
def compute_dp_attention_local_info(
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
if not enable_dp_attention:
return tp_rank, tp_size, 0
local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
local_tp_rank = tp_rank % local_tp_size
local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
local_attn_tp_size = local_tp_size // local_dp_size
local_attn_dp_rank = local_tp_rank // local_attn_tp_size
local_attn_tp_rank = local_tp_rank % local_attn_tp_size
return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
def initialize_dp_attention(
......@@ -43,22 +63,32 @@ def initialize_dp_attention(
tp_rank: int,
tp_size: int,
dp_size: int,
moe_dense_tp_size: int,
pp_size: int,
):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
_, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
)
if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_DP_SIZE = dp_size
_ATTN_DP_SIZE = dp_size
if moe_dense_tp_size is None:
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
else:
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
else:
local_rank = tp_rank
_DP_SIZE = 1
_ATTN_DP_SIZE = 1
_LOCAL_ATTN_DP_SIZE = 1
tp_group = get_tp_group()
_ATTN_TP_GROUP = GroupCoordinator(
......@@ -93,13 +123,33 @@ def get_attention_tp_size():
def get_attention_dp_rank():
assert _DP_RANK is not None, "dp attention not initialized!"
return _DP_RANK
assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
return _ATTN_DP_RANK
def get_attention_dp_size():
assert _DP_SIZE is not None, "dp attention not initialized!"
return _DP_SIZE
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _ATTN_DP_SIZE
def get_local_attention_dp_rank():
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_RANK
def get_local_attention_dp_size():
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_SIZE
def get_local_attention_dp_rank():
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_RANK
def get_local_attention_dp_size():
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _LOCAL_ATTN_DP_SIZE
@contextmanager
......@@ -112,19 +162,19 @@ def disable_dp_size():
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _DP_SIZE
assert _DP_SIZE is not None, "dp attention not initialized!"
global _ATTN_DP_SIZE
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
old_dp_size = _DP_SIZE
_DP_SIZE = 1
old_dp_size = _ATTN_DP_SIZE
_ATTN_DP_SIZE = 1
try:
yield
finally:
_DP_SIZE = old_dp_size
_ATTN_DP_SIZE = old_dp_size
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
dp_rank = get_local_attention_dp_rank()
if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
......
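To make the new rank arithmetic in dp_attention.py concrete, below is a minimal, self-contained sketch (plain Python, no sglang imports) contrasting compute_dp_attention_world_info with the new compute_dp_attention_local_info. The function bodies are copied from the hunks above; the concrete sizes tp_size=4, dp_size=4, moe_dense_tp_size=1 in the driver are illustrative assumptions, not values taken from the diff.

def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    # Global attention-DP view (as in the hunk above).
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size
    attn_dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, attn_dp_rank


def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    # "Local" view used when --moe-dense-tp-size overrides the TP size of the
    # dense layers (as in the hunk above).
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_tp_rank = tp_rank % local_tp_size
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
    local_attn_tp_size = local_tp_size // local_dp_size
    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


if __name__ == "__main__":
    tp_size, dp_size, moe_dense_tp_size = 4, 4, 1  # illustrative assumption
    for tp_rank in range(tp_size):
        _, _, attn_dp_rank = compute_dp_attention_world_info(
            True, tp_rank, tp_size, dp_size
        )
        _, _, local_attn_dp_rank = compute_dp_attention_local_info(
            True, tp_rank, tp_size, dp_size, moe_dense_tp_size
        )
        # attn_dp_rank runs 0..3, while local_attn_dp_rank is always 0: with
        # moe_dense_tp_size=1 each rank forms its own local group, which is the
        # value get_dp_local_info now uses for indexing.
        print(tp_rank, attn_dp_rank, local_attn_dp_rank)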
......@@ -30,9 +30,10 @@ from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_size,
get_local_attention_dp_rank,
get_local_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
......@@ -46,6 +47,18 @@ from sglang.srt.utils import dump_to_file
logger = logging.getLogger(__name__)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import dump_to_file
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class LogitsProcessorOutput:
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
......@@ -170,7 +183,7 @@ class LogitsMetadata:
return
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_attention_dp_rank()
dp_rank = get_local_attention_dp_rank()
if dp_rank == 0:
dp_local_start_pos = torch.zeros_like(
self.global_num_tokens_for_logprob_gpu[0]
......@@ -324,7 +337,8 @@ class LogitsProcessor(nn.Module):
if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
not self.do_tensor_parallel_all_gather
or get_local_attention_dp_size() == 1
), "dp attention + sharded lm_head doesn't support full logits"
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
......
......@@ -207,7 +207,8 @@ class Scheduler(
self.page_size = server_args.page_size
# Distributed rank info
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
self.tp_rank,
......@@ -768,7 +769,7 @@ class Scheduler(
)
# send out reqs to the next stage
dp_offset = self.dp_rank * self.attn_tp_size
dp_offset = self.attn_dp_rank * self.attn_tp_size
if self.attn_tp_rank == 0:
point_to_point_pyobj(
recv_reqs,
......@@ -815,7 +816,7 @@ class Scheduler(
recv_reqs = None
else:
if self.attn_tp_rank == 0:
dp_offset = self.dp_rank * self.attn_tp_size
dp_offset = self.attn_dp_rank * self.attn_tp_size
recv_reqs = point_to_point_pyobj(
[],
self.pp_rank * self.tp_size + dp_offset,
......@@ -1610,6 +1611,7 @@ class Scheduler(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
......@@ -1622,6 +1624,7 @@ class Scheduler(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
moe_dense_tp_size: Optional[int],
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
......@@ -1631,15 +1634,15 @@ class Scheduler(
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
global_num_tokens_for_logprob = 0
num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
num_tokens = num_tokens * speculative_num_draft_tokens
global_num_tokens_for_logprob = num_tokens
num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
global_num_tokens_for_logprob = sum(
num_tokens_for_logprob = sum(
[
# We should have at least 1 token for sample in every case.
max(extend_len - logprob_start_len, 1)
......@@ -1666,7 +1669,7 @@ class Scheduler(
[
num_tokens,
can_cuda_graph,
global_num_tokens_for_logprob,
num_tokens_for_logprob,
is_extend_in_batch,
],
dtype=torch.int64,
......@@ -1689,8 +1692,15 @@ class Scheduler(
local_batch = get_idle_batch()
if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
# TODO: handle the case when moe_dense_tp_size != 1
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
local_batch.global_num_tokens = [num_tokens]
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
else:
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = (
global_num_tokens_for_logprob
)
# Check forward mode for cuda graph
if not disable_cuda_graph:
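The scheduler hunk above also changes which token lists the batch carries when --moe-dense-tp-size=1 is combined with --enable-dp-lm-head: each attention-DP rank keeps only its own counts instead of the all-gathered per-rank lists. The sketch below isolates that branch as a standalone function; the helper name pick_token_lists and its signature are illustrative, not part of the diff.

from typing import List, Optional, Tuple


def pick_token_lists(
    moe_dense_tp_size: Optional[int],
    enable_dp_lm_head: bool,
    num_tokens: int,
    num_tokens_for_logprob: int,
    global_num_tokens: List[int],
    global_num_tokens_for_logprob: List[int],
) -> Tuple[List[int], List[int]]:
    # Mirrors the new branch in the diff: with --moe-dense-tp-size=1 and
    # --enable-dp-lm-head, each attention-DP rank keeps only its local counts.
    if moe_dense_tp_size == 1 and enable_dp_lm_head:
        return [num_tokens], [num_tokens_for_logprob]
    # Otherwise the all-gathered per-rank lists are used unchanged; the TODO in
    # the diff notes that moe_dense_tp_size != 1 is not yet specialized here.
    return global_num_tokens, global_num_tokens_for_logprob


# Illustrative values: two DP ranks with 5 and 7 local tokens.
assert pick_token_lists(1, True, 5, 5, [5, 7], [5, 7]) == ([5], [5])
assert pick_token_lists(None, False, 5, 5, [5, 7], [5, 7]) == ([5, 7], [5, 7])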
......@@ -2177,8 +2187,8 @@ class Scheduler(
def get_print_prefix(self):
prefix = ""
if self.dp_rank is not None:
prefix += f" DP{self.dp_rank}"
if self.attn_dp_rank is not None:
prefix += f" DP{self.attn_dp_rank}"
if self.server_args.tp_size > 1:
prefix += f" TP{self.tp_rank}"
if self.pp_size > 1:
......
......@@ -401,6 +401,7 @@ class ModelRunner:
tp_rank=self.tp_rank,
tp_size=self.tp_size,
dp_size=self.server_args.dp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
pp_size=self.server_args.pp_size,
)
......
......@@ -40,9 +40,9 @@ from sglang.srt.layers.dp_attention import (
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
......@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.dp_size = get_attention_dp_size()
self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.self_attn = DeepseekV2AttentionMLA(
......@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
)
self.input_is_scattered = (
previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
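The input_is_scattered change just above adds a layer_id > 0 guard so that layer 0, which has no previous decoder layer, is never treated as receiving scattered input. A minimal sketch of the guarded condition follows; _FFNInputMode here is a stand-in for the enum defined in deepseek_v2.py, and the free function is illustrative only.

from enum import Enum, auto


class _FFNInputMode(Enum):  # stand-in for the enum in deepseek_v2.py
    SCATTERED = auto()
    FULL = auto()


def input_is_scattered(layer_id: int, previous_ffn_input_mode: _FFNInputMode) -> bool:
    # The added layer_id > 0 guard keeps layer 0 out of the scattered path even
    # if the computed "previous layer" info reports SCATTERED.
    return layer_id > 0 and previous_ffn_input_mode == _FFNInputMode.SCATTERED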
......@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
......@@ -1289,7 +1289,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.dp_size != 1:
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
......@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_attention_dp_size()
self.dp_size = get_local_attention_dp_size()
def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens
......@@ -1478,7 +1478,7 @@ class DeepseekV2ForCausalLM(nn.Module):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size()
self.dp_size = get_local_attention_dp_size()
def determine_n_share_experts_fusion(
self, architecture: str = "DeepseekV3ForCausalLM"
......
......@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
from sglang.srt.layers.dp_attention import (
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
......@@ -198,7 +198,6 @@ class Llama4Attention(nn.Module):
self.use_rope = int((layer_id + 1) % 4 != 0)
self.use_qk_norm = config.use_qk_norm and self.use_rope
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
......@@ -342,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.dp_size = get_attention_dp_size()
self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
......@@ -405,7 +404,7 @@ class Llama4DecoderLayer(nn.Module):
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
......@@ -430,7 +429,7 @@ class Llama4DecoderLayer(nn.Module):
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.dp_size != 1:
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
......