Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f77b6cf
Unverified
Commit
2f77b6cf
authored
Nov 20, 2024
by
Woosuk Kwon
Committed by
GitHub
Nov 20, 2024
Browse files
[TPU] Implement prefix caching for TPUs (#10307)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
c68f7ede
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
182 additions
and
105 deletions
+182
-105
requirements-tpu.txt
requirements-tpu.txt
+3
-3
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+43
-23
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+134
-77
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+2
-2
No files found.
requirements-tpu.txt
View file @
2f77b6cf
...
@@ -16,8 +16,8 @@ ray[default]
...
@@ -16,8 +16,8 @@ ray[default]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241
028
+cpu
torch==2.6.0.dev20241
114
+cpu
torchvision==0.20.0.dev20241
028
+cpu
torchvision==0.20.0.dev20241
114
+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241
028
-cp310-cp310-linux_x86_64.whl
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241
114
-cp310-cp310-linux_x86_64.whl
jaxlib==0.4.32.dev20240829
jaxlib==0.4.32.dev20240829
jax==0.4.32.dev20240829
jax==0.4.32.dev20240829
vllm/attention/backends/pallas.py
View file @
2f77b6cf
...
@@ -65,6 +65,7 @@ class PallasMetadata(AttentionMetadata):
...
@@ -65,6 +65,7 @@ class PallasMetadata(AttentionMetadata):
# or all decoding.
# or all decoding.
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
context_lens
:
Optional
[
torch
.
Tensor
]
=
None
context_lens
:
Optional
[
torch
.
Tensor
]
=
None
effective_query_lens
:
Optional
[
torch
.
Tensor
]
=
None
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
...
@@ -72,8 +73,6 @@ class PallasMetadata(AttentionMetadata):
...
@@ -72,8 +73,6 @@ class PallasMetadata(AttentionMetadata):
return
None
return
None
assert
self
.
num_decode_tokens
==
0
assert
self
.
num_decode_tokens
==
0
assert
self
.
block_tables
is
None
assert
self
.
context_lens
is
None
return
self
return
self
@
property
@
property
...
@@ -186,29 +185,50 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -186,29 +185,50 @@ class PallasAttentionBackendImpl(AttentionImpl):
query
=
query
*
self
.
scale
query
=
query
*
self
.
scale
if
attn_metadata
.
num_prefills
>
0
:
if
attn_metadata
.
num_prefills
>
0
:
assert
seq_len
%
16
==
0
,
(
if
attn_metadata
.
block_tables
is
None
:
"Pallas FlashAttention kernel requires seq_len to be a "
# Prefill without paged KV cache.
f
"multiple of 16 but got
{
seq_len
}
"
)
assert
seq_len
%
16
==
0
,
(
"Pallas FlashAttention kernel requires seq_len to be a "
# Handle GQA/MQA.
f
"multiple of 16 but got
{
seq_len
}
"
)
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
# Handle GQA/MQA.
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
if
self
.
num_kv_heads
!=
self
.
num_heads
:
self
.
head_size
)
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
dim
=-
2
)
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
self
.
head_size
)
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
# while the input is [batch_size, seq_len, num_heads, d_model].
dim
=-
2
)
# Permute the input to match the required format.
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
output
=
torch
.
ops
.
xla
.
flash_attention
(
self
.
head_size
)
query
.
permute
(
0
,
2
,
1
,
3
),
# FlashAttention kernel requires the input shape to be
key
.
permute
(
0
,
2
,
1
,
3
),
# [batch_size, num_heads, seq_len, d_model]
value
.
permute
(
0
,
2
,
1
,
3
),
# while the input is [batch_size, seq_len, num_heads, d_model].
True
,
# Permute the input to match the required format.
)
output
=
torch
.
ops
.
xla
.
flash_attention
(
output
=
output
.
permute
(
0
,
2
,
1
,
3
)
query
.
permute
(
0
,
2
,
1
,
3
),
key
.
permute
(
0
,
2
,
1
,
3
),
value
.
permute
(
0
,
2
,
1
,
3
),
True
,
)
output
=
output
.
permute
(
0
,
2
,
1
,
3
)
else
:
# Prefill with paged KV cache.
# TODO(woosuk): Tune the below knobs.
num_kv_pages_per_compute_block
=
16
num_queries_per_compute_block
=
16
assert
seq_len
%
num_queries_per_compute_block
==
0
output
=
torch
.
ops
.
xla
.
multi_queries_paged_attention
(
query
,
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
attn_metadata
.
effective_query_lens
,
num_kv_pages_per_compute_block
,
num_queries_per_compute_block
,
use_kernel
=
True
,
)
else
:
else
:
# Decoding run.
# Decoding run.
assert
kv_cache
[
0
].
numel
()
>
0
assert
kv_cache
[
0
].
numel
()
>
0
...
...
vllm/worker/tpu_model_runner.py
View file @
2f77b6cf
import
enum
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
...
@@ -11,7 +12,6 @@ import torch_xla.core.xla_model as xm
...
@@ -11,7 +12,6 @@ import torch_xla.core.xla_model as xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
@@ -39,6 +39,15 @@ _ENABLE_TOP_P = False
...
@@ -39,6 +39,15 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES
=
128
_MAX_NUM_SAMPLES
=
128
class
ExecutionMode
(
enum
.
Enum
):
PREFILL
=
enum
.
auto
()
DECODE
=
enum
.
auto
()
PREFIX_PREFILL
=
enum
.
auto
()
def
is_prefill
(
self
)
->
bool
:
return
self
in
(
ExecutionMode
.
PREFILL
,
ExecutionMode
.
PREFIX_PREFILL
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForTPU
(
ModelRunnerInputBase
):
class
ModelInputForTPU
(
ModelRunnerInputBase
):
token_ids
:
torch
.
Tensor
token_ids
:
torch
.
Tensor
...
@@ -140,16 +149,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -140,16 +149,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
self
.
model
=
ModelWrapper
(
model
,
self
.
vllm_config
)
model
=
ModelWrapper
(
model
)
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
_dummy_run
(
def
_dummy_run
(
self
,
self
,
batch_size
:
int
,
batch_size
:
int
,
seq_len
:
int
,
seq_len
:
int
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
is_prompt
:
bool
,
exec_mode
:
ExecutionMode
,
)
->
None
:
)
->
None
:
if
is_prompt
:
exec_mode
=
ExecutionMode
(
exec_mode
)
if
exec_mode
.
is_prefill
():
seq_len
=
(
seq_len
+
15
)
//
16
*
16
seq_len
=
(
seq_len
+
15
)
//
16
*
16
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -160,18 +174,38 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -160,18 +174,38 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
slot_mapping
=
torch
.
zeros
((
batch_size
,
seq_len
),
slot_mapping
=
torch
.
zeros
((
batch_size
,
seq_len
),
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
batch_size
,
num_prefill_tokens
=
batch_size
*
seq_len
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
block_tables
=
None
,
context_lens
=
None
,
)
input_lens
=
torch
.
ones
((
batch_size
,
),
input_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
self
.
device
)
if
exec_mode
==
ExecutionMode
.
PREFILL
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
batch_size
,
num_prefill_tokens
=
batch_size
*
seq_len
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
block_tables
=
None
,
context_lens
=
None
,
effective_query_lens
=
None
,
)
else
:
context_lens
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
effective_query_lens
=
torch
.
ones_like
(
context_lens
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
batch_size
,
num_prefill_tokens
=
batch_size
*
seq_len
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
effective_query_lens
=
effective_query_lens
,
)
else
:
else
:
assert
seq_len
==
1
assert
seq_len
==
1
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
token_ids
=
torch
.
zeros
((
batch_size
,
seq_len
),
...
@@ -204,7 +238,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -204,7 +238,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
)
)
t
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
t
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
ones
((
batch_size
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
)
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
num_samples
=
_MAX_NUM_SAMPLES
if
exec_mode
.
is_prefill
()
else
1
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
...
@@ -213,7 +247,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -213,7 +247,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
# be re-compiled for every different shapes. This overhead is inevitable
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs on disk (VLLM_XLA_CACHE_PATH).
# graphs on disk (VLLM_XLA_CACHE_PATH).
if
is_prompt
:
if
exec_mode
.
is_prefill
()
:
# Prefill
# Prefill
torch
.
_dynamo
.
mark_dynamic
(
token_ids
,
1
)
torch
.
_dynamo
.
mark_dynamic
(
token_ids
,
1
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
1
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
1
)
...
@@ -229,15 +263,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -229,15 +263,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch
.
_dynamo
.
mark_dynamic
(
t
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
t
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
p
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
p
,
0
)
# Dummy run.
# Dummy run.
self
.
model
(
token_ids
,
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
position_ids
,
num_samples
,
kv_caches
)
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
kv_caches
,
is_prompt
=
is_prompt
)
def
warmup_model
(
def
warmup_model
(
self
,
self
,
...
@@ -248,13 +275,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -248,13 +275,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
start
=
time
.
time
()
start
=
time
.
time
()
for
batch_size
in
[
1
]:
for
batch_size
in
[
1
]:
seq_len
=
16
seq_len
=
16
while
True
:
while
seq_len
<=
self
.
model_config
.
max_model_len
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
True
)
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
exec_mode
=
ExecutionMode
.
PREFILL
)
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
if
seq_len
>=
self
.
model_config
.
max_model_len
:
break
num_tokens
=
batch_size
*
seq_len
num_tokens
=
batch_size
*
seq_len
if
num_tokens
>=
self
.
scheduler_config
.
max_num_batched_tokens
:
if
num_tokens
>=
self
.
scheduler_config
.
max_num_batched_tokens
:
break
break
...
@@ -263,12 +290,39 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -263,12 +290,39 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
end
=
time
.
time
()
end
=
time
.
time
()
logger
.
info
(
"Compilation for prefill done in %.2f s."
,
end
-
start
)
logger
.
info
(
"Compilation for prefill done in %.2f s."
,
end
-
start
)
# Prefix prefill
if
self
.
cache_config
.
enable_prefix_caching
:
logger
.
info
(
"Compiling the model with different input shapes for "
"prefix prefill..."
)
start
=
time
.
time
()
for
batch_size
in
[
1
]:
seq_len
=
16
while
seq_len
<=
self
.
model_config
.
max_model_len
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
exec_mode
=
ExecutionMode
.
PREFIX_PREFILL
)
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
num_tokens
=
batch_size
*
seq_len
if
(
num_tokens
>=
self
.
scheduler_config
.
max_num_batched_tokens
):
break
seq_len
=
seq_len
*
2
end
=
time
.
time
()
logger
.
info
(
"Compilation for prefix prefill done in %.2f s."
,
end
-
start
)
# Decode
# Decode
start
=
time
.
time
()
start
=
time
.
time
()
seq_len
=
1
seq_len
=
1
batch_size
=
8
# Must be in sync with _get_padded_batch_size()
batch_size
=
8
# Must be in sync with _get_padded_batch_size()
while
True
:
while
True
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
False
)
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
exec_mode
=
ExecutionMode
.
DECODE
)
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
logger
.
info
(
"batch_size: %d, seq_len: %d"
,
batch_size
,
seq_len
)
...
@@ -287,9 +341,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -287,9 +341,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
batch_idx
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
assert
len
(
seq_ids
)
==
1
...
@@ -298,19 +354,31 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -298,19 +354,31 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
# Could include output tokens when a request is preempted.
# Could include output tokens when a request is preempted.
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt_tokens
=
seq_data
.
get_token_ids
()
seq_len
=
len
(
prompt_tokens
)
num_computed_blocks
=
len
(
seq_group_metadata
.
computed_block_nums
)
num_computed_tokens
=
num_computed_blocks
*
self
.
block_size
if
num_computed_tokens
>
0
:
prompt_tokens
=
prompt_tokens
[
num_computed_tokens
:]
context_lens
.
append
(
seq_len
)
else
:
context_lens
.
append
(
0
)
prompt_len
=
len
(
prompt_tokens
)
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
input_tokens
.
extend
(
prompt_tokens
)
input_tokens
.
extend
(
prompt_tokens
)
input_positions
.
extend
(
list
(
range
(
prompt
_len
))
)
input_positions
.
extend
(
range
(
num_computed_tokens
,
seq
_len
))
assert
seq_group_metadata
.
block_tables
is
not
None
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
for
i
in
range
(
prompt
_len
):
for
i
in
range
(
num_computed_tokens
,
seq
_len
):
block_number
=
block_table
[
i
//
self
.
block_size
]
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
slot_mapping
.
append
(
slot
)
if
num_computed_tokens
>
0
:
self
.
block_tables
[
batch_idx
,
:
len
(
block_table
)]
=
block_table
# Add paddings to EACH prompt to the smallest power of 2 that is
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
# greater than or equal to the prompt length.
...
@@ -338,14 +406,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -338,14 +406,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
device
=
"cpu"
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
num_prefills
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
0
,
# NOTE: This is not used.
num_prefill_tokens
=
0
,
# NOTE: This is not used.
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
multi_modal_placeholder_index_maps
=
None
,
block_tables
=
None
,
block_tables
=
block_tables
,
context_lens
=
None
,
context_lens
=
context_lens
,
effective_query_lens
=
prompt_lens
,
)
)
return
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
return
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
...
@@ -550,6 +625,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -550,6 +625,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
# process them separately. This is a temporary hack that should be
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
# optimized by using SplashAttention.
orig_slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
orig_slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
orig_block_tables
=
model_input
.
attn_metadata
.
block_tables
orig_context_lens
=
model_input
.
attn_metadata
.
context_lens
orig_effective_query_lens
=
\
model_input
.
attn_metadata
.
effective_query_lens
batch_size
=
model_input
.
input_lens
.
shape
[
0
]
batch_size
=
model_input
.
input_lens
.
shape
[
0
]
start_idx
=
0
start_idx
=
0
next_token_ids
=
[]
next_token_ids
=
[]
...
@@ -568,18 +647,24 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -568,18 +647,24 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
attn_metadata
.
num_prefills
=
1
attn_metadata
.
num_prefills
=
1
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
None
,
start_idx
:
end_idx
].
to
(
self
.
device
)
if
orig_context_lens
[
i
].
item
()
>
0
:
attn_metadata
.
context_lens
=
orig_context_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
attn_metadata
.
block_tables
=
orig_block_tables
[
i
].
unsqueeze
(
0
).
to
(
self
.
device
)
attn_metadata
.
effective_query_lens
=
\
orig_effective_query_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
else
:
attn_metadata
.
context_lens
=
None
attn_metadata
.
block_tables
=
None
attn_metadata
.
effective_query_lens
=
None
input_lens
=
model_input
.
input_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
input_lens
=
model_input
.
input_lens
[
i
:
i
+
1
].
to
(
self
.
device
)
t
=
model_input
.
t
[
i
:
i
+
1
].
to
(
self
.
device
)
t
=
model_input
.
t
[
i
:
i
+
1
].
to
(
self
.
device
)
p
=
model_input
.
p
[
i
:
i
+
1
].
to
(
self
.
device
)
p
=
model_input
.
p
[
i
:
i
+
1
].
to
(
self
.
device
)
output_token_ids
=
self
.
model
(
token_ids
,
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
model_input
.
num_samples
,
kv_caches
,
kv_caches
)
is_prompt
=
True
)
next_token_ids
.
append
(
output_token_ids
[
0
])
next_token_ids
.
append
(
output_token_ids
[
0
])
start_idx
=
end_idx
start_idx
=
end_idx
...
@@ -624,15 +709,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -624,15 +709,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
input_lens
=
model_input
.
input_lens
.
to
(
self
.
device
)
input_lens
=
model_input
.
input_lens
.
to
(
self
.
device
)
for
i
in
range
(
num_steps
):
for
i
in
range
(
num_steps
):
slot_mapping
=
attn_metadata
.
slot_mapping
slot_mapping
=
attn_metadata
.
slot_mapping
output_token_ids
=
self
.
model
(
token_ids
,
output_token_ids
=
self
.
model
(
token_ids
,
position_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
attn_metadata
,
input_lens
,
t
,
p
,
model_input
.
num_samples
,
model_input
.
num_samples
,
kv_caches
,
kv_caches
)
is_prompt
=
False
)
self
.
cached_step_outputs
.
append
(
output_token_ids
)
self
.
cached_step_outputs
.
append
(
output_token_ids
)
if
i
<
num_steps
-
1
:
if
i
<
num_steps
-
1
:
...
@@ -667,34 +747,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -667,34 +747,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
return
[
sampler_output
]
return
[
sampler_output
]
class
ModelWrapper
(
TorchCompileWrapperWithCustomDispatcher
):
class
ModelWrapper
(
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
,
vllm_config
:
VllmConfig
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
self
.
model
=
model
self
.
model
=
model
compiled_callable
=
torch
.
compile
(
self
.
forward
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
super
().
__init__
(
compiled_callable
,
compilation_level
=
vllm_config
.
compilation_config
.
level
)
def
__call__
(
self
,
*
args
,
is_prompt
:
bool
,
**
kwargs
):
if
len
(
self
.
compiled_codes
)
<
3
or
not
self
.
use_custom_dispatcher
:
# not fully compiled yet, or not using the custom dispatcher,
# let PyTorch handle it
return
self
.
compiled_callable
(
*
args
,
**
kwargs
)
# the 3 compiled codes are:
# 0: for profiling
# 1: for prompt
# 2: for decode
# dispatch to the compiled code directly, skip PyTorch
if
is_prompt
:
with
self
.
dispatch_to_code
(
1
):
return
self
.
forward
(
*
args
,
**
kwargs
)
else
:
with
self
.
dispatch_to_code
(
2
):
return
self
.
forward
(
*
args
,
**
kwargs
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/worker/tpu_worker.py
View file @
2f77b6cf
...
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
...
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.tpu_model_runner
import
ExecutionMode
,
TPUModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerBase
,
WorkerInput
)
WorkerInput
)
...
@@ -112,7 +112,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -112,7 +112,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
batch_size
=
1
,
batch_size
=
1
,
seq_len
=
self
.
scheduler_config
.
max_num_batched_tokens
,
seq_len
=
self
.
scheduler_config
.
max_num_batched_tokens
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
is_prompt
=
True
,
exec_mode
=
ExecutionMode
.
PREFILL
,
)
)
# Synchronize before measuring the memory usage.
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment