sglang / Commits

Commit 7b68d271 (Unverified)
Authored Jul 21, 2025 by Xiaoze Fan; committed by GitHub on Jul 21, 2025

[Feature] Add a test for Layer-wise Prefill (#8231)

Signed-off-by: jason-fxz <jason341132@qq.com>

Parent: 74f59ae5

Showing 1 changed file with 299 additions and 0 deletions

test/srt/test_forward_split_prefill.py (new file, 0 → 100644): +299, -0
"""
Test forward_split_prefill functionality.
Usage:
python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill
or
python3 test_forward_split_prefill.py
"""
import
time
import
unittest
import
numpy
as
np
import
torch
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.test.test_utils
import
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
CustomTestCase
class
TestForwardSplitPrefill
(
CustomTestCase
):
"""Test cases for forward_split_prefill functionality."""
@
classmethod
def
setUpClass
(
cls
):
"""Set up the test environment once for all tests."""
cls
.
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls
.
tp_size
=
1
cls
.
device
=
"cuda"
# Initialize server args
cls
.
server_args
=
ServerArgs
(
model_path
=
cls
.
model_path
,
tokenizer_path
=
cls
.
model_path
,
host
=
"127.0.0.1"
,
disable_cuda_graph
=
True
,
# Disable CUDA graph for testing split prefill
disable_hybrid_swa_memory
=
True
,
port
=
30000
,
tp_size
=
cls
.
tp_size
,
mem_fraction_static
=
0.8
,
trust_remote_code
=
True
,
)
cls
.
port_args
=
PortArgs
.
init_new
(
cls
.
server_args
)
# Load model and tokenizer
cls
.
model_config
=
ModelConfig
.
from_server_args
(
cls
.
server_args
)
cls
.
model_runner
=
ModelRunner
(
model_config
=
cls
.
model_config
,
mem_fraction_static
=
cls
.
server_args
.
mem_fraction_static
,
gpu_id
=
0
,
tp_rank
=
0
,
tp_size
=
cls
.
tp_size
,
pp_rank
=
0
,
pp_size
=
1
,
nccl_port
=
cls
.
port_args
.
nccl_port
,
server_args
=
cls
.
server_args
,
)
cls
.
tokenizer
=
get_tokenizer
(
cls
.
server_args
.
tokenizer_path
,
tokenizer_mode
=
cls
.
server_args
.
tokenizer_mode
,
trust_remote_code
=
cls
.
server_args
.
trust_remote_code
,
)
print
(
f
"Test with model:
{
cls
.
model_path
}
, num_hidden_layers:
{
cls
.
model_config
.
num_hidden_layers
}
"
)
def
prepare_test_batch
(
self
,
batch_size
=
2
,
input_len
=
128
,
is_split_prefill
=
True
):
"""Prepare a test batch for split prefill testing."""
# Create synthetic input
input_ids
=
np
.
random
.
randint
(
10
,
1000
,
(
batch_size
,
input_len
),
dtype
=
np
.
int32
)
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_new_tokens
=
8
,
)
reqs
=
[]
for
i
in
range
(
batch_size
):
req
=
Req
(
rid
=
i
,
origin_input_text
=
""
,
origin_input_ids
=
list
(
input_ids
[
i
]),
sampling_params
=
sampling_params
,
)
req
.
prefix_indices
=
[]
req
.
fill_ids
=
req
.
origin_input_ids
req
.
extend_input_len
=
len
(
req
.
fill_ids
)
-
len
(
req
.
prefix_indices
)
req
.
logprob_start_len
=
len
(
req
.
origin_input_ids
)
-
1
reqs
.
append
(
req
)
batch
=
ScheduleBatch
.
init_new
(
reqs
=
reqs
,
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool_allocator
=
self
.
model_runner
.
token_to_kv_pool_allocator
,
tree_cache
=
None
,
model_config
=
self
.
model_config
,
enable_overlap
=
False
,
spec_algorithm
=
SpeculativeAlgorithm
.
NONE
,
enable_custom_logit_processor
=
False
,
)
if
is_split_prefill
:
batch
.
prepare_for_split_prefill
()
else
:
batch
.
prepare_for_extend
()
# Create forward batch
model_worker_batch
=
batch
.
get_model_worker_batch
()
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
return
forward_batch
def
test_split_prefill_functionality
(
self
):
"""Test that split prefill can complete successfully."""
print
(
"
\n
=== Testing split prefill functionality ==="
)
forward_batch
=
self
.
prepare_test_batch
(
batch_size
=
2
,
input_len
=
64
)
# Reset split index
forward_batch
.
split_index
=
0
# Test split prefill in chunks
num_layers
=
self
.
model_config
.
num_hidden_layers
chunk_size
=
max
(
1
,
num_layers
//
4
)
# Split into 4 chunks
results
=
[]
split_count
=
0
while
forward_batch
.
split_index
<
num_layers
:
print
(
f
"Processing split
{
split_count
}
, split_index:
{
forward_batch
.
split_index
}
"
)
result
=
self
.
model_runner
.
forward_split_prefill
(
forward_batch
=
forward_batch
,
reinit_attn_backend
=
(
split_count
==
0
),
forward_count
=
chunk_size
,
)
results
.
append
(
result
)
split_count
+=
1
# Verify split_index is updated correctly
expected_next_index
=
min
(
split_count
*
chunk_size
,
num_layers
)
self
.
assertEqual
(
forward_batch
.
split_index
,
expected_next_index
)
# The last result should contain logits
self
.
assertIsNotNone
(
results
[
-
1
],
"Final split should return logits"
)
print
(
f
"Split prefill completed in
{
split_count
}
splits"
)
def
test_split_prefill_vs_normal_prefill
(
self
):
"""Test that split prefill produces the same results as normal prefill."""
print
(
"
\n
=== Testing split prefill vs normal prefill consistency ==="
)
forward_batch_normal
=
self
.
prepare_test_batch
(
batch_size
=
2
,
input_len
=
128
,
is_split_prefill
=
False
)
forward_batch_split
=
self
.
prepare_test_batch
(
batch_size
=
2
,
input_len
=
128
,
is_split_prefill
=
True
)
# Ensure same input
forward_batch_split
.
input_ids
=
forward_batch_normal
.
input_ids
.
clone
()
forward_batch_split
.
positions
=
forward_batch_normal
.
positions
.
clone
()
# Method 1: Normal extend (prefill)
print
(
"Running normal extend (prefill)..."
)
normal_result
=
self
.
model_runner
.
forward_extend
(
forward_batch_normal
)
# Method 2: Split prefill
print
(
"Running split prefill..."
)
num_layers
=
self
.
model_config
.
num_hidden_layers
chunk_size
=
max
(
1
,
num_layers
//
3
)
# Split into 3 chunks
split_result
=
None
while
forward_batch_split
.
split_index
<
num_layers
:
result
=
self
.
model_runner
.
forward_split_prefill
(
forward_batch
=
forward_batch_split
,
forward_count
=
chunk_size
,
)
if
result
is
not
None
:
split_result
=
result
# Compare results
self
.
assertIsNotNone
(
normal_result
,
"Normal prefill should return result"
)
self
.
assertIsNotNone
(
split_result
,
"Split prefill should return result"
)
# Compare logits shapes
self
.
assertEqual
(
normal_result
.
next_token_logits
.
shape
,
split_result
.
next_token_logits
.
shape
,
"Logits shapes should match"
,
)
# Compare logits values (should be very close due to same computation)
# Use a larger tolerance for numerical differences in split computation
torch
.
testing
.
assert_close
(
normal_result
.
next_token_logits
,
split_result
.
next_token_logits
,
rtol
=
1e-3
,
atol
=
1e-3
,
msg
=
"Split prefill and normal prefill should produce similar logits"
,
)
print
(
"✓ Split prefill and normal prefill produce consistent results"
)
def
test_split_prefill_different_chunk_sizes
(
self
):
"""Test split prefill with different chunk sizes."""
print
(
"
\n
=== Testing split prefill with different chunk sizes ==="
)
num_layers
=
self
.
model_config
.
num_hidden_layers
chunk_sizes
=
[
1
,
2
,
max
(
1
,
num_layers
//
2
),
num_layers
]
# Prepare identical batches for each test
base_batch
=
self
.
prepare_test_batch
(
batch_size
=
1
,
input_len
=
16
)
base_input_ids
=
base_batch
.
input_ids
.
clone
()
base_positions
=
base_batch
.
positions
.
clone
()
results
=
[]
for
chunk_size
in
chunk_sizes
:
if
chunk_size
>
num_layers
:
continue
print
(
f
"Testing chunk size:
{
chunk_size
}
"
)
# Prepare fresh batch
forward_batch
=
self
.
prepare_test_batch
(
batch_size
=
1
,
input_len
=
16
)
forward_batch
.
input_ids
=
base_input_ids
.
clone
()
forward_batch
.
positions
=
base_positions
.
clone
()
forward_batch
.
split_index
=
0
# Run split prefill
split_result
=
None
while
forward_batch
.
split_index
<
num_layers
:
result
=
self
.
model_runner
.
forward_split_prefill
(
forward_batch
=
forward_batch
,
forward_count
=
chunk_size
,
)
if
result
is
not
None
:
split_result
=
result
self
.
assertIsNotNone
(
split_result
,
f
"Split prefill should succeed with chunk_size=
{
chunk_size
}
"
,
)
results
.
append
(
split_result
)
# Compare all results should be identical (same input, same computation)
if
len
(
results
)
>
1
:
for
i
,
result
in
enumerate
(
results
[
1
:],
1
):
torch
.
testing
.
assert_close
(
results
[
0
].
next_token_logits
,
result
.
next_token_logits
,
rtol
=
1e-3
,
atol
=
1e-3
,
msg
=
f
"Results with different chunk sizes should be identical (chunk_size
{
chunk_sizes
[
i
]
}
)"
,
)
print
(
"✓ All chunk sizes produce consistent results"
)
def
test_split_prefill_edge_cases
(
self
):
"""Test edge cases for split prefill."""
print
(
"
\n
=== Testing split prefill edge cases ==="
)
# Test with single layer chunks
forward_batch
=
self
.
prepare_test_batch
(
batch_size
=
1
,
input_len
=
8
)
# Process one layer at a time
num_layers
=
self
.
model_config
.
num_hidden_layers
for
layer_idx
in
range
(
num_layers
):
result
=
self
.
model_runner
.
forward_split_prefill
(
forward_batch
=
forward_batch
,
reinit_attn_backend
=
(
layer_idx
==
0
),
forward_count
=
1
,
# One layer at a time
)
if
layer_idx
==
num_layers
-
1
:
# Last layer should return result
self
.
assertIsNotNone
(
result
,
"Last layer should return logits"
)
else
:
# Intermediate layers should return None
self
.
assertIsNone
(
result
,
f
"Layer
{
layer_idx
}
should return None"
)
print
(
"✓ Single layer processing works correctly"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
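For quick reference, the calling pattern these tests exercise can be condensed into the short sketch below. It is not part of the committed file; it assumes a forward_batch prepared via prepare_for_split_prefill(), plus the model_runner and model_config set up in setUpClass above.

# Sketch of the layer-wise prefill driving loop used throughout the tests.
# Intermediate chunks return None; the call that finishes the final layer
# returns the logits result.
num_layers = model_config.num_hidden_layers
chunk_size = max(1, num_layers // 4)  # any positive chunk size works
logits = None
while forward_batch.split_index < num_layers:
    out = model_runner.forward_split_prefill(
        forward_batch=forward_batch,
        reinit_attn_backend=(forward_batch.split_index == 0),
        forward_count=chunk_size,
    )
    if out is not None:
        logits = out
# After the loop, `logits` matches the next_token_logits that a normal
# forward_extend() call would produce, within the tolerances checked above.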