vllm · Commit cd3aa153 (unverified)
Authored Dec 02, 2023 by Woosuk Kwon; committed by GitHub on Dec 02, 2023

Fix broken worker test (#1900)

Parent: 9b294976
Showing 1 changed file with 14 additions and 9 deletions.
tests/worker/test_worker.py → tests/worker/test_model_runner.py (renamed, +14 −9)
@@ -2,18 +2,19 @@ import random
 
 import torch
 
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner
 
 
-def test_worker_prepare_inputs_for_prompt():
-    worker = Worker(None, None, None)
-    worker.block_size = 16
+def test_prepare_prompt():
+    model_runner = ModelRunner(None, None, None)
+    model_runner.set_block_size(16)
+
     batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
         # make sure all tokens fit into one block
-        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
         seq_data = list(range(prompt_len))
         seq_group_metadata_list.append(
@@ -24,6 +25,7 @@ def test_worker_prepare_inputs_for_prompt():
                 sampling_params=SamplingParams(temperature=0),
                 block_tables={0: [1]},
             ))
+
     expected_selected_token_indices = []
     selected_token_start_idx = 0
     max_seq_len = max(prompt_lens)
@@ -31,12 +33,15 @@ def test_worker_prepare_inputs_for_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
 
-    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+    input_tokens, input_positions, _ = model_runner._prepare_prompt(
         seq_group_metadata_list)
-    assert input_tokens.shape == input_positions.shape == (batch_size,
-                                                           max_seq_len)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    assert input_tokens.shape == (batch_size, max_seq_len)
+    assert input_positions.shape == (batch_size, max_seq_len)
     torch.testing.assert_close(input_tokens, input_positions)
-    actual = input_metadata.selected_token_indices
+    actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
                             device=actual.device,
                             dtype=actual.dtype)
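The rename tracks the refactor that moved input preparation from Worker into the new ModelRunner class, which is why the old test's call to worker._prepare_inputs no longer worked. Two details in the test body are worth unpacking: prompt_len = i % (model_runner.block_size - 1) + 1 keeps every prompt between 1 and block_size - 1 tokens, so each sequence fits in a single 16-token block, and the expected-indices loop computes where each prompt's last token lands once the batch is padded to max_seq_len. Below is a minimal standalone sketch of that index computation, using illustrative prompt lengths in place of the test's randomized ones:

    # Sketch of the selected-token index computation from the test body.
    # The prompt_lens values are illustrative; the test randomizes them.
    prompt_lens = [3, 1, 2]
    max_seq_len = max(prompt_lens)  # every batch row is padded to this width

    expected_selected_token_indices = []
    selected_token_start_idx = 0
    for prompt_len in prompt_lens:
        # index of this row's last prompt token in the flattened
        # (batch_size, max_seq_len) layout
        expected_selected_token_indices.append(selected_token_start_idx +
                                               prompt_len - 1)
        selected_token_start_idx += max_seq_len  # advance to the next padded row

    assert expected_selected_token_indices == [2, 3, 7]

These are the positions the test expects _prepare_sample to report in sampling_metadata.selected_token_indices: one sampled position per sequence, at its last prompt token.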