"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "11d22e0e809d1219a067ded8a18f7b0129fc58c7"
Unverified commit e7bc6003, authored by Zhihao Zhang and committed by GitHub

[Feature] Speculative decoding support lookahead (#9873)
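
Usage sketch (not part of the diff below): assuming the offline sglang.Engine entry point forwards the new ServerArgs fields added in this PR as keyword arguments, and with a placeholder model path, lookahead speculative decoding could be enabled roughly as follows.

# Hypothetical example: flag names mirror the ServerArgs additions in this PR;
# the model path is a placeholder and Engine(...) is assumed to map kwargs onto ServerArgs.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    speculative_algorithm="LOOKAHEAD",
    speculative_num_draft_tokens=8,
    speculative_lookahead_max_match_window_size=12,
    speculative_lookahead_max_bfs_breadth=10,
    speculative_lookahead_branch_length=18,
    speculative_lookahead_match_type="BFS",  # or "PROB"
)
out = llm.generate(["The capital of France is"], {"max_new_tokens": 32})
print(out[0]["text"])
llm.shutdown()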


Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
parent 2a2ff9a8
@@ -102,6 +102,8 @@ dev = ["sglang[test]"]
     "srt/layers/moe/fused_moe_triton/configs/*/*.json",
     "srt/layers/quantization/configs/*.json",
     "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
+    "srt/speculative/cpp_lookahead/*.cpp",
+    "srt/speculative/cpp_lookahead/*.h",
 ]

 [tool.setuptools.packages.find]
......
@@ -1110,7 +1110,8 @@ def sample_sharegpt_requests(
             add_generation_prompt=True,
             tokenize=False,
         )
-        prompt = prompt.replace(tokenizer.bos_token, "")
+        if tokenizer.bos_token:
+            prompt = prompt.replace(tokenizer.bos_token, "")
         prompt_token_ids = tokenizer.encode(prompt)
         completion = dataset[i][1]
......
@@ -29,6 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
 from sglang.srt.utils import (
     is_flashinfer_available,
     is_sm100_supported,
@@ -317,7 +318,9 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -422,7 +425,9 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -638,7 +643,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -651,7 +658,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -673,7 +682,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         assert self.sliding_window_size is not None
         for wrapper_id in range(2):
@@ -721,7 +732,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -753,7 +766,9 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
     ):
@@ -858,7 +873,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -873,7 +890,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -909,7 +928,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -955,7 +976,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -997,7 +1020,9 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[
+            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+        ],
         use_sliding_window_kv_pool: bool = False,
     ):
         bs = len(seq_lens)
@@ -1024,8 +1049,8 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(
+                spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
......
@@ -74,6 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -950,7 +951,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
+    spec_info: Optional[
+        Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
+    ] = None

     # Whether to return hidden states
     return_hidden_states: bool = False
@@ -1600,7 +1603,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)

-        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
+        if (
+            self.spec_algorithm.is_eagle()
+            or self.spec_algorithm.is_standalone()
+            or self.spec_algorithm.is_lookahead()
+        ):
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
@@ -1975,7 +1982,9 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+    spec_info: Optional[
+        Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
+    ] = None

     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
......
@@ -385,6 +385,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_lookahead():
+            from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
+
+            self.draft_worker = LOOKAHEADWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
@@ -740,8 +752,8 @@
                 else (
                     server_args.speculative_num_draft_tokens
                     + (
-                        server_args.speculative_eagle_topk
-                        * server_args.speculative_num_steps
+                        (server_args.speculative_eagle_topk or 1)
+                        * (server_args.speculative_num_steps or 1)
                     )
                 )
             )
@@ -784,7 +796,7 @@
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             draft_token_to_kv_pool=(
                 None
-                if self.draft_worker is None
+                if self.draft_worker is None or self.spec_algorithm.is_lookahead()
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -821,7 +833,7 @@
             token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
             draft_token_to_kv_pool=(
                 None
-                if self.draft_worker is None
+                if self.draft_worker is None or self.spec_algorithm.is_lookahead()
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -2358,9 +2370,8 @@ class Scheduler(
         self.req_to_token_pool.clear()
         self.token_to_kv_pool_allocator.clear()

-        if not self.spec_algorithm.is_none():
-            self.draft_worker.model_runner.req_to_token_pool.clear()
-            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
+        if self.draft_worker:
+            self.draft_worker.clear_cache_pool()

         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
......
@@ -84,6 +84,7 @@ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicat
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (
     trace_get_proc_propagate_context,
     trace_req_finish,
@@ -174,6 +175,15 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py

+        speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.reserve_input_token_num = (
+            0
+            if speculative_algorithm.is_none()
+            else server_args.speculative_num_draft_tokens
+        )
+
         if self.model_config.is_multimodal:
             import_processors()
             try:
@@ -618,6 +628,7 @@
         _max_req_len = self.context_len

         input_token_num = len(input_ids) if input_ids is not None else 0
+        input_token_num += self.reserve_input_token_num
         if input_token_num >= self.context_len:
             if self.server_args.allow_auto_truncate:
                 logger.warning(
......
@@ -275,6 +275,7 @@ class CudaGraphRunner:
         if (
             model_runner.spec_algorithm.is_eagle()
             or model_runner.spec_algorithm.is_standalone()
+            or model_runner.spec_algorithm.is_lookahead()
         ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
@@ -441,11 +442,21 @@
            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )

+        is_lookahead_supported = (
+            (
+                forward_batch.batch_size * self.num_tokens_per_bs
+                == forward_batch.input_ids.numel()
+            )
+            if self.model_runner.spec_algorithm.is_lookahead()
+            else True
+        )
+
         return (
             is_bs_supported
             and is_encoder_lens_supported
             and is_tbo_supported
             and capture_hidden_mode_matches
+            and is_lookahead_supported
         )

     def capture(self) -> None:
@@ -856,6 +867,20 @@
                 seq_lens_cpu=None,
             )
+        elif self.model_runner.spec_algorithm.is_lookahead():
+            from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
+
+            spec_info = LookaheadVerifyInput(
+                draft_token=None,
+                tree_mask=self.custom_mask,
+                positions=None,
+                retrive_index=None,
+                retrive_next_token=None,
+                retrive_next_sibling=None,
+                draft_token_num=self.num_tokens_per_bs,
+            )
+            spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
+
         return spec_info
......
@@ -1402,7 +1402,7 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

-        if not self.spec_algorithm.is_none():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                 max_num_reqs = self.server_args.max_num_reqs
......
@@ -286,6 +286,14 @@ class ServerArgs:
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
     speculative_attention_mode: str = "prefill"
+    # For lookahead only
+    speculative_lookahead_min_match_window_size: int = 1
+    speculative_lookahead_max_match_window_size: int = 12
+    speculative_lookahead_min_bfs_breadth: int = 1
+    speculative_lookahead_max_bfs_breadth: int = 10
+    speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS"
+    speculative_lookahead_branch_length: int = 18
+    speculative_lookahead_capacity: int = 10 * 1000 * 1000

     # Expert parallelism
     ep_size: int = 1
@@ -529,7 +537,7 @@
                 # Standalone speculative decoding needs more memory than other speculative
                 # decoding algorithms since the draft model is typically larger.
                 reserved_mem += 6 * 1024
-            else:
+            elif self.speculative_algorithm != "LOOKAHEAD":
                 reserved_mem += 2 * 1024
         if self.enable_dp_attention:
             reserved_mem += 4 * 1024
@@ -780,11 +788,11 @@
             self.speculative_algorithm = "EAGLE"

         if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
-            if self.speculative_algorithm == "STANDALONE":
+            if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention:
                 # TODO: support dp attention for standalone speculative decoding
-                assert (
-                    self.enable_dp_attention is False
-                ), "Currently standalone speculative decoding does not support dp attention."
+                raise ValueError(
+                    "Currently standalone speculative decoding does not support dp attention."
+                )
             if self.max_running_requests is None:
                 self.max_running_requests = 48
             self.disable_overlap_schedule = True
@@ -858,6 +866,39 @@
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
             # assert self.speculative_num_steps < self.speculative_num_draft_tokens

+        if self.speculative_algorithm == "LOOKAHEAD":
+            if not self.device.startswith("cuda"):
+                raise ValueError(
+                    "Lookahead speculative decoding only supports CUDA device."
+                )
+            if self.max_running_requests is None:
+                self.max_running_requests = 48
+            self.disable_overlap_schedule = True
+            self.enable_mixed_chunk = False
+            self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth
+            if self.speculative_num_draft_tokens is None:
+                # TODO: Do better auto choose in the future
+                self.speculative_num_draft_tokens = (
+                    self.speculative_lookahead_max_match_window_size
+                )
+            logger.warning(
+                "The overlap scheduler and mixed chunked prefill are disabled because of "
+                "using lookahead speculative decoding."
+            )
+            if (
+                self.speculative_eagle_topk > 1
+                and self.page_size > 1
+                and self.attention_backend != "flashinfer"
+            ):
+                raise ValueError(
+                    "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
+                )
+            if self.enable_dp_attention:
+                # TODO: support dp attention for lookahead speculative decoding
+                raise ValueError(
+                    "Currently lookahead speculative decoding does not support dp attention."
+                )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
@@ -1690,7 +1731,7 @@
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
+            choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
@@ -1750,6 +1791,50 @@
             help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
             default=ServerArgs.speculative_attention_mode,
         )
+        # Lookahead speculative decoding
+        parser.add_argument(
+            "--speculative-lookahead-min-match-window-size",
+            type=int,
+            default=ServerArgs.speculative_lookahead_min_match_window_size,
+            help="The minimum window size for pattern matching in lookahead speculative decoding.",
+        )
+        parser.add_argument(
+            "--speculative-lookahead-max-match-window-size",
+            type=int,
+            default=ServerArgs.speculative_lookahead_max_match_window_size,
+            help="The maximum window size for pattern matching in lookahead speculative decoding.",
+        )
+        parser.add_argument(
+            "--speculative-lookahead-min-bfs-breadth",
+            type=int,
+            default=ServerArgs.speculative_lookahead_min_bfs_breadth,
+            help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
+        )
+        parser.add_argument(
+            "--speculative-lookahead-max-bfs-breadth",
+            type=int,
+            default=ServerArgs.speculative_lookahead_max_bfs_breadth,
+            help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
+        )
+        parser.add_argument(
+            "--speculative-lookahead-match-type",
+            type=str,
+            choices=["BFS", "PROB"],
+            default=ServerArgs.speculative_lookahead_match_type,
+            help="The match type for cache tree.",
+        )
+        parser.add_argument(
+            "--speculative-lookahead-branch-length",
+            type=int,
+            default=ServerArgs.speculative_lookahead_branch_length,
+            help="The branch length for lookahead speculative decoding.",
+        )
+        parser.add_argument(
+            "--speculative-lookahead-capacity",
+            type=int,
+            default=ServerArgs.speculative_lookahead_capacity,
+            help="The cache capacity for lookahead speculative decoding.",
+        )

         # Expert parallelism
         parser.add_argument(
......
../../../../../sgl-kernel/.clang-format
\ No newline at end of file
#include "lookahead.h"
#include <limits>
#include <vector>
namespace lookahead {
struct Node {
std::unordered_map<int32_t, int32_t> next;
};
Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
Lookahead::Result info;
std::vector<int32_t> prevs;
info.token.reserve(draft_token_num);
prevs.reserve(draft_token_num);
std::queue<std::tuple<int32_t, int32_t, int32_t>> queue;
info.token.emplace_back(last_token);
prevs.emplace_back(-1);
for (auto [token, next] : tree[root].next) {
queue.emplace(token, next, 0);
}
while (queue.size()) {
auto [token, next, prev] = queue.front();
queue.pop();
info.token.emplace_back(token);
prevs.emplace_back(prev);
for (auto [t, n] : tree[next].next) {
queue.emplace(t, n, info.token.size() - 1);
}
}
// zero padding to length
while (info.token.size() < draft_token_num) {
info.token.emplace_back(0);
prevs.emplace_back(0);
}
int n = info.token.size();
info.mask.resize(n * n, 0);
info.mask[0] = 1;
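  // Row i of the (n x n) mask records which earlier draft tokens token i may attend to:
  // it copies its parent's visible set (the first prevs[i] + 1 entries) and then marks itself,
  // so each row encodes the root-to-node path of the draft tree.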
for (int i = 0; i < n; ++i) {
if (prevs[i] != -1) {
memcpy(&info.mask[i * n], &info.mask[prevs[i] * n], prevs[i] + 1);
}
info.mask[i * n + i] = 1;
}
return info;
}
Lookahead::Lookahead(size_t capacity, const Param& param) {
param_ = param;
nodes_.resize(capacity);
for (auto& node : nodes_) {
node_pool_.emplace_back(&node);
}
free_node_count_ = node_pool_.size();
root_ = getNode();
if (!(param_.branch_length > 1)) {
throw std::runtime_error(
"param_.branch_length must be greater than 1, current value: " + std::to_string(param_.branch_length));
}
if (!(param_.min_match_window_size > 0)) {
throw std::runtime_error(
"min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size));
}
if (!(param_.min_match_window_size <= param_.max_match_window_size)) {
throw std::runtime_error(
"min_match_window_size must be less than or equal to max_match_window_size, current min_match_window_size: " +
std::to_string(param_.min_match_window_size) +
", max_match_window_size: " + std::to_string(param_.max_match_window_size));
}
if (!(param_.max_match_window_size < param_.branch_length)) {
throw std::runtime_error(
"max_match_window_size must be less than branch_length, current max_match_window_size: " +
std::to_string(param_.max_match_window_size) + ", branch_length: " + std::to_string(param_.branch_length));
}
if (!(param_.min_bfs_breadth > 0)) {
throw std::runtime_error(
"min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth));
}
if (!(param_.min_bfs_breadth <= param_.max_bfs_breadth)) {
throw std::runtime_error(
"min_bfs_breadth must be less than or equal to max_bfs_breadth, current min_bfs_breadth: " +
std::to_string(param_.min_bfs_breadth) + ", max_bfs_breadth: " + std::to_string(param_.max_bfs_breadth));
}
if (!(param_.draft_token_num > 0)) {
throw std::runtime_error(
"draft_token_num must be greater than 0, current value: " + std::to_string(param_.draft_token_num));
}
for (auto config : param_.batch_draft_token_num) {
if (config != std::numeric_limits<decltype(config)>::max()) {
if (!(config <= param_.draft_token_num)) {
throw std::runtime_error(
"batch_draft_token_num config value " + std::to_string(config) +
" must be less than or equal to draft_token_num: " + std::to_string(param_.draft_token_num));
}
}
}
for (auto config : param_.batch_min_match_window_size) {
if (config != std::numeric_limits<decltype(config)>::max()) {
if (!(config >= param_.min_match_window_size)) {
throw std::runtime_error(
"batch_min_match_window_size config value " + std::to_string(config) +
" must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size));
}
if (!(config <= param_.max_match_window_size)) {
throw std::runtime_error(
"batch_min_match_window_size config value " + std::to_string(config) +
" must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size));
}
}
}
quit_flag_ = false;
insert_worker_ = std::thread(&Lookahead::insert, this);
}
Lookahead::~Lookahead() {
quit_flag_ = true;
insert_queue_.close();
insert_worker_.join();
}
std::vector<std::pair<TrieNode*, int32_t>>
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
auto draft_token_num = param_.get_draft_token_num(batch_size);
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
auto max_match_window_size = param_.max_match_window_size;
std::vector<std::pair<TrieNode*, int32_t>> result;
result.reserve(param_.max_match_window_size - param_.min_match_window_size);
for (int32_t match_window_size = std::min(tokens.size(), param_.max_match_window_size);
match_window_size >= param_.min_match_window_size;
--match_window_size) {
auto start = tokens.data() + tokens.size() - match_window_size;
auto end = start + match_window_size;
auto cursor = root_;
while (start != end) {
auto iter = cursor->child.find(*start);
if (iter == cursor->child.end()) {
cursor = nullptr;
break;
}
++start;
cursor = iter->second;
}
if (cursor) {
result.emplace_back(std::make_pair(cursor, match_window_size));
}
}
return result;
}
void Lookahead::squeeze(size_t count) {
if (!(node_pool_.size() >= free_node_count_ + count)) {
throw std::runtime_error(
"Insufficient node size to release required nodes. "
"available to release: " +
std::to_string(node_pool_.size() - free_node_count_) + ", required to release: " + std::to_string(count));
}
while (count--) {
auto last = global_lru_.back();
global_lru_.pop_back();
if (!last->child.empty()) {
throw std::runtime_error("The node to be released still has child nodes and cannot be released. ");
}
last->parent->lru.erase(last->parent_lru_pos);
last->parent->sorted_children.erase(last);
last->parent->child.erase(last->token);
node_pool_[free_node_count_++] = last;
}
}
void Lookahead::synchronize() const {
while (!insert_queue_.empty()) {
std::this_thread::sleep_for(std::chrono::microseconds(10));
}
}
void Lookahead::insert() {
while (!quit_flag_) {
std::vector<int32_t> data;
if (!insert_queue_.dequeue(data)) {
continue;
}
const auto* token = data.data();
size_t size = data.size();
std::unique_lock<std::mutex> lock(mutex_);
for (size_t i = 0; i + param_.min_match_window_size < size; ++i) {
auto start = token + i;
auto end = start + std::min(size - i, param_.branch_length);
if (end - start > free_node_count_) {
squeeze(end - start - free_node_count_);
}
TrieNode* cursor = root_;
path_.clear();
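      // Walk the trie along this window, creating missing nodes and bumping hit frequencies.
      // Each parent's LRU list and the global LRU list are kept up to date so squeeze() can
      // evict from the least recently used end, and the whole path is moved to the front of
      // the global LRU once the walk finishes.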
while (start != end) {
auto token = *start;
auto iter = cursor->child.find(token);
if (iter == cursor->child.end()) {
iter = cursor->child.insert({token, getNode()}).first;
auto node = iter->second;
cursor->lru.emplace_front(node);
global_lru_.emplace_back(node);
node->token = token;
node->parent = cursor;
node->parent_lru_pos = cursor->lru.begin();
node->global_lru_pos = --global_lru_.end();
node->freq = 1;
cursor->sorted_children.insert(node);
} else {
auto node = iter->second;
cursor->sorted_children.erase(node);
node->freq++;
cursor->sorted_children.insert(node);
cursor->lru.splice(cursor->lru.begin(), cursor->lru, node->parent_lru_pos);
}
cursor = iter->second;
path_.emplace_back(cursor);
++start;
}
for (auto it = path_.rbegin(); it != path_.rend(); ++it) {
TrieNode* node = *it;
global_lru_.splice(global_lru_.begin(), global_lru_, node->global_lru_pos);
}
}
}
}
void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
for (auto&& token : tokens) {
insert_queue_.enqueue(std::move(token));
}
}
Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
(param_.max_match_window_size - param_.min_match_window_size + 1);
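  // Candidates from match() are ordered from the longest matched suffix to the shortest.
  // A longer (more specific) match starts its BFS with a narrower fan-out near min_bfs_breadth,
  // shorter matches fall back to a wider fan-out, and the fan-out shrinks by one
  // bfs_breadth_scale step per tree level as the traversal goes deeper.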
auto draft_token_num = param_.get_draft_token_num(batch_size);
std::vector<Node> tree(draft_token_num + 1);
int root = 0;
int cursor = 1;
for (auto [node, depth] : nodes) {
std::queue<std::tuple<int32_t, double, const TrieNode*>> queue; // parent, bfs_breadth, node
queue.push({root, (param_.max_match_window_size - depth) * bfs_breadth_scale + param_.min_bfs_breadth, node});
while (queue.size() && cursor <= draft_token_num) {
auto front = queue.front();
queue.pop();
auto parent = std::get<0>(front);
auto cur_breadth = std::get<1>(front);
auto iter = std::get<2>(front)->lru.begin();
auto breadth = std::max(1, int32_t(cur_breadth));
for (int i = 0; i < breadth && iter != std::get<2>(front)->lru.end() && cursor <= draft_token_num; ++i, ++iter) {
auto token = (*iter)->token;
auto pos = -1;
if (auto tit = tree[parent].next.find(token); tit != tree[parent].next.end()) {
pos = tit->second;
} else {
pos = tree[parent].next.insert(std::make_pair(token, cursor++)).first->second;
}
queue.emplace(pos, cur_breadth - bfs_breadth_scale, *iter);
}
}
}
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
auto draft_token_num = param_.get_draft_token_num(batch_size);
struct CompareByLastDouble {
bool operator()(
const std::tuple<double, const TrieNode*, double>& a, // parent_pos, node, final_prob
const std::tuple<double, const TrieNode*, double>& b) const {
return std::get<2>(a) < std::get<2>(b);
}
};
std::priority_queue<
std::tuple<double, const TrieNode*, double>,
std::vector<std::tuple<double, const TrieNode*, double>>,
CompareByLastDouble>
heap;
std::vector<Node> tree(draft_token_num + 1);
int root = 0;
int cursor = 1;
int top_k = param_.max_bfs_breadth;
auto addToHeap = [&heap, &top_k](int parent, const TrieNode* trie_node, double prob) -> void {
double sum_freq = 0.0;
int count = 0;
std::list<std::pair<TrieNode*, int32_t>> topk_children;
for (auto* child : trie_node->sorted_children) {
sum_freq += static_cast<double>(child->freq);
topk_children.emplace_back(child, child->freq);
if (++count >= top_k) break;
}
if (sum_freq <= 0) sum_freq = 1.0;
for (const auto& [child, freq] : topk_children) {
double norm_freq = static_cast<double>(freq) / sum_freq * prob;
heap.emplace(parent, child, norm_freq);
}
};
for (auto [node, _] : nodes) {
addToHeap(root, node, 1.0);
while (!heap.empty() && cursor <= draft_token_num) {
auto [parent, trie_node, prob] = heap.top(); // parent_pos, node, final_prob
heap.pop();
auto token = trie_node->token;
int pos = -1;
auto tit = tree[parent].next.find(token);
if (tit != tree[parent].next.end()) {
pos = tit->second;
} else {
pos = cursor++;
tree[parent].next[token] = pos;
}
addToHeap(pos, trie_node, prob);
}
}
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
std::unique_lock<std::mutex> lock(mutex_);
Result merged_result;
auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb;
for (const auto& tks : tokens) {
Result res = (this->*match_func)(tks, tokens.size());
merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
merged_result.mask.insert(merged_result.mask.end(), res.mask.begin(), res.mask.end());
}
return merged_result;
}
void Lookahead::Result::truncate(size_t n) {
if (n < token.size()) {
int full_n = token.size();
for (int i = 1; i < n; ++i) {
memcpy(&mask[i * n], &mask[i * full_n], sizeof(mask[0]) * n);
}
token.resize(n);
mask.resize(n * n);
}
}
} // namespace lookahead
#pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <list>
#include <mutex>
#include <set>
#include <sstream>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "param.h"
#include "queue.h"
namespace lookahead {
struct TrieNode {
std::unordered_map<int32_t, TrieNode*> child;
std::list<TrieNode*>::const_iterator global_lru_pos;
std::list<TrieNode*>::const_iterator parent_lru_pos;
int32_t token;
TrieNode* parent;
std::list<TrieNode*> lru;
int32_t freq = 0;
struct CompareByFreq {
bool operator()(TrieNode* a, TrieNode* b) const {
return std::tie(b->freq, a->token, a) < std::tie(a->freq, b->token, b);
}
};
std::multiset<TrieNode*, CompareByFreq> sorted_children;
};
class Lookahead {
std::vector<TrieNode> nodes_;
std::vector<TrieNode*> node_pool_;
size_t free_node_count_;
std::list<TrieNode*> global_lru_;
TrieNode* root_;
std::vector<TrieNode*> path_;
Param param_;
std::vector<std::pair<TrieNode*, int32_t>> match(const std::vector<int32_t>& tokens, size_t batch_size) const;
void squeeze(size_t count);
TrieNode* getNode() {
auto node = node_pool_[--free_node_count_];
node->~TrieNode();
new (node) TrieNode();
return node;
}
mutable std::mutex mutex_;
bool quit_flag_;
utils::Queue<std::vector<int32_t>> insert_queue_;
std::thread insert_worker_;
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
public:
Lookahead(size_t capacity, const Param& param);
Lookahead() = default;
~Lookahead();
static Lookahead& instance() {
static Lookahead instance;
return instance;
}
void synchronize() const;
void asyncInsert(std::vector<std::vector<int32_t>>&& tokens);
struct Result {
std::vector<int32_t> token;
std::vector<uint8_t> mask;
void truncate(size_t n);
};
Result batchMatch(const std::vector<std::vector<int32_t>>& tokens) const;
void reset() {
std::unique_lock<std::mutex> lock(mutex_);
global_lru_.clear();
path_.clear();
node_pool_.clear();
for (auto& node : nodes_) {
node_pool_.emplace_back(&node);
}
free_node_count_ = node_pool_.size();
root_ = getNode();
}
const Param& param() const {
return param_;
}
private:
Result matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const;
Result matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const;
void insert();
};
} // namespace lookahead
# -*- coding: utf-8 -*-
# from sglang.op.lookahead import Lookahead, Param
import logging
import os
from typing import List, Tuple
import numpy as np
from torch.utils.cpp_extension import load
logger = logging.getLogger(__name__)
_abs_path = os.path.dirname(os.path.abspath(__file__))
lookahead_cache_cpp = load(
name="lookahead_cache_cpp",
sources=[
f"{_abs_path}/lookahead_cache_binding.cpp",
f"{_abs_path}/lookahead.cpp",
],
extra_cflags=["-O3", "-std=c++20"],
)
class LookaheadCache:
def __init__(
self,
branch_length=18,
min_match_window_size=1,
max_match_window_size=10,
min_bfs_breadth=1,
max_bfs_breadth=8,
draft_token_num=8,
match_type="BFS",
capacity=1000000,
):
param = lookahead_cache_cpp.Param()
param.branch_length = branch_length
param.min_match_window_size = min_match_window_size
param.max_match_window_size = max_match_window_size
param.min_bfs_breadth = min_bfs_breadth
param.max_bfs_breadth = max_bfs_breadth
param.draft_token_num = draft_token_num
param.match_type = match_type
self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
self.default_mask = np.ones((1, 1), dtype=np.int64)
self.draft_token_num = draft_token_num
def batch_put(self, batch_tokens: List[List[int]]):
self.cache.asyncInsert(batch_tokens)
def synchronize(self):
self.cache.synchronize()
def reset(self):
self.cache.reset()
def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
result = self.cache.batchMatch(batch_tokens)
return np.array(result.token), np.array(result.mask)
def leaf_paths_from_mask(
self, tokens: List[int], tree_mask: List[List[int]]
) -> List[List[int]]:
"""
Find all leaf paths according to the binary tree_mask (i.e., paths that are not prefixes of any other path).
Args:
    tokens : List[int]  # token list corresponding to the mask columns
    tree_mask : List[List[int]]  # n x n binary matrix
Returns:
List[List[int]] # token lists of only the leaf paths, preserving their order of appearance
"""
row_sets = [
(i, {idx for idx, v in enumerate(row) if v == 1})
for i, row in enumerate(tree_mask)
]
leaf_sets = []
leaf_rows = []
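        # Scan rows bottom-up: a row whose visible-token set is contained in an already
        # kept set is a prefix of that longer path, so only rows not covered by any later
        # row survive as leaf paths.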
for i, cur_set in reversed(row_sets):
if any(cur_set <= kept for kept in leaf_sets):
continue
leaf_sets.append(cur_set)
leaf_rows.append(i)
leaf_rows.reverse()
result = []
for r in leaf_rows:
path = [tokens[col] for col in range(len(tokens)) if tree_mask[r][col] == 1]
result.append(path)
return result
def debug_result(
self, decoding_ids: np.ndarray, decoding_masks: np.ndarray, tokenizer=None
):
decoding_ids = decoding_ids.reshape(-1, self.draft_token_num)
decoding_masks = decoding_masks.reshape(
-1, self.draft_token_num, self.draft_token_num
)
logger.info(f"\n{decoding_ids=}\n{decoding_masks=}")
for i in range(decoding_ids.shape[0]):
leaf_paths = self.leaf_paths_from_mask(
decoding_ids[i].tolist(), decoding_masks[i].tolist()
)
if tokenizer is None:
logger.info(f"draft path {i}: {leaf_paths}")
else:
logger.info(f"result {i}:")
for leaf_path in leaf_paths:
logger.info(
f"draft path {i}: {leaf_path} -> {tokenizer.decode(leaf_path, ensure_ascii=False)}"
)
# main function
if __name__ == "__main__":
format = f"%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(
level=logging.DEBUG,
format=format,
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
token_ids = [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
]
cache = LookaheadCache(branch_length=12, draft_token_num=8)
cache.batch_put(token_ids)
cache.synchronize()
decoding_ids, decoding_masks = cache.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]])
cache.debug_result(decoding_ids, decoding_masks)
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "lookahead.h"
PYBIND11_MODULE(lookahead_cache_cpp, m) {
using namespace lookahead;
namespace py = pybind11;
m.doc() = "";
py::class_<Lookahead>(m, "Lookahead")
.def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
.def("asyncInsert", &Lookahead::asyncInsert, "")
.def("batchMatch", &Lookahead::batchMatch, "")
.def("reset", &Lookahead::reset, "")
.def("synchronize", &Lookahead::synchronize, "");
py::class_<Param>(m, "Param")
.def(py::init<>())
.def_readwrite("enable", &Param::enable)
.def_readwrite("enable_router_mode", &Param::enable_router_mode)
.def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
.def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
.def_readwrite("min_match_window_size", &Param::min_match_window_size)
.def_readwrite("max_match_window_size", &Param::max_match_window_size)
.def_readwrite("branch_length", &Param::branch_length)
.def_readwrite("draft_token_num", &Param::draft_token_num)
.def_readwrite("match_type", &Param::match_type)
.def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size)
.def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
.def("get_draft_token_num", &Param::get_draft_token_num, "")
.def("get_min_match_window_size", &Param::get_min_match_window_size, "")
.def("parse", &Param::parse, "")
.def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "")
.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
.def("detail", &Param::detail, "");
py::class_<Lookahead::Result>(m, "Result")
.def(py::init<>())
.def_readwrite("token", &Lookahead::Result::token)
.def_readwrite("mask", &Lookahead::Result::mask)
.def("truncate", &Lookahead::Result::truncate);
}
#pragma once
#include <cstddef>
#include <iostream>
#include <limits>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
namespace lookahead {
struct Param {
bool enable;
bool enable_router_mode;
size_t min_bfs_breadth;
size_t max_bfs_breadth;
size_t min_match_window_size;
size_t max_match_window_size;
size_t branch_length;
size_t draft_token_num;
std::string match_type;
std::vector<size_t> batch_min_match_window_size;
std::vector<size_t> batch_draft_token_num;
size_t get_draft_token_num(size_t batch_size) const {
if (batch_size < batch_draft_token_num.size()) {
if (batch_draft_token_num[batch_size] !=
std::numeric_limits<decltype(batch_draft_token_num)::value_type>::max()) {
return batch_draft_token_num[batch_size];
}
}
return draft_token_num - 1;
}
size_t get_min_match_window_size(size_t batch_size) const {
if (batch_size < batch_min_match_window_size.size()) {
if (batch_min_match_window_size[batch_size] !=
std::numeric_limits<decltype(batch_min_match_window_size)::value_type>::max()) {
return batch_min_match_window_size[batch_size];
}
}
return min_match_window_size;
}
std::vector<size_t> parse(const std::string& value) {
// 0-1|10,2-3|20,
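    // Each comma-separated segment has the form "<lo>-<hi>|<value>": the value applies to every
    // batch size in [lo, hi]. Positions never mentioned keep the sentinel
    // std::numeric_limits<size_t>::max() and therefore fall back to the scalar default at lookup.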
std::vector<size_t> result;
if (value.empty()) {
return result;
}
std::vector<size_t> mark;
std::regex comma_re(",");
std::sregex_token_iterator first{value.begin(), value.end(), comma_re, -1}, last;
for (auto p : std::vector<std::string>(first, last)) {
std::cerr << "seg " << p << std::endl;
}
for (const auto& seg : std::vector<std::string>(first, last)) {
std::regex pipe_re("\\|");
std::sregex_token_iterator seg_first{seg.begin(), seg.end(), pipe_re, -1}, seg_last;
std::vector<std::string> part(seg_first, seg_last);
for (auto p : part) {
std::cerr << "part " << p << std::endl;
}
if (part.size() != 2) {
throw std::runtime_error(
"failed to get config, invalid config: " + seg + ", part's size = " + std::to_string(part.size()));
}
std::regex endash_re("-");
std::sregex_token_iterator range_first{part[0].begin(), part[0].end(), endash_re, -1}, range_last;
std::vector<std::string> range(range_first, range_last);
if (range.size() != 2) {
throw std::runtime_error("failed to get range, invalid config: " + value);
}
size_t L = std::atoi(range[0].c_str());
size_t R = std::atoi(range[1].c_str());
if (L > R || R > 128) {
throw std::runtime_error("invalid range, config: " + value);
}
if (R >= result.size()) {
result.resize(R + 1, std::numeric_limits<decltype(result)::value_type>::max());
mark.resize(result.size(), false);
}
size_t config = std::atoi(part[1].c_str());
do {
if (mark[L]) {
throw std::runtime_error("repeated position " + std::to_string(L) + ", config : " + value);
}
mark[L] = true;
result[L] = config;
} while (++L <= R);
}
return result;
}
void resetBatchMinMatchWindowSize(const std::string& value) {
batch_min_match_window_size = parse(value);
}
void resetBatchReturnTokenNum(const std::string& value) {
batch_draft_token_num = parse(value);
}
std::string detail() {
std::stringstream ss;
ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
<< ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
<< ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
<< ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
<< ", match_type = " << match_type;
ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
for (int i = 0; i < batch_min_match_window_size.size(); ++i) {
ss << i << "|" << batch_min_match_window_size[i] << ",";
}
ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
for (int i = 0; i < batch_draft_token_num.size(); ++i) {
ss << i << "|" << batch_draft_token_num[i] << ",";
}
return ss.str();
}
};
} // namespace lookahead
#pragma once
#include <condition_variable>
#include <queue>
namespace utils {
template <typename T>
class Queue {
public:
bool enqueue(T&& rhs) {
{
std::lock_guard<std::mutex> lock(mutex_);
if (closed_) {
return false;
}
queue_.emplace(std::move(rhs));
}
cv_.notify_one();
return true;
}
bool enqueue(const T& rhs) {
{
std::lock_guard<std::mutex> lock(mutex_);
if (closed_) {
return false;
}
queue_.emplace(rhs);
}
cv_.notify_one();
return true;
}
bool dequeue(T& rhs) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return queue_.size() || closed_; });
if (closed_) {
return false;
}
rhs = std::move(queue_.front());
queue_.pop();
return true;
}
size_t size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
bool empty() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.empty();
}
void close() {
{
std::lock_guard<std::mutex> lock(mutex_);
closed_ = true;
}
cv_.notify_all();
}
private:
std::queue<T> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
bool closed_{false};
};
} // namespace utils
@@ -771,6 +771,10 @@ class EAGLEWorker(TpModelWorker):
         return score_list, token_list, parents_list

+    def clear_cache_pool(self):
+        self.model_runner.req_to_token_pool.clear()
+        self.model_runner.token_to_kv_pool_allocator.clear()
+
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
......
from __future__ import annotations
import copy
import logging
from typing import Optional
import torch
import triton
logger = logging.getLogger(__name__)
from dataclasses import dataclass
import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import (
ScheduleBatch,
get_last_loc,
global_server_args_dict,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.eagle_utils import (
TREE_SPEC_KERNEL_AVAILABLE,
assign_req_to_token_pool,
create_flashinfer_kv_indices_triton,
get_src_tgt_cache_loc,
get_target_cache_loc,
)
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda():
from sgl_kernel import (
top_k_renorm_prob,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
elif is_hip():
from sgl_kernel import verify_tree_greedy
@dataclass
class LookaheadVerifyInput:
def __init__(
self,
draft_token: torch.Tensor,
tree_mask: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
draft_token_num: int,
):
self.draft_token = draft_token
self.custom_mask = tree_mask
self.positions = positions
self.retrive_index = retrive_index
self.retrive_next_token = retrive_next_token
self.retrive_next_sibling = retrive_next_sibling
self.draft_token_num = draft_token_num
self.device = self.custom_mask.device
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
if batch.forward_mode.is_idle():
return
batch.input_ids = self.draft_token
if page_size == 1:
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
end_offset = batch.seq_lens + self.draft_token_num
else:
prefix_lens = batch.seq_lens
end_offset = prefix_lens + self.draft_token_num
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
prefix_lens, end_offset, last_loc, len(batch.input_ids)
)
self.last_loc = last_loc
bs = batch.batch_size()
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
end_offset,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
bs = len(req_pool_indices)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.qo_indptr = (
torch.arange(0, bs + 1, dtype=torch.int32, device=self.device)
* self.draft_token_num
)
kv_indices = torch.empty(
cum_kv_seq_len[-1], dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
cum_kv_seq_len,
None,
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
def _fill_requests(
self,
batch: ScheduleBatch,
logits_output: torch.Tensor,
):
accept_index_cpu = self.accept_index.tolist()
predict_cpu = self.predict.tolist()
has_finished = False
# Iterate every accepted token and check if req has finished after append the token
# should be checked BEFORE free kv cache slots
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
for j, idx in enumerate(accept_index_row):
if idx == -1:
break
id = predict_cpu[idx]
req.output_ids.append(id)
req.check_finished()
if req.finished():
has_finished = True
# set all tokens after finished token to -1 and break
self.accept_index[i, j + 1 :] = -1
break
else:
if req.grammar is not None:
try:
req.grammar.accept_token(id)
except ValueError as e:
logger.info(
f"{i=}, {req=}\n"
f"{self.accept_index=}\n"
f"{self.predict=}\n"
)
raise e
req.spec_verify_ct += 1
if has_finished:
self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
self.accept_index = self.accept_index[self.accept_index != -1]
logits_output.next_token_logits = logits_output.next_token_logits[
self.accept_index
]
if logits_output.hidden_states is not None:
logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
self.verified_id = self.predict[self.accept_index]
def _free_cache(self, batch: ScheduleBatch, page_size: int):
bs = batch.batch_size()
# Free the KV cache for unaccepted tokens
if page_size == 1:
# TODO: boolean array index leads to a device sync. Remove it.
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[self.accept_index] = False
batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
else:
# Shift the accepted tokens to the beginning.
# Only evict the last part
src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
batch.seq_lens,
batch.out_cache_loc,
self.accept_index,
self.accept_length,
self.draft_token_num,
page_size,
)
to_free_slots = torch.empty(
(to_free_num_slots.sum().item(),),
dtype=torch.int64,
device=to_free_num_slots.device,
)
# out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
# accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
# tgt_cache_loc: [0 1 , 3 4 , 6 ]
# to_free_slots: [ 2, 5, 7 8]
# to_free_slots also needs to be page-aligned without the first partial page
#
# split each row of out_cache_loc into two parts.
# 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
# 2. the second part goes to to_free_slots.
get_target_cache_loc[(bs,)](
tgt_cache_loc,
to_free_slots,
self.accept_length,
to_free_num_slots,
batch.out_cache_loc,
self.draft_token_num,
next_power_of_2(self.draft_token_num),
next_power_of_2(bs),
)
# Free the kv cache
batch.token_to_kv_pool_allocator.free(to_free_slots)
# Copy the kv cache
batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
tgt_cache_loc, src_cache_loc
)
batch.out_cache_loc = tgt_cache_loc
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.accept_length + 1,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
triton.next_power_of_2(bs),
)
def _greedy_verify(
self,
batch: ScheduleBatch,
logits_output: LogitsProcessorOutput,
):
bs = batch.batch_size()
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num)
candidates = self.draft_token.reshape(bs, self.draft_token_num)
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1
self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
self.accept_index = torch.full(
(bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
)
self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
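        # verify_tree_greedy (from sgl_kernel) walks the draft tree via retrive_next_token /
        # retrive_next_sibling and, as understood here, accepts a draft token only when it
        # equals the target model's argmax prediction at its parent node. predicts,
        # accept_index, and accept_length are filled in place.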
verify_tree_greedy(
predicts=self.predict, # mutable
accept_index=self.accept_index, # mutable
accept_token_num=self.accept_length, # mutable
candidates=candidates,
retrive_index=self.retrive_index,
retrive_next_token=self.retrive_next_token,
retrive_next_sibling=self.retrive_next_sibling,
target_predict=target_predict,
)
def _sampling_verify(
self,
batch: ScheduleBatch,
logits_output: LogitsProcessorOutput,
sampling_info: SamplingBatchInfo,
):
bs = batch.batch_size()
candidates = self.draft_token.reshape(bs, self.draft_token_num)
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1
self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
self.accept_index = torch.full(
(bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
)
self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
# apply temperature and get target probs
expanded_temperature = torch.repeat_interleave(
sampling_info.temperatures, self.draft_token_num, dim=0
) # (bs * draft_token_num, 1)
target_probs = F.softmax(
logits_output.next_token_logits / expanded_temperature, dim=-1
) # (bs * draft_token_num, vocab_size)
        # NOTE: Profiling shows that top_k_renorm_prob and top_p_renorm_prob are the main
        # contributors to the poor performance of _sampling_verify.
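        # For intuition (illustrative numbers, not taken from the code): with top_k=2 and
        # probs [0.5, 0.3, 0.2], top-k renormalization keeps the two largest entries and
        # rescales them to sum to 1, giving [0.625, 0.375, 0.0]; top-p instead keeps the
        # smallest set of entries whose cumulative probability reaches top_p.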
target_probs = top_k_renorm_prob(
target_probs,
torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0),
) # (bs * draft_token_num, vocab_size)
if sampling_info.need_top_p_sampling:
# logger.info("Using top-p sampling in speculative decoding verification.")
target_probs = top_p_renorm_prob(
target_probs,
torch.repeat_interleave(
sampling_info.top_ps, self.draft_token_num, dim=0
),
)
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
draft_probs = torch.zeros(
target_probs.shape, dtype=torch.float32, device=self.device
)
# coins for rejection sampling
coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device)
# coins for final sampling
coins_for_final_sampling = torch.rand(
(bs,), dtype=torch.float32, device=self.device
)
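        # Reading inferred from the kernel's argument names (not documented here): `coins`
        # provides one uniform(0, 1) sample per draft token for the accept/reject test,
        # and `coins_for_final_sampling` provides one sample per request for drawing the
        # bonus token.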
tree_speculative_sampling_target_only(
predicts=self.predict, # mutable
accept_index=self.accept_index, # mutable
accept_token_num=self.accept_length, # mutable
candidates=candidates.to(torch.int64),
retrive_index=self.retrive_index.to(torch.int64),
retrive_next_token=self.retrive_next_token.to(torch.int64),
retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
uniform_samples=coins,
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
deterministic=True,
)
def verify(
self,
batch: ScheduleBatch,
logits_output: LogitsProcessorOutput,
page_size: int,
vocab_mask: Optional[torch.Tensor] = None, # For grammar
) -> torch.Tensor:
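        # Full verification pass: filter the sampling info down to this batch, apply any
        # custom logit processors, penalties, and grammar masks, select the accepted
        # tokens (greedy tree verification below), then update the requests, free the KV
        # slots of rejected drafts, and return the filtered logits, the accepted token
        # ids, and the total number of accepted tokens.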
bs = self.retrive_index.shape[0]
sampling_info = batch.sampling_info
if bs != len(sampling_info):
sampling_info = copy.deepcopy(sampling_info)
# NOTE: retrive_index are the indices of the requests that are kept.
sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
apply_custom_logit_processor(
logits_output.next_token_logits,
sampling_info,
num_tokens_in_batch=self.draft_token_num,
)
# Apply penalty
if sampling_info.penalizer_orchestrator.is_required:
# This is a relaxed version of penalties for speculative decoding.
linear_penalty = torch.zeros(
(bs, logits_output.next_token_logits.shape[1]),
dtype=torch.float32,
device=self.device,
)
sampling_info.apply_logits_bias(linear_penalty)
logits_output.next_token_logits.add_(
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
)
# Apply grammar mask
if vocab_mask is not None:
assert self.grammar is not None
self.grammar.apply_vocab_mask(
logits=logits_output.next_token_logits, vocab_mask=vocab_mask
)
        # Sample tokens. Fall back to greedy verification when the tree speculative
        # sampling kernel is unavailable (e.g., AMD/HIP builds).
is_all_greedy = sampling_info.is_all_greedy
if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
logger.warning(
"Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
"Falling back to greedy verification."
)
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
self._greedy_verify(batch, logits_output)
else:
            # NOTE: _sampling_verify currently performs worse than _greedy_verify (see the
            # renormalization note above), so greedy verification is used even for
            # non-greedy sampling parameters.
self._greedy_verify(batch, logits_output)
# self._sampling_verify(batch, logits_output, sampling_info)
self._fill_requests(batch, logits_output)
self._free_cache(batch, page_size)
batch.seq_lens.add_(self.accept_length + 1)
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
return logits_output, self.verified_id, self.accept_length.sum().item()
def filter_batch(self, new_indices: torch.Tensor):
pass
def merge_batch(self, spec_info: LookaheadVerifyInput):
pass
import logging
import os
import threading
import time
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import broadcast_pyobj
logger = logging.getLogger(__name__)
USE_FULL_MASK = True
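# When True, expand each request's tree mask to also cover its full existing context;
# see the QLEN_MASK note in _prepare_for_speculative_decoding.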
class LOOKAHEADWorker:
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
moe_ep_rank: int,
nccl_port: int,
target_worker: TpModelWorker,
):
self.target_worker = target_worker
self.model_runner = target_worker.model_runner
self.tp_rank = tp_rank
self.page_size = server_args.page_size
self.draft_token_num: int = server_args.speculative_num_draft_tokens
self.branch_length: int = server_args.speculative_lookahead_branch_length
self.max_match_window_size: int = (
server_args.speculative_lookahead_max_match_window_size
)
self.max_batch_size = target_worker.max_running_requests
self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"
self._init_preallocated_tensors()
self.lookahead_cache = LookaheadCache(
min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
capacity=server_args.speculative_lookahead_capacity,
branch_length=server_args.speculative_lookahead_branch_length,
draft_token_num=server_args.speculative_num_draft_tokens,
)
def clear_cache_pool(self):
self.lookahead_cache.reset()
def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
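        """Return the last ``n`` tokens of ``seq1 + seq2`` without concatenating both lists.

        Example: seq1=[1, 2, 3, 4], seq2=[5, 6], n=3 -> [4, 5, 6].
        """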
seq2_len = len(seq2)
if seq2_len >= n:
return seq2[-n:]
need_from_seq1 = n - seq2_len
return seq1[-need_from_seq1:] + seq2
def _init_preallocated_tensors(self):
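        # Allocate worst-case (max_batch_size) buffers once and pre-slice a view for every
        # possible batch size, so the decode loop never re-allocates draft-tree tensors.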
max_total_drafts = self.max_batch_size * self.draft_token_num
max_total_mask_size = (
self.max_batch_size * self.draft_token_num * self.draft_token_num
)
self.draft_tokens = torch.empty(
(max_total_drafts,), dtype=torch.int64, device=self.device
)
self.retrieve_indexes = torch.empty(
(self.max_batch_size, self.draft_token_num),
dtype=torch.int64,
device=self.device,
)
self.retrive_next_token = torch.empty(
(self.max_batch_size, self.draft_token_num),
dtype=torch.int64,
device=self.device,
)
self.retrive_next_sibling = torch.empty(
(self.max_batch_size, self.draft_token_num),
dtype=torch.int64,
device=self.device,
)
self.positions = torch.empty(
(max_total_drafts,), dtype=torch.int64, device=self.device
)
self.tree_mask = torch.empty(
(max_total_mask_size,), dtype=torch.bool, device=self.device
)
self.draft_tokens_batch = []
self.tree_mask_batch = []
self.retrieve_indexes_batch = []
self.retrive_next_token_batch = []
self.retrive_next_sibling_batch = []
self.positions_batch = []
for bs in range(0, self.max_batch_size + 1):
self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
self.positions_batch.append(self.positions[: bs * self.draft_token_num])
self.draft_tokens_batch.append(
self.draft_tokens[: bs * self.draft_token_num]
)
self.tree_mask_batch.append(
self.tree_mask[: bs * self.draft_token_num * self.draft_token_num]
)
def _prepare_draft_tokens(
self, batch: ScheduleBatch
) -> tuple[np.ndarray, np.ndarray]:
bs = batch.batch_size()
self.lookahead_cache.synchronize()
batch_tokens = []
for req in batch.reqs:
check_token = self._efficient_concat_last_n(
req.origin_input_ids, req.output_ids, self.max_match_window_size
)
batch_tokens.append(check_token)
req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
total_draft_token_num = len(req_drafts)
        # Speculative decoding could be skipped when the cache has no match, but we always
        # run it here, so every request must come back with exactly draft_token_num drafts.
assert (
total_draft_token_num == bs * self.draft_token_num
), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
return req_drafts, mask
def _prepare_for_speculative_decoding(self, batch: ScheduleBatch):
if batch.forward_mode.is_extend():
return
bs = batch.batch_size()
retrive_index = self.retrieve_indexes_batch[bs]
retrive_next_token = self.retrive_next_token_batch[bs]
retrive_next_sibling = self.retrive_next_sibling_batch[bs]
positions = self.positions_batch[bs]
tree_mask = self.tree_mask_batch[bs]
draft_tokens = self.draft_tokens_batch[bs]
req_drafts, mask = self._prepare_draft_tokens(batch)
tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)
reconstruct_indices_from_tree_mask(
tree_mask,
batch.seq_lens,
positions, # mutable
retrive_index, # mutable
retrive_next_token, # mutable
retrive_next_sibling, # mutable
bs,
self.draft_token_num,
)
# NOTE: QLEN_MASK is faster than FULL_MASK, but requires corresponding changes in flashinfer.
# Testing shows about 8% performance improvement (the effect is roughly proportional to batch size).
if USE_FULL_MASK:
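            # For request i with sequence length seq_len, prepend an all-ones
            # (draft_token_num, seq_len - 1) block for the existing context to the
            # (draft_token_num, draft_token_num) tree mask, flatten it, and concatenate
            # the per-request masks into a single 1-D tensor.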
tree_mask = []
mask = mask.reshape(
batch.batch_size(), self.draft_token_num, self.draft_token_num
)
for i, req in enumerate(batch.reqs):
seq_len = len(req.origin_input_ids) + len(req.output_ids)
req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
req_mask = torch.cat(
(req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
).to(torch.bool)
tree_mask.append(req_mask.flatten())
tree_mask = torch.cat(tree_mask, dim=0)
batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = LookaheadVerifyInput(
draft_tokens,
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
self.draft_token_num,
)
batch.spec_info.prepare_for_verify(batch, self.page_size)
def _update_lookahead_cache(self, batch: ScheduleBatch):
batch_tokens = []
for req in batch.reqs:
            # FIXME: Testing showed little difference between inserting and skipping the
            # prefill ('extend') tokens, so they are not inserted into the cache for now.
            # if batch.forward_mode.is_extend():
            #     put_ids = req.origin_input_ids + req.output_ids
            # else:
put_ids = self._efficient_concat_last_n(
req.origin_input_ids, req.output_ids, self.branch_length
)
batch_tokens.append(put_ids)
self.lookahead_cache.batch_put(batch_tokens)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
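        # Decode batches: build a draft tree from the lookahead cache, run one
        # TARGET_VERIFY forward pass with sampling skipped, accept the matching tokens,
        # refresh the cache with the new outputs, and restore DECODE mode. Extend
        # (prefill) batches fall through to the normal generation path below.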
self._prepare_for_speculative_decoding(batch)
model_worker_batch = batch.get_model_worker_batch()
bid = model_worker_batch.bid
num_accepted_tokens = 0
if model_worker_batch.forward_mode.is_target_verify():
logits_output, _, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
)
verify_input = model_worker_batch.spec_info
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
batch, logits_output, self.page_size
)
self._update_lookahead_cache(batch)
batch.forward_mode = ForwardMode.DECODE
else:
logits_output, next_token_ids, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(model_worker_batch)
)
return (
logits_output,
next_token_ids,
bid,
num_accepted_tokens,
can_run_cuda_graph,
)
...@@ -6,6 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
    EAGLE = auto()
    EAGLE3 = auto()
    STANDALONE = auto()
    LOOKAHEAD = auto()

    def is_none(self):
        return self == SpeculativeAlgorithm.NONE
...@@ -19,12 +20,16 @@ class SpeculativeAlgorithm(IntEnum):
    def is_standalone(self):
        return self == SpeculativeAlgorithm.STANDALONE

    def is_lookahead(self):
        return self == SpeculativeAlgorithm.LOOKAHEAD

    @staticmethod
    def from_string(name: str):
        name_map = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            "EAGLE3": SpeculativeAlgorithm.EAGLE3,
            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
            "LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
            None: SpeculativeAlgorithm.NONE,
        }
        if name is not None:
......
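As a quick sanity check of the new enum entry, a minimal sketch (not part of the commit; it uses only the names visible in this diff and the import path already used in the worker file above) is:

from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

algo = SpeculativeAlgorithm.from_string("LOOKAHEAD")
assert algo.is_lookahead() and not algo.is_none()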