Unverified Commit ed0a0b69 authored by u4lr451, committed by GitHub

Performance: Enable cuda graph for dp idle batch (#7269)


Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
parent fa42e419
@@ -1704,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend):
                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand = self.target_verify_metadata_topk_expand[bs]
                 # metadata_expand.max_seq_len_q = 1, already set in capture
                 # metadata_expand.cu_seqlens_q already set in capture
                 offsets = torch.arange(
                     self.speculative_num_draft_tokens, device=device
                 ).unsqueeze(
                     0
                 )  # shape: (1, self.speculative_num_draft_tokens)
                 cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
                 cum_len = torch.nn.functional.pad(
                     torch.cumsum(
@@ -1728,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend):
                 ).view(1, -1)
                 # avoid extracting padded seq indices which will be out of boundary
                 mask_extraction_indices[
-                    :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                    :,
+                    spec_info.positions.numel() * self.speculative_num_draft_tokens :,
                 ].fill_(0)
                 mask = spec_info.custom_mask[mask_extraction_indices].view(
                     -1, self.speculative_num_draft_tokens
                 )  # (bsz * draft_num, draft_num)
                 col_indices = offsets.expand(
                     mask.shape[0], self.speculative_num_draft_tokens
                 )
                 keys = torch.where(
-                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                    mask,
+                    col_indices,
+                    col_indices + self.speculative_num_draft_tokens,
                 )
                 _, sort_order = torch.sort(keys, dim=1)
@@ -1747,6 +1751,7 @@ class FlashAttentionBackend(AttentionBackend):
                     .gather(1, cols)
                     .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
                 )  # (bsz, draft_num)
                 metadata_expand.page_table.copy_(
                     non_masked_page_table.gather(1, sort_order)
                 )
@@ -1758,6 +1763,7 @@ class FlashAttentionBackend(AttentionBackend):
                         dtype=torch.int32,
                     )
                 )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -1767,7 +1773,11 @@ class FlashAttentionBackend(AttentionBackend):
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
...
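The new guard above matters for idle DP batches: their `accept_length_cpu` is an empty list (see the `EagleDraftInput` change further down), and Python's `max()` raises `ValueError` on an empty sequence. A minimal standalone sketch of the fallback, with a hypothetical helper name:

```python
from typing import List


def compute_max_seq_len_q(accept_length_cpu: List[int]) -> int:
    # Each request in a draft-extend batch contributes accept_length + 1
    # query tokens; an idle batch has no requests, so fall back to 1
    # instead of letting max([]) raise ValueError.
    if accept_length_cpu:
        return max(accept_length_cpu) + 1
    return 1


assert compute_max_seq_len_q([3, 1, 2]) == 4
assert compute_max_seq_len_q([]) == 1  # idle DP batch
```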
@@ -1821,11 +1821,6 @@ class Scheduler(
         else:
             can_cuda_graph = 0
 
-        if not spec_algorithm.is_none():
-            # TODO(sang): Support cuda graph when idle batch is there.
-            if local_batch is None or local_batch.forward_mode.is_idle():
-                can_cuda_graph = 0
-
         is_extend_in_batch = (
             local_batch.forward_mode.is_extend() if local_batch else False
         )
...
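The deleted block was the last place that demoted a step to eager execution whenever speculative decoding was enabled and a DP rank had no local work. A condensed, hypothetical restatement of the removed check (not the scheduler's actual helper), kept only to show what no longer happens:

```python
def demote_for_idle_spec_batch(can_cuda_graph: int, local_batch, spec_is_none: bool) -> int:
    # Pre-#7269 behavior, now removed: a missing or idle local batch under
    # speculative decoding marked this rank's step as not CUDA-graph-capable.
    if not spec_is_none and (local_batch is None or local_batch.forward_mode.is_idle()):
        return 0
    return can_cuda_graph
```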
@@ -306,27 +306,29 @@ class CudaGraphRunner:
             self.encoder_lens = None
 
         if self.require_gathered_buffer:
-            if self.require_mlp_tp_gather:
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_bs * self.dp_size * self.num_tokens_per_bs,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
             else:
                 assert self.require_attn_tp_gather
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_bs * self.num_tokens_per_bs,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )
 
         # Capture
         try:
@@ -674,11 +676,12 @@ class CudaGraphRunner:
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=forward_batch.forward_mode,
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
@@ -686,7 +689,7 @@ class CudaGraphRunner:
             self.seq_lens[:bs],
             forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
-            forward_batch.forward_mode,
+            self.capture_forward_mode,
             forward_batch.spec_info,
             seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
@@ -736,11 +739,7 @@ class CudaGraphRunner:
             else:
                 spec_info = EagleVerifyInput(
                     draft_token=None,
-                    custom_mask=torch.ones(
-                        (num_tokens * self.model_runner.model_config.context_len),
-                        dtype=torch.bool,
-                        device="cuda",
-                    ),
+                    custom_mask=self.custom_mask,
                     positions=None,
                     retrive_index=None,
                     retrive_next_token=None,
...
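Two pieces work together in this file: a single `custom_mask` buffer is allocated once at the largest size any captured shape can need, `get_spec_info` hands that same buffer to `EagleVerifyInput` when the graph is captured, and `replay_prepare` points an idle batch's `spec_info.custom_mask` at it again so the replayed graph only ever reads memory that existed at capture time. A toy sketch of that preallocate-and-reuse pattern (illustrative class and sizes, not SGLang's API):

```python
import torch


class MaskPool:
    # Toy stand-in for the CudaGraphRunner buffer above.
    def __init__(self, max_seq_len_sum: int, max_num_token: int, num_tokens_per_bs: int):
        # Sized once for the worst case; all-ones is a harmless placeholder
        # mask ("attend everywhere") for padded or idle entries.
        self.custom_mask = torch.ones(
            (max_seq_len_sum + max_num_token) * num_tokens_per_bs,
            dtype=torch.bool,
        )

    def attach(self, spec_info) -> None:
        # At replay time, reuse the captured buffer instead of allocating a
        # fresh mask; a new allocation would not be visible to the graph.
        spec_info.custom_mask = self.custom_mask


class _Spec:  # minimal stand-in for spec_info
    custom_mask = None


pool = MaskPool(max_seq_len_sum=4096, max_num_token=256, num_tokens_per_bs=2)
spec = _Spec()
pool.attach(spec)
assert spec.custom_mask.data_ptr() == pool.custom_mask.data_ptr()  # same storage
```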
@@ -99,6 +99,8 @@ class EagleDraftInput:
             topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
             capture_hidden_mode=capture_hidden_mode,
+            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
+            accept_length_cpu=[],
         )
 
     def prepare_extend_after_decode(
...
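Giving the idle input an empty `accept_length` tensor and an empty `accept_length_cpu` list is what lets the attention-backend guard earlier in this commit treat an idle rank as a zero-request draft-extend batch rather than a special case. A self-contained toy version of that shape convention (not SGLang's actual dataclass):

```python
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class ToyDraftInput:
    # Minimal stand-in for the idle EagleDraftInput: every per-request field
    # simply has zero rows.
    verified_id: torch.Tensor
    accept_length: torch.Tensor
    accept_length_cpu: List[int] = field(default_factory=list)

    @classmethod
    def create_idle_input(cls, device: str = "cpu") -> "ToyDraftInput":
        return cls(
            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
            accept_length_cpu=[],
        )


idle = ToyDraftInput.create_idle_input()
assert idle.accept_length.numel() == 0 and idle.accept_length_cpu == []
```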
@@ -322,13 +322,11 @@ class EAGLEWorker(TpModelWorker):
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
-            need_forward, can_run_draft_extend_cuda_graph = (
-                self.check_forward_draft_extend_after_decode(batch)
-            )
-            if need_forward:
+            if self.check_forward_draft_extend_after_decode(batch):
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
                     self.forward_draft_extend_after_decode(
-                        batch, can_run_draft_extend_cuda_graph
+                        batch,
                     )
             return (
                 logits_output,
@@ -344,7 +342,7 @@ class EAGLEWorker(TpModelWorker):
             and batch.spec_info.verified_id.shape[0] > 0
         )
         if not self.server_args.enable_dp_attention:
-            return local_need_forward, True
+            return local_need_forward
 
         global_need_forward = torch.tensor(
             [
@@ -357,10 +355,7 @@ class EAGLEWorker(TpModelWorker):
         )
         global_need_forward_cnt = global_need_forward[0].item()
         need_forward = global_need_forward_cnt > 0
-        can_run_draft_extend_cuda_graph = (
-            global_need_forward_cnt == get_tensor_model_parallel_world_size()
-        )
-        return need_forward, can_run_draft_extend_cuda_graph
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -816,15 +811,12 @@ class EAGLEWorker(TpModelWorker):
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
-    def forward_draft_extend_after_decode(
-        self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
-    ):
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
         input_is_idle = batch.forward_mode.is_idle()
         if not input_is_idle:
             # Prepare metadata
@@ -836,14 +828,18 @@ class EAGLEWorker(TpModelWorker):
         else:
             batch = batch.copy()
             batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
             batch.spec_info = EagleDraftInput.create_idle_input(
                 device=self.device,
-                hidden_size=self.model_config.hidden_size,
+                hidden_size=hidden_size,
                 dtype=self.model_config.dtype,
                 topk=self.topk,
                 capture_hidden_mode=CaptureHiddenMode.LAST,
             )
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
@@ -858,8 +854,7 @@ class EAGLEWorker(TpModelWorker):
         # Run
         can_cuda_graph = (
-            can_run_draft_extend_cuda_graph
-            and self.cuda_graph_runner_for_draft_extend
+            self.cuda_graph_runner_for_draft_extend
             and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
         )
         if can_cuda_graph:
...
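With DP attention, whether to run the draft-extend pass is now a group-wide yes/no: each rank contributes a 0/1 flag, the flags are summed across the group, and every rank enters the pass (idle ranks with an idle batch) if any rank needs it, so the simplified `can_cuda_graph` check above no longer needs the old all-ranks-busy condition. A minimal sketch of that consensus pattern, assuming an already-initialized torch.distributed process group (names are illustrative):

```python
import torch
import torch.distributed as dist


def group_needs_draft_extend(local_need_forward: bool, group=None) -> bool:
    # Sum the per-rank 0/1 flags; if any rank has verified tokens to extend,
    # every rank returns True and runs the pass, idle ranks included.
    flag = torch.tensor([int(local_need_forward)], dtype=torch.int64)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=group)
    return int(flag.item()) > 0
```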