Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dfbe60dc
Unverified
Commit
dfbe60dc
authored
Jun 03, 2024
by
Cyrus Leung
Committed by
GitHub
Jun 02, 2024
Browse files
[Misc] Simplify code and fix type annotations in `conftest.py` (#5118)
parent
a66cf40b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
50 deletions
+42
-50
tests/conftest.py
tests/conftest.py
+42
-50
No files found.
tests/conftest.py
View file @
dfbe60dc
...
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
...
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
import
pytest
import
pytest
import
torch
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
(
AutoModelForCausalLM
,
AutoProcessor
,
AutoTokenizer
,
from
transformers
import
(
AutoModelForCausalLM
,
AutoProcessor
,
AutoTokenizer
,
LlavaConfig
,
LlavaForConditionalGeneration
)
LlavaConfig
,
LlavaForConditionalGeneration
)
...
@@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
...
@@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
TokenizerPoolConfig
,
VisionLanguageConfig
from
vllm.config
import
TokenizerPoolConfig
,
VisionLanguageConfig
from
vllm.distributed
import
destroy_model_parallel
from
vllm.distributed
import
destroy_model_parallel
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Text
Prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
MultiModalData
from
vllm.sequence
import
MultiModalData
,
SampleLogprobs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -188,10 +189,11 @@ class HfRunner:
...
@@ -188,10 +189,11 @@ class HfRunner:
prompts
:
List
[
str
],
prompts
:
List
[
str
],
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
outputs
:
List
[
Tuple
[
List
[
int
],
str
]]
=
[]
if
images
:
if
images
:
assert
len
(
prompts
)
==
len
(
images
)
assert
len
(
prompts
)
==
len
(
images
)
outputs
:
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]
=
[]
for
i
,
prompt
in
enumerate
(
prompts
):
for
i
,
prompt
in
enumerate
(
prompts
):
processor_kwargs
:
Dict
[
str
,
Any
]
=
{
processor_kwargs
:
Dict
[
str
,
Any
]
=
{
"text"
:
prompt
,
"text"
:
prompt
,
...
@@ -201,17 +203,13 @@ class HfRunner:
...
@@ -201,17 +203,13 @@ class HfRunner:
processor_kwargs
[
"images"
]
=
images
[
i
]
processor_kwargs
[
"images"
]
=
images
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
{
key
:
value
.
cuda
()
if
value
is
not
None
else
None
for
key
,
value
in
inputs
.
items
()
}
output_ids
=
self
.
model
.
generate
(
output_ids
=
self
.
model
.
generate
(
**
inputs
,
**
inputs
.
to
(
"cuda"
)
,
use_cache
=
True
,
use_cache
=
True
,
**
kwargs
,
**
kwargs
,
)
)
output_str
=
self
.
tokenize
r
.
batch_decode
(
output_str
=
self
.
processo
r
.
batch_decode
(
output_ids
,
output_ids
,
skip_special_tokens
=
True
,
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
False
,
clean_up_tokenization_spaces
=
False
,
...
@@ -224,23 +222,22 @@ class HfRunner:
...
@@ -224,23 +222,22 @@ class HfRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
images
:
Optional
[
"torch.Tensor"
]
=
None
,
images
:
Optional
[
List
[
Image
.
Image
]
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
outputs
=
self
.
generate
(
prompts
,
outputs
=
self
.
generate
(
prompts
,
do_sample
=
False
,
do_sample
=
False
,
max_new_tokens
=
max_tokens
,
max_new_tokens
=
max_tokens
,
images
=
images
)
images
=
images
)
for
i
in
range
(
len
(
outputs
)):
output_ids
,
output_str
=
outputs
[
i
]
return
[(
output_ids
[
0
],
output_str
[
0
])
outputs
[
i
]
=
(
output_ids
[
0
],
output_str
[
0
])
for
output_ids
,
output_str
in
outputs
]
return
outputs
def
generate_beam_search
(
def
generate_beam_search
(
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
beam_width
:
int
,
beam_width
:
int
,
max_tokens
:
int
,
max_tokens
:
int
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
List
[
int
]
]
,
List
[
str
]]
]
:
outputs
=
self
.
generate
(
prompts
,
outputs
=
self
.
generate
(
prompts
,
do_sample
=
False
,
do_sample
=
False
,
max_new_tokens
=
max_tokens
,
max_new_tokens
=
max_tokens
,
...
@@ -282,9 +279,7 @@ class HfRunner:
...
@@ -282,9 +279,7 @@ class HfRunner:
if
self
.
model
.
get_output_embeddings
().
bias
is
not
None
:
if
self
.
model
.
get_output_embeddings
().
bias
is
not
None
:
logits
+=
self
.
model
.
get_output_embeddings
(
logits
+=
self
.
model
.
get_output_embeddings
(
).
bias
.
unsqueeze
(
0
)
).
bias
.
unsqueeze
(
0
)
logprobs
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
logprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
dim
=-
1
,
dtype
=
torch
.
float32
)
seq_logprobs
.
append
(
logprobs
)
seq_logprobs
.
append
(
logprobs
)
all_logprobs
.
append
(
seq_logprobs
)
all_logprobs
.
append
(
seq_logprobs
)
return
all_logprobs
return
all_logprobs
...
@@ -294,10 +289,10 @@ class HfRunner:
...
@@ -294,10 +289,10 @@ class HfRunner:
prompts
:
List
[
str
],
prompts
:
List
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
List
[
Dict
[
int
,
float
]]
]]:
all_logprobs
=
[]
all_logprobs
:
List
[
List
[
Dict
[
int
,
float
]]]
=
[]
all_output_ids
=
[]
all_output_ids
:
List
[
List
[
int
]]
=
[]
all_output_strs
=
[]
all_output_strs
:
List
[
str
]
=
[]
for
prompt
in
prompts
:
for
prompt
in
prompts
:
input_ids
=
self
.
tokenizer
(
prompt
,
return_tensors
=
"pt"
).
input_ids
input_ids
=
self
.
tokenizer
(
prompt
,
return_tensors
=
"pt"
).
input_ids
...
@@ -310,7 +305,7 @@ class HfRunner:
...
@@ -310,7 +305,7 @@ class HfRunner:
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
seq_logprobs
=
[]
seq_logprobs
:
List
[
torch
.
Tensor
]
=
[]
for
_
,
hidden_states
in
enumerate
(
output
.
hidden_states
):
for
_
,
hidden_states
in
enumerate
(
output
.
hidden_states
):
last_hidden_states
=
hidden_states
[
-
1
][
0
]
last_hidden_states
=
hidden_states
[
-
1
][
0
]
logits
=
torch
.
matmul
(
logits
=
torch
.
matmul
(
...
@@ -321,13 +316,11 @@ class HfRunner:
...
@@ -321,13 +316,11 @@ class HfRunner:
None
)
is
not
None
:
None
)
is
not
None
:
logits
+=
self
.
model
.
get_output_embeddings
(
logits
+=
self
.
model
.
get_output_embeddings
(
).
bias
.
unsqueeze
(
0
)
).
bias
.
unsqueeze
(
0
)
logprobs
=
torch
.
nn
.
functional
.
log_softmax
(
logits
,
logprobs
=
F
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
dim
=-
1
,
dtype
=
torch
.
float32
)
seq_logprobs
.
append
(
logprobs
)
seq_logprobs
.
append
(
logprobs
)
# convert to dict
# convert to dict
seq_logprobs_lst
=
[]
seq_logprobs_lst
:
List
[
Dict
[
int
,
float
]]
=
[]
for
tok_idx
,
tok_logprobs
in
enumerate
(
seq_logprobs
):
for
tok_idx
,
tok_logprobs
in
enumerate
(
seq_logprobs
):
# drop prompt logprobs
# drop prompt logprobs
if
tok_idx
==
0
:
if
tok_idx
==
0
:
...
@@ -372,13 +365,13 @@ class VllmRunner:
...
@@ -372,13 +365,13 @@ class VllmRunner:
tokenizer_name
:
Optional
[
str
]
=
None
,
tokenizer_name
:
Optional
[
str
]
=
None
,
# Use smaller max model length, otherwise bigger model cannot run due
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
# to kv cache size limit.
max_model_len
=
1024
,
max_model_len
:
int
=
1024
,
dtype
:
str
=
"half"
,
dtype
:
str
=
"half"
,
disable_log_stats
:
bool
=
True
,
disable_log_stats
:
bool
=
True
,
tensor_parallel_size
:
int
=
1
,
tensor_parallel_size
:
int
=
1
,
block_size
:
int
=
16
,
block_size
:
int
=
16
,
enable_chunked_prefill
:
bool
=
False
,
enable_chunked_prefill
:
bool
=
False
,
swap_space
=
4
,
swap_space
:
int
=
4
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
self
.
model
=
LLM
(
self
.
model
=
LLM
(
...
@@ -399,32 +392,31 @@ class VllmRunner:
...
@@ -399,32 +392,31 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
images
:
Optional
[
"
torch.Tensor
"
]
=
None
,
images
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
List
[
int
]
]
,
List
[
str
]]
]
:
if
images
is
not
None
:
if
images
is
not
None
:
assert
len
(
prompts
)
==
images
.
shape
[
0
]
assert
len
(
prompts
)
==
len
(
images
)
prompt_inputs
:
List
[
Prompt
Inputs
]
=
[]
prompt_inputs
:
List
[
Text
Prompt
]
=
[]
for
i
,
prompt
in
enumerate
(
prompts
):
for
i
,
prompt
in
enumerate
(
prompts
):
image
=
None
if
images
is
None
else
images
[
i
:
i
+
1
]
prompt
=
TextPrompt
(
prompt
=
prompt
)
mm_data
=
None
if
image
is
None
else
MultiModalData
(
if
images
is
not
None
:
type
=
MultiModalData
.
Type
.
IMAGE
,
prompt
[
"multi_modal_data"
]
=
MultiModalData
(
data
=
image
,
type
=
MultiModalData
.
Type
.
IMAGE
,
)
data
=
images
[
i
:
i
+
1
],
)
prompt_inputs
.
append
({
prompt_inputs
.
append
(
prompt
)
"prompt"
:
prompt
,
"multi_modal_data"
:
mm_data
,
})
req_outputs
=
self
.
model
.
generate
(
prompt_inputs
,
req_outputs
=
self
.
model
.
generate
(
prompt_inputs
,
sampling_params
=
sampling_params
)
sampling_params
=
sampling_params
)
outputs
=
[]
outputs
:
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]
=
[]
for
req_output
in
req_outputs
:
for
req_output
in
req_outputs
:
prompt_str
=
req_output
.
prompt
prompt_str
=
req_output
.
prompt
prompt_ids
=
req_output
.
prompt_token_ids
prompt_ids
=
req_output
.
prompt_token_ids
req_sample_output_ids
=
[]
req_sample_output_ids
:
List
[
List
[
int
]]
=
[]
req_sample_output_strs
=
[]
req_sample_output_strs
:
List
[
str
]
=
[]
for
sample
in
req_output
.
outputs
:
for
sample
in
req_output
.
outputs
:
output_str
=
sample
.
text
output_str
=
sample
.
text
output_ids
=
sample
.
token_ids
output_ids
=
sample
.
token_ids
...
@@ -437,12 +429,12 @@ class VllmRunner:
...
@@ -437,12 +429,12 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]
]]:
assert
sampling_params
.
logprobs
is
not
None
assert
sampling_params
.
logprobs
is
not
None
req_outputs
=
self
.
model
.
generate
(
prompts
,
req_outputs
=
self
.
model
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
sampling_params
=
sampling_params
)
outputs
=
[]
outputs
:
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]
=
[]
for
req_output
in
req_outputs
:
for
req_output
in
req_outputs
:
for
sample
in
req_output
.
outputs
:
for
sample
in
req_output
.
outputs
:
output_str
=
sample
.
text
output_str
=
sample
.
text
...
@@ -467,7 +459,7 @@ class VllmRunner:
...
@@ -467,7 +459,7 @@ class VllmRunner:
prompts
:
List
[
str
],
prompts
:
List
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]
]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
logprobs
=
num_logprobs
)
logprobs
=
num_logprobs
)
...
@@ -481,7 +473,7 @@ class VllmRunner:
...
@@ -481,7 +473,7 @@ class VllmRunner:
prompts
:
List
[
str
],
prompts
:
List
[
str
],
beam_width
:
int
,
beam_width
:
int
,
max_tokens
:
int
,
max_tokens
:
int
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
)
->
List
[
Tuple
[
List
[
List
[
int
]
]
,
List
[
str
]]
]
:
beam_search_params
=
SamplingParams
(
n
=
beam_width
,
beam_search_params
=
SamplingParams
(
n
=
beam_width
,
use_beam_search
=
True
,
use_beam_search
=
True
,
temperature
=
0.0
,
temperature
=
0.0
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment