Unverified commit 3196999f authored by Cheng Wan, committed by GitHub

Reduce computation and communication in DP attention (#4521)

parent 9e0186f3
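As context for the hunks below, one hedged reading of the reworked DeepseekV2 decoder layer: each DP rank keeps only its own tokens between layers, the attention tensor-parallel partial outputs are summed by the same all-reduce that shares tokens across DP ranks, and only the MLP/MoE sees the full DP batch. The skeleton below illustrates that flow with stand-in callables; it is not SGLang code.

import torch

def decoder_layer_sketch(local_hidden, gathered_buffer, attn, mlp,
                         dp_gather_partial, dp_scatter):
    # Attention runs on the DP-local slice only.
    local_hidden = attn(local_hidden)
    # Every attention-TP rank contributes its partial output; the gather's
    # all-reduce both sums those partials and shares tokens across DP ranks,
    # so no separate attention all-reduce is needed.
    dp_gather_partial(gathered_buffer, local_hidden)
    # The MLP/MoE still sees the full DP batch.
    full_hidden = mlp(gathered_buffer)
    # Scatter back to the DP-local slice for the next layer's attention.
    dp_scatter(local_hidden, full_hidden)
    return local_hidden

# Toy usage with identity stand-ins, just to show the data movement.
buf = torch.zeros(8, 16)
out = decoder_layer_sketch(
    torch.randn(4, 16), buf,
    attn=lambda x: x, mlp=lambda x: x,
    dp_gather_partial=lambda dst, src: dst[: src.shape[0]].copy_(src),
    dp_scatter=lambda dst, src: dst.copy_(src[: dst.shape[0]]),
)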
@@ -189,6 +189,9 @@ class GroupCoordinator:
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
use_message_queue_broadcaster: (
bool # a hint of whether to use message queue broadcaster
)
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
@@ -241,6 +244,7 @@ class GroupCoordinator:
self.use_custom_allreduce = use_custom_allreduce
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
self.use_message_queue_broadcaster = use_message_queue_broadcaster
# lazy import to avoid documentation build error
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
@@ -269,7 +273,7 @@ class GroupCoordinator:
HpuCommunicator,
)
self.hpu_communicator: Optional[HpuCommunicator]
self.hpu_communicator: Optional[HpuCommunicator] = None
if use_hpu_communicator and self.world_size > 1:
self.hpu_communicator = HpuCommunicator(group=self.device_group)
@@ -277,7 +281,7 @@ class GroupCoordinator:
XpuCommunicator,
)
self.xpu_communicator: Optional[XpuCommunicator]
self.xpu_communicator: Optional[XpuCommunicator] = None
if use_xpu_communicator and self.world_size > 1:
self.xpu_communicator = XpuCommunicator(group=self.device_group)
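A note on the Optional[...] changes in this hunk: a bare annotation such as self.hpu_communicator: Optional[HpuCommunicator] only records a type hint and never creates the attribute, so a later "is None" check raises AttributeError whenever the guarded branch is skipped. A minimal, self-contained illustration (not SGLang code):

from typing import Optional

class Annotated:
    def __init__(self, enabled: bool):
        self.comm: Optional[object]          # annotation only; attribute may never exist
        if enabled:
            self.comm = object()

class Assigned:
    def __init__(self, enabled: bool):
        self.comm: Optional[object] = None   # always defined, safe to test for None
        if enabled:
            self.comm = object()

try:
    Annotated(enabled=False).comm
except AttributeError:
    print("bare annotation: attribute missing when the branch is skipped")

print(Assigned(enabled=False).comm)          # None, no exception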
@@ -53,10 +53,8 @@ def initialize_dp_attention(
)
if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_DP_SIZE = dp_size
else:
local_rank = tp_rank
_DP_SIZE = 1
tp_group = get_tp_group()
@@ -65,7 +63,7 @@ def initialize_dp_attention(
list(range(head, head + _ATTN_TP_SIZE))
for head in range(0, tp_size, _ATTN_TP_SIZE)
],
local_rank,
tp_group.local_rank,
torch.distributed.get_backend(tp_group.device_group),
SYNC_TOKEN_IDS_ACROSS_TP,
False,
@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
def dp_gather(
def _dp_gather(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: Union[str, int],
is_partial: bool,
):
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
global_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0 and (
layer_id != "embedding" or get_attention_tp_rank() == 0
):
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
assert (
global_tokens.untyped_storage().data_ptr()
!= local_tokens.untyped_storage().data_ptr()
@@ -216,6 +213,22 @@ def dp_gather(
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
def dp_gather_partial(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
):
_dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
def dp_gather_replicate(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
):
_dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
def dp_scatter(
local_tokens: torch.Tensor, # output
global_tokens: torch.Tensor, # input
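The only behavioral difference between the two wrappers added above is who writes into the gather buffer. A hypothetical single-process mock (the real _dp_gather also uses a Triton memcpy and finishes with tensor_model_parallel_all_reduce): with is_partial=True every attention-TP rank contributes its partial slice and the all-reduce sums them; with is_partial=False the local tokens are already replicated across the attention-TP group, so only rank 0 writes to avoid counting them attn_tp_size times.

import torch

def mock_dp_gather(global_tokens, local_tokens, start, is_partial, attn_tp_rank):
    global_tokens.zero_()
    if local_tokens.shape[0] > 0 and (is_partial or attn_tp_rank == 0):
        global_tokens[start : start + local_tokens.shape[0]] = local_tokens
    # SGLang then runs tensor_model_parallel_all_reduce(global_tokens).

buf = torch.zeros(8, 4)
local = torch.ones(3, 4)
mock_dp_gather(buf, local, start=2, is_partial=False, attn_tp_rank=1)
print(buf.abs().sum().item())  # 0.0: replicated input, non-zero ranks stay silent
mock_dp_gather(buf, local, start=2, is_partial=True, attn_tp_rank=1)
print(buf.abs().sum().item())  # 12.0: partial input, every rank contributes its share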
@@ -236,16 +249,3 @@ def dp_scatter(
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
def do_logits_dp_scatter(logits: torch.Tensor):
local_logits = torch.empty(
(forward_batch.input_ids.shape[0], *logits.shape[1:]),
dtype=logits.dtype,
device=logits.device,
)
dp_scatter(local_logits, logits, forward_batch)
return local_logits
return do_logits_dp_scatter
@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
dp_gather,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
@@ -428,7 +428,7 @@ class LogitsProcessor(nn.Module):
logits_metadata.gathered_buffer,
hidden_states.clone(),
)
dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
if hasattr(lm_head, "weight"):
logits = torch.matmul(
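In this logits path, dp_gather_replicate assembles every DP rank's rows into one buffer so a single lm_head matmul can score all tokens; the rows are already identical across attention-TP ranks, hence no summation is involved. A sketch under assumed shapes (not the LogitsProcessor code):

import torch

vocab, hidden = 32, 16
lm_head = torch.randn(vocab, hidden)
per_rank_rows = [torch.randn(3, hidden), torch.randn(5, hidden)]  # two DP ranks

gathered = torch.cat(per_rank_rows, dim=0)  # what dp_gather_replicate builds
logits = gathered @ lm_head.t()             # one matmul for all DP ranks' tokens
print(logits.shape)                         # torch.Size([8, 32])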
@@ -33,7 +33,7 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.dp_attention import (
dp_gather,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
@@ -939,47 +939,58 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if residual is None:
if hidden_states.shape[0] == 0:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Scatter
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if get_attention_tp_rank() == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather(
hidden_states, local_hidden_states, forward_batch, self.layer_id
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
# Fully Connected
hidden_states = self.mlp(hidden_states)
# Scatter
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual
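One detail worth noting in the hunk above: only attention-TP rank 0 adds the residual before dp_gather_partial, so the all-reduce that sums the tensor-parallel partials counts the residual exactly once. A hypothetical numeric check (pure PyTorch, no process group):

import torch

attn_tp_size = 2
residual = torch.full((4,), 10.0)
partials = [torch.full((4,), 1.0) for _ in range(attn_tp_size)]  # per-rank partial attention output

contributions = [
    (p + residual) if rank == 0 else p              # rank 0 folds the residual in
    for rank, p in enumerate(partials)
]
reduced = torch.stack(contributions).sum(0)         # what the all-reduce produces
expected = torch.stack(partials).sum(0) + residual  # full attention output + residual, once
assert torch.equal(reduced, expected)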
@@ -1025,18 +1036,6 @@ class DeepseekV2Model(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
# Gather
if self.dp_size != 1:
input_ids, local_input_ids = (
torch.empty(
(forward_batch.gathered_buffer.shape[0],),
dtype=input_ids.dtype,
device=input_ids.device,
),
input_ids,
)
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
@@ -1087,15 +1086,6 @@ class DeepseekV2ForCausalLM(nn.Module):
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if self.dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
@@ -11,7 +11,7 @@ from sglang.test.test_utils import (
)
class TestDPAttention(unittest.TestCase):
class TestDPAttentionDP2TP2(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -59,7 +59,3 @@ class TestDPAttention(unittest.TestCase):
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
unittest.main()
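The renamed test class presumably pins the configuration this change targets: data-parallel size 2 with tensor-parallel size 2. A hedged sketch of the corresponding server launch command (the model path is a placeholder; the test itself uses DEFAULT_MLA_MODEL_NAME_FOR_TEST, and the flag names below are the usual SGLang server arguments, assumed rather than taken from this diff):

import sys

MODEL = "deepseek-ai/DeepSeek-V2-Lite"  # placeholder, not necessarily the test's model
launch_cmd = [
    sys.executable, "-m", "sglang.launch_server",
    "--model-path", MODEL,
    "--tp-size", "2",
    "--dp-size", "2",
    "--enable-dp-attention",
]
print(" ".join(launch_cmd))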