"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "63c1acfbfde843eec3f8b2a16edcd15c9b4e7b28"
Unverified commit 25c83fff, authored by Cheng Wan, committed by GitHub

Performing Vocabulary Parallelism for LM Head across Attention TP Groups (#5558)


Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
parent 9f2c9568
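
For orientation before the diff: under DP attention, the tensor-parallel ranks are partitioned into per-DP-replica attention TP groups, and this change shards the LM head vocabulary over that smaller group instead of the full TP group, so the logits all-gather no longer crosses DP replicas. The sketch below is illustrative only (it is not sglang code); the sizes and the contiguous rank-to-group mapping are assumptions.

```python
# Illustrative sketch, not sglang code. Assumed sizes: tp_size = 8, dp_size = 4,
# giving attention TP groups of attn_tp_size = tp_size // dp_size = 2 ranks,
# with a contiguous rank-to-group mapping.
tp_size, dp_size = 8, 4
attn_tp_size = tp_size // dp_size

for rank in range(tp_size):
    attn_dp_rank = rank // attn_tp_size   # which DP attention replica this rank serves
    attn_tp_rank = rank % attn_tp_size    # position inside that attention TP group
    print(f"rank {rank}: attention DP group {attn_dp_rank}, attention TP rank {attn_tp_rank}")

# Before this commit, ParallelLMHead sharded the vocabulary over all tp_size ranks,
# so assembling full logits required an all-gather over the whole TP group, i.e.
# across DP replicas. With --enable-dp-lm-head, each attention TP group holds a
# complete copy of the vocabulary split across its attn_tp_size ranks, so the
# gather stays inside the group.
```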
Server arguments documentation table:

@@ -221,3 +221,4 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `triton_attention_num_kv_splits` | Use to adjust the number of KV splits in triton kernels. | `8` |
 | `flashinfer_mla_disable_ragged` | Disable the use of the [ragged prefill](https://github.com/flashinfer-ai/flashinfer/blob/5751fc68f109877f6e0fc54f674cdcdef361af56/docs/tutorials/kv_layout.rst#L26) wrapper for the FlashInfer MLA attention backend. Ragged prefill increases throughput by computing MHA instead of paged MLA when there is no prefix match. Only use it when FlashInfer is being used as the MLA backend. | `False` |
 | `disable_chunked_prefix_cache` | Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. | `False` |
+| `enable_dp_lm_head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | `False` |
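
In practice the new flag is meant to be combined with `--enable-dp-attention`; the check added in `server_args.py` further down rejects `--enable-dp-lm-head` on its own. A typical invocation (hypothetical model and sizes) simply adds both flags to an existing DP-attention launch, for example `python -m sglang.launch_server --model-path <model> --tp-size 8 --dp-size 8 --enable-dp-attention --enable-dp-lm-head`.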
sglang/srt/layers/dp_attention.py:

@@ -252,12 +252,12 @@ def dp_scatter(
     )


-def tp_reduce_scatter(
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)


-def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
sglang/srt/layers/logits_processor.py:

@@ -23,15 +23,16 @@ import triton.language as tl
 from torch import nn

 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -198,12 +199,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.do_tensor_parallel_all_gather = (
-            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
-        )
-        self.do_tensor_parallel_all_gather_dp_attn = (
-            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
-        )
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )

@@ -442,7 +451,19 @@
             logits.mul_(self.logit_scale)

         if self.do_tensor_parallel_all_gather:
-            logits = tensor_model_parallel_all_gather(logits)
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)

         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
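
A note on the `use_attn_tp_group` branch above: the gather buffer is allocated as `(vocab_size, num_tokens)` and then viewed transposed, so splitting the transposed view along the vocab dimension yields one slot per rank, each backed by a contiguous block of the allocated buffer. The single-process sketch below (no real collectives; the group size and shapes are made up, and `copy_` stands in for `attn_tp_all_gather`) shows that the gathered shards land in vocab order:

```python
import torch

# Single-process sketch of the transposed-buffer trick used in LogitsProcessor.
# attn_tp_size, vocab_size and num_tokens are made-up numbers for illustration.
attn_tp_size, vocab_size, num_tokens = 2, 8, 3
torch.manual_seed(0)

# Each attention-TP rank computes logits only for its own vocab shard.
shard_logits = [
    torch.randn(num_tokens, vocab_size // attn_tp_size) for _ in range(attn_tp_size)
]

# Allocate (vocab_size, num_tokens) and view it transposed, as in the diff above.
global_logits = torch.empty(vocab_size, num_tokens).T  # logical shape (num_tokens, vocab_size)

# Splitting along the last dim gives one output slot per rank; each slot maps to
# a contiguous block of the underlying (vocab_size, num_tokens) buffer.
slots = list(global_logits.tensor_split(attn_tp_size, dim=-1))
for rank, slot in enumerate(slots):
    slot.copy_(shard_logits[rank])  # stands in for what the all-gather would write

# The transposed buffer now holds the full logits in vocab order.
assert torch.equal(global_logits, torch.cat(shard_logits, dim=-1))
print(global_logits.shape)  # torch.Size([3, 8])
```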
sglang/srt/layers/vocab_parallel_embedding.py:

@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,

@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()

@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.enable_tp = enable_tp
         if self.enable_tp:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1

@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
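
With `use_attn_tp_group=True`, the shard owned by each rank is derived from the attention TP rank and size rather than the full TP rank and world size. A toy partition of the vocabulary, ignoring the padding that `VocabParallelEmbedding` applies and using made-up sizes:

```python
# Toy vocab partition across an attention TP group (padding ignored; numbers are
# illustrative only, not taken from any real model).
vocab_size = 102400
attn_tp_size = 4  # assumed attention TP group size
shard = vocab_size // attn_tp_size

for attn_tp_rank in range(attn_tp_size):
    start, end = attn_tp_rank * shard, (attn_tp_rank + 1) * shard
    print(f"attention TP rank {attn_tp_rank}: vocab rows [{start}, {end})")
```

Each attention TP group covers the whole vocabulary, so the groups are replicas of one another and no cross-group communication is needed for the LM head.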
sglang/srt/managers/schedule_batch.py:

@@ -74,6 +74,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
sglang/srt/models/deepseek_v2.py:

@@ -36,13 +36,13 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
+    attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    tp_all_gather,
-    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -1323,7 +1323,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1339,7 +1339,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.input_is_scattered:
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             if hidden_states.shape[0] != 0:
                 hidden_states, residual = self.post_attention_layernorm(
                     hidden_states, residual

@@ -1349,7 +1349,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_states += residual
             tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
             hidden_states = tensor_list[self.attn_tp_rank]
-            tp_reduce_scatter(hidden_states, tensor_list)
+            attn_tp_reduce_scatter(hidden_states, tensor_list)
             residual = hidden_states
             if hidden_states.shape[0] != 0:
                 hidden_states = self.post_attention_layernorm(hidden_states)

@@ -1373,7 +1373,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                 hidden_states,
             )
-            tp_all_gather(
+            attn_tp_all_gather(
                 list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
             )

@@ -1475,6 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()
sglang/srt/models/llama.py:

@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -420,6 +421,7 @@ class LlamaForCausalLM(nn.Module):
                 config.hidden_size,
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
             )
             self.logits_processor = LogitsProcessor(config)
             self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
sglang/srt/server_args.py:

@@ -159,6 +159,7 @@ class ServerArgs:
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
+    enable_dp_lm_head: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"

@@ -323,6 +324,11 @@
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
             )

+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_lm_head."
+
         # DeepEP MoE
         self.enable_sp_layernorm = False
         if self.enable_deepep_moe:

@@ -1055,6 +1061,11 @@
             action="store_true",
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
+        parser.add_argument(
+            "--enable-dp-lm-head",
+            action="store_true",
+            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
+        )
         parser.add_argument(
             "--enable-ep-moe",
             action="store_true",