Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4634c872
"tests/vscode:/vscode.git/clone" did not exist on "afccc9d434d26750d39a1391b833e1042a526121"
Unverified
Commit
4634c872
authored
Jul 18, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 18, 2024
Browse files
[TPU] Refactor TPU worker & model runner (#6506)
parent
c8a7d51c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
272 additions
and
166 deletions
+272
-166
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+200
-97
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+72
-69
No files found.
vllm/worker/tpu_model_runner.py
View file @
4634c872
import
time
import
time
from
typing
import
List
,
Optional
,
Tuple
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -12,10 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
...
@@ -12,10 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -27,7 +34,44 @@ _ENABLE_TOP_P = False
...
@@ -27,7 +34,44 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES
=
128
_MAX_NUM_SAMPLES
=
128
class
TPUModelRunner
:
@
dataclass
(
frozen
=
True
)
class
ModelInputForTPU
(
ModelRunnerInputBase
):
token_ids
:
torch
.
Tensor
position_ids
:
torch
.
Tensor
attn_metadata
:
AttentionMetadata
input_lens
:
torch
.
Tensor
t
:
torch
.
Tensor
p
:
torch
.
Tensor
num_samples
:
int
best_of
:
List
[
int
]
seq_groups
:
List
[
List
[
int
]]
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
tensor_dict
=
{
"token_ids"
:
self
.
token_ids
,
"position_ids"
:
self
.
position_ids
,
"input_lens"
:
self
.
input_lens
,
"t"
:
self
.
t
,
"p"
:
self
.
p
,
"num_samples"
:
self
.
num_samples
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
:
Type
[
"ModelInputForTPU"
],
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
"ModelInputForTPU"
:
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
attn_backend
,
tensor_dict
)
return
cls
(
**
tensor_dict
)
class
TPUModelRunner
(
ModelRunnerBase
[
ModelInputForTPU
]):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -79,6 +123,7 @@ class TPUModelRunner:
...
@@ -79,6 +123,7 @@ class TPUModelRunner:
multimodal_config
=
self
.
multimodal_config
,
multimodal_config
=
self
.
multimodal_config
,
lora_config
=
None
,
lora_config
=
None
,
)
)
model
=
model
.
eval
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
model
=
ModelWrapper
(
model
)
model
=
ModelWrapper
(
model
)
...
@@ -147,8 +192,8 @@ class TPUModelRunner:
...
@@ -147,8 +192,8 @@ class TPUModelRunner:
# Dummy run.
# Dummy run.
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
input_
le
n
s
,
t
,
p
,
num_sampl
es
)
num_samp
les
,
kv_cach
es
)
def
warmup_model
(
def
warmup_model
(
self
,
self
,
...
@@ -177,7 +222,7 @@ class TPUModelRunner:
...
@@ -177,7 +222,7 @@ class TPUModelRunner:
# Decode
# Decode
start
=
time
.
time
()
start
=
time
.
time
()
seq_len
=
1
seq_len
=
1
batch_size
=
1
batch_size
=
8
# Must be in sync with _get_padded_batch_size()
while
True
:
while
True
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
False
)
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
False
)
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
...
@@ -195,10 +240,10 @@ class TPUModelRunner:
...
@@ -195,10 +240,10 @@ class TPUModelRunner:
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
List
[
int
]
]
=
[]
input_positions
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
slot_mapping
:
List
[
List
[
int
]
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
...
@@ -212,50 +257,46 @@ class TPUModelRunner:
...
@@ -212,50 +257,46 @@ class TPUModelRunner:
prompt_len
=
len
(
prompt_tokens
)
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
input_tokens
.
app
end
(
prompt_tokens
)
input_tokens
.
ext
end
(
prompt_tokens
)
input_positions
.
app
end
(
list
(
range
(
prompt_len
)))
input_positions
.
ext
end
(
list
(
range
(
prompt_len
)))
assert
seq_group_metadata
.
block_tables
is
not
None
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
slot_mapping
.
append
([])
for
i
in
range
(
prompt_len
):
for
i
in
range
(
prompt_len
):
block_number
=
block_table
[
i
//
self
.
block_size
]
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
[
-
1
].
append
(
slot
)
slot_mapping
.
append
(
slot
)
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len
=
_get_padded_prefill_len
(
prompt_len
)
num_paddings
=
padded_prompt_len
-
prompt_len
input_tokens
+=
[
0
]
*
num_paddings
input_positions
+=
[
0
]
*
num_paddings
slot_mapping
+=
[
_PAD_SLOT_ID
]
*
num_paddings
assert
len
(
prompt_lens
)
>
0
assert
len
(
prompt_lens
)
>
0
num_prefills
=
len
(
prompt_lens
)
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
sum
(
prompt_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
# Add paddings to make the shape [batch_size, max_prompt_len] where
# max_prompt_len is smallest power of 2 that is greater than or equal
# to the maximum prompt length.
# We need the 2D input shape because the Pallas FlashAttention kernel
# does not support packed 1D inputs.
# We pad the seq_len to powers of 2 to reduce the compilation overhead.
max_prompt_len
=
_get_padded_prefill_len
(
max
(
prompt_lens
))
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
input_positions
=
torch
.
tensor
(
input_positions
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
slot_mapping
=
make_tensor_with_pad
(
slot_mapping
,
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
max_prompt_len
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
"cpu"
)
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
# NOTE: This is not used.
num_prefill_tokens
=
0
,
# NOTE: This is not used.
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
block_tables
=
None
,
block_tables
=
None
,
...
@@ -306,22 +347,22 @@ class TPUModelRunner:
...
@@ -306,22 +347,22 @@ class TPUModelRunner:
input_tokens
=
torch
.
tensor
(
input_tokens
,
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_positions
,
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
"cpu"
)
context_lens
=
torch
.
tensor
(
context_lens
,
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
input_lens
=
torch
.
tensor
([
1
]
*
batch_size
,
input_lens
=
torch
.
tensor
([
1
]
*
batch_size
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
...
@@ -382,16 +423,18 @@ class TPUModelRunner:
...
@@ -382,16 +423,18 @@ class TPUModelRunner:
t
+=
[
1.0
]
*
num_paddings
t
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
return
t
,
p
,
best_of
return
t
,
p
,
best_of
def
_execut
e_model
(
def
prepar
e_model
_input
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
virtual_engine
:
int
=
0
,
)
->
List
[
CompletionSequenceGroupOutput
]:
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
# Prepare inputs.
)
->
ModelInputForTPU
:
del
finished_requests_ids
# Unused.
assert
virtual_engine
==
0
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# all decodes.
...
@@ -400,16 +443,104 @@ class TPUModelRunner:
...
@@ -400,16 +443,104 @@ class TPUModelRunner:
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
padded_batch_size
=
inputs
[
0
].
shape
[
0
]
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
=
inputs
padded_batch_size
=
input_tokens
.
shape
[
0
]
t
,
p
,
best_of
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
t
,
p
,
best_of
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
padded_batch_size
)
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
seq_groups
=
[
list
(
metadata
.
seq_data
.
keys
())
for
metadata
in
seq_group_metadata_list
]
return
ModelInputForTPU
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
best_of
,
seq_groups
)
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForTPU
:
model_input
=
ModelInputForTPU
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
)
return
model_input
def
execute_model
(
self
,
model_input
:
ModelInputForTPU
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
assert
intermediate_tensors
is
None
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
def
_execute_model
(
*
args
,
clone
:
bool
=
False
)
->
torch
.
Tensor
:
"""Move input args from CPU to device and execute the model."""
def
_copy_to_device
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
clone
:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x
=
x
.
clone
()
return
x
.
to
(
self
.
device
)
new_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
torch
.
Tensor
):
arg
=
_copy_to_device
(
arg
)
elif
isinstance
(
arg
,
AttentionMetadata
):
arg
.
slot_mapping
=
_copy_to_device
(
arg
.
slot_mapping
)
if
getattr
(
arg
,
"block_tables"
,
None
)
is
not
None
:
arg
.
block_tables
=
_copy_to_device
(
arg
.
block_tables
)
if
getattr
(
arg
,
"context_lens"
,
None
)
is
not
None
:
arg
.
context_lens
=
_copy_to_device
(
arg
.
context_lens
)
new_args
.
append
(
arg
)
return
self
.
model
(
*
new_args
)
num_prefills
=
model_input
.
attn_metadata
.
num_prefills
is_prompt
=
num_prefills
>
0
if
is_prompt
:
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
next_token_ids
=
[]
orig_slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
batch_size
=
model_input
.
input_lens
.
shape
[
0
]
start_idx
=
0
for
i
in
range
(
batch_size
):
# Get the actual prefill_len.
prefill_len
=
model_input
.
input_lens
[
i
:
i
+
1
].
item
()
prefill_len
=
_get_padded_prefill_len
(
prefill_len
)
end_idx
=
start_idx
+
prefill_len
model_input
.
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
None
,
start_idx
:
end_idx
]
model_input
.
attn_metadata
.
num_prefills
=
1
output_token_ids
=
_execute_model
(
model_input
.
token_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
position_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
,
clone
=
True
)
# Retrieve the outputs to CPU.
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
start_idx
=
end_idx
else
:
# Execute the model.
# Execute the model.
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
output_token_ids
=
_execute_model
(
*
inputs
[
2
:],
t
,
p
,
num_samples
)
model_input
.
token_ids
,
model_input
.
position_ids
,
model_input
.
attn_metadata
,
model_input
.
input_lens
,
model_input
.
t
,
model_input
.
p
,
model_input
.
num_samples
,
kv_caches
)
# Retrieve the outputs to CPU.
# Retrieve the outputs to CPU.
next_token_ids
=
nex
t_token_ids
.
cpu
().
tolist
()
next_token_ids
=
outpu
t_token_ids
.
cpu
().
tolist
()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# The TPU backend does not reuse the sampler, since the TPU backend
...
@@ -417,13 +548,13 @@ class TPUModelRunner:
...
@@ -417,13 +548,13 @@ class TPUModelRunner:
zero_logprob
=
Logprob
(
0.0
)
zero_logprob
=
Logprob
(
0.0
)
batch_idx
=
0
batch_idx
=
0
sampler_outputs
=
[]
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group
in
model_input
.
seq_groups
:
seq_ids
=
seq_group
seq_outputs
=
[]
seq_outputs
=
[]
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
if
is_prompt
:
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_id
=
seq_ids
[
0
]
for
i
in
range
(
best_of
[
batch_idx
]):
for
i
in
range
(
model_input
.
best_of
[
batch_idx
]):
next_token_id
=
next_token_ids
[
batch_idx
][
i
]
next_token_id
=
next_token_ids
[
batch_idx
][
i
]
seq_outputs
.
append
(
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
SequenceOutput
(
seq_id
,
next_token_id
,
...
@@ -438,35 +569,6 @@ class TPUModelRunner:
...
@@ -438,35 +569,6 @@ class TPUModelRunner:
batch_idx
+=
1
batch_idx
+=
1
sampler_outputs
.
append
(
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
sampler_outputs
def
execute_model
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
assert
seq_group_metadata_list
is
not
None
assert
len
(
seq_group_metadata_list
)
>
0
if
seq_group_metadata_list
[
0
].
is_prompt
:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# aware of this limitation, we need to handle batch size > 1
# internally by calling the model multiple times and concatenating
# the outputs.
# FIXME(woosuk): This is a temporary hack to not change the existing
# scheduler. We need to fix this in the future.
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
sampler_outputs
+=
self
.
_execute_model
([
seq_group_metadata
],
kv_caches
)
else
:
sampler_outputs
=
self
.
_execute_model
(
seq_group_metadata_list
,
kv_caches
)
return
[
SamplerOutput
(
sampler_outputs
)]
return
[
SamplerOutput
(
sampler_outputs
)]
...
@@ -474,36 +576,37 @@ class ModelWrapper(nn.Module):
...
@@ -474,36 +576,37 @@ class ModelWrapper(nn.Module):
def
__init__
(
self
,
model
:
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
super
().
__init__
()
self
.
model
=
model
.
eval
()
self
.
model
=
model
def
forward
(
def
forward
(
self
,
self
,
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
input_lens
:
torch
.
Tensor
,
input_lens
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
"""Executes the forward pass of the model and samples the next token.
Args:
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
"""
batch_size
,
seq_len
=
token_ids
.
shape
batch_size
,
seq_len
=
token_ids
.
shape
# Calculate the positions to sample from.
# Calculate the positions to sample from.
base
_indicies
=
torch
.
arange
(
start
_indicies
=
torch
.
arange
(
batch_size
,
dtype
=
torch
.
int32
,
device
=
input_lens
.
device
)
*
seq_len
batch_size
,
dtype
=
torch
.
int32
,
device
=
input_lens
.
device
)
*
seq_len
logits_indices
=
base
_indicies
+
input_lens
-
1
logits_indices
=
start
_indicies
+
input_lens
-
1
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
# sampler and sampling metadata.
...
...
vllm/worker/tpu_worker.py
View file @
4634c872
...
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
TPUWorker
(
LoraNotSupportedWorkerBase
):
class
TPUWorker
(
LoraNotSupportedWorkerBase
,
LocalOrDistributedWorkerBase
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -57,7 +58,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -57,7 +58,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
cache_config
.
cache_dtype
]
self
.
model_runner
=
TPUModelRunner
(
model_config
,
self
.
model_runner
:
TPUModelRunner
=
TPUModelRunner
(
model_config
,
parallel_config
,
parallel_config
,
scheduler_config
,
scheduler_config
,
device_config
,
device_config
,
...
@@ -196,40 +198,48 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -196,40 +198,48 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
return
dtype_size
*
total
return
dtype_size
*
total
def
execute_model
(
@
property
def
do_metadata_broadcast
(
self
)
->
bool
:
# TODO(woosuk): Support TP.
return
False
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return
[
self
.
tpu_cache
]
def
prepare_worker_input
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
)
->
WorkerInput
:
if
not
self
.
is_driver_worker
:
virtual_engine
=
execute_model_req
.
virtual_engine
self
.
_execute_model_non_driver
()
num_seq_groups
=
len
(
execute_model_req
.
seq_group_metadata_list
)
return
[]
blocks_to_swap_in
=
_make_src_to_dst
(
assert
execute_model_req
is
not
None
execute_model_req
.
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
# Issue cache operations.
blocks_to_swap_out
=
_make_src_to_dst
(
self
.
cache_swap
(
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
execute_model_req
.
blocks_to_swap_in
,
blocks_to_copy
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_copy
,
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
self
.
device
)
execute_model_req
.
blocks_to_copy
,
return
WorkerInput
(
num_seq_groups
=
num_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
virtual_engine
=
virtual_engine
,
)
)
# Run the model.
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
def
execute_worker
(
self
,
worker_input
:
WorkerInput
)
->
None
:
assert
len
(
seq_group_metadata_list
)
>
0
virtual_engine
=
worker_input
.
virtual_engine
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
assert
virtual_engine
==
0
self
.
tpu_cache
)
return
output
def
cache_swap
(
self
,
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]],
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
attn_backend
=
self
.
model_runner
.
attn_backend
attn_backend
=
self
.
model_runner
.
attn_backend
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
if
blocks_to_swap_in
:
# Issue cache operations.
if
worker_input
.
blocks_to_swap_in
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_in
if
src_indices
.
numel
()
>
0
:
# Swap from CPU to TPU.
# Swap from CPU to TPU.
src_indices
,
dst_indices
=
_make_src_to_dst
(
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
...
@@ -237,28 +247,21 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -237,28 +247,21 @@ class TPUWorker(LoraNotSupportedWorkerBase):
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
if
blocks_to_swap_out
:
if
worker_input
.
blocks_to_swap_out
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_out
if
src_indices
.
numel
()
>
0
:
# Swap from TPU to CPU.
# Swap from TPU to CPU.
src_indices
,
dst_indices
=
_make_src_to_dst
(
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
].
cpu
()
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
]
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
].
cpu
()
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
]
if
blocks_to_copy
:
if
worker_input
.
blocks_to_copy
is
not
None
:
src_to_dst
=
_make_src_to_dst
(
blocks_to_copy
,
self
.
device
,
src_indices
,
dst_indices
=
worker_input
.
blocks_to_copy
self
.
device
)
if
src_indices
.
numel
()
>
0
:
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
src_to_dst
)
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
(
src_indices
,
dst_indices
))
def
start_worker_execution_loop
(
self
)
->
None
:
while
self
.
_execute_model_non_driver
():
pass
def
_execute_model_non_driver
(
self
)
->
bool
:
self
.
model_runner
.
execute_model
(
None
,
self
.
tpu_cache
)
return
True
def
_make_src_to_dst
(
def
_make_src_to_dst
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment