Unverified commit 8efe23f1
Authored Nov 09, 2023 by Yanming W; committed by GitHub Nov 08, 2023
Fix input_metadata.selected_token_indices in worker prepare_inputs (#1546)
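The index arithmetic this change pins down (and that the new test below checks) can be sketched as follows. This is a standalone illustration, not code from the commit, and the prompt_lens values are made up for the example: prompt rows are padded to a common max_seq_len, so the flattened position of the last real token of row i is i * max_seq_len + prompt_lens[i] - 1.

# Standalone sketch of the expected index layout (example values, not part of this commit).
prompt_lens = [2, 4, 3]            # per-prompt token counts (assumed example)
max_seq_len = max(prompt_lens)     # every row is padded to this length

selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
    # last real token of this padded row
    selected_token_indices.append(selected_token_start_idx + prompt_len - 1)
    # the next row starts one full padded row later
    selected_token_start_idx += max_seq_len

print(selected_token_indices)      # [1, 7, 10]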
Parent: 06458a0b
Showing 2 changed files with 47 additions and 1 deletion:
tests/worker/test_worker.py  +44 −0
vllm/worker/worker.py        +3 −1
tests/worker/test_worker.py  0 → 100644 (new file)
# pylint: disable=protected-access
import random

import torch

from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


def test_worker_prepare_inputs_for_prompt():
    worker = Worker(None, None, None)
    worker.block_size = 16
    batch_size = random.randint(1, 256)
    prompt_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # make sure all tokens fit into one block
        prompt_len = i % (worker.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        seq_data = list(range(prompt_len))
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData(seq_data)},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    max_seq_len = max(prompt_lens)
    for prompt_len in prompt_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               prompt_len - 1)
        selected_token_start_idx += max_seq_len
    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
        seq_group_metadata_list)
    assert input_tokens.shape == input_positions.shape == (batch_size,
                                                           max_seq_len)
    torch.testing.assert_close(input_tokens, input_positions)
    actual = input_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
vllm/worker/worker.py
@@ -211,12 +211,14 @@ class Worker:
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
         max_seq_len = max(prompt_lens) if prompt_lens else 1
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             if seq_group_metadata.is_prompt:
                 # We need to do this in this loop as we need to know max_seq_len
                 assert len(seq_ids) == 1, "Prompt input should have only one seq."
                 sampling_params = seq_group_metadata.sampling_params
+                assert len(prompt_lens) == len(seq_group_metadata_list)
+                prompt_len = prompt_lens[i]
                 if sampling_params.prompt_logprobs is not None:
                     selected_token_indices.extend(
                         range(selected_token_start_idx,
...
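Read together with the test above, the hunk's intent appears to be: _prepare_inputs pads every prompt row into a (batch_size, max_seq_len) tensor, so the sampler index for group i has to come from prompt_lens[i] and advance by max_seq_len per row. The enumerate-based loop, the new length assertion, and prompt_len = prompt_lens[i] make that per-group lookup explicit; the replaced loop had no index with which to do it.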