Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e865b6a
Unverified
Commit
6e865b6a
authored
Dec 05, 2025
by
Chukwuma Nwaugha
Committed by
GitHub
Dec 05, 2025
Browse files
Refactor example prompts fixture (#29854)
Signed-off-by: nwaughac@gmail.com
parent
d698bb38
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
22 deletions
+25
-22
tests/conftest.py
tests/conftest.py
+25
-22
No files found.
tests/conftest.py
View file @
6e865b6a
...
@@ -27,7 +27,7 @@ import threading
...
@@ -27,7 +27,7 @@ import threading
from
collections.abc
import
Generator
from
collections.abc
import
Generator
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
,
Callable
,
TypedDict
,
TypeVar
,
cast
from
typing
import
Any
,
Callable
,
TypedDict
,
TypeVar
,
cast
,
TYPE_CHECKING
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -67,6 +67,11 @@ from vllm.transformers_utils.utils import maybe_model_redirect
...
@@ -67,6 +67,11 @@ from vllm.transformers_utils.utils import maybe_model_redirect
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers.generation.utils
import
GenerateOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_TEST_DIR
=
os
.
path
.
dirname
(
__file__
)
_TEST_DIR
=
os
.
path
.
dirname
(
__file__
)
...
@@ -202,10 +207,7 @@ def dynamo_reset():
...
@@ -202,10 +207,7 @@ def dynamo_reset():
@
pytest
.
fixture
@
pytest
.
fixture
def
example_prompts
()
->
list
[
str
]:
def
example_prompts
()
->
list
[
str
]:
prompts
=
[]
return
[
prompt
for
filename
in
_TEST_PROMPTS
for
prompt
in
_read_prompts
(
filename
)]
for
filename
in
_TEST_PROMPTS
:
prompts
+=
_read_prompts
(
filename
)
return
prompts
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -224,10 +226,7 @@ class DecoderPromptType(Enum):
...
@@ -224,10 +226,7 @@ class DecoderPromptType(Enum):
@
pytest
.
fixture
@
pytest
.
fixture
def
example_long_prompts
()
->
list
[
str
]:
def
example_long_prompts
()
->
list
[
str
]:
prompts
=
[]
return
[
prompt
for
filename
in
_LONG_PROMPTS
for
prompt
in
_read_prompts
(
filename
)]
for
filename
in
_LONG_PROMPTS
:
prompts
+=
_read_prompts
(
filename
)
return
prompts
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
...
@@ -353,10 +352,13 @@ class HfRunner:
...
@@ -353,10 +352,13 @@ class HfRunner:
trust_remote_code
=
trust_remote_code
,
trust_remote_code
=
trust_remote_code
,
)
)
else
:
else
:
model
=
auto_cls
.
from_pretrained
(
model
=
cast
(
model_name
,
nn
.
Module
,
trust_remote_code
=
trust_remote_code
,
auto_cls
.
from_pretrained
(
**
model_kwargs
,
model_name
,
trust_remote_code
=
trust_remote_code
,
**
model_kwargs
,
),
)
)
# in case some unquantized custom models are not in same dtype
# in case some unquantized custom models are not in same dtype
...
@@ -374,10 +376,12 @@ class HfRunner:
...
@@ -374,10 +376,12 @@ class HfRunner:
self
.
model
=
model
self
.
model
=
model
if
not
skip_tokenizer_init
:
if
not
skip_tokenizer_init
:
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
:
"PreTrainedTokenizer | PreTrainedTokenizerFast"
=
(
model_name
,
AutoTokenizer
.
from_pretrained
(
dtype
=
dtype
,
model_name
,
trust_remote_code
=
trust_remote_code
,
dtype
=
dtype
,
trust_remote_code
=
trust_remote_code
,
)
)
)
# don't put this import at the top level
# don't put this import at the top level
...
@@ -495,7 +499,7 @@ class HfRunner:
...
@@ -495,7 +499,7 @@ class HfRunner:
outputs
:
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]
=
[]
outputs
:
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]
=
[]
for
inputs
in
all_inputs
:
for
inputs
in
all_inputs
:
output_ids
=
self
.
model
.
generate
(
output_ids
:
torch
.
Tensor
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
use_cache
=
True
,
use_cache
=
True
,
**
kwargs
,
**
kwargs
,
...
@@ -505,8 +509,7 @@ class HfRunner:
...
@@ -505,8 +509,7 @@ class HfRunner:
skip_special_tokens
=
True
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
,
clean_up_tokenization_spaces
=
False
,
)
)
output_ids
=
output_ids
.
cpu
().
tolist
()
outputs
.
append
((
output_ids
.
cpu
().
tolist
(),
output_str
))
outputs
.
append
((
output_ids
,
output_str
))
return
outputs
return
outputs
def
generate_greedy
(
def
generate_greedy
(
...
@@ -574,7 +577,7 @@ class HfRunner:
...
@@ -574,7 +577,7 @@ class HfRunner:
all_logprobs
:
list
[
list
[
torch
.
Tensor
]]
=
[]
all_logprobs
:
list
[
list
[
torch
.
Tensor
]]
=
[]
for
inputs
in
all_inputs
:
for
inputs
in
all_inputs
:
output
=
self
.
model
.
generate
(
output
:
"GenerateOutput"
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
use_cache
=
True
,
use_cache
=
True
,
do_sample
=
False
,
do_sample
=
False
,
...
@@ -656,7 +659,7 @@ class HfRunner:
...
@@ -656,7 +659,7 @@ class HfRunner:
all_output_strs
:
list
[
str
]
=
[]
all_output_strs
:
list
[
str
]
=
[]
for
inputs
in
all_inputs
:
for
inputs
in
all_inputs
:
output
=
self
.
model
.
generate
(
output
:
"GenerateOutput"
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
use_cache
=
True
,
use_cache
=
True
,
do_sample
=
False
,
do_sample
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment