Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f136da15
Unverified
Commit
f136da15
authored
Jun 27, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 27, 2024
Browse files
[Hardware][TPU] Optimize KV cache swapping (#5878)
parent
c3dde367
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
22 deletions
+36
-22
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+4
-12
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+32
-10
No files found.
vllm/attention/backends/pallas.py
View file @
f136da15
...
...
@@ -28,21 +28,13 @@ class PallasAttentionBackend(AttentionBackend):
)
->
Tuple
[
int
,
...]:
return
(
num_kv_heads
,
num_blocks
,
block_size
,
head_size
)
@
torch
.
compile
(
backend
=
"openxla"
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
dst_kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
src_to_dst
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
src_k_cache
,
src_v_cache
=
src_kv_cache
dst_k_cache
,
dst_v_cache
=
dst_kv_cache
src_indices
,
dst_indices
=
src_to_dst
device
=
dst_k_cache
.
device
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
dst_k_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
dst_v_cache
,
True
)
dst_k_cache
[:,
dst_indices
]
=
src_k_cache
[:,
src_indices
].
to
(
device
)
dst_v_cache
[:,
dst_indices
]
=
src_v_cache
[:,
src_indices
].
to
(
device
)
raise
RuntimeError
(
"swap_blocks is not used for the TPU backend."
)
@
torch
.
compile
(
backend
=
"openxla"
)
@
staticmethod
...
...
vllm/worker/tpu_worker.py
View file @
f136da15
...
...
@@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union
import
torch
import
torch_xla.core.xla_model
as
xm
import
torch_xla.experimental.dynamo_set_buffer_donor
# noqa: F401
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
...
...
@@ -152,8 +153,8 @@ class TPUWorker(LoraNotSupportedWorkerBase):
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
head_size
=
self
.
model_config
.
get_head_size
()
self
.
cpu_cache
=
[]
self
.
tpu_cache
=
[]
self
.
cpu_cache
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
[]
self
.
tpu_cache
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
[]
tpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
num_gpu_blocks
,
self
.
block_size
,
num_kv_heads
,
head_size
)
cpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
...
...
@@ -227,18 +228,25 @@ class TPUWorker(LoraNotSupportedWorkerBase):
if
blocks_to_swap_in
:
# Swap from CPU to TPU.
src_
to_dst
=
_make_src_to_dst
(
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
src_
indices
,
dst_indices
=
_make_src_to_dst
(
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
for
i
in
range
(
num_layers
):
attn_backend
.
swap_blocks
(
self
.
cpu_cache
[
i
],
self
.
tpu_cache
[
i
],
src_to_dst
)
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
if
blocks_to_swap_out
:
# Swap from TPU to CPU.
src_
to_dst
=
_make_src_to_dst
(
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
src_
indices
,
dst_indices
=
_make_src_to_dst
(
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
for
i
in
range
(
num_layers
):
attn_backend
.
swap_blocks
(
self
.
tpu_cache
[
i
],
self
.
cpu_cache
[
i
],
src_to_dst
)
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
].
cpu
()
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
].
cpu
()
if
blocks_to_copy
:
src_to_dst
=
_make_src_to_dst
(
blocks_to_copy
,
self
.
device
,
self
.
device
)
...
...
@@ -267,3 +275,17 @@ def _make_src_to_dst(
device
=
dst_device
,
dtype
=
torch
.
int64
)
return
src_indices
,
dst_indices
@
torch
.
compile
(
backend
=
"openxla"
)
def
_insert_kv
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
tpu_k_cache
:
torch
.
Tensor
,
tpu_v_cache
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
tpu_k_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
tpu_v_cache
,
True
)
tpu_k_cache
[:,
indices
]
=
k
tpu_v_cache
[:,
indices
]
=
v
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment