xdb4_94051 / vllm · Commits · 32b6816e

Commit 32b6816e (unverified), authored Sep 01, 2023 by Woosuk Kwon, committed via GitHub on Sep 01, 2023

Add tests for models (#922)

Parent: c128d698

Showing 3 changed files with 178 additions and 0 deletions (+178 -0):

  requirements-dev.txt           +1   -0
  tests/conftest.py              +132 -0
  tests/models/test_models.py    +45  -0
requirements-dev.txt

@@ -10,3 +10,4 @@ types-setuptools

 # testing
 pytest
+pytest-forked
tests/conftest.py (new file, 100644)

from typing import List, Optional, Tuple

import pytest
import torch
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = [
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
    "Describe the basic components of a neural network and how it can be trained.",
    "Write a short story about a robot that dreams for the first time.",
    "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
    "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]


@pytest.fixture
def example_prompts() -> List[str]:
    return _TEST_PROMPTS


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}


class HfRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            output_ids = output_ids[0].cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        return self.generate(prompts, do_sample=False, max_new_tokens=max_tokens)


@pytest.fixture
def hf_runner():
    return HfRunner


class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=0,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        req_outputs = self.model.generate(
            prompts, sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            output_str = req_output.outputs[0].text
            output_ids = req_output.outputs[0].token_ids
            outputs.append((prompt_ids + output_ids, prompt_str + output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        return self.generate(prompts, greedy_params)


@pytest.fixture
def vllm_runner():
    return VllmRunner
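Note that the two fixtures return the runner classes rather than instances, so each test constructs its own runner per model and can delete it to release GPU memory before the next engine loads. As a minimal sketch of how the runners line up (not part of this commit; the model name, prompt slice, and token budget are illustrative assumptions, and it presumes a CUDA GPU with HfRunner, VllmRunner, and _TEST_PROMPTS from tests/conftest.py in scope):

# Illustrative only: drive HfRunner and VllmRunner directly, mirroring what
# test_models.py does through the fixtures.
prompts = _TEST_PROMPTS[:2]   # a couple of the shared prompts
max_tokens = 32               # assumed small budget for a quick check

hf = HfRunner("facebook/opt-125m", dtype="half")
hf_outputs = hf.generate_greedy(prompts, max_tokens)
del hf                        # release the HF model before building the vLLM engine (mirrors the test)

vllm = VllmRunner("facebook/opt-125m", dtype="half")
vllm_outputs = vllm.generate_greedy(prompts, max_tokens)

for (hf_ids, hf_str), (vllm_ids, vllm_str) in zip(hf_outputs, vllm_outputs):
    assert hf_str == vllm_str
    assert hf_ids == vllm_ids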
tests/models/test_models.py (new file, 100644)

"""Compare the outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/models/test_models.py --forked`.
"""
import pytest

MODELS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
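A brief usage note, not part of the diff: because each parametrized case loads a full model onto the GPU, the `--forked` flag named in the docstring (provided by the pytest-forked package added to requirements-dev.txt) runs every test in its own forked subprocess, so GPU memory is reliably released between models. A hedged sketch of invoking the suite programmatically for just the smallest model, assuming it is run from the vLLM repository root with pytest and pytest-forked installed:

# Illustrative sketch (assumptions: repo root as working directory; the -k
# substring filter selects only the facebook/opt-125m cases).
import sys

import pytest

sys.exit(pytest.main([
    "tests/models/test_models.py",
    "--forked",
    "-k", "opt-125m",
]))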