zhaoyu6 / sglang · Commit 8cda5a62 (Unverified)
Authored Sep 07, 2025 by Qiaolin Yu; committed by GitHub on Sep 07, 2025
Standalone speculative decoding (#10090)
Parent: 400d3b97

Changes: 11 files changed, 285 additions and 9 deletions (+285, -9)
python/sglang/srt/managers/schedule_batch.py                             +1   -1
python/sglang/srt/managers/scheduler.py                                 +12   -0
python/sglang/srt/model_executor/cuda_graph_runner.py                    +8   -2
python/sglang/srt/server_args.py                                        +17   -5
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py    +5   -1
python/sglang/srt/speculative/eagle_worker.py                            +8   -0
python/sglang/srt/speculative/spec_info.py                               +5   -0
python/sglang/srt/speculative/standalone_worker.py                     +109   -0
python/sglang/test/test_utils.py                                         +4   -0
test/srt/run_suite.py                                                    +1   -0
test/srt/test_standalone_speculative_decoding.py                       +115   -0
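The STANDALONE algorithm reuses the EAGLE speculative-decoding machinery but drives it with a full, separately trained draft model (the tests below pair Llama-3.2-1B as the draft with Llama-3.1-8B as the target) instead of a trained EAGLE draft head. For orientation, here is a minimal, model-free sketch of the draft-then-verify loop these algorithms share; every name in it is hypothetical, and real verification compares draft and target token distributions rather than greedy tokens:

# Toy draft/verify loop; the real one lives in
# EAGLEWorker.forward_batch_speculative_generation and the verify kernels.
def speculative_step(draft_next, target_next, prefix, num_draft_tokens=4):
    """Draft a few tokens, keep the prefix the target agrees with, plus one target token."""
    ctx, proposed = list(prefix), []
    for _ in range(num_draft_tokens):  # 1) cheap draft pass proposes a continuation
        tok = draft_next(ctx)
        proposed.append(tok)
        ctx.append(tok)

    accepted, ctx = [], list(prefix)
    for tok in proposed:  # 2) target verifies the proposal token by token
        expected = target_next(ctx)
        if expected != tok:
            accepted.append(expected)  # target's correction ends the step
            return accepted
        accepted.append(tok)
        ctx.append(tok)
    accepted.append(target_next(ctx))  # bonus token after full acceptance
    return accepted

# Deterministic stand-in "models" over integer tokens, just to run the loop.
draft = lambda ctx: (sum(ctx) + 1) % 7
target = lambda ctx: (sum(ctx) + 1) % 5
print(speculative_step(draft, target, prefix=[1, 2, 3]))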
python/sglang/srt/managers/schedule_batch.py

@@ -1539,7 +1539,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)

-        if self.spec_algorithm.is_eagle():
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return
python/sglang/srt/managers/scheduler.py

@@ -349,6 +349,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
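The new elif mirrors the EAGLE branch directly above it; condensed, the selection logic amounts to the sketch below (a hypothetical simplification, with the imports kept lazy as in the diff so draft-worker modules load only when actually used):

# Hypothetical condensation of the scheduler's draft-worker selection.
def make_draft_worker(spec_algorithm, **worker_kwargs):
    if spec_algorithm.is_eagle():
        from sglang.srt.speculative.eagle_worker import EAGLEWorker
        return EAGLEWorker(**worker_kwargs)
    if spec_algorithm.is_standalone():
        from sglang.srt.speculative.standalone_worker import StandaloneWorker
        return StandaloneWorker(**worker_kwargs)
    return None  # no speculative decoding configured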
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -271,7 +271,10 @@ class CudaGraphRunner:
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle()
+            or model_runner.spec_algorithm.is_standalone()
+        ):
             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen")
             else:

@@ -827,7 +830,10 @@ class CudaGraphRunner:
     def get_spec_info(self, num_tokens: int):
         spec_info = None
-        if self.model_runner.spec_algorithm.is_eagle():
+        if (
+            self.model_runner.spec_algorithm.is_eagle()
+            or self.model_runner.spec_algorithm.is_standalone()
+        ):
             from sglang.srt.speculative.eagle_utils import EagleVerifyInput

             if self.model_runner.is_draft_worker:
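Both hunks gate CUDA-graph behavior on the algorithm because graph replay requires fixed buffers and a fixed token count per batch slot. A self-contained capture/replay sketch of the pattern CudaGraphRunner relies on (standard PyTorch CUDA-graph API, hypothetical tensors; it needs a CUDA device to do anything):

import torch

if torch.cuda.is_available():
    static_in = torch.zeros(8, device="cuda")  # captured input buffer
    weight = torch.randn(8, device="cuda")

    # Warm up on a side stream before capture, as PyTorch recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out = static_in * weight
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = static_in * weight  # the captured "decode step"

    # Replay with new data: write into the captured buffer, never reallocate.
    static_in.copy_(torch.ones(8, device="cuda"))
    g.replay()
    print(static_out)  # reflects the new input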
python/sglang/srt/server_args.py

@@ -473,9 +473,14 @@ class ServerArgs:
             # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
             reserved_mem = 32 * 1024

-        # draft model and larger cuda graph buffers
         if self.speculative_algorithm is not None:
-            reserved_mem += 2 * 1024
+            # draft model and larger cuda graph buffers
+            if self.speculative_algorithm == "STANDALONE":
+                # Standalone speculative decoding needs more memory than other speculative
+                # decoding algorithms since the draft model is typically larger.
+                reserved_mem += 6 * 1024
+            else:
+                reserved_mem += 2 * 1024
         if self.enable_dp_attention:
             reserved_mem += 4 * 1024

@@ -704,7 +709,12 @@ class ServerArgs:
             # NEXTN shares the same implementation of EAGLE
             self.speculative_algorithm = "EAGLE"

-        if self.speculative_algorithm in ("EAGLE", "EAGLE3"):
+        if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
+            if self.speculative_algorithm == "STANDALONE":
+                # TODO: support dp attention for standalone speculative decoding
+                assert (
+                    self.enable_dp_attention is False
+                ), "Currently standalone speculative decoding does not support dp attention."
             if self.max_running_requests is None:
                 self.max_running_requests = 48
             self.disable_overlap_schedule = True

@@ -1499,7 +1509,7 @@ class ServerArgs:
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE", "EAGLE3", "NEXTN"],
+            choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
             help="Speculative algorithm.",
         )
         parser.add_argument(

@@ -2635,7 +2645,9 @@ def auto_choose_speculative_params(self: ServerArgs):
    """
    hf_config = self.get_hf_config()
    arch = hf_config.architectures[0]

+    if self.speculative_algorithm == "STANDALONE":
+        # The default value for standalone speculative decoding
+        return (3, 1, 4)
    if arch in ["LlamaForCausalLM"]:
        # The default value for llama
        return (5, 4, 8)
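The reservation rules in the first hunk compose additively. A worked example with the constants from the diff (the hunk doesn't restate units; elsewhere in server_args these values are megabytes, making 32 * 1024 a 32 GB base reservation):

# Worked example of the reserved-memory rules above (values in MB, per the diff).
def reserved_mem_mb(speculative_algorithm=None, enable_dp_attention=False):
    reserved = 32 * 1024  # base for large-memory GPUs (B200, MI300)
    if speculative_algorithm is not None:
        # STANDALONE reserves more: the draft is a full model, not a small head.
        reserved += 6 * 1024 if speculative_algorithm == "STANDALONE" else 2 * 1024
    if enable_dp_attention:
        reserved += 4 * 1024
    return reserved

print(reserved_mem_mb())              # 32768
print(reserved_mem_mb("EAGLE"))       # 34816
print(reserved_mem_mb("STANDALONE"))  # 38912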
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -341,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
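The new guard exists because a standalone draft model need not share the target's hidden size, so the capture buffer's second dimension may not match the incoming hidden states. A CPU-runnable illustration of the guarded copy_ pattern (hypothetical shapes; the real buffers live in EAGLEDraftExtendCudaGraphRunner):

import torch

buffer = torch.zeros(16, 4096)  # preallocated capture buffer
for src in (torch.randn(8, 4096), torch.randn(8, 2048)):
    if src.shape[1] == buffer.shape[1]:
        buffer[: src.shape[0]].copy_(src)   # matching hidden size: copy in place
        print("copied", tuple(src.shape))
    else:
        print("skipped", tuple(src.shape))  # mismatched hidden size: leave buffer alone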
python/sglang/srt/speculative/eagle_worker.py

@@ -730,6 +730,14 @@ class EAGLEWorker(TpModelWorker):
             # Set inputs
             forward_batch.input_ids = input_ids
+            # This is a temporary fix for the case that the user is using standalone
+            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+            # rope kernel needs cache_loc to be contiguous.
+            if (
+                self.server_args.speculative_algorithm == "STANDALONE"
+                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+            ):
+                out_cache_loc = out_cache_loc.contiguous()
             forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
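For background on the .contiguous() workaround: slicing a larger tensor often yields a strided view, and many custom kernels (here, the gpt-oss rope kernel) assume densely packed memory. A two-line demonstration:

import torch

cache = torch.arange(12).reshape(3, 4)
col = cache[:, 1]                        # strided view into the original storage
print(col.is_contiguous())               # False
print(col.contiguous().is_contiguous())  # True: copied into packed memory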
...
python/sglang/srt/speculative/spec_info.py
View file @
8cda5a62
...
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
...
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
NONE
=
auto
()
NONE
=
auto
()
EAGLE
=
auto
()
EAGLE
=
auto
()
EAGLE3
=
auto
()
EAGLE3
=
auto
()
STANDALONE
=
auto
()
def
is_none
(
self
):
def
is_none
(
self
):
return
self
==
SpeculativeAlgorithm
.
NONE
return
self
==
SpeculativeAlgorithm
.
NONE
...
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
...
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
def
is_eagle3
(
self
):
def
is_eagle3
(
self
):
return
self
==
SpeculativeAlgorithm
.
EAGLE3
return
self
==
SpeculativeAlgorithm
.
EAGLE3
def
is_standalone
(
self
):
return
self
==
SpeculativeAlgorithm
.
STANDALONE
@
staticmethod
@
staticmethod
def
from_string
(
name
:
str
):
def
from_string
(
name
:
str
):
name_map
=
{
name_map
=
{
"EAGLE"
:
SpeculativeAlgorithm
.
EAGLE
,
"EAGLE"
:
SpeculativeAlgorithm
.
EAGLE
,
"EAGLE3"
:
SpeculativeAlgorithm
.
EAGLE3
,
"EAGLE3"
:
SpeculativeAlgorithm
.
EAGLE3
,
"STANDALONE"
:
SpeculativeAlgorithm
.
STANDALONE
,
None
:
SpeculativeAlgorithm
.
NONE
,
None
:
SpeculativeAlgorithm
.
NONE
,
}
}
if
name
is
not
None
:
if
name
is
not
None
:
...
...
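With the mapping above, resolving and branching on the new algorithm looks like this (a short usage sketch; from_string(None) resolving to NONE is implied by the name_map entry):

from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

algo = SpeculativeAlgorithm.from_string("STANDALONE")
assert algo.is_standalone() and not algo.is_eagle()
assert SpeculativeAlgorithm.from_string(None).is_none()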
python/sglang/srt/speculative/standalone_worker.py (new file, 0 → 100644)

import logging
from contextlib import contextmanager
from typing import Optional

import torch

from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)

RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with patch_tensor_parallel_group(tp_group):
        yield


class StandaloneWorker(EAGLEWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with a target worker.
        # Draft and target worker own their own KV cache pools.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            TpModelWorker.__init__(
                self,
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
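One pattern in __init__ worth calling out: CUDA-graph capture is deferred by flipping server_args.disable_cuda_graph to True around the base TpModelWorker.__init__ and restoring the backed-up value afterwards. The same idea expressed as a reusable context manager (a hypothetical helper, not part of this commit):

from contextlib import contextmanager

@contextmanager
def temporarily_set(obj, attr, value):
    """Set obj.attr = value inside the block, then restore the original."""
    backup = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        setattr(obj, attr, backup)

# Sketch of how the deferral would read with it:
# with temporarily_set(server_args, "disable_cuda_graph", True):
#     TpModelWorker.__init__(self, ...)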
python/sglang/test/test_utils.py

@@ -72,6 +72,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
+    "meta-llama/Llama-3.1-8B-Instruct"
+)
+DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"

 # Other use cases
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
test/srt/run_suite.py

@@ -76,6 +76,7 @@ suites = {
         TestFile("test_harmony_parser.py", 20),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_hybrid_attn_backend.py", 100),
+        TestFile("test_standalone_speculative_decoding.py", 250),
         TestFile("test_input_embeddings.py", 38),
         TestFile("test_io_struct.py", 8),
         TestFile("test_jinja_template_utils.py", 1),
test/srt/test_standalone_speculative_decoding.py (new file, 0 → 100644)

import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

GSM_DATASET_PATH = None

# Default server arguments shared across all tests
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs",
    "8",
    "--speculative-algorithm",
    "STANDALONE",
    "--speculative-draft-model-path",
    DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST,
    "--speculative-num-steps",
    "4",
    "--speculative-eagle-topk",
    "2",
    "--speculative-num-draft-tokens",
    "7",
    "--mem-fraction-static",
    "0.7",  # quoted: launcher args must be strings
]


class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
    model = DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST
    draft_model = DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.7  # derived tests need to override this
    spec_decode_threshold = 3.6  # derived spec decoding tests need to override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses."""
        return DEFAULT_SERVER_ARGS + ["--attention-backend", "fa3"]

    @classmethod
    def setUpClass(cls):
        # disable deep gemm precompile to make launch server faster
        # please don't do this if you want to make your inference workload faster
        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
        model = cls.model
        cls.process = popen_launch_server(
            model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        # Use the appropriate metric key based on the test class
        metric_key = "accuracy"
        self.assertGreater(metrics[metric_key], self.accuracy_threshold)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase):
    @classmethod
    def get_server_args(cls):
        return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]


class TestStandaloneSpeculativeDecodingFlashinfer(TestStandaloneSpeculativeDecodingBase):
    @classmethod
    def get_server_args(cls):
        return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]


if __name__ == "__main__":
    unittest.main()
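To run just one of these suites directly (a hypothetical invocation from test/srt/; it mirrors what the `if __name__ == "__main__"` block above already does for the whole file):

# Load and run the fa3-backend base suite programmatically.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_standalone_speculative_decoding.TestStandaloneSpeculativeDecodingBase"
)
unittest.TextTestRunner(verbosity=2).run(suite)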