Unverified Commit d3d4d767 authored by Ying Sheng, committed by GitHub

[Eagle] Refactor eagle speculative decoding (#3986)


Co-authored-by: Ke Bao <ISPObaoke@163.com>
parent 5be8f1ed
......@@ -230,7 +230,7 @@ def extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
tree_cache=None,
model_config=model_runner.model_config,
enable_overlap=False,
......@@ -326,7 +326,7 @@ def latency_test_run_once(
# Clear the pools.
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()
model_runner.token_to_kv_pool_allocator.clear()
measurement_results = {
"run_name": run_name,
......
......@@ -20,14 +20,15 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
if is_flashinfer_available():
from flashinfer import (
......@@ -36,6 +37,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode
class WrapperDispatch(Enum):
......@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = [
......@@ -133,10 +136,13 @@ class FlashInferAttnBackend(AttentionBackend):
assert self.num_wrappers == 1
self.kv_last_page_len = kv_last_page_len_buf
self.qo_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
]
if not self.skip_prefill:
self.qo_indptr = [
torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
for _ in range(self.num_wrappers)
]
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
......@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
......@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
......@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
......@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
......@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if spec_info is None:
bs = len(req_pool_indices)
......@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
else:
assert isinstance(spec_info, EagleDraftInput)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
wrapper.begin_forward(
kv_indptr,
kv_indices,
......@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
......@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if use_ragged:
paged_kernel_lens = prefix_lens
......@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:
......@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
bs = len(req_pool_indices)
bs = len(seq_lens)
if spec_info is None:
assert len(seq_lens) == len(req_pool_indices)
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
......@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
assert isinstance(spec_info, EagleDraftInput) or isinstance(
spec_info, EagleVerifyInput
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
)
)
......@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
)
# Use as a fast path to override the indptr in flashinfer's plan function
# This is used to remove some host-to-device copy overhead.
global_override_indptr_cpu = None
class FlashInferMultiStepDraftBackend:
"""
Wrap multiple flashinfer attention backends as one for multiple consecutive
......@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
......@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
kv_last_page_len_buf=self.kv_last_page_len,
)
)
self.max_context_len = self.attn_backends[0].max_context_len
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
......@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
triton.next_power_of_2(bs),
)
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
global global_override_indptr_cpu
for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
global_override_indptr_cpu = indptr_cpu_whole[i]
call_fn(i, forward_batch)
global_override_indptr_cpu = None
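# Illustrative sketch (not part of this diff) of the optimization in common_template
# above: a single device-to-host copy of the whole indptr buffer replaces a per-step
# .cpu() call, so flashinfer's plan no longer triggers a sync on every draft step.
# The function names below are hypothetical.
import torch

def per_step_copies(kv_indptr_gpu: torch.Tensor, num_steps: int):
    # One device-to-host sync per draft step.
    return [kv_indptr_gpu[i].cpu() for i in range(num_steps)]

def single_copy(kv_indptr_gpu: torch.Tensor, num_steps: int):
    # One device-to-host sync for all draft steps; slicing the CPU tensor is free.
    indptr_cpu_whole = kv_indptr_gpu.cpu()
    return [indptr_cpu_whole[i] for i in range(num_steps)]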
def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
(
......@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
)
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
......@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
dtype=torch.int32,
device="cuda",
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
......@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)
req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)
kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ kv_start
+ offset,
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
def should_use_tensor_core(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
......@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
if env_override is not None:
return env_override.lower() == "true"
# Try to use _grouped_size_compiled_for_decode_kernels if available
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
else:
return False
except (ImportError, AttributeError):
pass
# Calculate GQA group size
gqa_group_size = num_attention_heads // num_kv_heads
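# Hypothetical numbers (not from this diff) to make the ratio above concrete: a model
# with num_attention_heads = 32 and num_kv_heads = 8 has gqa_group_size == 32 // 8 == 4;
# the caller uses this ratio when deciding whether the tensor-core decode kernels pay off.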
......@@ -1118,12 +1128,18 @@ def fast_decode_plan(
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
**kwargs,
non_blocking: bool = True,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
Modifications:
- Remove unnecessary device-to-device copy for the cuda graph buffers.
- Remove unnecessary host-to-device copy for the metadata buffers.
"""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
......@@ -1136,13 +1152,19 @@ def fast_decode_plan(
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# Skip these copies
# self._paged_kv_indptr_buf.copy_(indptr)
# self._paged_kv_indices_buf[: len(indices)] = indices
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
# NOTE(Zihao): the following tensors act as placeholders to pass dtype info
if not q_data_type:
q_data_type = data_type
if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,
......@@ -1159,6 +1181,7 @@ def fast_decode_plan(
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
stream = torch.cuda.current_stream()
......
......@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
spec_info.generate_attn_arg_prefill(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
self.req_to_token,
)
)
......
......@@ -22,7 +22,7 @@ from typing import List, Optional
import torch
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
logger = logging.getLogger(__name__)
......@@ -128,7 +128,7 @@ class HiCacheController:
def __init__(
self,
mem_pool_device: MHATokenToKVPool,
mem_pool_host: MLATokenToKVPoolHost,
mem_pool_host: MHATokenToKVPoolHost,
write_policy: str = "write_through_selective",
):
......
......@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
......@@ -523,7 +521,7 @@ class ScheduleBatch:
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None
# Batch configs
......@@ -596,7 +594,7 @@ class ScheduleBatch:
cls,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,
......@@ -606,7 +604,7 @@ class ScheduleBatch:
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
tree_cache=tree_cache,
model_config=model_config,
enable_overlap=enable_overlap,
......@@ -637,19 +635,19 @@ class ScheduleBatch:
return req_pool_indices
def alloc_token_slots(self, num_tokens: int):
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None:
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
logger.error(
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
)
if self.tree_cache is not None:
self.tree_cache.pretty_print()
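# Minimal sketch (not part of this diff) of the allocate / evict / retry pattern used
# in alloc_token_slots above. `allocator` and `tree_cache` are assumed duck-typed
# stand-ins for the TokenToKVPoolAllocator and the radix tree cache.
from typing import Optional

import torch

def alloc_with_eviction(allocator, tree_cache, num_tokens: int) -> Optional[torch.Tensor]:
    out_cache_loc = allocator.alloc(num_tokens)       # first attempt
    if out_cache_loc is None and tree_cache is not None:
        tree_cache.evict(num_tokens, allocator.free)  # return cached slots to the allocator
        out_cache_loc = allocator.alloc(num_tokens)   # retry once after eviction
    return out_cache_loc                              # may still be None; the caller handles OOM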
......@@ -917,12 +915,12 @@ class ScheduleBatch:
def check_decode_mem(self, buf_multiplier=1):
bs = len(self.reqs) * buf_multiplier
if self.token_to_kv_pool.available_size() >= bs:
if self.token_to_kv_pool_allocator.available_size() >= bs:
return True
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
if self.token_to_kv_pool.available_size() >= bs:
if self.token_to_kv_pool_allocator.available_size() >= bs:
return True
return False
......@@ -945,6 +943,10 @@ class ScheduleBatch:
reverse=True,
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
def get_required_tokens(num_reqs: int):
headroom_for_spec_decode = 0
if server_args.speculative_algorithm:
......@@ -958,18 +960,15 @@ class ScheduleBatch:
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
< get_required_tokens(len(sorted_indices))
or first_iter
):
if len(sorted_indices) == 1:
# Corner case: only one request left
assert (
self.token_to_kv_pool.available_size() > 0
self.token_to_kv_pool_allocator.available_size() > 0
), "No space left for only one request"
break
......@@ -983,7 +982,7 @@ class ScheduleBatch:
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
del self.tree_cache.entries[req.rid]
else:
......@@ -992,7 +991,7 @@ class ScheduleBatch:
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
......@@ -1001,10 +1000,13 @@ class ScheduleBatch:
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool.available_size()
- self.token_to_kv_pool_allocator.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
self.tree_cache.evict(
residual_size, self.token_to_kv_pool_allocator.free
)
req.reset_for_retract()
self.filter_batch(keep_indices=sorted_indices)
......@@ -1183,7 +1185,7 @@ class ScheduleBatch:
if self.spec_info:
self.spec_info.merge_batch(other.spec_info)
def get_model_worker_batch(self):
def get_model_worker_batch(self) -> ModelWorkerBatch:
if self.forward_mode.is_decode_or_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
......@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
req_pool_indices: torch.Tensor
# The sequence length
seq_lens: torch.Tensor
# The indices of output tokens in the token_to_kv_pool
# The indices of output tokens in the token_to_kv_pool_allocator
out_cache_loc: torch.Tensor
# The sum of all sequence lengths
......
......@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
......@@ -75,7 +79,7 @@ class SchedulePolicy:
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool=None, disable=False
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
)
def calc_priority(self, waiting_queue: List[Req]) -> bool:
......@@ -251,7 +255,7 @@ class PrefillAdder:
def __init__(
self,
tree_cache: BasePrefixCache,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
running_batch: ScheduleBatch,
new_token_ratio: float,
rem_input_tokens: int,
......@@ -259,7 +263,7 @@ class PrefillAdder:
mixed_with_decode_tokens: int = 0,
):
self.tree_cache = tree_cache
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
......@@ -291,7 +295,7 @@ class PrefillAdder:
@property
def rem_total_tokens(self):
return (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.rem_total_token_offset
)
......@@ -299,7 +303,7 @@ class PrefillAdder:
@property
def cur_rem_tokens(self):
return (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.cur_rem_token_offset
)
......@@ -332,7 +336,6 @@ class PrefillAdder:
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)
self._prefill_one_req(
0,
req.extend_input_len,
......@@ -400,8 +403,8 @@ class PrefillAdder:
tokens_freed += tokens_occupied
if (
self.rem_chunk_tokens is None
or req.extend_input_len <= self.rem_chunk_tokens
self.rem_chunk_tokens is None # chunked prefill is disabled
or req.extend_input_len <= self.rem_chunk_tokens # it is the last chunk
):
# Non-chunked prefill
self.can_run_list.append(req)
......@@ -411,10 +414,11 @@ class PrefillAdder:
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
)
else:
if self.rem_chunk_tokens == 0:
return AddReqResult.OTHER
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[:trunc_len]
......@@ -457,10 +461,11 @@ class PrefillAdder:
),
)
else:
if self.rem_chunk_tokens == 0:
return AddReqResult.OTHER
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
......
......@@ -164,7 +164,7 @@ class Scheduler:
self.server_args.speculative_num_draft_tokens
+ (
self.server_args.speculative_eagle_topk
* self.server_args.speculative_num_steps
* self.server_args.speculative_num_draft_tokens
)
)
if not self.spec_algorithm.is_none()
......@@ -309,7 +309,9 @@ class Scheduler:
)
# Init memory pool and cache
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)
if (
server_args.chunked_prefill_size is not None
......@@ -317,18 +319,18 @@ class Scheduler:
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
disable=server_args.disable_radix_cache,
)
......@@ -458,7 +460,6 @@ class Scheduler:
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
]
)
......@@ -809,7 +810,8 @@ class Scheduler:
running_bs: int,
):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
self._largest_prefill_len = max(
self._largest_prefill_len, adder.log_input_tokens
......@@ -844,7 +846,8 @@ class Scheduler:
self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
if RECORD_STEP_TIME:
......@@ -894,7 +897,8 @@ class Scheduler:
def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
......@@ -999,7 +1003,7 @@ class Scheduler:
# Prefill policy
adder = PrefillAdder(
self.tree_cache,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.running_batch,
self.new_token_ratio,
self.max_prefill_tokens,
......@@ -1099,7 +1103,7 @@ class Scheduler:
new_batch = ScheduleBatch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.tree_cache,
self.model_config,
self.enable_overlap,
......@@ -1143,8 +1147,6 @@ class Scheduler:
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
self.new_token_ratio = new_token_ratio
if self.draft_worker:
self.draft_worker.finish_request(retracted_reqs)
logger.info(
"Decode out of memory happened. "
......@@ -1184,11 +1186,12 @@ class Scheduler:
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
bid = model_worker_batch.bid
else:
(
logits_output,
next_token_ids,
model_worker_batch,
bid,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (
......@@ -1214,7 +1217,7 @@ class Scheduler:
next_token_ids=next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=model_worker_batch.bid,
bid=bid,
)
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
......@@ -1230,6 +1233,7 @@ class Scheduler:
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
if batch.forward_mode.is_decode():
assert isinstance(result, GenerationBatchResult)
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
......@@ -1302,7 +1306,7 @@ class Scheduler:
if self.is_mixed_chunk and self.enable_overlap and req.finished():
# Free the one delayed token for the mixed decode batch
j = len(batch.out_cache_loc) - len(batch.reqs) + i
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
continue
if req.is_chunked <= 0:
......@@ -1420,23 +1424,27 @@ class Scheduler:
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap:
assert batch.spec_algorithm.is_none()
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
next_token_logprobs = logits_output.next_token_logprobs
else:
elif batch.spec_algorithm.is_none():
# spec decoding handles output logprobs inside verify process.
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
self.token_to_kv_pool.free_group_begin()
self.token_to_kv_pool_allocator.free_group_begin()
# Check finish condition
# NOTE: the lengths of reqs and next_token_ids don't match when spec decoding is used,
# so next_token_ids should not be relied on in spec decoding cases.
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
continue
if self.enable_overlap and req.finished():
# Free the one delayed token
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
continue
if batch.spec_algorithm.is_none():
......@@ -1479,7 +1487,7 @@ class Scheduler:
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool.free_group_end()
self.token_to_kv_pool_allocator.free_group_end()
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if (
......@@ -1718,9 +1726,6 @@ class Scheduler:
and not self.model_config.is_multimodal_gen
)
):
if self.draft_worker and req.finished():
self.draft_worker.finish_request(req)
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
......@@ -1860,7 +1865,7 @@ class Scheduler:
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.tree_cache,
self.model_config,
self.enable_overlap,
......@@ -1916,11 +1921,11 @@ class Scheduler:
if self.grammar_backend:
self.grammar_backend.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
self.token_to_kv_pool_allocator.clear()
if not self.spec_algorithm.is_none():
self.draft_worker.model_runner.req_to_token_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
self.num_generated_tokens = 0
self.forward_ct_decode = 0
......
......@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SessionParams,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
......@@ -257,9 +255,6 @@ class TokenizerManager:
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher = TypeBasedDispatcher(
[
......@@ -309,10 +304,6 @@ class TokenizerManager:
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)
......@@ -774,14 +765,6 @@ class TokenizerManager:
)
return res[0].internal_state
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
res: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return res[0]
def get_log_request_metadata(self):
max_length = None
skip_names = None
......
......@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
......@@ -49,6 +50,8 @@ class TpModelWorker:
dp_rank: Optional[int],
nccl_port: int,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
# Parse args
self.tp_rank = tp_rank
......@@ -77,6 +80,8 @@ class TpModelWorker:
nccl_port=nccl_port,
server_args=server_args,
is_draft_worker=is_draft_worker,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
)
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
......@@ -154,7 +159,7 @@ class TpModelWorker:
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool,
self.model_runner.token_to_kv_pool_allocator,
)
def forward_batch_generation(
......
......@@ -100,7 +100,7 @@ class TpModelWorkerClient:
def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
self.worker.model_runner.token_to_kv_pool,
self.worker.model_runner.token_to_kv_pool_allocator,
)
def forward_thread_func(self):
......
from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
......@@ -21,11 +20,13 @@ class ChunkCacheEntry:
class ChunkCache(BasePrefixCache):
def __init__(
self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset()
......@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
req.req_pool_idx, :token_id_len
]
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)
self.token_to_kv_pool_allocator.free(kv_indices)
if req.rid in self.entries:
del self.entries[req.rid]
......@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):
def protected_size(self):
return 0
def pretty_print(self):
return ""
......@@ -7,8 +7,8 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
BaseTokenToKVPool,
MLATokenToKVPoolHost,
MHATokenToKVPool,
MHATokenToKVPoolHost,
ReqToTokenPool,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
......@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool: MHATokenToKVPool,
):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host
)
......
......@@ -20,9 +20,12 @@ Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a request to its token locations.
BaseTokenToKVPool maps a token location to its KV cache data.
TokenToKVPoolAllocator maps a token location to its KV cache data.
KVCache actually holds the physical KV cache; the allocation indices are handed out
by TokenToKVPoolAllocator.
"""
import abc
import logging
import threading
from enum import IntEnum
......@@ -89,7 +92,7 @@ class ReqToTokenPool:
self.free_slots = list(range(self.size))
class BaseTokenToKVPool:
class TokenToKVPoolAllocator:
"""A memory pool that maps a token location to its kv cache data."""
def __init__(
......@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.device = device
self.free_slots = None
......@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
self.is_in_free_group = False
self.free_group = []
class KVCache(abc.ABC):
@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()
@abc.abstractmethod
def set_kv_buffer(
self,
layer: RadixAttention,
......@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
raise NotImplementedError()
class MHATokenToKVPool(BaseTokenToKVPool):
class MHATokenToKVPool(KVCache):
def __init__(
self,
......@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
......@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_2[loc] = src_2.to(dtype).view(store_dtype)
class MLATokenToKVPool(BaseTokenToKVPool):
class MLATokenToKVPool(KVCache):
def __init__(
self,
size: int,
......@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.kv_lora_rank = kv_lora_rank
memory_saver_adapter = TorchMemorySaverAdapter.create(
......@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
self.kv_buffer[layer_id][loc] = cache_k
class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
self,
size: int,
......@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
heavy_channel_num: int,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
......@@ -437,12 +460,12 @@ def synchronized(func):
return wrapper
class MLATokenToKVPoolHost:
class MHATokenToKVPoolHost:
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 4.0,
host_to_device_ratio: float = 2.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu",
):
......
......@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
import torch
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
......@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
disable: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.disable = disable
self.reset()
......@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_ids_len
]
self.token_to_kv_pool.free(kv_indices)
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
......@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
......@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids)
......
......@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -98,6 +99,8 @@ class ModelRunner:
nccl_port: int,
server_args: ServerArgs,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
# Parse args
self.model_config = model_config
......@@ -115,6 +118,8 @@ class ModelRunner:
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
# Model-specific adjustment
if (
......@@ -257,8 +262,8 @@ class ModelRunner:
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
torch.get_device_module(self.device).set_device(self.gpu_id)
if self.device == "cuda":
backend = "nccl"
elif self.device == "xpu":
......@@ -660,12 +665,25 @@ class ModelRunner:
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
max_num_reqs = self.server_args.max_num_reqs
else:
# We are sharing the `token_to_kv_pool`, and both verify and draft tokens
# can be concurrently allocated, so we should leave some headroom for it.
self.server_args.draft_runner_cache_size = (
self.max_total_num_tokens
+ max_num_reqs * self.server_args.speculative_num_steps
# draft
+ max_num_reqs
* self.server_args.speculative_num_steps
* self.server_args.speculative_eagle_topk
# verify
+ max_num_reqs * self.server_args.speculative_num_draft_tokens
# buffer
+ 100
)
# The target worker and draft worker share the same indices for the
# token_to_kv_pool, so we should make sure to match max_total_num_tokens.
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
self.server_args.max_num_reqs = max_num_reqs
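# Hypothetical numbers (not from this diff) to make the headroom formula above
# concrete, assuming a target KV capacity of 200_000 tokens:
_example_max_total_num_tokens = 200_000
_example_max_num_reqs = 32
_example_num_steps = 5            # speculative_num_steps
_example_topk = 4                 # speculative_eagle_topk
_example_num_draft_tokens = 8     # speculative_num_draft_tokens
_example_cache_size = (
    _example_max_total_num_tokens
    + _example_max_num_reqs * _example_num_steps * _example_topk  # draft: 640
    + _example_max_num_reqs * _example_num_draft_tokens           # verify: 256
    + 100                                                         # buffer
)
assert _example_cache_size == 200_996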
if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens:
......@@ -681,12 +699,25 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static."
)
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if self.req_to_token_pool is None:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
)
else:
assert self.is_draft_worker
if (
self.model_config.attention_arch == AttentionArch.MLA
......
......@@ -280,11 +280,16 @@ class ServerArgs:
self.disable_overlap_schedule = True
self.prefill_only_one_req = True
self.disable_cuda_graph_padding = True
self.disable_radix_cache = True
self.chunked_prefill_size = -1
if self.max_running_requests is None:
self.max_running_requests = 32
logger.info(
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
"Overlap scheduler are disabled because of using "
"eagle speculative decoding."
"Max running request set to 32 because of using eagle speculative decoding."
)
# The token generated from the verify step is counted.
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
assert self.speculative_num_steps < self.speculative_num_draft_tokens
# GGUF
if (
......
......@@ -3,14 +3,8 @@
from typing import List
import torch
from sglang.srt.utils import is_cuda_available
if is_cuda_available():
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
def build_tree_kernel_efficient_preprocess(
......
......@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_worker import EAGLEWorker
......
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING, List
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import (
build_tree_kernel,
......@@ -25,7 +26,7 @@ if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
@dataclasses.dataclass
@dataclass
class EagleDraftInput:
# The inputs for decode
# shape: (b, topk)
......@@ -46,57 +47,46 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
# indices of unfinished requests during extend-after-decode,
# e.g. [0, 2, 3, 4] if only the request at index 1 is finished
keep_indices: List[int] = None
def prepare_for_extend(self, batch: ScheduleBatch):
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
batch.out_cache_loc = out_cache_loc
assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
# Prefill generates only 1 token.
assert len(self.verified_id) == len(batch.seq_lens)
pt = 0
for i, req in enumerate(batch.reqs):
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
:pre_len
] = req.prefix_indices
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
for i, extend_len in enumerate(batch.extend_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.concat(
(input_ids[1:], self.verified_id[i].reshape(1))
)
pt += req.extend_input_len
# TODO: support batching inputs
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist()
assert len(batch.req_pool_indices) == len(batch.reqs)
pt = 0
i = 0
for req in batch.reqs:
self.keep_indices = []
for idx, req in enumerate(batch.reqs):
if req.finished():
continue
self.keep_indices.append(idx)
# assert seq_len - pre_len == req.extend_input_len
input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len
i += 1
assert pt == batch.out_cache_loc.shape[0]
self.positions = torch.empty_like(self.verified_id)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
self.accept_length.add_(1)
create_extend_spec_info[(self.accept_length.numel(),)](
......@@ -117,14 +107,22 @@ class EagleDraftInput:
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
bs = self.accept_length.numel()
keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
req_pool_indices = req_pool_indices[keep_indices]
assert req_pool_indices.shape[0] == bs
assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
# TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)](
......@@ -162,7 +160,21 @@ class EagleDraftInput:
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
@dataclasses.dataclass
@dataclass
class EagleVerifyOutput:
# Draft input batch
draft_input: EagleDraftInput
# Logit outputs from target worker
logits_output: LogitsProcessorOutput
# Accepted token ids including the bonus token
verified_id: torch.Tensor
# Accepted token length per sequence in a batch, on CPU
accept_length_per_req_cpu: List[int]
# Accepted indices from logits_output.next_token_logits
accepeted_indices_cpu: List[int]
@dataclass
class EagleVerifyInput:
draft_token: torch.Tensor
custom_mask: torch.Tensor
......@@ -267,6 +279,7 @@ class EagleVerifyInput:
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
batch_size = len(req_pool_indices)
......@@ -285,7 +298,11 @@ class EagleVerifyInput:
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(
paged_kernel_lens_sum + self.draft_token_num * batch_size,
dtype=torch.int32,
device="cuda",
)
create_flashinfer_kv_indices_triton[(batch_size,)](
req_to_token,
......@@ -298,7 +315,21 @@ class EagleVerifyInput:
)
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
def verify(
self,
batch: ScheduleBatch,
logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
) -> torch.Tensor:
"""WARNING: This API in-place modifies the states of logits_output
Verify and find accepted tokens based on logits output and batch
(which contains spec decoding information).
This API updates values inside logits_output based on the accepted
tokens. I.e., logits_output.next_token_logits only contains
accepeted token logits.
"""
draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
dim=-1,
......@@ -367,7 +398,6 @@ class EagleVerifyInput:
new_accept_index = []
unfinished_index = []
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False
......@@ -382,7 +412,6 @@ class EagleVerifyInput:
id = predict_cpu[idx]
# if not found_finished:
req.output_ids.append(id)
finished_extend_len[req.rid] = j + 1
req.check_finished()
if req.finished():
has_finished = True
......@@ -400,11 +429,10 @@ class EagleVerifyInput:
accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
batch.token_to_kv_pool.free(mem_need_free_idx)
token_to_kv_pool_allocator.free(mem_need_free_idx)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
......@@ -427,20 +455,16 @@ class EagleVerifyInput:
]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
unfinished_index
]
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (
draft_input,
logits_output,
verified_id,
finished_extend_len,
accept_length_cpu,
batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
return EagleVerifyOutput(
draft_input=draft_input,
logits_output=logits_output,
verified_id=verified_id,
accept_length_per_req_cpu=accept_length_cpu,
accepeted_indices_cpu=accept_index,
)
......@@ -456,6 +480,18 @@ def eagle_verify_retrive(
draft_token_num: tl.constexpr,
max_len_upper: tl.constexpr,
):
"""
Args:
retrive_index: Pointer to indices of draft tokens
accept_mask: Mask indicating which tokens were accepted
retrive_cum_len: Cumulative lengths of token sequences in a batch
accept_index (out): Accepted token indices
accept_length (out): Length of accepted tokens per sequence in a batch
extract_index (out): Indices of the last accepted tokens
max_len: Maximum length in a batch
draft_token_num: Number of tokens speculatively generated
max_len_upper: An upper bound on the token sequence length
"""
pid = tl.program_id(axis=0)
retrive_end = tl.load(retrive_cum_len + pid + 1)
......@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
tl.store(kv_indptr + zid, base + zid * iters)
@torch.compile
@torch.compile(dynamic=True)
def select_top_k_tokens(
i: int,
topk_p: torch.Tensor,
......@@ -671,13 +707,11 @@ def select_top_k_tokens(
.unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
)
else:
# The later decode steps
expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
) # (b, topk)
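# Tiny shape check (not part of this diff) for the broadcasted multiply above, with
# assumed b = 2, topk = 3; `fast_topk` is assumed to behave like torch.topk here.
import torch

_b, _topk = 2, 3
_scores = torch.rand(_b, _topk)                  # cumulative path scores, (b, topk)
_topk_p = torch.rand(_b, _topk * _topk)          # per-node top-k probs, flattened
_expand = _scores.unsqueeze(2) * _topk_p.reshape(-1, _topk, _topk)
assert _expand.shape == (_b, _topk, _topk)       # (b, topk, 1) x (b, topk, topk)
_cs_p, _cs_index = torch.topk(_expand.flatten(start_dim=1), _topk, dim=-1)
assert _cs_p.shape == (_b, _topk)                # (b, topk)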
......
import logging
import os
import time
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
from huggingface_hub import snapshot_download
......@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
assign_draft_cache_locs,
fast_topk,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_available_gpu_memory
logger = logging.getLogger(__name__)
......@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
nccl_port: int,
target_worker: TpModelWorker,
):
# Override context length with target model's context length
server_args.context_length = target_worker.model_runner.model_config.context_len
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
# Do not capture cuda graph in `super().__init__()`
# We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
# Load hot token ids
# Lossy optimization by using hot tokens
if server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = (
......@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
else:
self.hot_token_id = None
# We share the allocator with the target worker. The draft and target workers
# each own their own KV cache.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Init the draft worker
super().__init__(
gpu_id=gpu_id,
......@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
nccl_port=nccl_port,
dp_rank=dp_rank,
is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
self.target_worker = target_worker
self.finish_extend_len = []
# Parse arguments
self.topk = server_args.speculative_eagle_topk
......@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
server_args.speculative_algorithm
)
self.server_args = server_args
self.use_nan_detection = self.server_args.enable_nan_detection
self.device = self.model_runner.device
self.gpu_id = self.model_runner.gpu_id
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
......@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
head = head.clone()
self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id]
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
self.draft_model_runner.model.set_embed_and_head(embed, head)
self.draft_model_runner.server_args.disable_cuda_graph = (
backup_disable_cuda_graph
)
# Create multi-step attn backends and cuda graph runners
if server_args.attention_backend == "flashinfer":
......@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
)
self.model_runner.draft_attn_backend = self.draft_attn_backend
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()
def init_cuda_graphs(self):
......@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
return
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
logger.info(
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
@property
def draft_model_runner(self):
return self.model_runner
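# The EAGLE worker's own `model_runner` runs the draft model (the target model
# lives in `self.target_worker`); this property is only a readable alias used
# throughout the code below.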
def forward_batch_speculative_generation(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
"""Run speculative decoding forward.
NOTE: Many states of the batch are modified in place as the forward runs. The
final batch state is not guaranteed to match the input.
Args:
batch: The batch to run forward. The state of the batch is modified as it runs.
Returns:
A tuple of the final logits output of the target model, the accepted next tokens,
the batch id (used for the overlap schedule), and the number of accepted tokens.
"""
assert not batch.spec_algorithm.is_none()
if batch.forward_mode.is_decode():
# Draft
spec_info: EagleVerifyInput = self.draft(batch)
# Verify
(
next_draft_input,
logits_output,
verified_id,
self.finish_extend_len,
accept_length_cpu,
model_worker_batch,
) = self.verify(batch, spec_info)
batch.spec_info = next_draft_input
# if it is None, means all requsets are finished
spec_info, to_free_cache_loc = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info
)
# Free the cache locations used by the draft pass (placed here to avoid an extra synchronization and to hide kernel launch overhead).
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
# If it is None, all requests have finished.
if batch.spec_info.verified_id is not None:
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verified_id,
model_worker_batch,
sum(accept_length_cpu),
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
)
else:
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
# Forward with the draft model.
batch.spec_info = EagleDraftInput(
hidden_states=logits_output.hidden_states,
verified_id=next_token_ids,
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids
)
self.forward_draft_extend(batch)
return logits_output, next_token_ids, model_worker_batch, 0
return logits_output, next_token_ids, bid, 0
def forward_target_extend(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int]:
"""Run the target extend.
Args:
batch: The batch to run. States could be modified.
Returns:
logits_output: The logits output. It contains the full hidden states.
next_token_ids: The next token ids generated.
bid: The model batch id, used for the overlap schedule.
"""
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
return logits_output, next_token_ids, model_worker_batch.bid
def draft(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
# Parse args
num_seqs = batch.batch_size()
spec_info = batch.spec_info
......@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
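# For example (illustrative values), with seq_lens = [5, 7] and topk = 2 the
# repeat_interleave above yields positions = [5, 5, 7, 7]: each sequence gets
# one position per top-k draft branch.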
......@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch
)
if can_cuda_graph:
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
forward_batch
......@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
else:
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
......@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
batch.sampling_info.is_all_greedy,
)
# Free cache locations
batch.token_to_kv_pool.free(out_cache_loc)
self._set_mem_pool(batch, self.target_worker.model_runner)
return ret
return ret, out_cache_loc
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
......@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
logits_output = self.model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self._detect_nan_if_needed(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
......@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
logits_output, _ = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res = spec_info.verify(batch, logits_output)
res: EagleVerifyOutput = spec_info.verify(
batch, logits_output, self.token_to_kv_pool_allocator
)
# Post process based on verified outputs.
# Pick only the indices that we care about (i.e., the accepted tokens)
logits_output.next_token_logits = logits_output.next_token_logits[
res.accepeted_indices_cpu
]
logits_output.hidden_states = logits_output.hidden_states[
res.accepeted_indices_cpu
]
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,)
batch.spec_info = res.draft_input
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
return logits_output, res, model_worker_batch
def forward_draft_extend(
self,
batch: ScheduleBatch,
hidden_states: torch.Tensor,
next_token_ids: List[int],
):
"""Run draft model extend. This API modifies the states of the batch.
Args:
batch: The batch to run.
hidden_states: Hidden states from the target model forward.
next_token_ids: Next token ids generated from the target forward.
"""
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.token_to_kv_pool = runner.token_to_kv_pool
batch.req_to_token_pool = runner.req_to_token_pool
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens
req_pool_indices_backup = batch.req_pool_indices
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# We don't need logprob for this extend.
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
):
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
spec_info = forward_batch.spec_info
spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
spec_info.hidden_states = logits_output.hidden_states
# Don't support prefix share now.
def finish_request(self, reqs: Union[Req, List[Req]]):
if not isinstance(reqs, List):
reqs = [reqs]
for req in reqs:
if req.rid not in self.finish_extend_len:
continue
req_len = (
len(req.origin_input_ids)
+ len(req.output_ids)
- self.finish_extend_len[req.rid]
- 1
)
kv_indices = self.model_runner.req_to_token_pool.req_to_token[
req.req_pool_idx
][:req_len]
self.model_runner.token_to_kv_pool.free(kv_indices)
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
draft_input.hidden_states = logits_output.hidden_states
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.use_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")
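# NaN detection is gated by `self.use_nan_detection` (set from
# `server_args.enable_nan_detection` above); when enabled, a NaN in the draft
# or target logits raises immediately instead of sampling from corrupted values.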
def load_token_map(token_map_path: str) -> List[int]:
......