norm / vllm · Commits · 15f56323

Unverified commit 15f56323, authored Oct 30, 2023 by Antoni Baum, committed by GitHub on Oct 30, 2023.

    Delay GPU->CPU sync in sampling (#1337)

Parent: aa9af07c
Showing 3 changed files with 66 additions and 47 deletions:

  vllm/model_executor/input_metadata.py  +18 -11
  vllm/model_executor/layers/sampler.py   +3 -34
  vllm/worker/worker.py                  +45  -2
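Read as a whole, the change moves the index bookkeeping that decides which hidden-state rows get sampled, and which rows belong to each sampling type, out of the sampler's forward pass and into the worker's input preparation, where the lists are built on the CPU and shipped to the device once per step inside InputMetadata. Below is a minimal, self-contained sketch of that pattern; the helper names, the padding rule, and the shapes are illustrative stand-ins rather than vLLM's actual code, and prompt-logprob handling is left out.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"


    def prepare_selected_token_indices(prompt_lens, num_generation_seqs):
        # CPU-side bookkeeping: take the flattened row of the last token of each
        # (padded) prompt, then one row per decoding sequence, and transfer once.
        indices = []
        start = 0
        max_prompt_len = max(prompt_lens) if prompt_lens else 0
        for prompt_len in prompt_lens:
            indices.append(start + prompt_len - 1)
            start += max_prompt_len  # prompts are padded to the longest prompt
        indices.extend(range(start, start + num_generation_seqs))
        return torch.tensor(indices, dtype=torch.long, device=device)


    def prune_hidden_states(hidden_states, selected_token_indices):
        # Sampler side: pure tensor work, no per-step Python index building.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        return hidden_states.index_select(0, selected_token_indices)


    # Two prompts (lengths 3 and 5, padded to 5) followed by 4 decoding sequences.
    hidden = torch.randn(2 * 5 + 4, 16, device=device)
    sel = prepare_selected_token_indices([3, 5], num_generation_seqs=4)
    print(prune_hidden_states(hidden, sel).shape)  # torch.Size([6, 16])

The real worker additionally extends the indices over every prompt position when prompt_logprobs is requested; that detail is omitted from the sketch.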
vllm/model_executor/input_metadata.py

@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from xformers.ops import AttentionBias

-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData

@@ -29,6 +29,8 @@ class InputMetadata:
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        selected_token_indices: torch.Tensor,
+        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups

@@ -38,6 +40,8 @@ class InputMetadata:
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
+        self.selected_token_indices = selected_token_indices
+        self.categorized_sample_indices = categorized_sample_indices
         self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None

@@ -72,13 +76,16 @@ class InputMetadata:
     def __repr__(self) -> str:
         # Print only useful metadata.
-        return (f'InputMetadata('
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'num_prompts={self.num_prompts}, '
-                f'prompt_lens={self.prompt_lens}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
-                f'context_lens={self.context_lens}, '
-                f'max_context_len={self.max_context_len}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'block_tables={self.block_tables}, '
-                f'slot_mapping={self.slot_mapping})')
+        return (f'InputMetadata('
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'num_prompts={self.num_prompts}, '
+                f'prompt_lens={self.prompt_lens}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
+                f'context_lens={self.context_lens}, '
+                f'max_context_len={self.max_context_len}), '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'block_tables={self.block_tables}, '
+                f'selected_token_indices={self.selected_token_indices}, '
+                f'categorized_sample_indices={self.categorized_sample_indices}, '
+                f'slot_mapping={self.slot_mapping})')
vllm/model_executor/layers/sampler.py

@@ -109,29 +109,8 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    selected_token_indices: List[int] = []
-    start_idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            prompt_len = input_metadata.prompt_lens[i]
-            if sampling_params.prompt_logprobs is not None:
-                selected_token_indices.extend(
-                    range(start_idx, start_idx + prompt_len - 1))
-            selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += input_metadata.max_prompt_len
-        else:
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(start_idx, start_idx + num_seqs))
-            start_idx += num_seqs
-
-    selected_token_indices = torch.tensor(selected_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, selected_token_indices)
+    return hidden_states.index_select(0, input_metadata.selected_token_indices)


 def _get_penalties(

@@ -426,21 +405,11 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = {t: [] for t in SamplingType}
-    start_idx = 0
+    categorized_sample_indices = input_metadata.categorized_sample_indices
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
+        _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
-        if (i < input_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: prompt token positions do not need sample, skip
-            prompt_len = input_metadata.prompt_lens[i]
-            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
-        num_seqs = len(seq_ids)
-        categorized_sample_indices[sampling_type].extend(
-            range(start_idx, start_idx + num_seqs))
-        start_idx += num_seqs

     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
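The _sample hunk above stops rebuilding index lists per step and instead consumes the per-sampling-type index tensors prepared by the worker. A rough sketch of that consumption pattern, assuming a stand-in SamplingType enum and toy shapes rather than vLLM's definitions:

    import enum

    import torch


    class SamplingType(enum.IntEnum):
        # Stand-in for vllm.sampling_params.SamplingType.
        GREEDY = 0
        RANDOM = 1


    device = "cuda" if torch.cuda.is_available() else "cpu"

    # CPU-side bookkeeping while the batch is prepared: row 0 samples greedily,
    # rows 1 and 2 sample randomly.
    index_lists = {t: [] for t in SamplingType}
    index_lists[SamplingType.GREEDY].append(0)
    index_lists[SamplingType.RANDOM].extend([1, 2])

    # One host-to-device transfer per sampling type, done up front.
    categorized_sample_indices = {
        t: torch.tensor(idx, dtype=torch.long, device=device)
        for t, idx in index_lists.items()
    }

    # Sampler side: plain tensor indexing, no index arithmetic in the hot path.
    probs = torch.rand(3, 8, device=device).softmax(dim=-1)
    greedy_rows = probs.index_select(
        0, categorized_sample_indices[SamplingType.GREEDY])
    print(greedy_rows.shape)  # torch.Size([1, 8])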
vllm/worker/worker.py

@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes

@@ -161,6 +161,10 @@ class Worker:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        selected_token_indices: List[int] = []
+        selected_token_start_idx = 0
+        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices_start_idx = 0

         # Add prompt tokens.
         prompt_lens: List[int] = []

@@ -180,6 +184,14 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
+            if sampling_params.prompt_logprobs is not None:
+                # NOTE: prompt token positions do not need sample, skip
+                categorized_sample_indices_start_idx += prompt_len - 1
+            categorized_sample_indices[
+                sampling_params.sampling_type].append(
+                    categorized_sample_indices_start_idx)
+            categorized_sample_indices_start_idx += 1
+
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.

@@ -205,14 +217,37 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
         for seq_group_metadata in seq_group_metadata_list:
             if seq_group_metadata.is_prompt:
+                # We need to do this in this loop as we need to know max_seq_len
+                assert len(seq_ids) == 1, "Prompt input should have only one seq."
+                sampling_params = seq_group_metadata.sampling_params
+                if sampling_params.prompt_logprobs is not None:
+                    selected_token_indices.extend(
+                        range(selected_token_start_idx,
+                              selected_token_start_idx + prompt_len - 1))
+                selected_token_indices.append(selected_token_start_idx +
+                                              prompt_len - 1)
+                selected_token_start_idx += max_seq_len
                 continue

             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))

+            num_seqs = len(seq_ids)
+            selected_token_indices.extend(
+                range(selected_token_start_idx,
+                      selected_token_start_idx + num_seqs))
+            selected_token_start_idx += num_seqs
+            categorized_sample_indices[sampling_params.sampling_type].extend(
+                range(categorized_sample_indices_start_idx,
+                      categorized_sample_indices_start_idx + num_seqs))
+            categorized_sample_indices_start_idx += num_seqs
+
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()

@@ -242,7 +277,6 @@ class Worker:
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)

-        max_seq_len = max(prompt_lens) if prompt_lens else 1
         padded_input_tokens = [
             _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
         ]

@@ -272,6 +306,13 @@ class Worker:
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
+        selected_token_indices = torch.tensor(selected_token_indices,
+                                              dtype=torch.long,
+                                              device="cuda")
+        categorized_sample_indices = {
+            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            for t, seq_ids in categorized_sample_indices.items()
+        }
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")

@@ -288,6 +329,8 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
+            selected_token_indices=selected_token_indices,
+            categorized_sample_indices=categorized_sample_indices,
             sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
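About the commit title: with the index tensors prepared ahead of time, the sampling path stays on the device and the device-to-host read of the results can happen late and only once. A tiny illustration (not vLLM code) of why batching that read matters; on a CUDA device, calls such as .item(), .tolist(), or .cpu() block the host until the queued GPU work has finished:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logits = torch.randn(4, 8, device=device)

    # Eager pattern: one read-back per row, each one a potential sync point.
    eager = [int(logits[i].argmax().item()) for i in range(logits.shape[0])]

    # Deferred pattern: do all the work on the device, read back once at the end.
    deferred = logits.argmax(dim=-1).tolist()

    assert eager == deferred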