Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c1f7600
Unverified
Commit
7c1f7600
authored
Mar 28, 2025
by
yarongmu-google
Committed by
GitHub
Mar 28, 2025
Browse files
[Kernel][TPU][ragged-paged-attn] vLLM code change for PR#8896 (#15659)
Signed-off-by:
Yarong Mu
<
ymu@google.com
>
parent
da461f3c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
37 deletions
+37
-37
requirements/tpu.txt
requirements/tpu.txt
+6
-6
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+22
-21
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+5
-6
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+4
-4
No files found.
requirements/tpu.txt
View file @
7c1f7600
...
...
@@ -17,9 +17,9 @@ ray[data]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev202503
19
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev202503
19
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev202503
19
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev202503
19
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev202503
19
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev202503
19
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev202503
28
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev202503
28
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev202503
28
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev202503
28
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev202503
28
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev202503
28
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
vllm/v1/attention/backends/pallas.py
View file @
7c1f7600
...
...
@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
num_kv_heads
*
head_size
)
return
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
@
staticmethod
def
swap_blocks
(
...
...
@@ -132,7 +132,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
PallasMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -142,14 +142,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
[num_blocks, block_size, num_kv_heads * head_size])
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# For determine_available_memory case.
if
kv_cache
[
0
]
.
numel
()
==
0
:
if
kv_cache
.
numel
()
==
0
:
if
output
is
None
:
output
=
torch
.
ones_like
(
query
)
return
output
...
...
@@ -158,15 +157,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
key_cache
,
value_cache
=
kv_cache
if
kv_cache
[
0
].
numel
()
>
0
:
if
kv_cache
.
numel
()
>
0
:
slot_mapping
=
attn_metadata
.
slot_mapping
write_to_kv_cache
(
key
,
value
,
k
ey_cache
,
value
_cache
,
slot_mapping
)
write_to_kv_cache
(
key
,
value
,
k
v
_cache
,
slot_mapping
)
output
=
torch
.
ops
.
xla
.
ragged_paged_attention
(
query
,
key_cache
,
value_cache
,
kv_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
attn_metadata
.
query_start_loc
,
...
...
@@ -183,23 +180,27 @@ class PallasAttentionBackendImpl(AttentionImpl):
def
write_to_kv_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
k_cache = [num_blocks, block_size, num_kv_heads * head_size]
v_cache = [num_blocks, block_size, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
"""
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
value_cache
,
True
)
_
,
_
,
num_combined_kv_heads
,
head_size
=
kv_cache
.
shape
num_kv_heads
=
num_combined_kv_heads
//
2
key_cache
=
key_cache
.
flatten
(
0
,
1
)
value_cache
=
value_cache
.
flatten
(
0
,
1
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
kv
=
torch
.
cat
([
key
,
value
],
axis
=-
1
).
reshape
(
-
1
,
num_combined_kv_heads
,
head_size
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
kv_cache
,
True
)
kv_cache
=
kv_cache
.
flatten
(
0
,
1
)
kv_cache
.
index_copy_
(
0
,
slot_mapping
,
kv
)
vllm/v1/worker/tpu_model_runner.py
View file @
7c1f7600
...
...
@@ -861,12 +861,11 @@ class TPUModelRunner:
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
)
dtype
=
kv_cache_spec
.
dtype
tpu_k_cache
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
tpu_v_cache
=
torch
.
zeros_like
(
tpu_k_cache
)
tpu_kv_cache
=
torch
.
zeros
(
kv_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
kv_caches
[
layer_name
]
=
(
tpu_k
_cache
,
tpu_
v_cache
)
kv_caches
[
layer_name
]
=
tpu_kv_cache
else
:
raise
NotImplementedError
...
...
@@ -893,7 +892,7 @@ class ModelWrapperV1(nn.Module):
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
list
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
],
kv_caches
:
list
[
torch
.
Tensor
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model.
...
...
vllm/v1/worker/tpu_worker.py
View file @
7c1f7600
...
...
@@ -136,10 +136,10 @@ class TPUWorker:
# Use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
tpu_k_cache
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
self
.
device
)
tpu_v_cache
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
self
.
device
)
kv_caches
[
layer_name
]
=
(
tpu_k
_cache
,
tpu_
v_cache
)
tpu_k
v
_cache
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
self
.
device
)
kv_caches
[
layer_name
]
=
tpu_kv_cache
else
:
raise
NotImplementedError
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment