Unverified Commit 3196999f authored by Cheng Wan, committed by GitHub

Reduce computation and communication in DP attention (#4521)

parent 9e0186f3
@@ -189,6 +189,9 @@ class GroupCoordinator:
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_message_queue_broadcaster: (
+        bool  # a hint of whether to use message queue broadcaster
+    )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
@@ -241,6 +244,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_message_queue_broadcaster = use_message_queue_broadcaster

         # lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
@@ -269,7 +273,7 @@ class GroupCoordinator:
             HpuCommunicator,
         )
-        self.hpu_communicator: Optional[HpuCommunicator]
+        self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
@@ -277,7 +281,7 @@ class GroupCoordinator:
             XpuCommunicator,
         )
-        self.xpu_communicator: Optional[XpuCommunicator]
+        self.xpu_communicator: Optional[XpuCommunicator] = None
        if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
...
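Note: the two kinds of change above are small and easy to misread in diff form. Below is a minimal sketch, with a hypothetical `CoordinatorSketch` class standing in for `GroupCoordinator`, of what they do: the coordinator stores a `use_message_queue_broadcaster` hint next to the other hints, and each optional per-backend communicator is now initialized to `None` rather than merely annotated, so the attribute exists even when the feature is off or `world_size == 1`. This is an illustration, not the real implementation.

```python
from typing import Any, Optional


class CoordinatorSketch:
    """Illustration only; stands in for GroupCoordinator above."""

    def __init__(
        self,
        world_size: int,
        use_message_queue_broadcaster: bool = False,
        use_hpu_communicator: bool = False,
    ):
        self.world_size = world_size
        # New hint flag, stored like use_pynccl / use_custom_allreduce.
        self.use_message_queue_broadcaster = use_message_queue_broadcaster
        # A bare annotation (`self.hpu_communicator: Optional[Any]`) does not
        # create the attribute; assigning None does, which keeps later
        # `if self.hpu_communicator is not None` checks safe on single-rank runs.
        self.hpu_communicator: Optional[Any] = None
        if use_hpu_communicator and self.world_size > 1:
            self.hpu_communicator = object()  # placeholder for HpuCommunicator(group=...)
```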
@@ -53,10 +53,8 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
-        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
@@ -65,7 +63,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)


-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[str, int],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-    if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
-    ):
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
             global_tokens.untyped_storage().data_ptr()
             != local_tokens.untyped_storage().data_ptr()
@@ -216,6 +213,22 @@ def dp_gather(
         global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
+
+
 def dp_scatter(
     local_tokens: torch.Tensor,  # output
     global_tokens: torch.Tensor,  # input
@@ -236,16 +249,3 @@ def dp_scatter(
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
         )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
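For orientation, here is a condensed, self-contained sketch of the refactor in the dp_attention module: the string `layer_id` argument of `dp_gather` becomes an explicit `is_partial` flag on a private `_dp_gather`, exposed through the two thin wrappers above. The rank stub, the plain slice assignment, and the omitted all-reduce are simplifications for illustration, not the real implementation.

```python
import torch


def _attention_tp_rank_stub() -> int:
    return 0  # the real code queries get_attention_tp_rank()


def _dp_gather_sketch(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    local_start_pos: int,
    is_partial: bool,
) -> None:
    global_tokens.fill_(0)
    # Partial gather: every attention-TP rank contributes its partial result,
    # and the subsequent all-reduce sums them. Replicate gather: all ranks in
    # the group hold identical tokens, so only rank 0 writes; otherwise the
    # all-reduce would scale the result by the attention-TP size.
    if local_tokens.shape[0] > 0 and (is_partial or _attention_tp_rank_stub() == 0):
        end = local_start_pos + local_tokens.shape[0]
        global_tokens[local_start_pos:end] = local_tokens
    # ... the real _dp_gather then all-reduces global_tokens across the TP group.


def dp_gather_partial_sketch(global_tokens, local_tokens, local_start_pos):
    _dp_gather_sketch(global_tokens, local_tokens, local_start_pos, is_partial=True)


def dp_gather_replicate_sketch(global_tokens, local_tokens, local_start_pos):
    _dp_gather_sketch(global_tokens, local_tokens, local_start_pos, is_partial=False)
```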
@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
@@ -428,7 +428,7 @@ class LogitsProcessor(nn.Module):
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
...
@@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
     decode_attention_fwd_grouped_rope,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_partial,
     dp_scatter,
     get_attention_dp_size,
     get_attention_tp_rank,
@@ -939,47 +939,58 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        # Scatter
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            # Self Attention
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )

         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
+                if get_attention_tp_rank() == 0:
+                    hidden_states += residual
                 hidden_states, local_hidden_states = (
                     forward_batch.gathered_buffer,
                     hidden_states,
                 )
-                dp_gather(
-                    hidden_states, local_hidden_states, forward_batch, self.layer_id
-                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

         # Fully Connected
         hidden_states = self.mlp(hidden_states)

+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return hidden_states, residual
@@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        # Gather
-        if self.dp_size != 1:
-            input_ids, local_input_ids = (
-                torch.empty(
-                    (forward_batch.gathered_buffer.shape[0],),
-                    dtype=input_ids.dtype,
-                    device=input_ids.device,
-                ),
-                input_ids,
-            )
-            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-
-        if self.dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)

         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
...
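To make the reshuffled decoder-layer communication easier to follow, here is a hedged, single-process sketch of the new ordering: local attention, one partial gather that also finishes the residual add, layernorm and MLP on the gathered tokens, then a scatter back to local tokens. The gather/scatter stubs and the placeholder transform are assumptions for illustration; the real code uses `dp_gather_partial`, `dp_scatter`, `post_attention_layernorm`, and `forward_batch.gathered_buffer` as shown in the diff above.

```python
import torch


def _gather_partial_stub(global_t: torch.Tensor, local_t: torch.Tensor) -> None:
    # Single-rank stand-in for dp_gather_partial: place local tokens into the buffer.
    global_t.fill_(0)
    global_t[: local_t.shape[0]] = local_t


def _scatter_stub(local_t: torch.Tensor, global_t: torch.Tensor) -> None:
    # Single-rank stand-in for dp_scatter: take this rank's slice of the buffer.
    local_t.copy_(global_t[: local_t.shape[0]])


def decoder_layer_flow(
    attn_out: torch.Tensor,         # attention output on local tokens only
    residual: torch.Tensor,         # residual for the local tokens
    gathered_buffer: torch.Tensor,  # plays the role of forward_batch.gathered_buffer
) -> tuple[torch.Tensor, torch.Tensor]:
    # 1. Fold the residual in before the gather (attention-TP rank 0 only in the
    #    real code), so one partial gather finishes both the TP reduction and
    #    the residual add.
    hidden = attn_out + residual
    # 2. Partial gather into the shared buffer (in the real implementation this
    #    sums TP-partial outputs and concatenates DP shards).
    _gather_partial_stub(gathered_buffer, hidden)
    # 3. The local slice of the gathered tensor becomes the next residual.
    _scatter_stub(residual, gathered_buffer)
    # 4. post_attention_layernorm and the MLP run on the gathered tokens
    #    (a placeholder transform here).
    mlp_out = torch.nn.functional.gelu(gathered_buffer)
    # 5. Scatter the MLP output back to this rank's local tokens.
    local_out = torch.empty_like(residual)
    _scatter_stub(local_out, mlp_out)
    return local_out, residual
```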
@@ -11,7 +11,7 @@ from sglang.test.test_utils import (
 )


-class TestDPAttention(unittest.TestCase):
+class TestDPAttentionDP2TP2(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -59,7 +59,3 @@ class TestDPAttention(unittest.TestCase):
         metrics = run_eval(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["score"], 0.8)
-
-
-if __name__ == "__main__":
-    unittest.main()