Unverified commit 24f7cb1e authored by Zhihao Zhang, committed by GitHub

[speculative decoding] rename lookahead to ngram (#11010)


Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
parent e05555fa
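For orientation: the renamed component implements n-gram speculative decoding. Draft tokens are proposed by matching the tail of the current sequence against a cache of previously seen n-grams, and the target model verifies the whole draft in a single forward pass. A minimal sketch of the idea, with hypothetical names (the commit's real implementation is the C++ trie under cpp_ngram further down in this diff):

# Illustrative sketch of n-gram draft-and-verify; every name here is
# hypothetical and greatly simplified relative to the C++ trie in this commit.
from typing import Dict, List, Tuple


def build_ngram_cache(tokens: List[int], window: int) -> Dict[Tuple[int, ...], List[int]]:
    """Map each window-sized context to the tokens observed after it."""
    cache: Dict[Tuple[int, ...], List[int]] = {}
    for i in range(len(tokens) - window):
        key = tuple(tokens[i : i + window])
        cache.setdefault(key, []).append(tokens[i + window])
    return cache


def draft(context: List[int], cache: Dict[Tuple[int, ...], List[int]],
          window: int, num_draft: int) -> List[int]:
    """Propose up to num_draft tokens by repeated suffix lookups."""
    out = list(context)
    for _ in range(num_draft):
        followers = cache.get(tuple(out[-window:]))
        if not followers:
            break
        out.append(followers[0])  # the real cache ranks children by frequency
    return out[len(context):]

The target model then scores context plus draft in one forward pass and keeps only the longest prefix of the draft that matches its own next-token choices, which is what the verify path below does.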
@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
     "srt/layers/moe/fused_moe_triton/configs/*/*.json",
     "srt/layers/quantization/configs/*.json",
     "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
-    "srt/speculative/cpp_lookahead/*.cpp",
-    "srt/speculative/cpp_lookahead/*.h",
+    "srt/speculative/cpp_ngram/*.cpp",
+    "srt/speculative/cpp_ngram/*.h",
 ]

 [tool.setuptools.packages.find]
...
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
+from sglang.srt.speculative.ngram_utils import NgramVerifyInput
 from sglang.srt.utils import (
     get_int_env_var,
     is_flashinfer_available,
@@ -344,9 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -453,9 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -673,9 +669,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -690,9 +684,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -718,9 +710,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -770,9 +760,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):
@@ -806,9 +794,7 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
@@ -919,9 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -937,9 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:
@@ -977,9 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
@@ -1026,9 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
@@ -1071,9 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[
-            Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-        ],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
     ):
@@ -1102,7 +1078,7 @@ class FlashInferIndicesUpdaterPrefill:
             custom_mask = None
         else:
             assert isinstance(
-                spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
+                spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
             )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
...
@@ -74,7 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
+    from sglang.srt.speculative.ngram_utils import NgramVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -953,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
-        Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
-    ] = None
+    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
+        None
+    )

     # Whether to return hidden states
     return_hidden_states: bool = False
@@ -1608,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             self.spec_algorithm.is_eagle()
             or self.spec_algorithm.is_standalone()
-            or self.spec_algorithm.is_lookahead()
+            or self.spec_algorithm.is_ngram()
         ):
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
@@ -1984,9 +1984,9 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
-        Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
-    ] = None
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
+        None
+    )
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
...
@@ -388,10 +388,10 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
-        elif self.spec_algorithm.is_lookahead():
-            from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker

-            self.draft_worker = LOOKAHEADWorker(
+            self.draft_worker = NGRAMWorker(
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
                 moe_ep_rank=moe_ep_rank,
@@ -826,7 +826,7 @@ class Scheduler(
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             draft_token_to_kv_pool=(
                 None
-                if self.draft_worker is None or self.spec_algorithm.is_lookahead()
+                if self.draft_worker is None or self.spec_algorithm.is_ngram()
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -863,7 +863,7 @@ class Scheduler(
             token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
             draft_token_to_kv_pool=(
                 None
-                if self.draft_worker is None or self.spec_algorithm.is_lookahead()
+                if self.draft_worker is None or self.spec_algorithm.is_ngram()
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
...
@@ -246,7 +246,7 @@ class CudaGraphRunner:
         if (
             model_runner.spec_algorithm.is_eagle()
             or model_runner.spec_algorithm.is_standalone()
-            or model_runner.spec_algorithm.is_lookahead()
+            or model_runner.spec_algorithm.is_ngram()
         ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
@@ -413,12 +413,12 @@ class CudaGraphRunner:
             forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
         )

-        is_lookahead_supported = (
+        is_ngram_supported = (
             (
                 forward_batch.batch_size * self.num_tokens_per_bs
                 == forward_batch.input_ids.numel()
             )
-            if self.model_runner.spec_algorithm.is_lookahead()
+            if self.model_runner.spec_algorithm.is_ngram()
             else True
         )
@@ -427,7 +427,7 @@ class CudaGraphRunner:
             and is_encoder_lens_supported
             and is_tbo_supported
             and capture_hidden_mode_matches
-            and is_lookahead_supported
+            and is_ngram_supported
         )

     def capture(self) -> None:
@@ -838,10 +838,10 @@ class CudaGraphRunner:
                 seq_lens_cpu=None,
             )
-        elif self.model_runner.spec_algorithm.is_lookahead():
-            from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
+        elif self.model_runner.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_utils import NgramVerifyInput

-            spec_info = LookaheadVerifyInput(
+            spec_info = NgramVerifyInput(
                 draft_token=None,
                 tree_mask=self.custom_mask,
                 positions=None,
...
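The is_ngram_supported check above exists because a captured CUDA graph bakes in tensor shapes: replay is only valid when every request contributes exactly num_tokens_per_bs tokens. A hedged standalone restatement of that guard (the function name is illustrative; the shape equality mirrors the check in CudaGraphRunner):

# Standalone restatement of the is_ngram_supported guard; name is illustrative.
import torch


def can_replay_ngram_graph(input_ids: torch.Tensor, batch_size: int, num_tokens_per_bs: int) -> bool:
    # CUDA graph replay requires the flattened token count to match the
    # batch_size x tokens-per-request shape the graph was captured with.
    return batch_size * num_tokens_per_bs == input_ids.numel()


tokens = torch.arange(16)                        # 2 requests x 8 draft tokens each
assert can_replay_ngram_graph(tokens, 2, 8)      # shapes line up: replay is safe
assert not can_replay_ngram_graph(tokens, 2, 7)  # mismatch: fall back to eager mode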
@@ -286,14 +286,14 @@ class ServerArgs:
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None
     speculative_attention_mode: str = "prefill"
-    # For lookahead only
-    speculative_lookahead_min_match_window_size: int = 1
-    speculative_lookahead_max_match_window_size: int = 12
-    speculative_lookahead_min_bfs_breadth: int = 1
-    speculative_lookahead_max_bfs_breadth: int = 10
-    speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS"
-    speculative_lookahead_branch_length: int = 18
-    speculative_lookahead_capacity: int = 10 * 1000 * 1000
+    # For ngram only
+    speculative_ngram_min_match_window_size: int = 1
+    speculative_ngram_max_match_window_size: int = 12
+    speculative_ngram_min_bfs_breadth: int = 1
+    speculative_ngram_max_bfs_breadth: int = 10
+    speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS"
+    speculative_ngram_branch_length: int = 18
+    speculative_ngram_capacity: int = 10 * 1000 * 1000

     # Expert parallelism
     ep_size: int = 1
@@ -566,7 +566,7 @@ class ServerArgs:
             # Standalone speculative decoding needs more memory than other speculative
             # decoding algorithms since the draft model is typically larger.
             reserved_mem += 6 * 1024
-        elif self.speculative_algorithm != "LOOKAHEAD":
+        elif self.speculative_algorithm != "NGRAM":
             reserved_mem += 2 * 1024
         if self.enable_dp_attention:
             reserved_mem += 4 * 1024
@@ -1024,23 +1024,23 @@ class ServerArgs:
                 "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
             )

-        if self.speculative_algorithm == "LOOKAHEAD":
+        if self.speculative_algorithm == "NGRAM":
             if not self.device.startswith("cuda"):
                 raise ValueError(
-                    "Lookahead speculative decoding only supports CUDA device."
+                    "Ngram speculative decoding only supports CUDA device."
                 )
             if self.max_running_requests is None:
                 self.max_running_requests = 48
             self.disable_overlap_schedule = True
             self.enable_mixed_chunk = False
-            self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth
+            self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth
             if self.speculative_num_draft_tokens is None:
                 self.speculative_num_draft_tokens = (
-                    self.speculative_lookahead_max_match_window_size
+                    self.speculative_ngram_max_match_window_size
                 )
             logger.warning(
                 "The overlap scheduler and mixed chunked prefill are disabled because of "
-                "using lookahead speculative decoding."
+                "using ngram speculative decoding."
             )

         if (
@@ -1052,9 +1052,9 @@ class ServerArgs:
                 "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
             )
         if self.enable_dp_attention:
-            # TODO: support dp attention for lookahead speculative decoding
+            # TODO: support dp attention for ngram speculative decoding
             raise ValueError(
-                "Currently lookahead speculative decoding does not support dp attention."
+                "Currently ngram speculative decoding does not support dp attention."
            )

     def _handle_load_format(self):
@@ -1921,7 +1921,7 @@ class ServerArgs:
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"],
+            choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
@@ -1981,49 +1981,49 @@ class ServerArgs:
             help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
             default=ServerArgs.speculative_attention_mode,
         )
-        # Lookahead speculative decoding
+        # Ngram speculative decoding
         parser.add_argument(
-            "--speculative-lookahead-min-match-window-size",
+            "--speculative-ngram-min-match-window-size",
             type=int,
-            default=ServerArgs.speculative_lookahead_min_match_window_size,
-            help="The minimum window size for pattern matching in lookahead speculative decoding.",
+            default=ServerArgs.speculative_ngram_min_match_window_size,
+            help="The minimum window size for pattern matching in ngram speculative decoding.",
         )
         parser.add_argument(
-            "--speculative-lookahead-max-match-window-size",
+            "--speculative-ngram-max-match-window-size",
             type=int,
-            default=ServerArgs.speculative_lookahead_max_match_window_size,
-            help="The maximum window size for pattern matching in lookahead speculative decoding.",
+            default=ServerArgs.speculative_ngram_max_match_window_size,
+            help="The maximum window size for pattern matching in ngram speculative decoding.",
         )
         parser.add_argument(
-            "--speculative-lookahead-min-bfs-breadth",
+            "--speculative-ngram-min-bfs-breadth",
            type=int,
-            default=ServerArgs.speculative_lookahead_min_bfs_breadth,
-            help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
+            default=ServerArgs.speculative_ngram_min_bfs_breadth,
+            help="The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
         )
         parser.add_argument(
-            "--speculative-lookahead-max-bfs-breadth",
+            "--speculative-ngram-max-bfs-breadth",
             type=int,
-            default=ServerArgs.speculative_lookahead_max_bfs_breadth,
-            help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
+            default=ServerArgs.speculative_ngram_max_bfs_breadth,
+            help="The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
         )
         parser.add_argument(
-            "--speculative-lookahead-match-type",
+            "--speculative-ngram-match-type",
             type=str,
             choices=["BFS", "PROB"],
-            default=ServerArgs.speculative_lookahead_match_type,
+            default=ServerArgs.speculative_ngram_match_type,
             help="The match type for cache tree.",
         )
         parser.add_argument(
-            "--speculative-lookahead-branch-length",
+            "--speculative-ngram-branch-length",
             type=int,
-            default=ServerArgs.speculative_lookahead_branch_length,
-            help="The branch length for lookahead speculative decoding.",
+            default=ServerArgs.speculative_ngram_branch_length,
+            help="The branch length for ngram speculative decoding.",
        )
         parser.add_argument(
-            "--speculative-lookahead-capacity",
+            "--speculative-ngram-capacity",
             type=int,
-            default=ServerArgs.speculative_lookahead_capacity,
-            help="The cache capacity for lookahead speculative decoding.",
+            default=ServerArgs.speculative_ngram_capacity,
+            help="The cache capacity for ngram speculative decoding.",
         )

         # Expert parallelism
...
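Taken together, the renamed knobs select and tune the NGRAM algorithm. A hedged configuration sketch using only fields visible in this diff; model_path is illustrative and the values simply repeat the defaults:

# Sketch of an NGRAM configuration; values mirror the defaults above and
# model_path is illustrative. Validation (shown earlier in this diff) will
# additionally disable the overlap scheduler and mixed chunked prefill.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="Qwen/Qwen2.5-Coder-7B-Instruct",
    speculative_algorithm="NGRAM",
    speculative_ngram_min_match_window_size=1,
    speculative_ngram_max_match_window_size=12,
    speculative_ngram_min_bfs_breadth=1,
    speculative_ngram_max_bfs_breadth=10,
    speculative_ngram_match_type="BFS",  # or "PROB"
    speculative_ngram_branch_length=18,
    speculative_ngram_capacity=10 * 1000 * 1000,
)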
#include "lookahead.h" #include "ngram.h"
#include <limits> #include <limits>
#include <vector> #include <vector>
namespace lookahead { namespace ngram {
struct Node { struct Node {
std::unordered_map<int32_t, int32_t> next; std::unordered_map<int32_t, int32_t> next;
}; };
Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) { Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
Lookahead::Result info; Ngram::Result info;
std::vector<int32_t> prevs; std::vector<int32_t> prevs;
info.token.reserve(draft_token_num); info.token.reserve(draft_token_num);
prevs.reserve(draft_token_num); prevs.reserve(draft_token_num);
...@@ -50,7 +50,7 @@ Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<No ...@@ -50,7 +50,7 @@ Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<No
return info; return info;
} }
Lookahead::Lookahead(size_t capacity, const Param& param) { Ngram::Ngram(size_t capacity, const Param& param) {
param_ = param; param_ = param;
nodes_.resize(capacity); nodes_.resize(capacity);
for (auto& node : nodes_) { for (auto& node : nodes_) {
...@@ -116,17 +116,16 @@ Lookahead::Lookahead(size_t capacity, const Param& param) { ...@@ -116,17 +116,16 @@ Lookahead::Lookahead(size_t capacity, const Param& param) {
} }
quit_flag_ = false; quit_flag_ = false;
insert_worker_ = std::thread(&Lookahead::insert, this); insert_worker_ = std::thread(&Ngram::insert, this);
} }
Lookahead::~Lookahead() { Ngram::~Ngram() {
quit_flag_ = true; quit_flag_ = true;
insert_queue_.close(); insert_queue_.close();
insert_worker_.join(); insert_worker_.join();
} }
std::vector<std::pair<TrieNode*, int32_t>> std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
auto draft_token_num = param_.get_draft_token_num(batch_size); auto draft_token_num = param_.get_draft_token_num(batch_size);
auto min_match_window_size = param_.get_min_match_window_size(batch_size); auto min_match_window_size = param_.get_min_match_window_size(batch_size);
auto max_match_window_size = param_.max_match_window_size; auto max_match_window_size = param_.max_match_window_size;
...@@ -154,7 +153,7 @@ Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const { ...@@ -154,7 +153,7 @@ Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
return result; return result;
} }
void Lookahead::squeeze(size_t count) { void Ngram::squeeze(size_t count) {
if (!(node_pool_.size() >= free_node_count_ + count)) { if (!(node_pool_.size() >= free_node_count_ + count)) {
throw std::runtime_error( throw std::runtime_error(
"Insufficient node size to release required nodes. " "Insufficient node size to release required nodes. "
...@@ -177,13 +176,13 @@ void Lookahead::squeeze(size_t count) { ...@@ -177,13 +176,13 @@ void Lookahead::squeeze(size_t count) {
} }
} }
void Lookahead::synchronize() const { void Ngram::synchronize() const {
while (!insert_queue_.empty()) { while (!insert_queue_.empty()) {
std::this_thread::sleep_for(std::chrono::microseconds(10)); std::this_thread::sleep_for(std::chrono::microseconds(10));
} }
} }
void Lookahead::insert() { void Ngram::insert() {
while (!quit_flag_) { while (!quit_flag_) {
std::vector<int32_t> data; std::vector<int32_t> data;
if (!insert_queue_.dequeue(data)) { if (!insert_queue_.dequeue(data)) {
...@@ -239,13 +238,13 @@ void Lookahead::insert() { ...@@ -239,13 +238,13 @@ void Lookahead::insert() {
} }
} }
void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) { void Ngram::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
for (auto&& token : tokens) { for (auto&& token : tokens) {
insert_queue_.enqueue(std::move(token)); insert_queue_.enqueue(std::move(token));
} }
} }
Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const { Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size); std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) / double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
...@@ -284,7 +283,7 @@ Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t ...@@ -284,7 +283,7 @@ Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t
return fillResult(tokens.back(), draft_token_num + 1, tree, root); return fillResult(tokens.back(), draft_token_num + 1, tree, root);
} }
Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const { Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size); std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
auto draft_token_num = param_.get_draft_token_num(batch_size); auto draft_token_num = param_.get_draft_token_num(batch_size);
...@@ -346,10 +345,10 @@ Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_ ...@@ -346,10 +345,10 @@ Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_
return fillResult(tokens.back(), draft_token_num + 1, tree, root); return fillResult(tokens.back(), draft_token_num + 1, tree, root);
} }
Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const { Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
Result merged_result; Result merged_result;
auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb; auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb;
for (const auto& tks : tokens) { for (const auto& tks : tokens) {
Result res = (this->*match_func)(tks, tokens.size()); Result res = (this->*match_func)(tks, tokens.size());
merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end()); merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
...@@ -358,7 +357,7 @@ Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& ...@@ -358,7 +357,7 @@ Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>&
return merged_result; return merged_result;
} }
void Lookahead::Result::truncate(size_t n) { void Ngram::Result::truncate(size_t n) {
if (n < token.size()) { if (n < token.size()) {
int full_n = token.size(); int full_n = token.size();
for (int i = 1; i < n; ++i) { for (int i = 1; i < n; ++i) {
...@@ -369,4 +368,4 @@ void Lookahead::Result::truncate(size_t n) { ...@@ -369,4 +368,4 @@ void Lookahead::Result::truncate(size_t n) {
} }
} }
} // namespace lookahead } // namespace ngram
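For readers new to this file: match walks the trie to the deepest node reachable from a suffix of the input (between the min and max match window sizes), and matchBFS then expands that subtree breadth-first until draft_token_num tokens are collected. A conceptual Python sketch of the BFS expansion, where nested dicts stand in for the pooled TrieNode structure and the per-level breadth interpolation and frequency ordering are omitted:

# Conceptual sketch of Ngram::matchBFS; a nested dict stands in for TrieNode.
# The real code also scales breadth per level between min_bfs_breadth and
# max_bfs_breadth and visits children via the frequency-sorted multiset.
from collections import deque
from typing import Dict, List


def bfs_draft(root: Dict[int, dict], breadth: int, draft_token_num: int) -> List[int]:
    drafts: List[int] = []
    queue = deque([root])
    while queue and len(drafts) < draft_token_num:
        node = queue.popleft()
        for token, child in list(node.items())[:breadth]:
            drafts.append(token)
            queue.append(child)
            if len(drafts) >= draft_token_num:
                break
    return drafts


# Example: a tiny trie built from the continuations [7, 8, 9] and [7, 8, 5]
trie = {8: {9: {}, 5: {}}}
print(bfs_draft(trie, breadth=2, draft_token_num=3))  # [8, 9, 5]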
@@ -15,7 +15,7 @@
 #include "param.h"
 #include "queue.h"

-namespace lookahead {
+namespace ngram {

 struct TrieNode {
   std::unordered_map<int32_t, TrieNode*> child;
@@ -34,7 +34,7 @@ struct TrieNode {
   std::multiset<TrieNode*, CompareByFreq> sorted_children;
 };

-class Lookahead {
+class Ngram {
   std::vector<TrieNode> nodes_;
   std::vector<TrieNode*> node_pool_;
   size_t free_node_count_;
@@ -61,12 +61,12 @@ class Lookahead {
   std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;

  public:
-  Lookahead(size_t capacity, const Param& param);
-  Lookahead() = default;
-  ~Lookahead();
+  Ngram(size_t capacity, const Param& param);
+  Ngram() = default;
+  ~Ngram();

-  static Lookahead& instance() {
-    static Lookahead instance;
+  static Ngram& instance() {
+    static Ngram instance;
     return instance;
   }
@@ -107,4 +107,4 @@ class Lookahead {
   void insert();
 };

-}  // namespace lookahead
+}  // namespace ngram
 # -*- coding: utf-8 -*-
-# from sglang.op.lookahead import Lookahead, Param
-
 import logging
 import os
 from typing import List, Tuple
@@ -12,17 +10,17 @@ from torch.utils.cpp_extension import load

 logger = logging.getLogger(__name__)

 _abs_path = os.path.dirname(os.path.abspath(__file__))
-lookahead_cache_cpp = load(
-    name="lookahead_cache_cpp",
+ngram_cache_cpp = load(
+    name="ngram_cache_cpp",
     sources=[
-        f"{_abs_path}/lookahead_cache_binding.cpp",
-        f"{_abs_path}/lookahead.cpp",
+        f"{_abs_path}/ngram_cache_binding.cpp",
+        f"{_abs_path}/ngram.cpp",
     ],
     extra_cflags=["-O3", "-std=c++20"],
 )


-class LookaheadCache:
+class NgramCache:
     def __init__(
         self,
         branch_length=18,
@@ -34,7 +32,7 @@ class LookaheadCache:
         match_type="BFS",
         capacity=1000000,
     ):
-        param = lookahead_cache_cpp.Param()
+        param = ngram_cache_cpp.Param()
         param.branch_length = branch_length
         param.min_match_window_size = min_match_window_size
         param.max_match_window_size = max_match_window_size
@@ -42,7 +40,7 @@ class LookaheadCache:
         param.max_bfs_breadth = max_bfs_breadth
         param.draft_token_num = draft_token_num
         param.match_type = match_type
-        self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
+        self.cache = ngram_cache_cpp.Ngram(capacity, param)

         self.default_mask = np.ones((1, 1), dtype=np.int64)
         self.draft_token_num = draft_token_num
@@ -131,7 +129,7 @@ if __name__ == "__main__":
         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
         [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
     ]
-    cache = LookaheadCache(branch_length=12, draft_token_num=8)
+    cache = NgramCache(branch_length=12, draft_token_num=8)
     cache.batch_put(token_ids)
     cache.synchronize()
...
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

-#include "lookahead.h"
+#include "ngram.h"

-PYBIND11_MODULE(lookahead_cache_cpp, m) {
-  using namespace lookahead;
+PYBIND11_MODULE(ngram_cache_cpp, m) {
+  using namespace ngram;
   namespace py = pybind11;
   m.doc() = "";

-  py::class_<Lookahead>(m, "Lookahead")
+  py::class_<Ngram>(m, "Ngram")
       .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
-      .def("asyncInsert", &Lookahead::asyncInsert, "")
-      .def("batchMatch", &Lookahead::batchMatch, "")
-      .def("reset", &Lookahead::reset, "")
-      .def("synchronize", &Lookahead::synchronize, "");
+      .def("asyncInsert", &Ngram::asyncInsert, "")
+      .def("batchMatch", &Ngram::batchMatch, "")
+      .def("reset", &Ngram::reset, "")
+      .def("synchronize", &Ngram::synchronize, "");

   py::class_<Param>(m, "Param")
       .def(py::init<>())
@@ -35,9 +35,9 @@ PYBIND11_MODULE(lookahead_cache_cpp, m) {
       .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
       .def("detail", &Param::detail, "");

-  py::class_<Lookahead::Result>(m, "Result")
+  py::class_<Ngram::Result>(m, "Result")
       .def(py::init<>())
-      .def_readwrite("token", &Lookahead::Result::token)
-      .def_readwrite("mask", &Lookahead::Result::mask)
-      .def("truncate", &Lookahead::Result::truncate);
+      .def_readwrite("token", &Ngram::Result::token)
+      .def_readwrite("mask", &Ngram::Result::mask)
+      .def("truncate", &Ngram::Result::truncate);
 }
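NgramCache above is the intended wrapper, but the binding can also be driven directly, which makes the renamed surface area visible in one place. A hedged sketch, assuming the extension was built by the torch.utils.cpp_extension.load call shown earlier:

# Hedged sketch of the raw binding surface; production code goes through
# NgramCache instead. Importing the module triggers the JIT build above.
from sglang.srt.speculative.cpp_ngram.ngram_cache import ngram_cache_cpp

param = ngram_cache_cpp.Param()
param.branch_length = 18
param.min_match_window_size = 1
param.max_match_window_size = 12
param.min_bfs_breadth = 1
param.max_bfs_breadth = 10
param.draft_token_num = 8
param.match_type = "BFS"

cache = ngram_cache_cpp.Ngram(1_000_000, param)  # capacity, param
cache.asyncInsert([[1, 2, 3, 4, 5, 6, 7, 8]])    # queued to the insert thread
cache.synchronize()                              # block until the queue drains
result = cache.batchMatch([[1, 2, 3]])           # Result with .token and .mask
result.truncate(4)                               # keep at most 4 draft tokens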
@@ -9,7 +9,7 @@
 #include <string>
 #include <vector>

-namespace lookahead {
+namespace ngram {

 struct Param {
   bool enable;
@@ -122,4 +122,4 @@ struct Param {
   }
 };

-}  // namespace lookahead
+}  // namespace ngram
@@ -42,7 +42,7 @@ elif is_hip():

 @dataclass
-class LookaheadVerifyInput:
+class NgramVerifyInput:
     def __init__(
         self,
         draft_token: torch.Tensor,
@@ -408,5 +408,5 @@ class LookaheadVerifyInput:
     def filter_batch(self, new_indices: torch.Tensor):
         pass

-    def merge_batch(self, spec_info: LookaheadVerifyInput):
+    def merge_batch(self, spec_info: NgramVerifyInput):
         pass
@@ -12,8 +12,8 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
-from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
+from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
+from sglang.srt.speculative.ngram_utils import NgramVerifyInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import broadcast_pyobj
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 USE_FULL_MASK = True

-class LOOKAHEADWorker:
+class NGRAMWorker:
     def __init__(
         self,
         server_args: ServerArgs,
@@ -38,9 +38,9 @@ class LOOKAHEADWorker:
         self.tp_rank = tp_rank
         self.page_size = server_args.page_size
         self.draft_token_num: int = server_args.speculative_num_draft_tokens
-        self.branch_length: int = server_args.speculative_lookahead_branch_length
+        self.branch_length: int = server_args.speculative_ngram_branch_length
         self.max_match_window_size: int = (
-            server_args.speculative_lookahead_max_match_window_size
+            server_args.speculative_ngram_max_match_window_size
         )

         self.max_batch_size = target_worker.max_running_requests
@@ -48,18 +48,18 @@ class LOOKAHEADWorker:
         self._init_preallocated_tensors()

-        self.lookahead_cache = LookaheadCache(
-            min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
-            max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
-            min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
-            max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
-            capacity=server_args.speculative_lookahead_capacity,
-            branch_length=server_args.speculative_lookahead_branch_length,
+        self.ngram_cache = NgramCache(
+            min_match_window_size=server_args.speculative_ngram_min_match_window_size,
+            max_match_window_size=server_args.speculative_ngram_max_match_window_size,
+            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
+            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
+            capacity=server_args.speculative_ngram_capacity,
+            branch_length=server_args.speculative_ngram_branch_length,
             draft_token_num=server_args.speculative_num_draft_tokens,
         )

     def clear_cache_pool(self):
-        self.lookahead_cache.reset()
+        self.ngram_cache.reset()

     def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
         seq2_len = len(seq2)
@@ -124,14 +124,14 @@ class LOOKAHEADWorker:
     ) -> tuple[np.ndarray, np.ndarray]:
         bs = batch.batch_size()

-        self.lookahead_cache.synchronize()
+        self.ngram_cache.synchronize()
         batch_tokens = []
         for req in batch.reqs:
             check_token = self._efficient_concat_last_n(
                 req.origin_input_ids, req.output_ids, self.max_match_window_size
             )
             batch_tokens.append(check_token)
-        req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
+        req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
         total_draft_token_num = len(req_drafts)

         # Check if speculative decoding is needed; here we always enforce it
@@ -184,9 +184,9 @@ class LOOKAHEADWorker:
             tree_mask.append(req_mask.flatten())
         tree_mask = torch.cat(tree_mask, dim=0)

-        batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
+        batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
         batch.forward_mode = ForwardMode.TARGET_VERIFY
-        batch.spec_info = LookaheadVerifyInput(
+        batch.spec_info = NgramVerifyInput(
             draft_tokens,
             tree_mask,
             positions,
@@ -197,7 +197,7 @@ class LOOKAHEADWorker:
         )
         batch.spec_info.prepare_for_verify(batch, self.page_size)

-    def _update_lookahead_cache(self, batch: ScheduleBatch):
+    def _update_ngram_cache(self, batch: ScheduleBatch):
         batch_tokens = []
         for req in batch.reqs:
             # FIXME: Whether to insert 'extend' into the cache or not, after testing,
@@ -209,7 +209,7 @@ class LOOKAHEADWorker:
                 req.origin_input_ids, req.output_ids, self.branch_length
             )
             batch_tokens.append(put_ids)
-        self.lookahead_cache.batch_put(batch_tokens)
+        self.ngram_cache.batch_put(batch_tokens)

     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         self._prepare_for_speculative_decoding(batch)
@@ -227,7 +227,7 @@ class LOOKAHEADWorker:
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
                 batch, logits_output, self.page_size
             )
-            self._update_lookahead_cache(batch)
+            self._update_ngram_cache(batch)
             batch.forward_mode = ForwardMode.DECODE
         else:
...
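The cycle implemented by forward_batch_speculative_generation above, restated compactly. Hedged: target_forward is a hypothetical stand-in for running the target model on the TARGET_VERIFY batch, and the non-verify branch is omitted.

# Hedged restatement of NGRAMWorker's verify cycle; target_forward is a
# hypothetical stand-in, everything else mirrors the methods shown above.
def speculative_step(worker, batch, target_forward):
    # 1. Draft: build draft tokens + tree mask and flip the batch to TARGET_VERIFY.
    worker._prepare_for_speculative_decoding(batch)
    # 2. Score: one target-model forward pass over all draft positions at once.
    logits_output = target_forward(batch)
    # 3. Verify: keep the longest draft prefix the target model agrees with.
    logits_output, next_token_ids, num_accepted = batch.spec_info.verify(
        batch, logits_output, worker.page_size
    )
    # 4. Learn: feed the newly accepted tokens back into the n-gram cache.
    worker._update_ngram_cache(batch)
    return next_token_ids, num_accepted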
@@ -6,7 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
     EAGLE = auto()
     EAGLE3 = auto()
     STANDALONE = auto()
-    LOOKAHEAD = auto()
+    NGRAM = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -20,8 +20,8 @@ class SpeculativeAlgorithm(IntEnum):
     def is_standalone(self):
         return self == SpeculativeAlgorithm.STANDALONE

-    def is_lookahead(self):
-        return self == SpeculativeAlgorithm.LOOKAHEAD
+    def is_ngram(self):
+        return self == SpeculativeAlgorithm.NGRAM

     @staticmethod
     def from_string(name: str):
@@ -29,7 +29,7 @@ class SpeculativeAlgorithm(IntEnum):
         "EAGLE": SpeculativeAlgorithm.EAGLE,
         "EAGLE3": SpeculativeAlgorithm.EAGLE3,
         "STANDALONE": SpeculativeAlgorithm.STANDALONE,
-        "LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
+        "NGRAM": SpeculativeAlgorithm.NGRAM,
         None: SpeculativeAlgorithm.NONE,
     }
     if name is not None:
...
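One practical consequence of the enum rename: launch configurations that passed the old string must be updated, since only the new name remains in the lookup table. A quick hedged check:

from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

assert SpeculativeAlgorithm.from_string("NGRAM").is_ngram()
# "LOOKAHEAD" no longer appears in the name table above, so the old string
# is expected to fail the lookup after this change.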
@@ -82,7 +82,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
     "meta-llama/Llama-3.1-8B-Instruct"
 )
 DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
-DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
+DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"

 # Other use cases
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
...
@@ -314,7 +314,7 @@ set(SOURCES
     "csrc/kvcacheio/transfer.cu"
     "csrc/speculative/eagle_utils.cu"
-    "csrc/speculative/lookahead_utils.cu"
+    "csrc/speculative/ngram_utils.cu"
     "csrc/speculative/packbit.cu"
     "csrc/speculative/speculative_sampling.cu"
...