Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
837e1851
Unverified
Commit
837e1851
authored
Mar 24, 2024
by
youkaichao
Committed by
GitHub
Mar 24, 2024
Browse files
[CI/Build] fix flaky test (#3602)
parent
42bc3861
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
16 deletions
+10
-16
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+10
-16
No files found.
tests/worker/test_model_runner.py
View file @
837e1851
import
random
import
pytest
import
torch
from
vllm.config
import
ModelConfig
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
,
_
BATCH_SIZE_ALIGNMENT
from
vllm.worker.model_runner
import
ModelRunner
,
_
get_graph_batch_size
def get_aligned_size(batch_size: int, alignment: int) -> int:
    """Round ``batch_size`` up to the nearest multiple of ``alignment``.

    Args:
        batch_size: The size to round up.
        alignment: The multiple to round up to. A value of zero raises
            ``ZeroDivisionError`` from the floor division below.

    Returns:
        For a positive ``alignment``, the smallest multiple of
        ``alignment`` that is greater than or equal to ``batch_size``.
    """
    # Adding (alignment - 1) before the floor division turns it into a
    # ceiling division; multiplying back yields the aligned size.
    return (batch_size + alignment - 1) // alignment * alignment
def
test_prepare_prompt
():
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
def
test_prepare_prompt
(
batch_size
):
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
model_runner
.
set_block_size
(
16
)
batch_size
=
random
.
randint
(
1
,
256
)
prompt_lens
=
[]
seq_group_metadata_list
=
[]
block_tables
=
{
0
:
[
1
]}
...
...
@@ -111,7 +107,8 @@ def test_prepare_prompt():
torch
.
testing
.
assert_close
(
actual
,
expected
)
def
test_prepare_decode_cuda_graph
():
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
def
test_prepare_decode_cuda_graph
(
batch_size
):
model_config
=
ModelConfig
(
"facebook/opt-125m"
,
"facebook/opt-125m"
,
...
...
@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
model_runner
=
ModelRunner
(
model_config
,
None
,
None
,
None
,
None
)
model_runner
.
set_block_size
(
16
)
batch_size
=
random
.
randint
(
1
,
256
)
prompt_lens
=
[]
seq_group_metadata_list
=
[]
for
i
in
range
(
batch_size
):
...
...
@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
input_tokens
,
input_positions
,
input_metadata
,
_
,
_
,
_
=
(
model_runner
.
_prepare_decode
(
seq_group_metadata_list
))
expected_bs
=
_get_graph_batch_size
(
len
(
seq_group_metadata_list
))
# Verify input metadata is correct for decoding (not prompts).
device
=
model_runner
.
device
assert
input_metadata
.
is_prompt
is
False
assert
input_metadata
.
prompt_lens
is
None
assert
input_metadata
.
num_prompt_tokens
==
0
assert
input_metadata
.
num_generation_tokens
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
))
assert
input_metadata
.
num_generation_tokens
==
expected_bs
assert
input_metadata
.
max_seq_len
is
None
assert
input_metadata
.
subquery_start_loc
is
None
assert
input_metadata
.
seq_start_loc
is
None
...
...
@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
assert
input_metadata
.
use_cuda_graph
is
True
assert
input_metadata
.
kv_cache_dtype
==
"auto"
assert
input_tokens
.
shape
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
),
)
assert
input_positions
.
shape
==
(
get_aligned_size
(
len
(
seq_group_metadata_list
),
_BATCH_SIZE_ALIGNMENT
),
)
assert
input_tokens
.
shape
==
(
expected_bs
,
)
assert
input_positions
.
shape
==
(
expected_bs
,
)
torch
.
testing
.
assert_close
(
input_tokens
,
input_positions
)
# Verify Sampling
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment