Unverified Commit 646cef2e authored by Yi Zhang's avatar Yi Zhang Committed by GitHub
Browse files

support qwen3 dense model dp attention (#7681)

parent 1dce6c48
...@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
...@@ -264,6 +265,7 @@ class Qwen2Model(nn.Module): ...@@ -264,6 +265,7 @@ class Qwen2Model(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
enable_tp=not global_server_args_dict["enable_dp_attention"],
prefix=add_prefix("embed_tokens", prefix), prefix=add_prefix("embed_tokens", prefix),
) )
else: else:
...@@ -332,7 +334,11 @@ class Qwen2Model(nn.Module): ...@@ -332,7 +334,11 @@ class Qwen2Model(nn.Module):
} }
) )
else: else:
hidden_states, _ = self.norm(hidden_states, residual) if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
......
...@@ -14,6 +14,8 @@ from sglang.srt.distributed import ( ...@@ -14,6 +14,8 @@ from sglang.srt.distributed import (
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
...@@ -54,18 +56,21 @@ class Qwen3Attention(nn.Module): ...@@ -54,18 +56,21 @@ class Qwen3Attention(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads self.total_num_heads = num_heads
assert self.total_num_heads % self.tp_size == 0 attn_tp_rank = get_attention_tp_rank()
self.num_heads = self.total_num_heads // self.tp_size attn_tp_size = get_attention_tp_size()
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= self.tp_size: if self.total_num_kv_heads >= attn_tp_size:
# Number of KV heads is greater than TP size, so we partition # Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % self.tp_size == 0 assert self.total_num_kv_heads % attn_tp_size == 0
else: else:
# Number of KV heads is less than TP size, so we replicate # Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
assert self.tp_size % self.total_num_kv_heads == 0 assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
...@@ -84,6 +89,8 @@ class Qwen3Attention(nn.Module): ...@@ -84,6 +89,8 @@ class Qwen3Attention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=attention_bias, bias=attention_bias,
quant_config=quant_config, quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("qkv_proj", prefix), prefix=add_prefix("qkv_proj", prefix),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
...@@ -91,6 +98,9 @@ class Qwen3Attention(nn.Module): ...@@ -91,6 +98,9 @@ class Qwen3Attention(nn.Module):
hidden_size, hidden_size,
bias=attention_bias, bias=attention_bias,
quant_config=quant_config, quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
reduce_results=False,
prefix=add_prefix("o_proj", prefix), prefix=add_prefix("o_proj", prefix),
) )
...@@ -176,6 +186,18 @@ class Qwen3DecoderLayer(nn.Module): ...@@ -176,6 +186,18 @@ class Qwen3DecoderLayer(nn.Module):
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=False,
is_previous_layer_sparse=False,
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -184,20 +206,24 @@ class Qwen3DecoderLayer(nn.Module): ...@@ -184,20 +206,24 @@ class Qwen3DecoderLayer(nn.Module):
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: hidden_states, residual = self.layer_communicator.prepare_attn(
residual = hidden_states hidden_states, residual, forward_batch
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
) )
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual return hidden_states, residual
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment