Unverified commit 10d60cd4 authored by u4lr451, committed by GitHub

feat: mtp support dp-attention (#6081)


Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
parent 8a10c4c3
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
...
@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Init the global shared states for cuda graph."""
         raise NotImplementedError()
...
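The base-class signature above now carries both limits. A minimal sketch (not taken from the diff; `ToyBackend` and its buffer names are hypothetical) of how a backend can use the two arguments: per-sequence state is sized by `max_bs`, while per-token state is sized by `max_num_tokens`, which under speculative decoding equals `max_bs * num_tokens_per_bs`.

```python
import torch


class ToyBackend:
    """Hypothetical backend illustrating the new init_cuda_graph_state contract."""

    def __init__(self, max_context_len: int, device: str = "cpu"):
        self.max_context_len = max_context_len
        self.device = device

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # Per-sequence metadata still scales with the number of sequences.
        self.cuda_graph_seq_lens = torch.zeros(
            max_bs, dtype=torch.int32, device=self.device
        )
        # Per-token buffers must cover every query token of a speculative step,
        # i.e. max_num_tokens = max_bs * num_tokens_per_bs.
        self.cuda_graph_kv_indices = torch.zeros(
            max_num_tokens * self.max_context_len, dtype=torch.int32, device=self.device
        )
```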
@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
...
@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.

         Args:
@@ -1999,9 +1999,9 @@ class FlashAttentionMultiStepBackend:
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata(forward_batch)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
...
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
+                (max_num_tokens * self.max_context_len,),
                 dtype=torch.int32,
                 device="cuda",
             )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device="cuda",
             )
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
...
@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:
         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
...
@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         def call_fn(i, forward_batch):
...
@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
             if forward_batch_child.batch_size > 0:
                 child.init_forward_metadata(forward_batch=forward_batch_child)

-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
         for item in self.children:
             # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,
...
@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+        elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_attn_lse = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits),
+            (max_num_tokens, self.num_head, self.max_kv_splits),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_num_kv_splits = torch.full(
-            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.int32,
                 device=self.device,
             )
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
         if self.sliding_window_size is not None and self.sliding_window_size > 0:
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
-                    (max_bs * self.sliding_window_size),
+                    (max_num_tokens * self.sliding_window_size),
                     dtype=torch.int32,
                     device=self.device,
                 )
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                 self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
             self.cuda_graph_window_num_kv_splits = torch.full(
-                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
             )

     def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
             )
             custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:
         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
...
@@ -238,6 +238,10 @@ def _dp_gather(
         assert (
             local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between global_tokens and local_tokens not allowed"
+        if forward_batch.forward_mode.is_draft_extend():
+            shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+            local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
         )
@@ -288,6 +292,10 @@ def dp_scatter(
         assert (
             local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between local_tokens and global_tokens not allowed"
+        if forward_batch.forward_mode.is_draft_extend():
+            shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+            local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
         )
...
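A small standalone illustration of the clamp added above, with made-up sizes: for a draft-extend batch the scheduler-side token count can exceed the tokens actually materialized locally, so the copy length is capped at `local_tokens.shape[0]` before `memcpy_triton` runs.

```python
import torch

local_tokens = torch.randn(5, 8)     # only 5 rows exist locally
local_num_tokens = torch.tensor(7)   # count was scaled by the draft-token factor
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
assert local_num_tokens.item() == 5  # the copy never reads past the local buffer
```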
@@ -862,6 +862,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
@@ -1760,11 +1761,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
             enable_custom_logit_processor=self.enable_custom_logit_processor,
+            global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            is_extend_in_batch=self.is_extend_in_batch,
         )

     def __str__(self):
         return (
-            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
             f"#req={(len(self.reqs))})"
         )
@@ -1833,6 +1838,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None

     # Overlap event
     launch_done: Optional[threading.Event] = None
...
@@ -1350,6 +1350,29 @@ class Scheduler(
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()

+    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
+        """Coordinate the DP attention batch."""
+        local_info = torch.tensor(
+            [
+                (new_batch is not None),
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 1),
+            dtype=torch.int64,
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_info.flatten(),
+            local_info,
+            group=self.tp_cpu_group,
+        )
+        any_new_batch = any(
+            global_info[:, 0, 0].tolist()
+        )  # Any DP worker has forward batch
+        return any_new_batch
+
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
@@ -1383,7 +1406,14 @@ class Scheduler(
             self.running_batch.merge_batch(self.last_batch)

         new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
+
+        # TODO(ch-wan): minor refactor is needed here to improve readability
+        any_new_batch = (
+            self.server_args.enable_dp_attention
+            and not self.spec_algorithm.is_none()
+            and self.coordinate_spec_dp_attn_batch(new_batch)
+        )
+        if new_batch is not None or any_new_batch:
             # Run prefill first if possible
             ret = new_batch
         else:
@@ -1732,8 +1762,6 @@ class Scheduler(
             num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1809,6 +1837,7 @@ class Scheduler(
             local_batch.global_num_tokens_for_logprob = (
                 global_num_tokens_for_logprob
            )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
             local_batch.tbo_split_seq_index = tbo_split_seq_index
             local_batch.global_forward_mode = global_forward_mode
@@ -1816,6 +1845,7 @@ class Scheduler(
             if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph

+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
         return local_batch, any(is_extend_in_batch)

     def get_idle_batch(self):
...
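For reference, a sketch of how the gathered flags in `coordinate_spec_dp_attn_batch` are laid out (the shapes follow the diff; the flag values are invented): every one of the `dp_size * attn_tp_size` ranks contributes one int64 flag, and only the first attention-TP rank of each DP group is consulted.

```python
import torch

dp_size, attn_tp_size = 2, 2
# Flattened all_gather result in rank order: DP0/TP0, DP0/TP1, DP1/TP0, DP1/TP1
flags = torch.tensor([1, 1, 0, 0], dtype=torch.int64)
global_info = flags.reshape(dp_size, attn_tp_size, 1)
any_new_batch = any(global_info[:, 0, 0].tolist())
assert any_new_batch  # DP rank 0 has a prefill batch, so every rank must run a forward
```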
@@ -242,13 +242,13 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        if global_server_args_dict["attention_backend"] == "flashmla":
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        else:
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

         self.seq_lens_cpu = torch.full(
@@ -323,12 +323,15 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
-            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_global_tokens in self.graphs
+                total_batch_size in self.graphs
                 if self.disable_padding
-                else total_global_tokens <= self.max_bs
+                else total_batch_size <= self.max_bs
             )
         else:
             is_bs_supported = (
@@ -460,7 +463,7 @@ class CudaGraphRunner:
         self.global_num_tokens_gpu.copy_(
             torch.tensor(
                 [
-                    num_tokens // self.dp_size + (i < bs % self.dp_size)
+                    num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                     for i in range(self.dp_size)
                 ],
                 dtype=torch.int32,
@@ -605,9 +608,12 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention or self.enable_sp_layernorm:
-            index = bisect.bisect_left(
-                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -650,13 +656,13 @@ class CudaGraphRunner:
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
             forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
-            self.encoder_lens,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             forward_batch.forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )

         # Store fields
...
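Two small worked examples (numbers invented) for the arithmetic changed above: the effective batch size divides the global token count back by `num_tokens_per_bs` under EAGLE, and the corrected remainder test distributes the padded token count, not the batch size, across DP ranks.

```python
num_tokens_per_bs = 4                 # draft tokens carried per sequence under EAGLE
global_num_tokens_cpu = [8, 12]       # per-DP-rank token counts, already scaled by 4
total_batch_size = sum(global_num_tokens_cpu) // num_tokens_per_bs
assert total_batch_size == 5          # compared against the captured batch sizes

dp_size, num_tokens = 2, 25           # padded global token count for the replay
split = [num_tokens // dp_size + (i < num_tokens % dp_size) for i in range(dp_size)]
assert split == [13, 12] and sum(split) == num_tokens
```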
@@ -320,17 +320,30 @@ class ForwardBatch:
         # For DP attention
         if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
+
-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)

-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
...
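A toy computation (values invented) of the scaling introduced above: the per-rank request counts reported by the scheduler are expanded by `spec_num_draft_tokens`, so the DP-attention `gathered_buffer` is allocated with a row for every draft token.

```python
spec_num_draft_tokens = 4          # falls back to 1 when the batch is not speculative
global_num_tokens = [2, 3]         # requests per DP rank
scaled = [x * spec_num_draft_tokens for x in global_num_tokens]
assert scaled == [8, 12]
gathered_rows = sum(scaled)        # 20 rows reserved in gathered_buffer
```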
@@ -163,6 +163,7 @@ class ModelRunner:
             logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -196,6 +197,7 @@ class ModelRunner:
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )
...
@@ -22,7 +22,6 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -77,6 +76,7 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -90,15 +90,16 @@ class DeepseekModelNextN(nn.Module):
         else:
             hidden_states = input_embeds

-        hidden_states = self.eh_proj(
-            torch.cat(
-                (
-                    self.enorm(hidden_states),
-                    self.hnorm(forward_batch.spec_info.hidden_states),
-                ),
-                dim=-1,
-            )
-        )
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
+            )

         residual = None
         hidden_states, residual = self.decoder(
@@ -127,23 +128,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
     def forward(
...
@@ -1399,7 +1399,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1500,6 +1502,11 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_states, residual, forward_batch
             )

+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
         return hidden_states, residual

     def op_comm_prepare_attn(
...
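The NOTE above points at an aliasing problem; the linked discussion has the details. As a generic toy illustration (unrelated tensors, not the model's own buffers), `clone()` is what decouples a value from later in-place updates of the storage it would otherwise share:

```python
import torch

buf = torch.zeros(4)
aliased = buf          # shares storage with buf
safe = buf.clone()     # private copy
buf.add_(1.0)          # later in-place update
assert aliased[0].item() == 1.0 and safe[0].item() == 0.0
```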
@@ -38,6 +38,10 @@ class EAGLEDraftCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -53,7 +57,9 @@ class EAGLEDraftCudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
@@ -78,10 +84,26 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.max_num_token, self.model_runner.model_config.hidden_size),
+            (self.max_bs, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+            self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+
         # Capture
         try:
             with model_capture_mode():
@@ -92,11 +114,26 @@ class EAGLEDraftCudaGraphRunner:
             )

     def can_run(self, forward_batch: ForwardBatch):
-        is_bs_supported = (
-            forward_batch.batch_size in self.graphs
-            if self.disable_padding
-            else forward_batch.batch_size <= self.max_bs
-        )
+        if self.enable_dp_attention:
+            # TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
+            if not forward_batch.can_run_dp_cuda_graph:
+                return False
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            is_bs_supported = (
+                total_batch_size in self.graphs
+                if self.disable_padding
+                else total_batch_size <= self.max_bs
+            )
+        else:
+            is_bs_supported = (
+                forward_batch.batch_size in self.graphs
+                if self.disable_padding
+                else forward_batch.batch_size <= self.max_bs
+            )
         return is_bs_supported

     def capture(self):
@@ -116,8 +153,40 @@ class EAGLEDraftCudaGraphRunner:
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
-            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
         )

         # Forward batch
@@ -133,11 +202,14 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
             ),
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
         )

         # Attention backend
@@ -147,6 +219,9 @@ class EAGLEDraftCudaGraphRunner:
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -184,7 +259,15 @@ class EAGLEDraftCudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
@@ -203,6 +286,13 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         # Attention backend
         if bs != raw_bs:
             forward_batch.batch_size = bs
@@ -210,8 +300,10 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]

-        if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(self.seq_len_fill_value)
+        # Special handle for seq_len_cpu used when flashinfer mla is used
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
...
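A condensed sketch of the capture/replay contract used above (standalone tensors here; in the runner these are attributes): the graph is captured with an even split of the padded token count, and replay only overwrites the same fixed GPU buffer with the batch's real per-rank counts, so no re-capture is needed.

```python
import torch

dp_size, num_tokens = 2, 8
global_num_tokens_gpu = torch.zeros(dp_size, dtype=torch.int32)

# Capture time: record the graph with a valid, evenly distributed placeholder.
global_num_tokens_gpu.copy_(torch.tensor(
    [num_tokens // dp_size + (i < num_tokens % dp_size) for i in range(dp_size)],
    dtype=torch.int32,
))

# Replay time: refresh the captured buffer in place with the real counts.
global_num_tokens_gpu.copy_(torch.tensor([3, 5], dtype=torch.int32))
```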
@@ -35,6 +35,8 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -51,7 +53,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
         self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
-            self.max_num_token
+            self.max_bs, self.max_num_token
         )
         self.seq_len_fill_value = (
             self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -90,6 +92,21 @@ class EAGLEDraftExtendCudaGraphRunner:
             (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
         )

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+            self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+
         # Capture
         try:
             with model_capture_mode():
@@ -100,15 +117,30 @@ class EAGLEDraftExtendCudaGraphRunner:
             )

     def can_run(self, forward_batch: ForwardBatch):
-        batch_size = forward_batch.seq_lens.numel()
-
-        is_bs_supported = (
-            batch_size in self.graphs
-            if self.disable_padding
-            else batch_size <= self.max_bs
-        )
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            if not forward_batch.can_run_dp_cuda_graph:
+                return False
+
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            is_bs_supported = (
+                total_batch_size in self.graphs
+                if self.disable_padding
+                else total_batch_size <= self.max_bs
+            )
+            return is_bs_supported
+        else:
+            batch_size = forward_batch.seq_lens.numel()
+            is_bs_supported = (
+                batch_size in self.graphs
+                if self.disable_padding
+                else batch_size <= self.max_bs
+            )

         return is_bs_supported

     def capture(self):
         CudaGraphRunner.capture(self)
@@ -128,6 +160,35 @@ class EAGLEDraftExtendCudaGraphRunner:
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             accept_length=accept_length,
@@ -147,6 +208,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -167,6 +231,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -203,24 +270,42 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]

-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
             self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
             self.accept_length.fill_(1)
+            self.extend_seq_lens.fill_(1)

         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
+        if forward_batch.extend_seq_lens is not None:
+            self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
         self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
-        self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
+        if forward_batch.spec_info.accept_length is not None:
+            self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
...
@@ -25,6 +25,8 @@ from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

+logger = logging.getLogger(__name__)
+
 if is_cuda():
     from sgl_kernel import (
         fast_topk,
@@ -69,6 +71,8 @@ class EagleDraftInput:
     kv_indices: torch.Tensor = None

     def prepare_for_extend(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_idle():
+            return
+
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
@@ -80,6 +84,24 @@ class EagleDraftInput:
             )
             pt += extend_len

+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        topk: int,
+        capture_hidden_mode: CaptureHiddenMode,
+    ):
+        return cls(
+            verified_id=None,
+            hidden_states=torch.empty(
+                (0, hidden_size), device=device, dtype=torch.float32
+            ),
+            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
+        )
+
     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
@@ -193,7 +215,35 @@ class EagleVerifyInput:
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None

+    @classmethod
+    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        return cls(
+            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
+            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            retrive_index=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_token=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_sibling=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_cum_len=None,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            spec_steps=spec_steps,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=0,
+            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
+        )
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.draft_token

         if page_size == 1:
@@ -279,6 +329,25 @@ class EagleVerifyInput:
             tokens. I.e., logits_output.next_token_logits only contains
             accepted token logits.
         """
+        if batch.forward_mode.is_idle():
+            return EagleVerifyOutput(
+                draft_input=EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                ),
+                logits_output=logits_output,
+                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
+                accept_length_per_req_cpu=[],
+                accepted_indices=torch.full(
+                    (0, self.spec_steps + 1),
+                    -1,
+                    dtype=torch.int32,
+                    device=batch.device,
+                ),
+            )
+
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
         sampling_info = batch.sampling_info
@@ -992,10 +1061,11 @@ def select_top_k_tokens(
         topk_index = topk_index.reshape(-1, topk**2)
         input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

-        selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
-            0, hidden_states.shape[0], step=topk, device="cuda"
-        ).repeat_interleave(topk)
-        hidden_states = hidden_states[selected_input_index, :]
+        if hidden_states.shape[0] > 0:
+            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+                0, hidden_states.shape[0], step=topk, device="cuda"
+            ).repeat_interleave(topk)
+            hidden_states = hidden_states[selected_input_index, :]

     tree_info = (
         expand_scores,  # shape: (b, topk, topk)
...
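A quick standalone check (shapes assumed) of what the `create_idle_input` helpers above produce: zero-row tensors, so a DP rank with no local work can still enter the forward pass and its collectives, while guards like `if hidden_states.shape[0] > 0:` skip the per-token math on that rank.

```python
import torch

hidden_size, topk = 16, 4
hidden_states = torch.empty((0, hidden_size), dtype=torch.float32)
topk_p = torch.empty((0, topk), dtype=torch.float32)
topk_index = torch.empty((0, topk), dtype=torch.int64)
assert hidden_states.shape[0] == topk_p.shape[0] == topk_index.shape[0] == 0
```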