Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b0e96aae
Unverified
Commit
b0e96aae
authored
Mar 19, 2025
by
iefgnoix
Committed by
GitHub
Mar 19, 2025
Browse files
[V1][TPU] Change kv cache shape. (#15145)
Signed-off-by:
Xiongfei Wei
<
isaacwxf23@gmail.com
>
parent
8310e0b5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
16 deletions
+13
-16
requirements/tpu.txt
requirements/tpu.txt
+6
-6
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+7
-10
No files found.
requirements/tpu.txt
View file @
b0e96aae
...
@@ -17,9 +17,9 @@ ray[data]
...
@@ -17,9 +17,9 @@ ray[data]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev2025031
4%2Bcxx11
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev2025031
9
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev2025031
4%2Bcxx11
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev2025031
9
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev2025031
4%2Bcxx11
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev2025031
9
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev2025031
4%2Bcxx11
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev2025031
9
-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev2025031
4%2Bcxx11
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev2025031
9
-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev2025031
4%2Bcxx11
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev2025031
9
-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
vllm/v1/attention/backends/pallas.py
View file @
b0e96aae
...
@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
...
@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
num_blocks
,
block_size
,
num_kv_heads
*
head_size
)
@
staticmethod
@
staticmethod
def
swap_blocks
(
def
swap_blocks
(
...
@@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads
,
head_size],
kv_cache = ([num_blocks, block_size, num_kv_heads
*
head_size],
[num_blocks, block_size, num_kv_heads
,
head_size])
[num_blocks, block_size, num_kv_heads
*
head_size])
attn_metadata: Metadata for attention.
attn_metadata: Metadata for attention.
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
...
@@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
num_tokens
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
num_tokens
,
self
.
num_kv_heads
,
self
.
head_size
)
key_cache
,
value_cache
=
kv_cache
key_cache
,
value_cache
=
kv_cache
if
kv_cache
[
0
].
numel
()
>
0
:
if
kv_cache
[
0
].
numel
()
>
0
:
...
@@ -192,10 +190,10 @@ def write_to_kv_cache(
...
@@ -192,10 +190,10 @@ def write_to_kv_cache(
""" Write the key and values to the KV cache.
""" Write the key and values to the KV cache.
Args:
Args:
key: shape = [num_tokens, num_kv_heads
,
head_size]
key: shape = [num_tokens, num_kv_heads
*
head_size]
value: shape = [num_tokens, num_kv_heads
,
head_size]
value: shape = [num_tokens, num_kv_heads
*
head_size]
k_cache = [num_blocks, block_size, num_kv_heads
,
head_size]
k_cache = [num_blocks, block_size, num_kv_heads
*
head_size]
v_cache = [num_blocks, block_size, num_kv_heads
,
head_size]
v_cache = [num_blocks, block_size, num_kv_heads
*
head_size]
"""
"""
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
...
@@ -203,6 +201,5 @@ def write_to_kv_cache(
...
@@ -203,6 +201,5 @@ def write_to_kv_cache(
key_cache
=
key_cache
.
flatten
(
0
,
1
)
key_cache
=
key_cache
.
flatten
(
0
,
1
)
value_cache
=
value_cache
.
flatten
(
0
,
1
)
value_cache
=
value_cache
.
flatten
(
0
,
1
)
slot_mapping
=
slot_mapping
.
flatten
()
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment