Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2c3c1bd0
Unverified
Commit
2c3c1bd0
authored
Sep 17, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 17, 2025
Browse files
[V0 Deprecation] Remove V0 Engine tests (#25114)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
5963b98b
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1 addition
and
622 deletions
+1
-622
tests/engine/conftest.py
tests/engine/conftest.py
+0
-12
tests/engine/test_computed_prefix_blocks.py
tests/engine/test_computed_prefix_blocks.py
+0
-37
tests/engine/test_executor.py
tests/engine/test_executor.py
+0
-111
tests/engine/test_multiproc_workers.py
tests/engine/test_multiproc_workers.py
+0
-179
tests/engine/test_options.py
tests/engine/test_options.py
+0
-58
tests/engine/test_short_mm_context.py
tests/engine/test_short_mm_context.py
+1
-0
tests/engine/test_stop_checker.py
tests/engine/test_stop_checker.py
+0
-225
No files found.
tests/engine/conftest.py
deleted
100644 → 0
View file @
5963b98b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
tests/engine/test_computed_prefix_blocks.py
deleted
100644 → 0
View file @
5963b98b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.sampling_params
import
SamplingParams
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
def
test_computed_prefix_blocks
(
model
:
str
,
block_size
:
int
):
# This test checks if we are able to run the engine to completion
# without triggering asserts.
# We are in a scenario where all blocks from the second request's prompt
# are full and already computed when the second request arrives.
prompt
=
(
"You are a helpful assistant. How do I build a car from cardboard and "
"paper clips? Is there an easy to follow video tutorial available "
"online for free?"
)
prompt2
=
(
" Please recommend to me some resources where I can learn not only to "
"handle technical difficulties of building a car, but also "
"decoration."
)
engine_args
=
EngineArgs
(
model
=
model
,
block_size
=
block_size
,
enable_prefix_caching
=
True
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
()
engine
.
add_request
(
"0"
,
prompt
+
prompt2
,
sampling_params
)
engine
.
step
()
engine
.
add_request
(
"1"
,
prompt
,
sampling_params
)
engine
.
step
()
tests/engine/test_executor.py
deleted
100644 → 0
View file @
5963b98b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
os
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
pytest
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.executor.uniproc_executor
import
UniProcExecutor
from
vllm.sampling_params
import
SamplingParams
class
Mock
:
...
class
CustomUniExecutor
(
UniProcExecutor
):
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
timeout
:
Optional
[
float
]
=
None
,
args
:
tuple
=
(),
kwargs
:
Optional
[
dict
]
=
None
)
->
list
[
Any
]:
# Drop marker to show that this was run
with
open
(
".marker"
,
"w"
):
...
return
super
().
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
CustomUniExecutorAsync
=
CustomUniExecutor
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
def
test_custom_executor_type_checking
(
model
):
with
pytest
.
raises
(
ValueError
):
engine_args
=
EngineArgs
(
model
=
model
,
distributed_executor_backend
=
Mock
)
LLMEngine
.
from_engine_args
(
engine_args
)
with
pytest
.
raises
(
ValueError
):
engine_args
=
AsyncEngineArgs
(
model
=
model
,
distributed_executor_backend
=
Mock
)
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
def
test_custom_executor
(
model
,
tmp_path
):
cwd
=
os
.
path
.
abspath
(
"."
)
os
.
chdir
(
tmp_path
)
try
:
assert
not
os
.
path
.
exists
(
".marker"
)
engine_args
=
EngineArgs
(
model
=
model
,
distributed_executor_backend
=
CustomUniExecutor
,
enforce_eager
=
True
,
# reduce test time
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
engine
.
add_request
(
"0"
,
"foo"
,
sampling_params
)
engine
.
step
()
assert
os
.
path
.
exists
(
".marker"
)
finally
:
os
.
chdir
(
cwd
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
def
test_custom_executor_async
(
model
,
tmp_path
):
cwd
=
os
.
path
.
abspath
(
"."
)
os
.
chdir
(
tmp_path
)
try
:
assert
not
os
.
path
.
exists
(
".marker"
)
engine_args
=
AsyncEngineArgs
(
model
=
model
,
distributed_executor_backend
=
CustomUniExecutorAsync
,
enforce_eager
=
True
,
# reduce test time
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
async
def
t
():
stream
=
await
engine
.
add_request
(
"0"
,
"foo"
,
sampling_params
)
async
for
x
in
stream
:
...
asyncio
.
run
(
t
())
assert
os
.
path
.
exists
(
".marker"
)
finally
:
os
.
chdir
(
cwd
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
def
test_respect_ray
(
model
):
# even for TP=1 and PP=1,
# if users specify ray, we should use ray.
# users might do this if they want to manage the
# resources using ray.
engine_args
=
EngineArgs
(
model
=
model
,
distributed_executor_backend
=
"ray"
,
enforce_eager
=
True
,
# reduce test time
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
assert
engine
.
model_executor
.
uses_ray
tests/engine/test_multiproc_workers.py
deleted
100644 → 0
View file @
5963b98b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
from
concurrent.futures
import
ThreadPoolExecutor
from
functools
import
partial
from
time
import
sleep
from
typing
import
Any
import
pytest
from
vllm.config
import
VllmConfig
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
class
DummyWorkerWrapper
(
WorkerWrapperBase
):
"""Dummy version of vllm.worker.worker.Worker"""
def
worker_method
(
self
,
worker_input
:
Any
)
->
tuple
[
int
,
Any
]:
sleep
(
0.05
)
if
isinstance
(
worker_input
,
Exception
):
# simulate error case
raise
worker_input
return
self
.
rpc_rank
,
input
def
_start_workers
()
->
tuple
[
list
[
ProcessWorkerWrapper
],
WorkerMonitor
]:
result_handler
=
ResultHandler
()
vllm_config
=
VllmConfig
()
workers
=
[
ProcessWorkerWrapper
(
result_handler
,
DummyWorkerWrapper
,
vllm_config
,
rank
)
for
rank
in
range
(
8
)
]
worker_monitor
=
WorkerMonitor
(
workers
,
result_handler
)
assert
not
worker_monitor
.
is_alive
()
result_handler
.
start
()
worker_monitor
.
start
()
assert
worker_monitor
.
is_alive
()
return
workers
,
worker_monitor
def
test_local_workers
()
->
None
:
"""Test workers with sync task submission"""
workers
,
worker_monitor
=
_start_workers
()
def
execute_workers
(
worker_input
:
str
)
->
None
:
worker_outputs
=
[
worker
.
execute_method
(
"worker_method"
,
worker_input
)
for
worker
in
workers
]
for
rank
,
output
in
enumerate
(
worker_outputs
):
assert
output
.
get
()
==
(
rank
,
input
)
executor
=
ThreadPoolExecutor
(
max_workers
=
4
)
# Test concurrent submission from different threads
futures
=
[
executor
.
submit
(
partial
(
execute_workers
,
f
"thread
{
thread_num
}
"
))
for
thread_num
in
range
(
4
)
]
for
future
in
futures
:
future
.
result
()
# Test error case
exception
=
ValueError
(
"fake error"
)
result
=
workers
[
0
].
execute_method
(
"worker_method"
,
exception
)
try
:
result
.
get
()
pytest
.
fail
(
"task should have failed"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ValueError
)
assert
str
(
e
)
==
"fake error"
# Test cleanup when a worker fails
assert
worker_monitor
.
is_alive
()
workers
[
3
].
process
.
kill
()
# Other workers should get shut down here
worker_monitor
.
join
(
20
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
assert
all
(
not
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Further attempts to submit tasks should fail
try
:
_result
=
workers
[
0
].
execute_method
(
"worker_method"
,
"test"
)
pytest
.
fail
(
"task should fail once workers have been shut down"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ChildProcessError
)
def
test_local_workers_clean_shutdown
()
->
None
:
"""Test clean shutdown"""
workers
,
worker_monitor
=
_start_workers
()
assert
worker_monitor
.
is_alive
()
assert
all
(
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Clean shutdown
worker_monitor
.
close
()
worker_monitor
.
join
(
20
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
assert
all
(
not
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Further attempts to submit tasks should fail
try
:
_result
=
workers
[
0
].
execute_method
(
"worker_method"
,
"test"
)
pytest
.
fail
(
"task should fail once workers have been shut down"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ChildProcessError
)
@
pytest
.
mark
.
asyncio
async
def
test_local_workers_async
()
->
None
:
"""Test local workers with async task submission"""
workers
,
worker_monitor
=
_start_workers
()
async
def
execute_workers
(
worker_input
:
str
)
->
None
:
worker_coros
=
[
worker
.
execute_method_async
(
"worker_method"
,
worker_input
)
for
worker
in
workers
]
results
=
await
asyncio
.
gather
(
*
worker_coros
)
for
rank
,
result
in
enumerate
(
results
):
assert
result
==
(
rank
,
input
)
tasks
=
[
asyncio
.
create_task
(
execute_workers
(
f
"task
{
task_num
}
"
))
for
task_num
in
range
(
4
)
]
for
task
in
tasks
:
await
task
# Test error case
exception
=
ValueError
(
"fake error"
)
try
:
_result
=
await
workers
[
0
].
execute_method_async
(
"worker_method"
,
exception
)
pytest
.
fail
(
"task should have failed"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ValueError
)
assert
str
(
e
)
==
"fake error"
# Test cleanup when a worker fails
assert
worker_monitor
.
is_alive
()
workers
[
3
].
process
.
kill
()
# Other workers should get shut down here
worker_monitor
.
join
(
20
)
# Ensure everything is stopped
assert
not
worker_monitor
.
is_alive
()
assert
all
(
not
worker
.
process
.
is_alive
()
for
worker
in
workers
)
# Further attempts to submit tasks should fail
try
:
_result
=
await
workers
[
0
].
execute_method_async
(
"worker_method"
,
"test"
)
pytest
.
fail
(
"task should fail once workers have been shut down"
)
except
Exception
as
e
:
assert
isinstance
(
e
,
ChildProcessError
)
tests/engine/test_options.py
deleted
100644 → 0
View file @
5963b98b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
nullcontext
import
pytest
from
vllm.entrypoints.llm
import
LLM
from
vllm.sampling_params
import
SamplingParams
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
def
test_skip_tokenizer_initialization
(
model
:
str
):
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm
=
LLM
(
model
=
model
,
skip_tokenizer_init
=
True
,
enforce_eager
=
True
,
)
sampling_params
=
SamplingParams
(
prompt_logprobs
=
True
,
detokenize
=
True
)
with
pytest
.
raises
(
ValueError
,
match
=
"cannot pass text prompts when"
):
llm
.
generate
(
"abc"
,
sampling_params
)
outputs
=
llm
.
generate
({
"prompt_token_ids"
:
[
1
,
2
,
3
]},
sampling_params
=
sampling_params
)
assert
len
(
outputs
)
>
0
completions
=
outputs
[
0
].
outputs
assert
len
(
completions
)
>
0
assert
completions
[
0
].
text
==
""
assert
completions
[
0
].
token_ids
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
def
test_enable_prompt_embeds
(
hf_runner
,
model
:
str
,
enable_prompt_embeds
:
bool
):
prompt
=
"abc"
with
hf_runner
(
model
)
as
hf_model
:
token_ids
=
hf_model
.
tokenizer
(
prompt
,
return_tensors
=
"pt"
).
input_ids
token_ids
=
token_ids
.
to
(
hf_model
.
model
.
device
)
embed_layer
=
hf_model
.
model
.
get_input_embeddings
()
prompt_embeds
=
embed_layer
(
token_ids
).
squeeze
(
0
)
ctx
=
(
nullcontext
()
if
enable_prompt_embeds
else
pytest
.
raises
(
ValueError
,
match
=
"set `--enable-prompt-embeds`"
))
llm
=
LLM
(
model
=
model
,
enable_prompt_embeds
=
enable_prompt_embeds
,
enforce_eager
=
True
,
)
with
ctx
:
llm
.
generate
({
"prompt_embeds"
:
prompt_embeds
})
tests/engine/test_short_mm_context.py
View file @
2c3c1bd0
...
...
@@ -25,6 +25,7 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
model
,
max_model_len
=
128
,
# LLaVA has a feature size of 576
enforce_eager
=
True
,
load_format
=
"dummy"
,
)
with
vllm_model
:
...
...
tests/engine/test_stop_checker.py
deleted
100644 → 0
View file @
5963b98b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
transformers
import
AutoTokenizer
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.reasoning
import
ReasoningParser
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceStatus
REASONING_MODEL_NAME
=
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
class
MockReasoningParser
(
ReasoningParser
):
"""Mock reasoning parser for testing purposes."""
def
__init__
(
self
,
tokenizer
:
AutoTokenizer
,
reasoning_active
:
bool
=
False
):
super
().
__init__
(
tokenizer
)
self
.
reasoning_active
=
reasoning_active
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
return
not
self
.
reasoning_active
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
return
input_ids
class
MockSequence
(
Sequence
):
"""Mock sequence for testing purposes."""
def
__init__
(
self
,
token_ids
,
output_text
=
"test_output"
,
eos_token_id
=
0
):
self
.
token_ids
=
token_ids
self
.
output_text
=
output_text
self
.
eos_token_id
=
eos_token_id
self
.
status
=
SequenceStatus
.
RUNNING
self
.
stop_reason
=
None
def
get_token_ids
(
self
):
return
self
.
token_ids
def
get_last_token_id
(
self
):
return
self
.
token_ids
[
-
1
]
if
self
.
token_ids
else
None
def
get_len
(
self
):
return
len
(
self
.
token_ids
)
def
get_output_len
(
self
):
return
len
(
self
.
token_ids
)
-
1
# Simulating prompt + outputs
@
pytest
.
fixture
def
deepseek_r1_qwen_tokenizer
():
return
AutoTokenizer
.
from_pretrained
(
REASONING_MODEL_NAME
)
@
pytest
.
fixture
def
stop_checker
():
return
StopChecker
(
max_model_len
=
10
)
@
pytest
.
fixture
def
stop_checker_with_reasoner
():
reasoner
=
MockReasoningParser
(
deepseek_r1_qwen_tokenizer
)
return
StopChecker
(
max_model_len
=
10
,
reasoner
=
reasoner
)
def
test_eos_token_stopping
(
stop_checker
):
"""Test sequence stopping when EOS token is encountered."""
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
0
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
()
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_STOPPED
def
test_ignore_eos
(
stop_checker
):
"""Test sequence continuing when EOS token is ignored."""
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
0
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
ignore_eos
=
True
)
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
RUNNING
def
test_min_tokens
(
stop_checker
):
"""Test min_tokens prevents early stopping."""
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
0
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
min_tokens
=
3
)
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
RUNNING
def
test_stop_token_ids
(
stop_checker
):
"""Test sequence stopping with custom stop token IDs."""
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
stop_token_ids
=
[
3
])
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_STOPPED
assert
seq
.
stop_reason
==
3
def
test_stop_strings
(
stop_checker
):
"""Test sequence stopping with stop strings."""
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
output_text
=
"test output with STOP"
,
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
stop
=
[
"STOP"
])
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_STOPPED
assert
seq
.
stop_reason
==
"STOP"
assert
"STOP"
not
in
seq
.
output_text
# Default behavior removes stop string
def
test_include_stop_str_in_output
(
stop_checker
):
"""Test keeping stop strings in output."""
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
output_text
=
"test output with STOP"
,
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
stop
=
[
"STOP"
],
include_stop_str_in_output
=
True
)
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
5
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_STOPPED
assert
"STOP"
in
seq
.
output_text
def
test_max_tokens
(
stop_checker
):
"""Test sequence stopping at max_tokens."""
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
max_tokens
=
2
)
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_LENGTH_CAPPED
def
test_max_model_len
(
stop_checker
):
"""Test sequence stopping at max_model_len."""
seq
=
MockSequence
(
token_ids
=
list
(
range
(
11
)),
eos_token_id
=
0
)
# 11 tokens, max is 10
sampling_params
=
SamplingParams
()
stop_checker
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_LENGTH_CAPPED
def
test_reasoning_skip_stops
(
stop_checker_with_reasoner
):
"""Test that stop tokens and strings are ignored during reasoning."""
# Set reasoning_active to True to simulate being in reasoning mode
stop_checker_with_reasoner
.
reasoner
.
reasoning_active
=
True
# Test with stop token
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
stop_token_ids
=
[
3
])
stop_checker_with_reasoner
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
RUNNING
# Test with stop string
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
output_text
=
"test STOP"
)
sampling_params
=
SamplingParams
(
stop
=
[
"STOP"
])
stop_checker_with_reasoner
.
maybe_stop_sequence
(
seq
,
new_char_count
=
4
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
RUNNING
# But EOS token still stops the sequence
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
0
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
()
stop_checker_with_reasoner
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_STOPPED
def
test_reasoning_end_enables_stops
(
stop_checker_with_reasoner
):
"""Test that stop tokens work after reasoning ends."""
# Set reasoning_active to False to simulate being out of reasoning mode
stop_checker_with_reasoner
.
reasoner
.
reasoning_active
=
False
# Test with stop token
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
eos_token_id
=
0
)
sampling_params
=
SamplingParams
(
stop_token_ids
=
[
3
])
stop_checker_with_reasoner
.
maybe_stop_sequence
(
seq
,
new_char_count
=
1
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_STOPPED
# Test with stop string
seq
=
MockSequence
(
token_ids
=
[
1
,
2
,
3
],
output_text
=
"test STOP"
)
sampling_params
=
SamplingParams
(
stop
=
[
"STOP"
])
stop_checker_with_reasoner
.
maybe_stop_sequence
(
seq
,
new_char_count
=
4
,
sampling_params
=
sampling_params
)
assert
seq
.
status
==
SequenceStatus
.
FINISHED_STOPPED
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment