Unverified commit b26bc86b, authored by Lianmin Zheng, committed by GitHub

Support page size > 1 + eagle (#4908)

parent 5ec5eaf7
...@@ -33,6 +33,7 @@ runtime_common = [ ...@@ -33,6 +33,7 @@ runtime_common = [
"prometheus-client>=0.20.0", "prometheus-client>=0.20.0",
"psutil", "psutil",
"pydantic", "pydantic",
"pynvml",
"python-multipart", "python-multipart",
"pyzmq>=25.1.2", "pyzmq>=25.1.2",
"soundfile==0.13.1", "soundfile==0.13.1",
......
...@@ -14,7 +14,6 @@ from functools import partial ...@@ -14,7 +14,6 @@ from functools import partial
from typing import TYPE_CHECKING, Callable, List, Optional, Union from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch import torch
import triton
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
...@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito ...@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
...@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend: ...@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
self.topk = topk self.topk = topk
self.speculative_num_steps = speculative_num_steps self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
self.page_size = model_runner.page_size
max_bs = model_runner.req_to_token_pool.size * self.topk max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros( self.kv_indptr = torch.zeros(
...@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend: ...@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
self.pool_len, self.pool_len,
kv_indices_buffer.shape[1], kv_indices_buffer.shape[1],
self.kv_indptr.shape[1], self.kv_indptr.shape[1],
-triton.next_power_of_2(num_seqs),
-triton.next_power_of_2(self.speculative_num_steps),
-triton.next_power_of_2(bs),
+next_power_of_2(num_seqs),
+next_power_of_2(self.speculative_num_steps),
+next_power_of_2(bs),
) )
assert forward_batch.spec_info is not None assert forward_batch.spec_info is not None
...@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend: ...@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
) )
def call_fn(i, forward_batch): def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = ( forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone() forward_batch.spec_info.kv_indptr.clone()
) )
......
...@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) )
return req_pool_indices return req_pool_indices
-def alloc_token_slots(self, num_tokens: int):
+def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
if self.token_to_kv_pool_allocator.available_size() < num_tokens: if self.token_to_kv_pool_allocator.available_size() < num_tokens:
if self.tree_cache is not None: if self.tree_cache is not None:
self.tree_cache.evict(num_tokens) self.tree_cache.evict(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens) out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None: if out_cache_loc is None:
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode" phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
...@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.tree_cache.pretty_print() self.tree_cache.pretty_print()
raise RuntimeError(error_msg) raise RuntimeError(error_msg)
-return out_cache_loc
+if backup_state:
+    return out_cache_loc, state
+else:
+    return out_cache_loc
def alloc_paged_token_slots_extend( def alloc_paged_token_slots_extend(
self, self,
...@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
last_loc: torch.Tensor, last_loc: torch.Tensor,
extend_num_tokens: int, extend_num_tokens: int,
backup_state: bool = False,
): ):
if ( if (
self.token_to_kv_pool_allocator.available_size() self.token_to_kv_pool_allocator.available_size()
...@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size, + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
) )
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend( out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens, seq_lens, last_loc, extend_num_tokens prefix_lens, seq_lens, last_loc, extend_num_tokens
) )
...@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) )
logger.error(error_msg) logger.error(error_msg)
raise RuntimeError(error_msg) raise RuntimeError(error_msg)
-return out_cache_loc
+if backup_state:
+    return out_cache_loc, state
+else:
+    return out_cache_loc
def alloc_paged_token_slots_decode( def alloc_paged_token_slots_decode(
self, self,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
last_loc: torch.Tensor, last_loc: torch.Tensor,
backup_state: bool = False,
): ):
if ( if (
self.token_to_kv_pool_allocator.available_size() self.token_to_kv_pool_allocator.available_size()
...@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.tree_cache.evict( self.tree_cache.evict(
len(seq_lens) * self.token_to_kv_pool_allocator.page_size, len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
) )
+if backup_state:
+    state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
if out_cache_loc is None: if out_cache_loc is None:
error_msg = ( error_msg = (
f"Decode out of memory. Try to lower your batch size.\n" f"Decode out of memory. Try to lower your batch size.\n"
...@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) )
logger.error(error_msg) logger.error(error_msg)
raise RuntimeError(error_msg) raise RuntimeError(error_msg)
-return out_cache_loc
+if backup_state:
+    return out_cache_loc, state
+else:
+    return out_cache_loc
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = [] self.encoder_lens_cpu = []
......
...@@ -1110,7 +1110,7 @@ class Scheduler( ...@@ -1110,7 +1110,7 @@ class Scheduler(
) )
if memory_leak: if memory_leak:
msg = ( msg = (
"KV cache pool leak detected! " "token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n" f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n" f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n" f"{self.tree_cache.evictable_size()=}\n"
...@@ -1121,7 +1121,7 @@ class Scheduler( ...@@ -1121,7 +1121,7 @@ class Scheduler(
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size: if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
msg = ( msg = (
"Memory pool leak detected!" "req_to_token_pool memory leak detected!"
f"available_size={len(self.req_to_token_pool.free_slots)}, " f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"total_size={self.req_to_token_pool.size}\n" f"total_size={self.req_to_token_pool.size}\n"
) )
......
...@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator: ...@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
if self.free_group: if self.free_group:
self.free(torch.cat(self.free_group)) self.free(torch.cat(self.free_group))
def backup_state(self):
return self.free_slots
def restore_state(self, free_slots):
self.free_slots = free_slots
def clear(self): def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = torch.arange( self.free_slots = torch.arange(
......
...@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator: ...@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
next_power_of_2(extend_num_tokens), next_power_of_2(extend_num_tokens),
) )
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
merged_value = self.ret_values.item() merged_value = self.ret_values.item()
num_new_pages = merged_value >> 32 num_new_pages = merged_value >> 32
if num_new_pages > len(self.free_pages): if num_new_pages > len(self.free_pages):
...@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator: ...@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
self.page_size, self.page_size,
) )
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = self.ret_values.item() num_new_pages = self.ret_values.item()
if num_new_pages > len(self.free_pages): if num_new_pages > len(self.free_pages):
return None return None
...@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator: ...@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
else: else:
self.free_group.append(free_index) self.free_group.append(free_index)
if self.debug_mode:
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
def free_group_begin(self): def free_group_begin(self):
self.is_not_in_free_group = False self.is_not_in_free_group = False
self.free_group = [] self.free_group = []
...@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator: ...@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
if self.free_group: if self.free_group:
self.free(torch.cat(self.free_group)) self.free(torch.cat(self.free_group))
def backup_state(self):
return self.free_pages
def restore_state(self, free_pages):
self.free_pages = free_pages
def clear(self): def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange( self.free_pages = torch.arange(
......
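Both allocators now expose a snapshot hook. A minimal standalone sketch of the pattern this enables (a toy free-list allocator, not the real TokenToKVPoolAllocator): snapshot the free list, allocate KV slots for the draft forward passes, then restore the snapshot so the same space is free again for the verify pass.

import torch

class ToyAllocator:
    def __init__(self, size: int):
        self.free_slots = torch.arange(1, size + 1)

    def alloc(self, n: int):
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def backup_state(self):
        return self.free_slots

    def restore_state(self, free_slots):
        self.free_slots = free_slots

allocator = ToyAllocator(16)
state = allocator.backup_state()   # snapshot before allocating draft KV slots
draft_slots = allocator.alloc(8)   # slots only needed during the draft forward passes
allocator.restore_state(state)     # roll back: the draft slots are free again for verify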
...@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): ...@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
if capture_bs is None: if capture_bs is None:
if server_args.speculative_algorithm is None: if server_args.speculative_algorithm is None:
if server_args.disable_cuda_graph_padding: if server_args.disable_cuda_graph_padding:
-capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
else:
-capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
else:
# Since speculative decoding requires more cuda graph memory, we
# capture less.
-capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
+capture_bs = (
+    list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+)
if _is_hip:
-capture_bs += [i * 8 for i in range(21, 33)]
+capture_bs += list(range(160, 257, 8))
if max(capture_bs) > model_runner.req_to_token_pool.size: if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
......
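For reference, a quick evaluation of the new capture lists (values copied from this diff; illustration only) shows which batch sizes the CUDA graph runner would capture:

# Illustration only: evaluate the new capture_bs expressions from this diff.
default_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
spec_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
print(default_bs)  # [1, 2, 4, 8, 16, 24, ..., 152, 160]
print(spec_bs)     # [1, ..., 8, 10, 12, ..., 32, 40, 56, ..., 136, 152]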
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
import logging import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import argparse import argparse
import dataclasses import dataclasses
import json
import logging import logging
import os import os
import random import random
...@@ -132,9 +133,9 @@ class ServerArgs: ...@@ -132,9 +133,9 @@ class ServerArgs:
# Speculative decoding # Speculative decoding
speculative_algorithm: Optional[str] = None speculative_algorithm: Optional[str] = None
speculative_draft_model_path: Optional[str] = None speculative_draft_model_path: Optional[str] = None
-speculative_num_steps: int = 5
-speculative_eagle_topk: int = 4
-speculative_num_draft_tokens: int = 8
+speculative_num_steps: Optional[int] = None
+speculative_eagle_topk: Optional[int] = None
+speculative_num_draft_tokens: Optional[int] = None
speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0 speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None speculative_token_map: Optional[str] = None
...@@ -313,12 +314,29 @@ class ServerArgs: ...@@ -313,12 +314,29 @@ class ServerArgs:
or self.speculative_algorithm == "EAGLE3" or self.speculative_algorithm == "EAGLE3"
): ):
if self.max_running_requests is None: if self.max_running_requests is None:
-self.max_running_requests = 32
+self.max_running_requests = 48
self.disable_overlap_schedule = True self.disable_overlap_schedule = True
logger.info( logger.info(
"Overlap scheduler is disabled because of using " "Overlap scheduler is disabled because of using "
"eagle speculative decoding." "eagle speculative decoding."
) )
# Auto choose parameters
if self.speculative_num_steps is None:
assert (
self.speculative_eagle_topk is None
and self.speculative_num_draft_tokens is None
)
(
self.speculative_num_steps,
self.speculative_eagle_topk,
self.speculative_num_draft_tokens,
) = auto_choose_speculative_params(self)
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
# The token generated from the verify step is counted. # The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
# assert self.speculative_num_steps < self.speculative_num_draft_tokens # assert self.speculative_num_steps < self.speculative_num_draft_tokens
...@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action): ...@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
raise ValueError(self.help) raise ValueError(self.help)
def auto_choose_speculative_params(self: ServerArgs):
"""
Automatically choose the parameters for speculative decoding.
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
"""
if self.decrypted_config_file:
config_path = self.decrypted_config_file
else:
config_path = os.path.join(self.model_path, "config.json")
if not os.path.exists(config_path):
raise ValueError(f"{config_path} is not found.")
config = json.load(open(config_path))
arch = config.get("architectures", ["Unknown"])[0]
if arch in ["LlamaForCausalLM"]:
# The default value for llama
return (5, 4, 8)
elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
# The default value for deepseek
return (5, 4, 8)
elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
return (5, 4, 8)
else:
# The default value for all other models
return (5, 4, 8)
from __future__ import annotations from __future__ import annotations
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
...@@ -10,11 +11,15 @@ import triton.language as tl ...@@ -10,11 +11,15 @@ import triton.language as tl
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import (
ScheduleBatch,
get_last_loc,
global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import is_cuda_available, is_hip from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2
if is_cuda_available(): if is_cuda_available():
from sgl_kernel import ( from sgl_kernel import (
...@@ -34,6 +39,9 @@ import logging ...@@ -34,6 +39,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
@dataclass @dataclass
class EagleDraftInput: class EagleDraftInput:
# The inputs for decode # The inputs for decode
...@@ -93,7 +101,7 @@ class EagleDraftInput: ...@@ -93,7 +101,7 @@ class EagleDraftInput:
torch.cumsum(self.accept_length, axis=0, dtype=torch.int), torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
self.positions, self.positions,
new_verified_id, new_verified_id,
-triton.next_power_of_2(speculative_num_steps + 1),
+next_power_of_2(speculative_num_steps + 1),
) )
batch.seq_lens_sum = sum(seq_lens_cpu) batch.seq_lens_sum = sum(seq_lens_cpu)
...@@ -225,18 +233,34 @@ class EagleVerifyInput: ...@@ -225,18 +233,34 @@ class EagleVerifyInput:
CaptureHiddenMode.FULL, CaptureHiddenMode.FULL,
) )
-def prepare_for_verify(self, batch: ScheduleBatch):
+def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
batch.input_ids = self.draft_token batch.input_ids = self.draft_token
batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
if page_size == 1:
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
end_offset = batch.seq_lens + self.draft_token_num
else:
prefix_lens = batch.seq_lens
end_offset = prefix_lens + self.draft_token_num
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
prefix_lens, end_offset, last_loc, len(batch.input_ids)
)
self.last_loc = last_loc
bs = batch.batch_size() bs = batch.batch_size()
assign_req_to_token_pool[(bs,)]( assign_req_to_token_pool[(bs,)](
batch.req_pool_indices, batch.req_pool_indices,
batch.req_to_token_pool.req_to_token, batch.req_to_token_pool.req_to_token,
batch.seq_lens, batch.seq_lens,
-batch.seq_lens + self.draft_token_num,
+end_offset,
batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
-triton.next_power_of_2(bs),
+next_power_of_2(bs),
) )
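prepare_for_verify now needs the location of each request's last prefix token so the paged allocator can keep filling the last partially used page. A hedged, plain-PyTorch reading of what get_last_loc is assumed to return (the real helper lives in schedule_batch and may use a Triton kernel):

import torch

def get_last_loc_ref(req_to_token, req_pool_indices, prefix_lens):
    # Assumed semantics: the KV-cache slot of the last prefix token per request,
    # or -1 when the prefix is empty.
    return torch.where(
        prefix_lens > 0,
        req_to_token[req_pool_indices, prefix_lens - 1],
        torch.full_like(prefix_lens, -1),
    )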
def generate_attn_arg_prefill( def generate_attn_arg_prefill(
...@@ -282,6 +306,7 @@ class EagleVerifyInput: ...@@ -282,6 +306,7 @@ class EagleVerifyInput:
batch: ScheduleBatch, batch: ScheduleBatch,
logits_output: torch.Tensor, logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Verify and find accepted tokens based on logits output and batch Verify and find accepted tokens based on logits output and batch
...@@ -305,6 +330,7 @@ class EagleVerifyInput: ...@@ -305,6 +330,7 @@ class EagleVerifyInput:
) )
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
# Apply penalty
if sampling_info.penalizer_orchestrator.is_required: if sampling_info.penalizer_orchestrator.is_required:
# This is a relaxed version of penalties for speculative decoding. # This is a relaxed version of penalties for speculative decoding.
linear_penalty = torch.zeros( linear_penalty = torch.zeros(
...@@ -317,6 +343,7 @@ class EagleVerifyInput: ...@@ -317,6 +343,7 @@ class EagleVerifyInput:
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
) )
# Sample tokens
if batch.sampling_info.is_all_greedy: if batch.sampling_info.is_all_greedy:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
target_predict = target_predict.reshape(bs, self.draft_token_num) target_predict = target_predict.reshape(bs, self.draft_token_num)
...@@ -378,13 +405,24 @@ class EagleVerifyInput: ...@@ -378,13 +405,24 @@ class EagleVerifyInput:
deterministic=True, deterministic=True,
) )
if SIMULATE_ACC_LEN:
# Do simulation
accept_index = _generate_simulated_accept_index(
accept_index=accept_index,
predict=predict, # mutable
accept_length=accept_length, # mutable
simulate_acc_len=SIMULATE_ACC_LEN,
bs=bs,
spec_steps=self.spec_steps,
)
new_accept_index = [] new_accept_index = []
unfinished_index = [] unfinished_index = []
accept_index_cpu = accept_index.tolist() accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist() predict_cpu = predict.tolist()
has_finished = False has_finished = False
# Iterate every accepted token and check if the request has finished after appending the token.
# This should be checked BEFORE freeing the KV cache slots.
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
new_accept_index_ = [] new_accept_index_ = []
...@@ -407,13 +445,28 @@ class EagleVerifyInput: ...@@ -407,13 +445,28 @@ class EagleVerifyInput:
unfinished_index.append(i) unfinished_index.append(i)
req.spec_verify_ct += 1 req.spec_verify_ct += 1
if has_finished:
accept_length = (accept_index != -1).sum(dim=1) - 1
# Free the KV cache for unaccepted tokens
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
if page_size != 1:
align_evict_mask_to_page_size[len(batch.seq_lens),](
batch.seq_lens,
evict_mask,
page_size,
self.draft_token_num,
next_power_of_2(self.draft_token_num),
)
token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
# Construct EagleVerifyOutput
if not has_finished: if not has_finished:
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
token_to_kv_pool_allocator.free(mem_need_free_idx)
batch.out_cache_loc = batch.out_cache_loc[accept_index] batch.out_cache_loc = batch.out_cache_loc[accept_index]
assign_req_to_token_pool[(bs,)]( assign_req_to_token_pool[(bs,)](
batch.req_pool_indices, batch.req_pool_indices,
...@@ -422,7 +475,7 @@ class EagleVerifyInput: ...@@ -422,7 +475,7 @@ class EagleVerifyInput:
batch.seq_lens + accept_length + 1, batch.seq_lens + accept_length + 1,
batch.out_cache_loc, batch.out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1], batch.req_to_token_pool.req_to_token.shape[1],
-triton.next_power_of_2(bs),
+next_power_of_2(bs),
) )
batch.seq_lens.add_(accept_length + 1) batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist() accept_length_cpu = accept_length.tolist()
...@@ -443,13 +496,6 @@ class EagleVerifyInput: ...@@ -443,13 +496,6 @@ class EagleVerifyInput:
accepeted_indices=accept_index, accepeted_indices=accept_index,
) )
else: else:
accept_length = (accept_index != -1).sum(dim=1) - 1
accept_index = accept_index[accept_index != -1]
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
token_to_kv_pool_allocator.free(mem_need_free_idx)
assign_req_to_token_pool[(bs,)]( assign_req_to_token_pool[(bs,)](
batch.req_pool_indices, batch.req_pool_indices,
batch.req_to_token_pool.req_to_token, batch.req_to_token_pool.req_to_token,
...@@ -457,7 +503,7 @@ class EagleVerifyInput: ...@@ -457,7 +503,7 @@ class EagleVerifyInput:
batch.seq_lens + accept_length + 1, batch.seq_lens + accept_length + 1,
batch.out_cache_loc[accept_index], batch.out_cache_loc[accept_index],
batch.req_to_token_pool.req_to_token.shape[1], batch.req_to_token_pool.req_to_token.shape[1],
-triton.next_power_of_2(bs),
+next_power_of_2(bs),
) )
batch.seq_lens.add_(accept_length + 1) batch.seq_lens.add_(accept_length + 1)
accept_length_cpu = accept_length.tolist() accept_length_cpu = accept_length.tolist()
...@@ -465,20 +511,21 @@ class EagleVerifyInput: ...@@ -465,20 +511,21 @@ class EagleVerifyInput:
draft_input = EagleDraftInput() draft_input = EagleDraftInput()
if len(new_accept_index) > 0: if len(new_accept_index) > 0:
new_accept_index = torch.tensor(new_accept_index, device="cuda") new_accept_index = torch.tensor(new_accept_index, device="cuda")
unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
draft_input.hidden_states = batch.spec_info.hidden_states[ draft_input.hidden_states = batch.spec_info.hidden_states[
new_accept_index new_accept_index
] ]
draft_input.verified_id = predict[new_accept_index] draft_input.verified_id = predict[new_accept_index]
-draft_input.accept_length = accept_length[unfinished_index]
draft_input.accept_length_cpu = [
    accept_length_cpu[i] for i in unfinished_index
]
+draft_input.accept_length = accept_length[unfinished_index_device]
if has_finished: if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[ draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-unfinished_index
+unfinished_index_device
]
draft_input.req_pool_indices_for_draft_extend = (
-batch.req_pool_indices[unfinished_index]
+batch.req_pool_indices[unfinished_index_device]
) )
else: else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens draft_input.seq_lens_for_draft_extend = batch.seq_lens
...@@ -564,13 +611,24 @@ def assign_draft_cache_locs( ...@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
pool_len: tl.constexpr, pool_len: tl.constexpr,
topk: tl.constexpr, topk: tl.constexpr,
speculative_num_steps: tl.constexpr, speculative_num_steps: tl.constexpr,
page_size: tl.constexpr,
): ):
BLOCK_SIZE: tl.constexpr = 32 BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
kv_start = tl.load(seq_lens + pid) kv_start = tl.load(seq_lens + pid)
-kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+if page_size == 1 or topk == 1:
+    kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+else:
+    prefix_len = tl.load(seq_lens + pid)
+    last_page_len = prefix_len % page_size
+    num_new_page = (
+        last_page_len + speculative_num_steps + page_size - 1
+    ) // page_size
+    kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE) num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
for i in range(num_loop): for i in range(num_loop):
...@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices( ...@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
tl.store(kv_indptr + zid, base + zid * iters) tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
seq_lens,
evict_mask,
page_size: tl.constexpr,
num_draft_tokens: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
t_range = tl.arange(0, BLOCK_SIZE)
bid = tl.program_id(axis=0)
seq_len = tl.load(seq_lens + bid)
io_mask = t_range < num_draft_tokens
mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
num_trues = tl.sum(mask_row)
num_false = num_draft_tokens - num_trues
start = (seq_len + num_false - 1) // page_size * page_size - seq_len
for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
tl.store(evict_mask + bid * num_draft_tokens + i, False)
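A pure-Python reference of the kernel above may help: draft slots that share a page with accepted tokens must not be freed, so the eviction mask is cleared for the whole trailing, possibly partial, page (a sketch mirroring the Triton code, not used by the runtime):

def align_evict_mask_ref(seq_lens, evict_mask, page_size, num_draft_tokens):
    # evict_mask: flat list of bools, length len(seq_lens) * num_draft_tokens,
    # True = slot can be freed.
    for bid, seq_len in enumerate(seq_lens):
        base = bid * num_draft_tokens
        num_kept = num_draft_tokens - sum(evict_mask[base : base + num_draft_tokens])
        # First draft slot that falls on the last (possibly partial) page.
        start = (seq_len + num_kept - 1) // page_size * page_size - seq_len
        for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
            evict_mask[base + i] = False  # keep the whole page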
@torch.compile(dynamic=True) @torch.compile(dynamic=True)
def select_top_k_tokens( def select_top_k_tokens(
i: int, i: int,
...@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim): ...@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
else: else:
# Use topk for efficiency with larger k values # Use topk for efficiency with larger k values
return torch.topk(values, topk, dim=dim) return torch.topk(values, topk, dim=dim)
def _generate_simulated_accept_index(
accept_index,
predict,
accept_length,
simulate_acc_len,
bs,
spec_steps,
):
simulate_acc_len_float = float(simulate_acc_len)
simulated_values = torch.normal(
mean=simulate_acc_len_float,
std=1.0,
size=(1,),
device="cpu",
)
# clamp simulated values to be between 1 and spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
simulate_acc_len = int(simulated_values.round().item())
accept_indx_first_col = accept_index[:, 0].view(-1, 1)
sim_accept_index = torch.full(
(bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
)
sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
simulate_acc_len, device=accept_index.device
)
accept_length.fill_(simulate_acc_len - 1)
predict.fill_(100) # some legit token id
return sim_accept_index
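Since SIMULATE_ACC_LEN is read from the environment when eagle_utils is imported (see above), it has to be set before that import, e.g. when benchmarking how throughput scales with the accepted length. A small sketch:

import os
# Must be set before sglang.srt.speculative.eagle_utils is imported.
os.environ["SIMULATE_ACC_LEN"] = "3"  # target mean accepted length per verify step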
...@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group ...@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode, CaptureHiddenMode,
...@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.device = server_args.device self.device = server_args.device
self.target_worker = target_worker self.target_worker = target_worker
self.page_size = server_args.page_size
self.speculative_algorithm = SpeculativeAlgorithm.from_string( self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm server_args.speculative_algorithm
) )
...@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker): ...@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
""" """
if batch.forward_mode.is_decode(): if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group): with self.draft_tp_context(self.draft_model_runner.tp_group):
-spec_info, to_free_cache_loc = self.draft(batch)
+spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify(
    batch, spec_info
)
-# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
-self.token_to_kv_pool_allocator.free(to_free_cache_loc)
# If it is None, it means all requests are finished # If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None: if batch.spec_info.verified_id is not None:
with self.draft_tp_context(self.draft_model_runner.tp_group): with self.draft_tp_context(self.draft_model_runner.tp_group):
...@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker): ...@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
) )
# Allocate cache locations # Allocate cache locations
-out_cache_loc = batch.alloc_token_slots(
-    num_seqs * self.topk * self.speculative_num_steps
-)
+if self.page_size == 1:
+    out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
+        num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+    )
else:
if self.topk == 1:
prefix_lens = batch.seq_lens
seq_lens = prefix_lens + self.speculative_num_steps
extend_num_tokens = num_seqs * self.speculative_num_steps
else:
# In this case, the last partial page needs to be duplicated.
# KV cache layout in batch.req_to_token_pool.req_to_token:
#
# | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
# prefix top-k = 0 tok-k = 1 top-k = 2
#
# "-" means prefix tokens
# "x" means speculative draft tokens
# "." means padded tokens
# TODO: fuse these ops
prefix_lens = batch.seq_lens
last_page_lens = prefix_lens % self.page_size
num_new_pages = (
last_page_lens + self.speculative_num_steps + self.page_size - 1
) // self.page_size
seq_lens = (
prefix_lens // self.page_size * self.page_size
+ num_new_pages * (self.page_size * self.topk)
)
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
raise NotImplementedError(
"page_size > 1 and top_k > 1 are not supported."
)
# TODO: Support page_size > 1 and top_k > 1
# 1. Duplicate the KV cache in the last partial page for all top-k segments
# 2. Modify generate_draft_decode_kv_indices accordingly
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
out_cache_loc, token_to_kv_pool_state_backup = (
batch.alloc_paged_token_slots_extend(
prefix_lens,
seq_lens,
last_loc,
extend_num_tokens,
backup_state=True,
)
)
assign_draft_cache_locs[(num_seqs,)]( assign_draft_cache_locs[(num_seqs,)](
batch.req_pool_indices, batch.req_pool_indices,
batch.req_to_token_pool.req_to_token, batch.req_to_token_pool.req_to_token,
...@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
batch.req_to_token_pool.req_to_token.shape[1], batch.req_to_token_pool.req_to_token.shape[1],
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
self.page_size,
) )
batch.out_cache_loc = out_cache_loc batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item() batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
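A worked example of the formulas used above (illustration of the page-aligned layout sketched in the comment; the top-k > 1 path itself still raises NotImplementedError):

page_size, topk, speculative_num_steps = 4, 2, 5
prefix_len = 13                                    # tokens already in the KV cache
last_page_len = prefix_len % page_size             # 1 token on the partial last page
num_new_pages = (last_page_len + speculative_num_steps + page_size - 1) // page_size  # 2
seq_len = prefix_len // page_size * page_size + num_new_pages * (page_size * topk)    # 12 + 16 = 28
extend_num_tokens = seq_len - prefix_len           # 15 KV slots allocated for this request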
...@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker): ...@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
# Run forward steps # Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch) score_list, token_list, parents_list = self.draft_forward(forward_batch)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
ret = EagleVerifyInput.create( ret = EagleVerifyInput.create(
spec_info.verified_id, spec_info.verified_id,
score_list, score_list,
...@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
self.speculative_num_steps, self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens, self.server_args.speculative_num_draft_tokens,
) )
-return ret, out_cache_loc
+return ret
def draft_forward(self, forward_batch: ForwardBatch): def draft_forward(self, forward_batch: ForwardBatch):
# Parse args # Parse args
...@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
return score_list, token_list, parents_list return score_list, token_list, parents_list
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
-spec_info.prepare_for_verify(batch)
+spec_info.prepare_for_verify(batch, self.page_size)
batch.forward_mode = ForwardMode.TARGET_VERIFY batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
...@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker): ...@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
self._detect_nan_if_needed(logits_output) self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify( res: EagleVerifyOutput = spec_info.verify(
-batch, logits_output, self.token_to_kv_pool_allocator
+batch,
+logits_output,
+self.token_to_kv_pool_allocator,
+self.page_size,
) )
# Post process based on verified outputs. # Post process based on verified outputs.
......
...@@ -76,11 +76,14 @@ def is_in_ci(): ...@@ -76,11 +76,14 @@ def is_in_ci():
if is_in_ci(): if is_in_ci():
-DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
-DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
+DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+    5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+)
else:
-DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
-DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
+DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+    7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+)
+DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
...@@ -1009,6 +1012,9 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple): ...@@ -1009,6 +1012,9 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
class CustomTestCase(unittest.TestCase): class CustomTestCase(unittest.TestCase):
pass
"""
def _callTestMethod(self, method): def _callTestMethod(self, method):
max_retry = int( max_retry = int(
os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0") os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
...@@ -1017,3 +1023,4 @@ class CustomTestCase(unittest.TestCase): ...@@ -1017,3 +1023,4 @@ class CustomTestCase(unittest.TestCase):
lambda: super(CustomTestCase, self)._callTestMethod(method), lambda: super(CustomTestCase, self)._callTestMethod(method),
max_retry=max_retry, max_retry=max_retry,
) )
"""
...@@ -18,7 +18,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei ...@@ -18,7 +18,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
pip install sgl-kernel==0.0.5.post4 --force-reinstall pip install sgl-kernel==0.0.5.post4 --force-reinstall
pip install torch_memory_saver pip install torch_memory_saver
-pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm
+pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm torchaudio
# For compiling xgrammar kernels
pip install cuda-python nvidia-cuda-nvrtc-cu12 pip install cuda-python nvidia-cuda-nvrtc-cu12
......
...@@ -26,7 +26,7 @@ suites = { ...@@ -26,7 +26,7 @@ suites = {
TestFile("test_abort.py", 51), TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22), TestFile("test_block_int8.py", 22),
TestFile("test_chunked_prefill.py", 336), TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 447), TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"), TestFile("test_ebnf_constrained.py"),
TestFile("test_fp8_kernel.py", 2), TestFile("test_fp8_kernel.py", 2),
TestFile("test_embedding_openai_server.py", 36), TestFile("test_embedding_openai_server.py", 36),
......
...@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase): ...@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase):
print(f"{metrics=}") print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20) self.assertGreater(metrics["accuracy"], 0.20)
-server_info = requests.get(self.base_url + "/get_server_info")
-avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+server_info = requests.get(self.base_url + "/get_server_info").json()
+avg_spec_accept_length = server_info["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
-self.assertGreater(avg_spec_accept_length, 3.5)
speculative_eagle_topk = server_info["speculative_eagle_topk"]
if speculative_eagle_topk == 1:
self.assertGreater(avg_spec_accept_length, 2.5)
else:
self.assertGreater(avg_spec_accept_length, 3.5)
# Wait a little bit so that the memory check happens. # Wait a little bit so that the memory check happens.
time.sleep(4) time.sleep(4)
...@@ -535,5 +541,36 @@ class TestEAGLEServerTriton(TestEAGLEServer): ...@@ -535,5 +541,36 @@ class TestEAGLEServerTriton(TestEAGLEServer):
) )
class TestEAGLEServerPageSize(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
6,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--page-size",
4,
],
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase): ...@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.60) self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info") server_info = requests.get(self.base_url + "/get_server_info")
print(f"{server_info=}")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}") print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.5) self.assertGreater(avg_spec_accept_length, 2.5)
......