Unverified Commit 013021b6 authored by Yineng Zhang, committed by GitHub

refactor EAGLE 2 (#3269)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
parent 3c8ac78d
@@ -21,6 +21,7 @@ def main():
        speculative_num_steps=3,
        speculative_eagle_topk=4,
        speculative_num_draft_tokens=16,
+        cuda_graph_max_bs=8,
    )
    outputs = llm.generate(prompts, sampling_params)
......
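For context only (this sketch is not part of the commit): the hunk above adds cuda_graph_max_bs=8 to the offline-engine speculative decoding example. A minimal launch sketch is below; the model paths are placeholders and the keyword names are assumed to mirror this sglang version's server arguments.

import sglang as sgl

# Hypothetical launch sketch; model paths are placeholders, argument names assumed.
llm = sgl.Engine(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
    speculative_num_steps=3,          # draft decoding steps per round
    speculative_eagle_topk=4,         # branches kept per draft step
    speculative_num_draft_tokens=16,  # total draft tokens sent to the verifier
    cuda_graph_max_bs=8,              # cap cuda graph capture at batch size 8
)
outputs = llm.generate(
    ["The capital of France is"], {"temperature": 0, "max_new_tokens": 32}
)
print(outputs)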
@@ -10,6 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
import os
from dataclasses import dataclass
from enum import Enum, auto
+from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union

import torch
@@ -34,6 +35,7 @@ if is_flashinfer_available():
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from flashinfer.cascade import merge_state
+    from flashinfer.decode import PosEncodingMode


class WrapperDispatch(Enum):
@@ -53,10 +55,19 @@ class PrefillMetadata:
    extend_no_prefix: bool
+# Reuse this workspace buffer across all flashinfer wrappers
+global_workspace_buffer = None


class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
        super().__init__()

        # Parse constants
@@ -69,6 +80,7 @@ class FlashInferAttnBackend(AttentionBackend):
            ),
        )
        self.max_context_len = model_runner.model_config.context_len
+        self.skip_prefill = skip_prefill

        assert not (
            model_runner.sliding_window_size is not None
@@ -90,16 +102,26 @@ class FlashInferAttnBackend(AttentionBackend):
            global_config.flashinfer_workspace_size = 512 * 1024 * 1024

        # Allocate buffers
-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device=model_runner.device,
-        )
+        global global_workspace_buffer
+        if global_workspace_buffer is None:
+            global_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_workspace_buffer
        max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if kv_indptr_buf is None:
+            self.kv_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
+        else:
+            assert self.num_wrappers == 1
+            self.kv_indptr = [kv_indptr_buf]
+
        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
@@ -122,12 +144,16 @@ class FlashInferAttnBackend(AttentionBackend):
        self.prefill_wrappers_verify = []
        self.decode_wrappers = []
        for _ in range(self.num_wrappers):
-            self.prefill_wrappers_paged.append(
-                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-            )
-            self.prefill_wrappers_verify.append(
-                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-            )
+            if not skip_prefill:
+                self.prefill_wrappers_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                    )
+                )
+                self.prefill_wrappers_verify.append(
+                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+                )
            self.decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer,
@@ -137,10 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
            )

        # Create indices updater
+        if not skip_prefill:
+            self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
+                model_runner, self
+            )
        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
-        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
-            model_runner, self
-        )

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
@@ -211,23 +238,30 @@ class FlashInferAttnBackend(AttentionBackend):
            self.prefill_wrappers_paged, use_ragged, extend_no_prefix
        )

-    def init_cuda_graph_state(self, max_bs: int):
-        cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len,),
-            dtype=torch.int32,
-            device="cuda",
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        if kv_indices_buf is None:
+            cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len,),
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = kv_indices_buf
+
        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
        ]
-        self.cuda_graph_custom_mask = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.uint8,
-            device="cuda",
-        )
-        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
-        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+            self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]

    def init_forward_metadata_capture_cuda_graph(
        self,
@@ -602,11 +636,8 @@ class FlashInferIndicesUpdaterDecode:
                self.req_to_token.shape[1],
            )
        else:
-            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
-                req_pool_indices,
-                paged_kernel_lens,
-                self.req_to_token,
-            )
+            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+            bs = kv_indptr.shape[0] - 1

        wrapper.end_forward()
        wrapper.begin_forward(
@@ -854,6 +885,132 @@ class FlashInferIndicesUpdaterPrefill:
        )
class FlashInferMultiStepDraftBackend:
"""
Wrap multiple flashinfer attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
FlashInferAttnBackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
)
)
self.max_context_len = self.attn_backends[0].max_context_len
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.kv_indptr_stride = self.kv_indptr.shape[1]
def common_template(self, forward_batch: ForwardBatch, call_fn: int):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
self.cuda_graph_kv_indices,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
self.kv_indptr_stride,
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
)
for i in range(self.speculative_num_steps):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
forward_batch.spec_info.kv_indices = (
forward_batch.spec_info.kv_indices.clone()
)
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
device="cuda",
)
self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
forward_batch.batch_size
][0]
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
self.common_template(forward_batch, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, call_fn)
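A rough illustration (not part of the commit) of the per-step buffer slicing that common_template performs above; the sizes below are made-up example values.

# Illustrative only: mirrors the slicing arithmetic in common_template.
topk, num_seqs, seq_lens_sum = 4, 2, 100   # assumed example values
bs = topk * num_seqs
speculative_num_steps = 3

for i in range(speculative_num_steps):
    # Step i reads row i of kv_indptr and the first
    # seq_lens_sum * topk + bs * (i + 1) entries of its kv_indices buffer,
    # because every draft branch appends one more cached token per step.
    num_kv_indices = seq_lens_sum * topk + bs * (i + 1)
    print(f"step {i}: kv_indptr[{i}, :{bs + 1}], kv_indices[:{num_kv_indices}]")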
@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
@@ -937,3 +1094,105 @@ def should_use_tensor_core(
        return gqa_group_size > 4
    else:
        return False
def fast_decode_plan(
self,
indptr: torch.Tensor,
indices: torch.Tensor,
last_page_len: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
pos_encoding_mode: str = "NONE",
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
data_type: Union[str, torch.dtype] = "float16",
q_data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime batch size {} "
" mismatches the batch size set during initialization {}".format(
batch_size, self._fixed_batch_size
)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
q_data_type = data_type
if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
dtype=(
getattr(torch, q_data_type)
if isinstance(q_data_type, str)
else q_data_type
),
)
self.empty_kv_cache = torch.empty(
0,
dtype=(
getattr(torch, data_type) if isinstance(data_type, str) else data_type
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
if self.use_tensor_cores:
if not self.is_cuda_graph_enabled:
# when not using cudagraph, we need to create the indptr buffer, otherwise
# the buffer is already created during initialization
self._qo_indptr_buf = torch.arange(
batch_size + 1, dtype=torch.int32, device=indptr.device
)
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._qo_indptr_buf,
indptr,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
empty_q_data,
)
else:
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
indptr,
self.last_page_len,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
PosEncodingMode[pos_encoding_mode].value,
logits_soft_cap,
empty_q_data,
empty_kv_cache,
)
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta
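To make the above concrete (an illustration, not code from the commit): FlashInferMultiStepDraftBackend.init_forward_metadata_capture_cuda_graph rebinds each decode wrapper's begin_forward to fast_decode_plan via functools.partial, so replanning during graph capture skips the heavier bookkeeping. A toy sketch of that override pattern, with hypothetical names:

from functools import partial

class _DummyWrapper:
    """Stand-in for BatchDecodeWithPagedKVCacheWrapper (hypothetical)."""

    def begin_forward(self, *args, **kwargs):
        print("original plan")

def _fast_plan(wrapper, *args, **kwargs):
    print(f"fast plan bound to {type(wrapper).__name__}")

w = _DummyWrapper()
# Instance-level override: the wrapper is pre-bound as the first argument,
# mirroring `decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)`.
w.begin_forward = partial(_fast_plan, w)
w.begin_forward()  # -> "fast plan bound to _DummyWrapper"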
@@ -103,69 +103,75 @@ def set_torch_compile_config():
    torch._dynamo.config.cache_size_limit = 1024
def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs
if capture_bs is None:
if server_args.disable_cuda_graph_padding:
capture_bs = list(range(1, 33)) + [64, 128]
else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very small. We add more values here to make sure we capture the maximum bs.
capture_bs = list(
sorted(
set(
capture_bs
+ [model_runner.req_to_token_pool.size - 1]
+ [model_runner.req_to_token_pool.size]
)
)
)
capture_bs = [
bs
for bs in capture_bs
if bs <= model_runner.req_to_token_pool.size
and bs <= server_args.cuda_graph_max_bs
]
compile_bs = (
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
if server_args.enable_torch_compile
else []
)
return capture_bs, compile_bs
# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None
def get_global_graph_memory_pool():
return global_graph_memory_pool
def set_global_graph_memory_pool(val):
global global_graph_memory_pool
global_graph_memory_pool = val
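As a rough illustration of the default schedule computed by get_batch_sizes_to_capture above (example numbers assumed, not taken from the commit):

# Example values only; the real limits come from ServerArgs and the req-to-token pool.
cuda_graph_max_bs, torch_compile_max_bs, pool_size = 160, 32, 120
enable_torch_compile = True

capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]        # default: 1..160
if max(capture_bs) > pool_size:
    # Make sure the largest runnable batch size is also captured.
    capture_bs = sorted(set(capture_bs + [pool_size - 1, pool_size]))
capture_bs = [bs for bs in capture_bs if bs <= pool_size and bs <= cuda_graph_max_bs]
compile_bs = (
    [bs for bs in capture_bs if bs <= torch_compile_max_bs]
    if enable_torch_compile
    else []
)
print(capture_bs[-3:], compile_bs)  # e.g. [112, 119, 120] [1, 2, 4, 8, 16, 24, 32]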
class CudaGraphRunner:
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

-    def __init__(self, model_runner: "ModelRunner"):
+    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.graphs = {}
-        self.input_buffers = {}
        self.output_buffers = {}
-        self.flashinfer_handlers = {}
-        self.graph_memory_pool = None
-        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
-        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
-        self.tp_size = self.model_runner.tp_size
-        self.dp_size = self.model_runner.server_args.dp_size
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size

        # Batch sizes to capture
-        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
-        if self.capture_bs is None:
-            if model_runner.server_args.disable_cuda_graph_padding:
-                self.capture_bs = list(range(1, 33)) + [64, 128]
-            else:
-                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
-            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-            # is very samll. We add more values here to make sure we capture the maximum bs.
-            self.capture_bs = list(
-                sorted(
-                    set(
-                        self.capture_bs
-                        + [model_runner.req_to_token_pool.size - 1]
-                        + [model_runner.req_to_token_pool.size]
-                    )
-                )
-            )
-        self.capture_bs = [
-            bs
-            for bs in self.capture_bs
-            if bs <= model_runner.req_to_token_pool.size
-            and bs <= model_runner.server_args.cuda_graph_max_bs
-        ]
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)

        self.capture_forward_mode = ForwardMode.DECODE
        self.num_tokens_per_bs = 1
        if model_runner.spec_algorithm.is_eagle():
            if self.model_runner.is_draft_worker:
-                self.num_tokens_per_bs = (
-                    self.model_runner.server_args.speculative_eagle_topk
-                )
+                raise RuntimeError("This should not happen")
            else:
                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                self.num_tokens_per_bs = (
@@ -182,10 +188,10 @@ class CudaGraphRunner:
        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0

-        if self.use_torch_compile:
+        if self.enable_torch_compile:
            set_torch_compile_config()

-        # Common inputs
+        # Graph inputs
        with torch.device("cuda"):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
@@ -301,7 +307,7 @@ class CudaGraphRunner:
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

-        # Common inputs
+        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
@@ -320,7 +326,7 @@ class CudaGraphRunner:
            global_num_tokens = None
            gathered_buffer = None

-        spec_info = self.get_spec_info(num_tokens, positions)
+        spec_info = self.get_spec_info(num_tokens)

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
@@ -335,7 +341,6 @@ class CudaGraphRunner:
            seq_lens_sum=seq_lens.sum(),
            encoder_lens=encoder_lens,
            return_logprob=False,
-            top_logprobs_nums=[0] * bs,
            positions=positions,
            global_num_tokens=global_num_tokens,
            gathered_buffer=gathered_buffer,
@@ -375,13 +380,14 @@ class CudaGraphRunner:
        torch.cuda.synchronize()
        self.model_runner.tp_group.barrier()

-        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+        global global_graph_memory_pool
+        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
            out = run_once()

        torch.cuda.synchronize()
        self.model_runner.tp_group.barrier()

-        self.graph_memory_pool = graph.pool()
+        global_graph_memory_pool = graph.pool()

        return graph, out

    def replay(self, forward_batch: ForwardBatch):
@@ -439,35 +445,26 @@ class CudaGraphRunner:
        )
        return logits_output
-    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import (
-                EAGLEDraftInput,
-                EagleVerifyInput,
-            )
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
-                spec_info = EAGLEDraftInput()
-                spec_info.load_server_args(self.model_runner.server_args)
-                spec_info.hidden_states = self.hidden_states[:num_tokens]
-                spec_info.positions = positions
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
-                    None,
-                    None,
-                    None,
-                    None,
-                    None,
-                    None,
-                    self.model_runner.server_args.speculative_num_draft_tokens,
-                )
-                spec_info.custom_mask = torch.zeros(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                )
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                    draft_token=None,
+                    custom_mask=torch.zeros(
+                        (num_tokens * self.model_runner.model_config.context_len),
+                        dtype=torch.bool,
+                        device="cuda",
+                    ),
+                    positions=None,
+                    retrive_index=None,
+                    retrive_cum_len=None,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
+                )

        return spec_info
@@ -197,64 +197,6 @@ class ForwardBatch:
    # For Qwen2-VL
    mrope_positions: torch.Tensor = None
def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
device = model_runner.device
hf_config = model_runner.model_config.hf_config
mrope_positions_list = [None] * self.seq_lens.shape[0]
if self.forward_mode.is_decode():
for i, _ in enumerate(mrope_positions_list):
mrope_position_delta = (
0
if batch.image_inputs[i] is None
else batch.image_inputs[i].mrope_position_delta
)
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
int(self.seq_lens[i]) - 1,
int(self.seq_lens[i]),
)
elif self.forward_mode.is_extend():
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
for i, image_inputs in enumerate(batch.image_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
extend_start_loc_cpu[i],
batch.extend_seq_lens[i],
batch.extend_prefix_lens[i],
)
if image_inputs is None:
# text only
mrope_positions = [
[
pos
for pos in range(
extend_prefix_len, extend_prefix_len + extend_seq_len
)
]
] * 3
else:
# TODO: current qwen2-vl do not support radix cache since mrope position calculation
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=image_inputs.image_grid_thws,
vision_start_token_id=hf_config.vision_start_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
)
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
    @classmethod
    def init_new(
        cls,
@@ -337,7 +279,7 @@ class ForwardBatch:
            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

        if model_runner.model_is_mrope:
-            ret.compute_mrope_positions(model_runner, batch)
+            ret._compute_mrope_positions(model_runner, batch)

        # Init lora information
        if model_runner.server_args.lora_paths is not None:
@@ -345,6 +287,63 @@ class ForwardBatch:
        return ret
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
device = model_runner.device
hf_config = model_runner.model_config.hf_config
mrope_positions_list = [None] * self.seq_lens.shape[0]
if self.forward_mode.is_decode():
for i, _ in enumerate(mrope_positions_list):
mrope_position_delta = (
0
if batch.image_inputs[i] is None
else batch.image_inputs[i].mrope_position_delta
)
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
int(self.seq_lens[i]) - 1,
int(self.seq_lens[i]),
)
elif self.forward_mode.is_extend():
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
for i, image_inputs in enumerate(batch.image_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
extend_start_loc_cpu[i],
batch.extend_seq_lens[i],
batch.extend_prefix_lens[i],
)
if image_inputs is None:
# text only
mrope_positions = [
[
pos
for pos in range(
extend_prefix_len, extend_prefix_len + extend_seq_len
)
]
] * 3
else:
# TODO: current qwen2-vl do not support radix cache since mrope position calculation
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=image_inputs.image_grid_thws,
vision_start_token_id=hf_config.vision_start_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
)
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def compute_position_triton(
    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
......
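For the text-only branch of _compute_mrope_positions above, all three rotary sections share the same 1D positions, so a request is simply assigned its extend range three times. A tiny illustration with assumed lengths (not part of the commit):

# Illustration only: prefix length 4 and extend length 3 are made-up values.
extend_prefix_len, extend_seq_len = 4, 3
mrope_positions = [
    list(range(extend_prefix_len, extend_prefix_len + extend_seq_len))
] * 3
print(mrope_positions)  # [[4, 5, 6], [4, 5, 6], [4, 5, 6]]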
@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
    MLATokenToKVPool,
    ReqToTokenPool,
)
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import ServerArgs
@@ -714,8 +715,6 @@ class ModelRunner:
    def init_cuda_graphs(self):
        """Capture cuda graphs."""
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
        self.cuda_graph_runner = None
        if not self.is_generation:
......
@@ -79,11 +79,13 @@ __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected
)


-def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token):
+def build_tree_kernel(
+    parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
+):
    bs = seq_lens.numel()
    device = parent_list.device
    tree_mask = torch.full(
-        (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,),
+        (seq_lens_sum * draft_token + draft_token * draft_token * bs,),
        True,
        device=device,
    )
......
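The only functional change above is that the tree mask size now uses the precomputed seq_lens_sum instead of torch.sum(seq_lens).item(), which avoids a GPU-to-CPU sync. For intuition, a small worked example (numbers assumed, not from the commit):

# Illustration: 2 requests with total sequence length 16, draft_token = 16.
bs, seq_lens_sum, draft_token = 2, 16, 16
tree_mask_numel = seq_lens_sum * draft_token + draft_token * draft_token * bs
print(tree_mask_numel)  # 16*16 + 16*16*2 = 768 mask entries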
from __future__ import annotations
import bisect
import time
from typing import TYPE_CHECKING, Callable
import torch
from sglang.srt.model_executor.cuda_graph_runner import (
CudaGraphRunner,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
set_global_graph_memory_pool,
set_torch_compile_config,
)
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_worker import EAGLEWorker
class EAGLEDraftCudaGraphRunner:
def __init__(self, eagle_worker: EAGLEWorker):
# Parse args
self.eagle_worker = eagle_worker
self.model_runner = model_runner = eagle_worker.model_runner
self.graphs = {}
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
server_args = model_runner.server_args
assert self.disable_padding
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.num_tokens_per_bs = server_args.speculative_eagle_topk
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
0
].get_cuda_graph_seq_len_fill_value()
if self.enable_torch_compile:
set_torch_compile_config()
# Graph inputs
with torch.device("cuda"):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
self.out_cache_loc = torch.zeros(
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
self.hidden_states = torch.zeros(
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
# Capture
try:
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n"
"Possible solutions:\n"
"1. disable cuda graph by --disable-cuda-graph\n"
"2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
"3. disable torch compile by not using --enable-torch-compile\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
def can_run(self, forward_batch: ForwardBatch):
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
return is_bs_supported
def capture(self):
CudaGraphRunner.capture(self)
def capture_one_batch_size(self, num_seqs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
num_tokens = num_seqs * self.num_tokens_per_bs
# Graph inputs
req_pool_indices = self.req_pool_indices[:num_seqs]
seq_lens = self.seq_lens[:num_seqs]
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
positions = self.positions[:num_tokens]
topk_p = self.topk_p[:num_seqs]
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
spec_info = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=hidden_states,
)
# Forward batch
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=num_seqs,
input_ids=None,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens.sum(),
return_logprob=False,
positions=positions,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
),
)
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
forward_batch
)
# Run and capture
def run_once():
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
ret = self.eagle_worker.draft_forward(forward_batch)
forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
return ret
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
set_global_graph_memory_pool(graph.pool())
return graph, out
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()
# Common inputs
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
forward_batch.out_cache_loc
)
self.positions[:raw_num_token].copy_(forward_batch.positions)
self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch
)
# Replay
self.graphs[bs].replay()
return self.output_buffers[bs]
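One detail worth calling out in EAGLEDraftCudaGraphRunner.replay above: the runtime batch is padded up to the nearest captured batch size via bisect. A toy sketch with assumed values (not part of the commit):

import bisect

capture_bs = [1, 2, 4, 8, 16, 24, 32]   # assumed captured sizes
raw_bs = 5                               # runtime batch size
bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
print(bs)  # 8 -> the graph captured for bs=8 is replayed; rows 5..7 hold filler state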
+import logging
+import time
from typing import List, Optional, Union

import torch
@@ -12,8 +14,18 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
-from sglang.srt.utils import rank0_print
+from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
+    EAGLEDraftCudaGraphRunner,
+)
+from sglang.srt.speculative.eagle_utils import (
+    EagleDraftInput,
+    EagleVerifyInput,
+    assign_draft_cache_locs,
+    fast_topk,
+    select_top_k_tokens,
+)
+
+logger = logging.getLogger(__name__)
class EAGLEWorker(TpModelWorker):
@@ -40,41 +52,47 @@ class EAGLEWorker(TpModelWorker):
            is_draft_worker=True,
        )

        self.target_worker = target_worker
-        self.server_args = server_args
        self.finish_extend_len = []

+        # Parse arguments
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.server_args = server_args
+
        # Share the embedding and lm_head
        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
        self.model_runner.model.set_embed_and_head(embed, head)
        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
-        self.model_runner.init_cuda_graphs()
-
-    def forward_draft_decode(self, batch: ScheduleBatch):
-        batch.spec_info.prepare_for_decode(batch)
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-
-    def forward_draft_extend(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-        batch.spec_info.prepare_for_extend(batch)
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
+        # Create multi-step attn backends and cuda graph runners
+        from sglang.srt.layers.attention.flashinfer_backend import (
+            FlashInferMultiStepDraftBackend,
+        )
+
+        self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+            self.model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
+        self.model_runner.draft_attn_backend = self.draft_attn_backend
+        self.init_cuda_graphs()
+
+    def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        self.cuda_graph_runner = None
+
+        if self.server_args.disable_cuda_graph:
+            return
+
+        tic = time.time()
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
        if batch.forward_mode.is_decode():
            # Draft
-            self._set_mem_pool(batch, self.model_runner)
-            for i in range(self.server_args.speculative_num_steps):
-                self.forward_draft_decode(batch)
-            batch.spec_info.clear_draft_cache(batch)
-            self._set_mem_pool(batch, self.target_worker.model_runner)
+            spec_info: EagleVerifyInput = self.draft(batch)

            # Verify
            (
@@ -84,8 +102,7 @@ class EAGLEWorker(TpModelWorker):
                self.finish_extend_len,
                accept_length_cpu,
                model_worker_batch,
-            ) = self.verify(batch)
-            next_draft_input.load_server_args(self.server_args)
+            ) = self.verify(batch, spec_info)
            batch.spec_info = next_draft_input
            # if it is None, means all requsets are finished
            if batch.spec_info.verified_id is not None:
@@ -107,29 +124,145 @@ class EAGLEWorker(TpModelWorker):
            )

            # Forward with the draft model.
-            spec_info = EAGLEDraftInput()
-            spec_info.load_server_args(self.server_args)
-            spec_info.hidden_states = logits_output.hidden_states
-            spec_info.verified_id = next_token_ids
-            batch.spec_info = spec_info
+            batch.spec_info = EagleDraftInput(
+                hidden_states=logits_output.hidden_states,
+                verified_id=next_token_ids,
+            )
            self.forward_draft_extend(batch)
            return logits_output, next_token_ids, model_worker_batch, 0

-    def verify(self, batch: ScheduleBatch):
-        verify_input = batch.spec_info.prepare_for_verify(batch)
+    def draft(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
verify_input.prepare_for_verify(batch)
# Parse args
num_seqs = batch.batch_size()
spec_info = batch.spec_info
# Allocate cache locations
out_cache_loc = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps
)
assign_draft_cache_locs[(num_seqs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
)
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch
)
if can_cuda_graph:
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
forward_batch
)
else:
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
ret = EagleVerifyInput.create(
spec_info.verified_id,
score_list,
token_list,
parents_list,
batch.seq_lens,
batch.seq_lens_sum,
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
)
# Free cache locations
batch.token_to_kv_pool.free(out_cache_loc)
self._set_mem_pool(batch, self.target_worker.model_runner)
return ret
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
spec_info = forward_batch.spec_info
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = (
spec_info.topk_p,
spec_info.topk_index,
spec_info.hidden_states,
)
# Return values
score_list: List[torch.Tensor] = []
token_list: List[torch.Tensor] = []
parents_list: List[torch.Tensor] = []
# Forward multiple steps
scores = None
for i in range(self.speculative_num_steps):
input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
i, topk_p, topk_index, hidden_states, scores, self.topk
)
score_list.append(tree_info[0])
token_list.append(tree_info[1])
parents_list.append(tree_info[2])
# Set inputs
forward_batch.input_ids = input_ids
forward_batch.out_cache_loc = out_cache_loc[
forward_batch.batch_size
* self.topk
* i : forward_batch.batch_size
* self.topk
* (i + 1)
]
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
hidden_states = logits_output.hidden_states
return score_list, token_list, parents_list
+    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
+        spec_info.prepare_for_verify(batch)
        batch.forward_mode = ForwardMode.TARGET_VERIFY
-        batch.spec_info = verify_input
-        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch()
        logits_output, _ = self.target_worker.forward_batch_generation(
            model_worker_batch, skip_sample=True
        )
-        verify_input.hidden_states = logits_output.hidden_states
-        res = verify_input.verify(batch, logits_output)
+        spec_info.hidden_states = logits_output.hidden_states
+        res = spec_info.verify(batch, logits_output)
        batch.forward_mode = ForwardMode.DECODE
        return res + (model_worker_batch,)
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
        batch.token_to_kv_pool = runner.token_to_kv_pool
        batch.req_to_token_pool = runner.req_to_token_pool
@@ -139,7 +272,7 @@ class EAGLEWorker(TpModelWorker):
        self._set_mem_pool(batch, self.model_runner)
        batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch)
+        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -155,13 +288,10 @@ class EAGLEWorker(TpModelWorker):
    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ):
-        sample_output = torch.softmax(
-            logits_output.next_token_logits, dim=-1
-        )  # TODO(kavioyu): Support more sampling methods
+        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        spec_info = forward_batch.spec_info
-        spec_info.sample_output = sample_output
+        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
        spec_info.hidden_states = logits_output.hidden_states
-        spec_info.prev_mode = forward_batch.forward_mode

    # Don't support prefix share now.
    def finish_request(self, reqs: Union[Req, List[Req]]):
......
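A small sketch of the draft sampling step used in capture_for_decode and draft_forward above (illustrative only; fast_topk is assumed to behave like torch.topk here, and the shapes are made up):

import torch

logits = torch.randn(2, 32000)         # (batch, vocab), dummy draft-model logits
topk = 4
probs = torch.softmax(logits, dim=-1)
topk_p, topk_index = torch.topk(probs, topk, dim=-1)
print(topk_p.shape, topk_index.shape)  # torch.Size([2, 4]) torch.Size([2, 4])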