Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
50e9ec41
"vscode:/vscode.git/clone" did not exist on "cb84e45ac75b42ba6795145923e8eb323bb825ad"
Unverified
Commit
50e9ec41
authored
Sep 14, 2024
by
Woosuk Kwon
Committed by
GitHub
Sep 14, 2024
Browse files
[TPU] Implement multi-step scheduling (#8489)
parent
47790f3e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
279 additions
and
76 deletions
+279
-76
vllm/config.py
vllm/config.py
+1
-1
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+6
-2
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+11
-5
vllm/worker/multi_step_tpu_worker.py
vllm/worker/multi_step_tpu_worker.py
+105
-0
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+156
-68
No files found.
vllm/config.py
View file @
50e9ec41
...
...
@@ -379,7 +379,7 @@ class ModelConfig:
self
.
use_async_output_proc
=
False
return
if
self
.
enforce_eager
:
if
device_config
.
device_type
==
"cuda"
and
self
.
enforce_eager
:
logger
.
warning
(
"To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output "
...
...
vllm/executor/ray_tpu_executor.py
View file @
50e9ec41
...
...
@@ -68,8 +68,12 @@ class RayTPUExecutor(TPUExecutor):
)
assert
self
.
speculative_config
is
None
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
if
self
.
scheduler_config
.
is_multi_step
:
worker_module_name
=
"vllm.worker.multi_step_tpu_worker"
worker_class_name
=
"MultiStepTPUWorker"
else
:
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
...
...
vllm/executor/tpu_executor.py
View file @
50e9ec41
...
...
@@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase):
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
,
):
from
vllm.worker.tpu_worker
import
TPUWorker
worker
=
TPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
if
self
.
scheduler_config
.
is_multi_step
:
from
vllm.worker.multi_step_tpu_worker
import
MultiStepTPUWorker
worker
=
MultiStepTPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
else
:
from
vllm.worker.tpu_worker
import
TPUWorker
worker
=
TPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
def
initialize_cache
(
self
,
...
...
vllm/worker/multi_step_tpu_worker.py
0 → 100644
View file @
50e9ec41
import
dataclasses
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.tpu_model_runner
import
ModelInputForTPU
from
vllm.worker.tpu_worker
import
TPUWorker
from
vllm.worker.worker_base
import
WorkerInput
class MultiStepTPUWorker(TPUWorker):
    """TPU worker that caches the model input across multi-step iterations.

    The model input is prepared (driver) or received via broadcast
    (non-driver) on the first step of a multi-step round and stored in
    ``cached_model_input``. On later steps only the ``is_first_multi_step``
    and ``is_last_step`` flags are refreshed on the cached input.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``TPUWorker`` and start with no cache."""
        super().__init__(*args, **kwargs)
        # Set on the first step of a round; reused by the following steps.
        self.cached_model_input: Optional[ModelInputForTPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        """Build the driver's inputs for this step and broadcast them.

        Args:
            execute_model_req: The request for the current step. Its
                ``virtual_engine`` must be 0.

        Returns:
            A ``(model_input, worker_input, {})`` tuple. The trailing dict is
            always empty.
        """
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        first_step = execute_model_req.is_first_multi_step
        last_step = execute_model_req.is_last_step

        model_input: ModelInputForTPU
        worker_input: WorkerInput
        if not first_step:
            # Later steps reuse the input prepared on the first step and
            # carry a default-constructed worker input.
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        else:
            worker_input = dataclasses.replace(
                self.prepare_worker_input(
                    execute_model_req=execute_model_req),
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input = self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids)
            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input,
                    async_callback=execute_model_req.async_callback)

        # Stamp the step flags of the current request onto the model input.
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=first_step,
            is_last_step=last_step)

        if self.do_metadata_broadcast:
            if first_step:
                # First step: send the full worker and model inputs.
                payload = worker_input.as_broadcastable_tensor_dict()
                payload.update(model_input.as_broadcastable_tensor_dict())
            else:
                # Later steps: only the two step flags are sent.
                payload = {
                    "is_first_multi_step": first_step,
                    "is_last_step": last_step,
                }
            broadcast_tensor_dict(payload, src=0)

        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        """Produce the inputs for one step on either kind of worker.

        On the driver, a ``None`` request broadcasts an empty dict (when
        metadata broadcast is enabled) and returns ``None``. On a non-driver
        worker, an empty broadcast returns ``None``; a two-entry broadcast
        updates the flags of the cached model input; anything else is decoded
        as a full worker/model input and cached.

        Args:
            execute_model_req: The request for this step; only read on the
                driver worker.

        Returns:
            ``None`` when there is nothing to execute, otherwise a
            ``(model_input, worker_input, {})`` tuple.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # An empty dict makes the receiving workers return None.
                    broadcast_tensor_dict({}, src=0)
                return None

            model_input, worker_input, _ = (
                self._get_driver_input_and_broadcast(execute_model_req))
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}

        received = broadcast_tensor_dict(src=0)
        if not received:
            return None

        if len(received) != 2:
            # Full payload: rebuild both inputs and remember the model input.
            worker_input = WorkerInput.from_broadcasted_tensor_dict(received)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(received))
            self.cached_model_input = model_input
            return model_input, worker_input, {}

        # Exactly two entries: the flag-only update for a later step.
        assert self.cached_model_input is not None
        self.cached_model_input = dataclasses.replace(
            self.cached_model_input,
            is_first_multi_step=received["is_first_multi_step"],
            is_last_step=received["is_last_step"])
        return self.cached_model_input, WorkerInput(), {}
vllm/worker/tpu_model_runner.py
View file @
50e9ec41
...
...
@@ -51,6 +51,8 @@ class ModelInputForTPU(ModelRunnerInputBase):
num_samples
:
int
best_of
:
List
[
int
]
seq_groups
:
List
[
List
[
int
]]
is_first_multi_step
:
bool
=
True
is_last_step
:
bool
=
True
virtual_engine
:
int
=
0
async_callback
:
Optional
[
Callable
]
=
None
...
...
@@ -65,6 +67,8 @@ class ModelInputForTPU(ModelRunnerInputBase):
"num_samples"
:
self
.
num_samples
,
"best_of"
:
self
.
best_of
,
"seq_groups"
:
self
.
seq_groups
,
"is_first_multi_step"
:
self
.
is_first_multi_step
,
"is_last_step"
:
self
.
is_last_step
,
"virtual_engine"
:
self
.
virtual_engine
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
...
...
@@ -118,6 +122,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self
.
block_size
,
False
,
)
self
.
cached_step_outputs
:
List
[
torch
.
Tensor
]
=
[]
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
...
...
@@ -518,97 +523,159 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
assert
intermediate_tensors
is
None
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
def
_execute_model
(
*
args
):
"""Move input args from CPU to device and execute the model."""
new_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
torch
.
Tensor
):
arg
=
arg
.
to
(
self
.
device
)
elif
isinstance
(
arg
,
AttentionMetadata
):
arg
.
slot_mapping
=
arg
.
slot_mapping
.
to
(
self
.
device
)
if
getattr
(
arg
,
"block_tables"
,
None
)
is
not
None
:
arg
.
block_tables
=
arg
.
block_tables
.
to
(
self
.
device
)
if
getattr
(
arg
,
"context_lens"
,
None
)
is
not
None
:
arg
.
context_lens
=
arg
.
context_lens
.
to
(
self
.
device
)
new_args
.
append
(
arg
)
return
self
.
model
(
*
new_args
,
is_prompt
=
is_prompt
)
num_prefills
=
model_input
.
attn_metadata
.
num_prefills
is_prompt
=
num_prefills
>
0
if
not
model_input
.
is_first_multi_step
:
if
not
model_input
.
is_last_step
:
return
[]
use_async_out_proc
=
model_input
.
async_callback
is
not
None
sampler_outputs
=
[]
num_outputs
=
len
(
self
.
cached_step_outputs
)
for
i
in
range
(
num_outputs
):
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
_make_decode_output
(
next_token_ids
,
model_input
.
seq_groups
)
sampler_outputs
.
append
(
sampler_output
)
if
i
<
num_outputs
-
1
and
use_async_out_proc
:
assert
model_input
.
async_callback
is
not
None
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
ctx
.
append_output
(
outputs
=
[
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
)
model_input
.
async_callback
()
if
use_async_out_proc
:
return
[
sampler_outputs
[
-
1
]]
else
:
return
sampler_outputs
is_prompt
=
model_input
.
attn_metadata
.
num_prefills
>
0
if
is_prompt
:
assert
num_steps
==
1
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
next_token_ids
=
[]
orig_slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
batch_size
=
model_input
.
input_lens
.
shape
[
0
]
start_idx
=
0
next_token_ids
=
[]
for
i
in
range
(
batch_size
):
# Get the actual prefill_len.
prefill_len
=
model_input
.
input_lens
[
i
:
i
+
1
].
item
()
prefill_len
=
_get_padded_prefill_len
(
prefill_len
)
end_idx
=
start_idx
+
prefill_len
model_input
.
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
None
,
start_idx
:
end_idx
]
model_input
.
attn_metadata
.
num_prefills
=
1
output_token_ids
=
_execute_model
(
model_input
.
token_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
position_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
)
if
i
==
0
and
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Retrieve the outputs to CPU.
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
token_ids
=
model_input
.
token_ids
[
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
position_ids
=
model_input
.
position_ids
[
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
.
num_prefills
=
1
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
input_lens
=
model_input
.
input_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
t
=
model_input
.
t
[
i
:
i
+
1
].
to
(
self
.
device
)
p
=
model_input
.
p
[
i
:
i
+
1
].
to
(
self
.
device
)
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
kv_caches
,
is_prompt
=
True
)
next_token_ids
.
append
(
output_token_ids
[
0
])
start_idx
=
end_idx
else
:
# Execute the model.
output_token_ids
=
_execute_model
(
model_input
.
token_ids
,
model_input
.
position_ids
,
model_input
.
attn_metadata
,
model_input
.
input_lens
,
model_input
.
t
,
model_input
.
p
,
model_input
.
num_samples
,
kv_caches
)
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Retrieve the outputs to CPU.
next_token_ids
=
output_token_ids
.
cpu
().
tolist
()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support the advanced sampling parameters such as logprobs.
zero_logprob
=
Logprob
(
0.0
)
batch_idx
=
0
sampler_outputs
=
[]
for
seq_group
in
model_input
.
seq_groups
:
s
eq_ids
=
seq_group
seq_outputs
=
[]
if
is_prompt
:
next_token_ids
=
[
output_token_ids
.
cpu
().
tolist
()
for
output_token_ids
in
next_token_ids
]
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support advanced sampling parameters such as logprobs.
zero_logprob
=
Logprob
(
0.0
)
s
ampler_outputs
=
[]
for
i
,
seq_group
in
enumerate
(
model_input
.
seq_groups
):
seq_ids
=
seq_group
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
for
i
in
range
(
model_input
.
best_of
[
batch_idx
]):
next_token_id
=
next_token_ids
[
batch_idx
][
i
]
seq_outputs
=
[]
for
j
in
range
(
model_input
.
best_of
[
i
]):
next_token_id
=
next_token_ids
[
i
][
j
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
batch_idx
+=
1
else
:
for
seq_id
in
seq_ids
:
next_token_id
=
next_token_ids
[
batch_idx
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
batch_idx
+=
1
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
[
SamplerOutput
(
sampler_outputs
)]
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
[
SamplerOutput
(
sampler_outputs
)]
else
:
token_ids
=
model_input
.
token_ids
.
to
(
self
.
device
)
position_ids
=
model_input
.
position_ids
.
to
(
self
.
device
)
attn_metadata
=
model_input
.
attn_metadata
attn_metadata
.
slot_mapping
=
attn_metadata
.
slot_mapping
.
to
(
self
.
device
)
attn_metadata
.
block_tables
=
attn_metadata
.
block_tables
.
to
(
self
.
device
)
attn_metadata
.
context_lens
=
attn_metadata
.
context_lens
.
to
(
self
.
device
)
t
=
model_input
.
t
.
to
(
self
.
device
)
p
=
model_input
.
p
.
to
(
self
.
device
)
input_lens
=
model_input
.
input_lens
.
to
(
self
.
device
)
for
i
in
range
(
num_steps
):
slot_mapping
=
attn_metadata
.
slot_mapping
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
kv_caches
,
is_prompt
=
False
)
self
.
cached_step_outputs
.
append
(
output_token_ids
)
if
i
<
num_steps
-
1
:
# Prepare the inputs for the next step.
token_ids
=
output_token_ids
.
unsqueeze
(
dim
=
1
).
int
()
position_ids
=
position_ids
+
1
attn_metadata
.
context_lens
=
attn_metadata
.
context_lens
+
1
block_tables
=
attn_metadata
.
block_tables
block_number
=
block_tables
.
gather
(
1
,
position_ids
.
long
()
//
self
.
block_size
)
block_offset
=
position_ids
%
self
.
block_size
is_padding
=
slot_mapping
==
_PAD_SLOT_ID
slot_mapping
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
=
slot_mapping
.
long
()
slot_mapping
=
torch
.
where
(
is_padding
,
_PAD_SLOT_ID
,
slot_mapping
)
attn_metadata
.
slot_mapping
=
slot_mapping
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
if
num_steps
>
1
:
return
[]
# Retrieve the outputs to CPU.
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
_make_decode_output
(
next_token_ids
,
model_input
.
seq_groups
)
return
[
sampler_output
]
class
ModelWrapper
(
TorchCompileWrapperWithCustomDispatcher
):
...
...
@@ -756,3 +823,24 @@ def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
cutoff_logit
=
torch
.
gather
(
logits_sorted
,
-
1
,
cutoff_index
)
logits
=
logits
.
masked_fill_
(
logits
<
cutoff_logit
,
-
float
(
"inf"
))
return
logits
def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    """Wrap flat decode token ids into a ``SamplerOutput``.

    ``next_token_ids`` is consumed in order: the k-th sequence id encountered
    while walking ``seq_groups`` is paired with ``next_token_ids[k]``.

    Args:
        next_token_ids: One token id per sequence, flattened across groups.
        seq_groups: Sequence ids, one inner list per sequence group.

    Returns:
        A ``SamplerOutput`` holding one ``CompletionSequenceGroupOutput`` per
        entry of ``seq_groups``. Every token is reported with the same
        ``Logprob(0.0)`` object.
    """
    placeholder_logprob = Logprob(0.0)
    group_outputs = []
    # Index into next_token_ids of the first sequence of the current group.
    cursor = 0
    for seq_ids in seq_groups:
        samples = []
        for position, seq_id in enumerate(seq_ids, start=cursor):
            token_id = next_token_ids[position]
            samples.append(
                SequenceOutput(seq_id, token_id,
                               {token_id: placeholder_logprob}))
        cursor += len(seq_ids)
        group_outputs.append(CompletionSequenceGroupOutput(samples, None))
    return SamplerOutput(group_outputs)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment