norm/vllm · Commits · 32b6816e

Commit 32b6816e (unverified), authored Sep 01, 2023 by Woosuk Kwon, committed via GitHub on Sep 01, 2023

Add tests for models (#922)

Parent: c128d698
Showing 3 changed files with 178 additions and 0 deletions (+178 −0):

  requirements-dev.txt          +1   −0
  tests/conftest.py             +132 −0
  tests/models/test_models.py   +45  −0
requirements-dev.txt

@@ -10,3 +10,4 @@ types-setuptools
 # testing
 pytest
+pytest-forked
tests/conftest.py (new file, mode 100644)

from typing import List, Optional, Tuple

import pytest
import torch
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = [
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
    "Describe the basic components of a neural network and how it can be trained.",
    "Write a short story about a robot that dreams for the first time.",
    "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
    "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]


@pytest.fixture
def example_prompts() -> List[str]:
    return _TEST_PROMPTS


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}


class HfRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            output_ids = output_ids[0].cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        return self.generate(prompts, do_sample=False, max_new_tokens=max_tokens)


@pytest.fixture
def hf_runner():
    return HfRunner


class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=0,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        req_outputs = self.model.generate(prompts, sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            output_str = req_output.outputs[0].text
            output_ids = req_output.outputs[0].token_ids
            outputs.append((prompt_ids + output_ids, prompt_str + output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        return self.generate(prompts, greedy_params)


@pytest.fixture
def vllm_runner():
    return VllmRunner
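For context, here is a minimal sketch of how these runner classes might be exercised directly, outside of pytest. It assumes a CUDA-capable GPU, that the facebook/opt-125m weights can be downloaded, and that the script is run from the tests/ directory so conftest.py is importable; it is illustrative only and not part of this commit.

# Illustrative sketch (not part of this commit): compare one model's greedy
# completions from the HF and vLLM runners defined in conftest.py above.
from conftest import HfRunner, VllmRunner, _TEST_PROMPTS

hf = HfRunner("facebook/opt-125m", dtype="half")            # assumes a CUDA GPU
hf_out = hf.generate_greedy(_TEST_PROMPTS[:1], max_tokens=32)
del hf  # drop the HF model before loading the vLLM engine, mirroring the test

vllm = VllmRunner("facebook/opt-125m", dtype="half")
vllm_out = vllm.generate_greedy(_TEST_PROMPTS[:1], max_tokens=32)

# Each runner returns a list of (token_ids, decoded_string) tuples.
print(hf_out[0][1] == vllm_out[0][1])  # True if the decoded outputs match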
tests/models/test_models.py (new file, mode 100644)

"""Compare the outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/models/test_models.py --forked`.
"""
import pytest

MODELS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
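The `--forked` flag comes from the pytest-forked plugin added to requirements-dev.txt in this commit; it runs each test in its own forked subprocess, which keeps the parametrized model cases isolated so that one model's GPU state does not leak into the next. Since pytest's standard `-k` filter matches against parametrized test IDs, a single model can likely be selected with something like `pytest tests/models/test_models.py --forked -k "opt-125m"`.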