Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
473e7b36
"csrc/cuda_utils.cpp" did not exist on "03ffd0a02251e10c1aa14fca8cb0ab1e4e40b886"
Unverified
Commit
473e7b36
authored
Oct 14, 2024
by
Woosuk Kwon
Committed by
GitHub
Oct 14, 2024
Browse files
[TPU] Fix TPU SMEM OOM by Pallas paged attention kernel (#9350)
parent
fd47e57f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
21 deletions
+89
-21
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+80
-21
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+9
-0
No files found.
vllm/attention/backends/pallas.py
View file @
473e7b36
...
...
@@ -208,35 +208,54 @@ class PallasAttentionBackendImpl(AttentionImpl):
else
:
# Decoding run.
assert
kv_cache
[
0
].
numel
()
>
0
query
=
query
.
squeeze
(
dim
=
1
)
pages_per_compute_block
=
16
# TODO(woosuk): Tune this value.
if
self
.
megacore_mode
==
"batch"
and
batch_size
%
2
!=
0
:
megacore_mode
=
None
else
:
megacore_mode
=
self
.
megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if
megacore_mode
is
not
None
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
assert
attn_metadata
.
block_tables
is
not
None
assert
attn_metadata
.
context_lens
is
not
None
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# the kernel compilation will fail. To avoid this, we split the
# batch dimension into smaller chunks and run the kernel multiple
# times.
MAX_SMEM_USAGE
=
512
*
1024
size_per_seq
=
4
*
attn_metadata
.
block_tables
.
shape
[
1
]
max_num_seq
=
MAX_SMEM_USAGE
//
size_per_seq
if
batch_size
<=
max_num_seq
:
output
=
paged_attention
(
query
,
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
pages_per_compute_block
,
megacore_mode
=
megacore_mode
,
self
.
megacore_mode
,
)
else
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
chunk_size
=
max_num_seq
# Make sure the chunk size is a multiple of 2.
chunk_size
=
chunk_size
//
2
*
2
num_chunks
=
(
batch_size
+
chunk_size
-
1
)
//
chunk_size
output
=
torch
.
empty_like
(
query
)
for
chunk_idx
in
range
(
num_chunks
):
chunk_start
=
chunk_idx
*
chunk_size
chunk_end
=
chunk_start
+
chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output
=
paged_attention
(
query
[
chunk_start
:
chunk_end
],
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
attn_metadata
.
context_lens
[
chunk_start
:
chunk_end
]
,
attn_metadata
.
block_tables
[
chunk_start
:
chunk_end
]
,
pages_per_compute_block
,
self
.
megacore_mode
,
)
output
[
chunk_start
:
chunk_end
]
=
chunk_output
# Reshape the output tensor.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
...
...
@@ -258,3 +277,43 @@ def write_to_kv_cache(
value_cache
=
value_cache
.
flatten
(
0
,
2
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
def
paged_attention
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
pages_per_compute_block
:
int
,
megacore_mode
:
Optional
[
str
],
)
->
torch
.
Tensor
:
batch_size
=
query
.
shape
[
0
]
if
megacore_mode
==
"batch"
and
batch_size
%
2
!=
0
:
megacore_mode
=
None
else
:
megacore_mode
=
megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if
megacore_mode
is
not
None
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
,
key_cache
,
value_cache
,
context_lens
,
block_tables
,
pages_per_compute_block
,
megacore_mode
=
megacore_mode
,
)
else
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
,
key_cache
,
value_cache
,
context_lens
,
block_tables
,
pages_per_compute_block
,
)
return
output
vllm/worker/tpu_model_runner.py
View file @
473e7b36
...
...
@@ -123,6 +123,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
)
self
.
cached_step_outputs
:
List
[
torch
.
Tensor
]
=
[]
smem_size
=
512
*
1024
block_table_size
=
4
*
self
.
block_tables
.
size
if
block_table_size
>=
smem_size
:
logger
.
warning
(
"The max_model_len (%d) is too large. This may degrade the "
"performance due to the insufficient smem size. Consider "
"setting --max-model-len to a smaller value."
,
self
.
model_config
.
max_model_len
)
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment