Unverified Commit 013021b6 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

refactor EAGLE 2 (#3269)


Co-authored-by: default avatarYing Sheng <sqy1415@gmail.com>
Co-authored-by: default avatarmerrymercy <lianminzheng@gmail.com>
Co-authored-by: default avatarYing1123 <sqy1415@gmail.com>
parent 3c8ac78d
......@@ -21,6 +21,7 @@ def main():
speculative_num_steps=3,
speculative_eagle_topk=4,
speculative_num_draft_tokens=16,
cuda_graph_max_bs=8,
)
outputs = llm.generate(prompts, sampling_params)
......
......@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
import torch
......@@ -34,6 +35,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode
class WrapperDispatch(Enum):
......@@ -53,10 +55,19 @@ class PrefillMetadata:
extend_no_prefix: bool
# Reuse this workspace buffer across all flashinfer wrappers
global_workspace_buffer = None
class FlashInferAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""
def __init__(self, model_runner: ModelRunner):
def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
# Parse constants
......@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
),
)
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
assert not (
model_runner.sliding_window_size is not None
......@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
# Allocate buffers
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
global global_workspace_buffer
if global_workspace_buffer is None:
global_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
]
if kv_indptr_buf is None:
self.kv_indptr = [
torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
for _ in range(self.num_wrappers)
]
else:
assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf]
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
......@@ -122,12 +144,16 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_wrappers_verify = []
self.decode_wrappers = []
for _ in range(self.num_wrappers):
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
if not skip_prefill:
self.prefill_wrappers_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
)
self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
......@@ -137,10 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
)
# Create indices updater
if not skip_prefill:
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
......@@ -211,23 +238,30 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
)
def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = kv_indices_buf
self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
def init_forward_metadata_capture_cuda_graph(
self,
......@@ -602,11 +636,8 @@ class FlashInferIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
else:
bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
req_pool_indices,
paged_kernel_lens,
self.req_to_token,
)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
wrapper.end_forward()
wrapper.begin_forward(
......@@ -854,6 +885,132 @@ class FlashInferIndicesUpdaterPrefill:
)
class FlashInferMultiStepDraftBackend:
"""
Wrap multiple flashinfer attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
FlashInferAttnBackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
)
)
self.max_context_len = self.attn_backends[0].max_context_len
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.kv_indptr_stride = self.kv_indptr.shape[1]
def common_template(self, forward_batch: ForwardBatch, call_fn: int):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
self.cuda_graph_kv_indices,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
self.kv_indptr_stride,
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
)
for i in range(self.speculative_num_steps):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
forward_batch.spec_info.kv_indices = (
forward_batch.spec_info.kv_indices.clone()
)
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
device="cuda",
)
self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
forward_batch.batch_size
][0]
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
self.common_template(forward_batch, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, call_fn)
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
......@@ -937,3 +1094,105 @@ def should_use_tensor_core(
return gqa_group_size > 4
else:
return False
def fast_decode_plan(
self,
indptr: torch.Tensor,
indices: torch.Tensor,
last_page_len: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
data_type: Union[str, torch.dtype] = "float16",
q_data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
" mismatches the batch size set during initialization {}".format(
batch_size, self._fixed_batch_size
)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
q_data_type = data_type
if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type)
if isinstance(q_data_type, str)
else q_data_type
),
)
self.empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, data_type) if isinstance(data_type, str) else data_type
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
if self.use_tensor_cores:
if not self.is_cuda_graph_enabled:
# when not using cudagraph, we need to create the indptr buffer, otherwise
# the buffer is already created during initialization
self._qo_indptr_buf = torch.arange(
batch_size + 1, dtype=torch.int32, device=indptr.device
)
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._qo_indptr_buf,
indptr,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
empty_q_data,
)
else:
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
indptr,
self.last_page_len,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
PosEncodingMode[pos_encoding_mode].value,
logits_soft_cap,
empty_q_data,
empty_kv_cache,
)
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta
......@@ -103,69 +103,75 @@ def set_torch_compile_config():
torch._dynamo.config.cache_size_limit = 1024
def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs
if capture_bs is None:
if server_args.disable_cuda_graph_padding:
capture_bs = list(range(1, 33)) + [64, 128]
else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very samll. We add more values here to make sure we capture the maximum bs.
capture_bs = list(
sorted(
set(
capture_bs
+ [model_runner.req_to_token_pool.size - 1]
+ [model_runner.req_to_token_pool.size]
)
)
)
capture_bs = [
bs
for bs in capture_bs
if bs <= model_runner.req_to_token_pool.size
and bs <= server_args.cuda_graph_max_bs
]
compile_bs = (
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
if server_args.enable_torch_compile
else []
)
return capture_bs, compile_bs
# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None
def get_global_graph_memory_pool():
return global_graph_memory_pool
def set_global_graph_memory_pool(val):
global global_graph_memory_pool
global_graph_memory_pool = val
class CudaGraphRunner:
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
def __init__(self, model_runner: "ModelRunner"):
def __init__(self, model_runner: ModelRunner):
# Parse args
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
self.output_buffers = {}
self.flashinfer_handlers = {}
self.graph_memory_pool = None
self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
self.tp_size = self.model_runner.tp_size
self.dp_size = self.model_runner.server_args.dp_size
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
# Batch sizes to capture
self.capture_bs = self.model_runner.server_args.cuda_graph_bs
if self.capture_bs is None:
if model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 33)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
if max(self.capture_bs) > model_runner.req_to_token_pool.size:
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very samll. We add more values here to make sure we capture the maximum bs.
self.capture_bs = list(
sorted(
set(
self.capture_bs
+ [model_runner.req_to_token_pool.size - 1]
+ [model_runner.req_to_token_pool.size]
)
)
)
self.capture_bs = [
bs
for bs in self.capture_bs
if bs <= model_runner.req_to_token_pool.size
and bs <= model_runner.server_args.cuda_graph_max_bs
]
self.compile_bs = (
[
bs
for bs in self.capture_bs
if bs <= self.model_runner.server_args.torch_compile_max_bs
]
if self.use_torch_compile
else []
)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.capture_forward_mode = ForwardMode.DECODE
self.num_tokens_per_bs = 1
if model_runner.spec_algorithm.is_eagle():
if self.model_runner.is_draft_worker:
self.num_tokens_per_bs = (
self.model_runner.server_args.speculative_eagle_topk
)
raise RuntimeError("This should not happen")
else:
self.capture_forward_mode = ForwardMode.TARGET_VERIFY
self.num_tokens_per_bs = (
......@@ -182,10 +188,10 @@ class CudaGraphRunner:
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
if self.use_torch_compile:
if self.enable_torch_compile:
set_torch_compile_config()
# Common inputs
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
......@@ -301,7 +307,7 @@ class CudaGraphRunner:
stream = self.stream
num_tokens = bs * self.num_tokens_per_bs
# Common inputs
# Graph inputs
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
......@@ -320,7 +326,7 @@ class CudaGraphRunner:
global_num_tokens = None
gathered_buffer = None
spec_info = self.get_spec_info(num_tokens, positions)
spec_info = self.get_spec_info(num_tokens)
forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
......@@ -335,7 +341,6 @@ class CudaGraphRunner:
seq_lens_sum=seq_lens.sum(),
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=positions,
global_num_tokens=global_num_tokens,
gathered_buffer=gathered_buffer,
......@@ -375,13 +380,14 @@ class CudaGraphRunner:
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
global global_graph_memory_pool
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
self.graph_memory_pool = graph.pool()
global_graph_memory_pool = graph.pool()
return graph, out
def replay(self, forward_batch: ForwardBatch):
......@@ -439,35 +445,26 @@ class CudaGraphRunner:
)
return logits_output
def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
def get_spec_info(self, num_tokens: int):
spec_info = None
if self.model_runner.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_utils import (
EAGLEDraftInput,
EagleVerifyInput,
)
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
if self.model_runner.is_draft_worker:
spec_info = EAGLEDraftInput()
spec_info.load_server_args(self.model_runner.server_args)
spec_info.hidden_states = self.hidden_states[:num_tokens]
spec_info.positions = positions
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
raise RuntimeError("This should not happen.")
else:
spec_info = EagleVerifyInput(
None,
None,
None,
None,
None,
None,
self.model_runner.server_args.speculative_num_draft_tokens,
)
spec_info.custom_mask = torch.zeros(
(num_tokens * self.model_runner.model_config.context_len),
dtype=torch.bool,
device="cuda",
draft_token=None,
custom_mask=torch.zeros(
(num_tokens * self.model_runner.model_config.context_len),
dtype=torch.bool,
device="cuda",
),
positions=None,
retrive_index=None,
retrive_cum_len=None,
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
capture_hidden_mode=CaptureHiddenMode.FULL,
)
spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
return spec_info
......@@ -197,64 +197,6 @@ class ForwardBatch:
# For Qwen2-VL
mrope_positions: torch.Tensor = None
def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
device = model_runner.device
hf_config = model_runner.model_config.hf_config
mrope_positions_list = [None] * self.seq_lens.shape[0]
if self.forward_mode.is_decode():
for i, _ in enumerate(mrope_positions_list):
mrope_position_delta = (
0
if batch.image_inputs[i] is None
else batch.image_inputs[i].mrope_position_delta
)
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
int(self.seq_lens[i]) - 1,
int(self.seq_lens[i]),
)
elif self.forward_mode.is_extend():
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
for i, image_inputs in enumerate(batch.image_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
extend_start_loc_cpu[i],
batch.extend_seq_lens[i],
batch.extend_prefix_lens[i],
)
if image_inputs is None:
# text only
mrope_positions = [
[
pos
for pos in range(
extend_prefix_len, extend_prefix_len + extend_seq_len
)
]
] * 3
else:
# TODO: current qwen2-vl do not support radix cache since mrope position calculation
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=image_inputs.image_grid_thws,
vision_start_token_id=hf_config.vision_start_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
)
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
@classmethod
def init_new(
cls,
......@@ -337,7 +279,7 @@ class ForwardBatch:
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
if model_runner.model_is_mrope:
ret.compute_mrope_positions(model_runner, batch)
ret._compute_mrope_positions(model_runner, batch)
# Init lora information
if model_runner.server_args.lora_paths is not None:
......@@ -345,6 +287,63 @@ class ForwardBatch:
return ret
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
device = model_runner.device
hf_config = model_runner.model_config.hf_config
mrope_positions_list = [None] * self.seq_lens.shape[0]
if self.forward_mode.is_decode():
for i, _ in enumerate(mrope_positions_list):
mrope_position_delta = (
0
if batch.image_inputs[i] is None
else batch.image_inputs[i].mrope_position_delta
)
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
int(self.seq_lens[i]) - 1,
int(self.seq_lens[i]),
)
elif self.forward_mode.is_extend():
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
for i, image_inputs in enumerate(batch.image_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
extend_start_loc_cpu[i],
batch.extend_seq_lens[i],
batch.extend_prefix_lens[i],
)
if image_inputs is None:
# text only
mrope_positions = [
[
pos
for pos in range(
extend_prefix_len, extend_prefix_len + extend_seq_len
)
]
] * 3
else:
# TODO: current qwen2-vl do not support radix cache since mrope position calculation
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=image_inputs.image_grid_thws,
vision_start_token_id=hf_config.vision_start_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
)
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def compute_position_triton(
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
......
......@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import ServerArgs
......@@ -714,8 +715,6 @@ class ModelRunner:
def init_cuda_graphs(self):
"""Capture cuda graphs."""
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
self.cuda_graph_runner = None
if not self.is_generation:
......
......@@ -79,11 +79,13 @@ __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected
)
def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
def build_tree_kernel(
parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
):
bs = seq_lens.numel()
device = parent_list.device
tree_mask = torch.full(
(torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
(seq_lens_sum * draft_token + draft_token * draft_token * bs,),
True,
device=device,
)
......
from __future__ import annotations
import bisect
import time
from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.model_executor.cuda_graph_runner import (
CudaGraphRunner,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
set_global_graph_memory_pool,
set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_worker import EAGLEWorker
class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
server_args = model_runner.server_args
assert self.disable_padding
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.num_tokens_per_bs = server_args.speculative_eagle_topk
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
0
].get_cuda_graph_seq_len_fill_value()
if self.enable_torch_compile:
set_torch_compile_config()
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.out_cache_loc = torch.zeros(
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
self.hidden_states = torch.zeros(
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
# Capture
try:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, forward_batch: ForwardBatch):
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
return is_bs_supported
def capture(self):
CudaGraphRunner.capture(self)
def capture_one_batch_size(self, num_seqs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_tokens = num_seqs * self.num_tokens_per_bs
# Graph inputs
req_pool_indices = self.req_pool_indices[:num_seqs]
seq_lens = self.seq_lens[:num_seqs]
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
positions = self.positions[:num_tokens]
topk_p = self.topk_p[:num_seqs]
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
spec_info = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=hidden_states,
)
# Forward batch
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=num_seqs,
input_ids=None,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
return_logprob=False,
positions=positions,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
),
)
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
forward_batch
)
# Run and capture
def run_once():
# Backup two fileds, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
ret = self.eagle_worker.draft_forward(forward_batch)
forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
return ret
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
set_global_graph_memory_pool(graph.pool())
return graph, out
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
# Common inputs
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
forward_batch.out_cache_loc
)
self.positions[:raw_num_token].copy_(forward_batch.positions)
self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch
)
# Replay
self.graphs[bs].replay()
return self.output_buffers[bs]
import logging
import time
from typing import List, Optional, Union
import torch
......@@ -12,8 +14,18 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
from sglang.srt.utils import rank0_print
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
assign_draft_cache_locs,
fast_topk,
select_top_k_tokens,
)
logger = logging.getLogger(__name__)
class EAGLEWorker(TpModelWorker):
......@@ -40,41 +52,47 @@ class EAGLEWorker(TpModelWorker):
is_draft_worker=True,
)
self.target_worker = target_worker
self.server_args = server_args
self.finish_extend_len = []
# Parse arguments
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.server_args = server_args
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
self.model_runner.init_cuda_graphs()
def forward_draft_decode(self, batch: ScheduleBatch):
batch.spec_info.prepare_for_decode(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
# Create multi-step attn backends and cuda graph runners
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend,
)
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
self.model_runner,
self.topk,
self.speculative_num_steps,
)
self.model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()
def init_cuda_graphs(self):
"""Capture cuda graphs."""
self.cuda_graph_runner = None
if self.server_args.disable_cuda_graph:
return
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
if batch.forward_mode.is_decode():
# Draft
self._set_mem_pool(batch, self.model_runner)
for i in range(self.server_args.speculative_num_steps):
self.forward_draft_decode(batch)
batch.spec_info.clear_draft_cache(batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
spec_info: EagleVerifyInput = self.draft(batch)
# Verify
(
......@@ -84,8 +102,7 @@ class EAGLEWorker(TpModelWorker):
self.finish_extend_len,
accept_length_cpu,
model_worker_batch,
) = self.verify(batch)
next_draft_input.load_server_args(self.server_args)
) = self.verify(batch, spec_info)
batch.spec_info = next_draft_input
# if it is None, means all requsets are finished
if batch.spec_info.verified_id is not None:
......@@ -107,29 +124,145 @@ class EAGLEWorker(TpModelWorker):
)
# Forward with the draft model.
spec_info = EAGLEDraftInput()
spec_info.load_server_args(self.server_args)
spec_info.hidden_states = logits_output.hidden_states
spec_info.verified_id = next_token_ids
batch.spec_info = spec_info
batch.spec_info = EagleDraftInput(
hidden_states=logits_output.hidden_states,
verified_id=next_token_ids,
)
self.forward_draft_extend(batch)
return logits_output, next_token_ids, model_worker_batch, 0
def verify(self, batch: ScheduleBatch):
verify_input = batch.spec_info.prepare_for_verify(batch)
verify_input.prepare_for_verify(batch)
def draft(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
# Parse args
num_seqs = batch.batch_size()
spec_info = batch.spec_info
# Allocate cache locations
out_cache_loc = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps
)
assign_draft_cache_locs[(num_seqs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
)
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch
)
if can_cuda_graph:
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
forward_batch
)
else:
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
ret = EagleVerifyInput.create(
spec_info.verified_id,
score_list,
token_list,
parents_list,
batch.seq_lens,
batch.seq_lens_sum,
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
)
# Free cache locations
batch.token_to_kv_pool.free(out_cache_loc)
self._set_mem_pool(batch, self.target_worker.model_runner)
return ret
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
spec_info = forward_batch.spec_info
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = (
spec_info.topk_p,
spec_info.topk_index,
spec_info.hidden_states,
)
# Return values
score_list: List[torch.Tensor] = []
token_list: List[torch.Tensor] = []
parents_list: List[torch.Tensor] = []
# Forward multiple steps
scores = None
for i in range(self.speculative_num_steps):
input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
i, topk_p, topk_index, hidden_states, scores, self.topk
)
score_list.append(tree_info[0])
token_list.append(tree_info[1])
parents_list.append(tree_info[2])
# Set inputs
forward_batch.input_ids = input_ids
forward_batch.out_cache_loc = out_cache_loc[
forward_batch.batch_size
* self.topk
* i : forward_batch.batch_size
* self.topk
* (i + 1)
]
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
hidden_states = logits_output.hidden_states
return score_list, token_list, parents_list
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = verify_input
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
logits_output, _ = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
verify_input.hidden_states = logits_output.hidden_states
res = verify_input.verify(batch, logits_output)
spec_info.hidden_states = logits_output.hidden_states
res = spec_info.verify(batch, logits_output)
batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,)
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.token_to_kv_pool = runner.token_to_kv_pool
batch.req_to_token_pool = runner.req_to_token_pool
......@@ -139,7 +272,7 @@ class EAGLEWorker(TpModelWorker):
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(batch)
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
......@@ -155,13 +288,10 @@ class EAGLEWorker(TpModelWorker):
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
):
sample_output = torch.softmax(
logits_output.next_token_logits, dim=-1
) # TODO(kavioyu): Support more sampling methods
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
spec_info = forward_batch.spec_info
spec_info.sample_output = sample_output
spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
spec_info.hidden_states = logits_output.hidden_states
spec_info.prev_mode = forward_batch.forward_mode
# Don't support prefix share now.
def finish_request(self, reqs: Union[Req, List[Req]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment