Unverified Commit 63195028 authored by Ke Bao, committed by GitHub

Support EAGLE draft extend CUDA graph (#6606)


Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
parent a3d7f4b6
@@ -1268,6 +1268,29 @@ class FlashAttentionBackend(AttentionBackend):
),
}
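# Pre-allocated buffers reused by the draft-extend CUDA graph: per-request
# KV sequence lengths, cumulative q/k lengths, and a paged KV-cache page table.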
self.draft_extend_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.zeros(
max_bs + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
if self.topk > 1:
self.target_verify_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
@@ -1508,6 +1531,32 @@ class FlashAttentionBackend(AttentionBackend):
self.target_verify_metadata_topk_normal[bs] = metadata
self.target_verify_metadata_topk_expand[bs] = metadata_expand
elif forward_mode.is_draft_extend():
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
:bs
]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
num_tokens_per_bs = num_tokens // bs
metadata.max_seq_len_q = num_tokens_per_bs
metadata.max_seq_len_k = seq_lens.max().item()
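# During capture every request contributes exactly num_tokens_per_bs query tokens,
# so cu_seqlens_q is a fixed arithmetic progression.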
metadata.cu_seqlens_q = torch.arange(
0,
bs * num_tokens_per_bs + 1,
num_tokens_per_bs,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
: (bs + 1)
]
metadata.page_table = self.draft_extend_metadata["page_table"][
req_pool_indices, :
]
self.draft_extend_metadata[bs] = metadata
if encoder_lens is not None:
encoder_bs = encoder_lens.numel()
@@ -1732,6 +1781,29 @@ class FlashAttentionBackend(AttentionBackend):
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
elif forward_mode.is_draft_extend():
metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
metadata.max_seq_len_k = seq_lens_cpu.max().item()
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
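# At replay the real query length of each request is its number of accepted
# tokens, so cu_seqlens_q is rebuilt from accept_length.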
accept_length = spec_info.accept_length[:bs]
metadata.max_seq_len_q = accept_length.max().item()
metadata.cu_seqlens_q[1:].copy_(
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
)
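# Rebuild the page table: gather one KV slot per page (stride = page_size)
# from req_to_token, then convert slot indices into page indices.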
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.draft_extend_metadata["strided_indices"][:max_seq_pages],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
if encoder_lens is not None:
# Only support encoder size 1 for now
......
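For intuition, here is a minimal standalone sketch of the strided page-table construction used in the draft-extend replay path above. The sizes and the toy req_to_token layout are hypothetical; only the indexing pattern mirrors the code:

import torch

page_size, max_context_len = 4, 16
# Toy req_to_token: maps (request row, position) -> KV-cache token slot.
req_to_token = torch.arange(2 * max_context_len, dtype=torch.int32).reshape(2, max_context_len)
strided_indices = torch.arange(0, max_context_len, page_size)      # positions 0, 4, 8, 12
max_seq_len_k = 10
max_seq_pages = (max_seq_len_k + page_size - 1) // page_size       # ceil(10 / 4) = 3 pages
req_pool_indices = torch.tensor([0, 1])
page_indices = req_to_token[req_pool_indices[:, None], strided_indices[:max_seq_pages]]
page_indices //= page_size   # first slot of each page -> page index, e.g. [[0, 1, 2], [4, 5, 6]]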
@@ -262,10 +262,14 @@ class ServerArgs:
self.mem_fraction_static = 0.88
if gpu_mem is not None and gpu_mem > 96 * 1024:
mem_fraction = self.mem_fraction_static
# 15 GB + additional 3GB for cuda graph
reserve_mem = 1024 * 18
# need to reserve more memory for the spec cuda graph
if self.speculative_algorithm is not None:
reserve_mem = 1024 * 20
self.mem_fraction_static = min(
mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,
- (gpu_mem - 1024 * 18)
- / gpu_mem,  # 15 GB + additional 3GB for cuda graph
(gpu_mem - reserve_mem) / gpu_mem,
)
# Set chunked prefill size, which depends on the gpu memory capacity
......
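For intuition, a rough worked example of the reservation formula above, assuming a hypothetical 141 GB GPU (gpu_mem here is in MiB) with speculative decoding enabled:

gpu_mem = 141 * 1024                 # hypothetical GPU memory in MiB
mem_fraction = 0.88
reserve_mem = 1024 * 20              # spec decoding reserves 20 GB instead of 18 GB
mem_fraction_static = min(
    mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem,   # ~0.92
    (gpu_mem - reserve_mem) / gpu_mem,                          # ~0.86
)
# The second term wins, leaving roughly 20 GB outside the static memory pool.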
from __future__ import annotations
import bisect
from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.model_executor.cuda_graph_runner import (
CudaGraphRunner,
LogitsProcessorOutput,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
set_global_graph_memory_pool,
set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
class EAGLEDraftExtendCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.padded_static_len = -1
# Attention backend
self.num_tokens_per_bs = self.speculative_num_steps + 1
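# Each request occupies a fixed speculative_num_steps + 1 token slots in the
# padded graph input.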
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
self.max_num_token
)
self.seq_len_fill_value = (
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
)
self.seq_lens_cpu = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
if self.enable_torch_compile:
set_torch_compile_config()
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
if self.eagle_worker.speculative_algorithm.is_eagle3():
self.hidden_states = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size * 3,
),
dtype=self.model_runner.dtype,
)
else:
self.hidden_states = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
self.accept_length = torch.ones((self.max_bs,), dtype=torch.int32)
# Capture
try:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture CUDA graph failed: {e}\n"
"Possible solutions:\n"
"1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, forward_batch: ForwardBatch):
batch_size = forward_batch.seq_lens.numel()
is_bs_supported = (
batch_size in self.graphs
if self.disable_padding
else batch_size <= self.max_bs
)
return is_bs_supported
def capture(self):
CudaGraphRunner.capture(self)
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_tokens = bs * self.num_tokens_per_bs
# Graph inputs
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
extend_seq_lens = self.extend_seq_lens[:bs]
accept_length = self.accept_length[:bs]
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]
spec_info = EagleDraftInput(
hidden_states=hidden_states,
accept_length=accept_length,
)
spec_info.positions = None
# Forward batch
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DRAFT_EXTEND,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
return_logprob=False,
positions=positions,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=CaptureHiddenMode.LAST,
attn_backend=self.eagle_worker.draft_extend_attn_backend,
extend_seq_lens=extend_seq_lens,
padded_static_len=self.padded_static_len,
)
self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_capture_cuda_graph(
bs=bs,
num_tokens=num_tokens,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DRAFT_EXTEND,
spec_info=spec_info,
)
# Run and capture
def run_once():
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
ret = self.eagle_worker.draft_model_runner.model.forward(
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
)
forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
return ret
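# Warm up twice outside the graph so one-time initialization happens before capture.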
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
set_global_graph_memory_pool(graph.pool())
return graph, out
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
# batch_size and num_seqs can differ when the batch contains finished requests,
# which are not counted in num_seqs
raw_bs = forward_batch.batch_size
num_tokens = forward_batch.input_ids.shape[0]
assert raw_bs * self.num_tokens_per_bs == num_tokens
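# Pad to the smallest captured batch size >= raw_bs; the padded slots are
# filled with dummy values below so the captured graph shape still matches.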
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.accept_length.fill_(1)
self.out_cache_loc.zero_()
# Common inputs
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
self.positions[:num_tokens].copy_(forward_batch.positions)
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(1)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
forward_batch.spec_info.positions = None
if bs != raw_bs:
forward_batch.spec_info.accept_length = self.accept_length[:bs]
self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
bs=bs,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
encoder_lens=None,
forward_mode=ForwardMode.DRAFT_EXTEND,
spec_info=forward_batch.spec_info,
seq_lens_cpu=self.seq_lens_cpu,
)
# Replay
self.graphs[bs].replay()
out = self.output_buffers[bs]
if bs != raw_bs:
forward_batch.spec_info.accept_length = self.accept_length[:raw_bs]
out = LogitsProcessorOutput(
next_token_logits=out.next_token_logits[:raw_bs],
hidden_states=out.hidden_states[:raw_bs],
)
return out
@@ -84,6 +84,7 @@ class EagleDraftInput:
self,
batch: ScheduleBatch,
speculative_num_steps: int,
pad_input: bool = False,
):
assert len(self.verified_id) == len(batch.out_cache_loc)
accept_length_cpu = batch.spec_info.accept_length_cpu
@@ -111,6 +112,50 @@ class EagleDraftInput:
batch.input_ids = self.verified_id
self.verified_id = new_verified_id
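# When the draft-extend CUDA graph is enabled, pad every request's input to a
# constant speculative_num_steps + 1 tokens so the batch matches a captured shape.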
if pad_input:
batch_size = sum(not req.finished() for req in batch.reqs)
# Total constant input length after padding
static_len = speculative_num_steps + 1
# Total size after padding
padded_input_size = batch_size * static_len
padded_len = padded_input_size - batch.input_ids.shape[0]
if padded_len > 0:
new_input_ids = torch.nn.functional.pad(
batch.input_ids, (0, padded_len), value=0
)
position_padding = torch.arange(
padded_len, device=self.positions.device
)
new_positions = torch.cat([self.positions, position_padding])
# need dummy hidden states for the padded positions
hidden_states_dim = self.hidden_states.shape[-1]
new_hidden_states = torch.cat(
[
self.hidden_states,
torch.zeros(
(padded_len, hidden_states_dim),
dtype=self.hidden_states.dtype,
device=self.hidden_states.device,
),
],
dim=0,
)
# allocate KV cache location for the padded tokens
padded_cache_loc = torch.zeros(
padded_len,
dtype=batch.out_cache_loc.dtype,
device=batch.out_cache_loc.device,
)
new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc])
batch.input_ids = new_input_ids
self.hidden_states = new_hidden_states
self.positions = new_positions
batch.out_cache_loc = new_out_cache_loc
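A tiny numeric illustration of the padding arithmetic above, using hypothetical values:

batch_size = 3                                   # unfinished requests (hypothetical)
speculative_num_steps = 4
static_len = speculative_num_steps + 1           # every request is padded to 5 tokens
padded_input_size = batch_size * static_len      # 15
num_real_tokens = 11                             # accepted tokens actually present (hypothetical)
padded_len = padded_input_size - num_real_tokens # 4 dummy tokens appended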
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
......
@@ -26,6 +26,9 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
@@ -189,6 +192,7 @@ class EAGLEWorker(TpModelWorker):
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "fa3":
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
FlashAttentionMultiStepBackend,
)
@@ -197,7 +201,10 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
- self.draft_extend_attn_backend = None
self.draft_extend_attn_backend = FlashAttentionBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "flashmla":
@@ -242,7 +249,18 @@ class EAGLEWorker(TpModelWorker):
# Capture extend
if self.draft_extend_attn_backend:
- raise NotImplementedError()
tic = time.perf_counter()
before_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
self
)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
)
@property
def draft_model_runner(self):
@@ -656,6 +674,7 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_logprob = False
@@ -665,7 +684,19 @@ class EAGLEWorker(TpModelWorker):
)
# Run
- logits_output, _ = self.draft_model_runner.forward(forward_batch)
can_cuda_graph = (
self.cuda_graph_runner_for_draft_extend
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
)
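# Replay the captured draft-extend graph when possible; otherwise fall back to
# the eager path and recompute attention metadata for this batch.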
if can_cuda_graph:
logits_output = self.cuda_graph_runner_for_draft_extend.replay(
forward_batch
)
else:
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self._detect_nan_if_needed(logits_output)
self.capture_for_decode(logits_output, forward_batch.spec_info)
......