Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f5c8628f
Unverified
Commit
f5c8628f
authored
Jun 26, 2024
by
Woosuk Kwon
Committed by
GitHub
Jun 26, 2024
Browse files
[Bugfix][TPU] Fix CPU cache allocation (#5869)
parent
cbc53b6b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
5 deletions
+8
-5
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+2
-3
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+6
-2
No files found.
vllm/attention/backends/pallas.py
View file @
f5c8628f
...
...
@@ -37,11 +37,10 @@ class PallasAttentionBackend(AttentionBackend):
)
->
None
:
src_k_cache
,
src_v_cache
=
src_kv_cache
dst_k_cache
,
dst_v_cache
=
dst_kv_cache
src_indices
,
dst_indices
=
src_to_dst
device
=
dst_k_cache
.
device
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
dst_k_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
dst_v_cache
,
True
)
device
=
dst_k_cache
.
device
src_indices
,
dst_indices
=
src_to_dst
dst_k_cache
[:,
dst_indices
]
=
src_k_cache
[:,
src_indices
].
to
(
device
)
dst_v_cache
[:,
dst_indices
]
=
src_v_cache
[:,
src_indices
].
to
(
device
)
...
...
vllm/worker/tpu_worker.py
View file @
f5c8628f
...
...
@@ -156,14 +156,18 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self
.
tpu_cache
=
[]
tpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
num_gpu_blocks
,
self
.
block_size
,
num_kv_heads
,
head_size
)
cpu_cache_shape
=
self
.
model_runner
.
attn_backend
.
get_kv_cache_shape
(
num_cpu_blocks
,
self
.
block_size
,
num_kv_heads
,
head_size
)
for
_
in
range
(
num_layers
):
tpu_k_cache
=
torch
.
zeros
(
tpu_cache_shape
,
dtype
=
dtype
,
device
=
self
.
device
)
tpu_v_cache
=
torch
.
zeros_like
(
tpu_k_cache
)
self
.
tpu_cache
.
append
((
tpu_k_cache
,
tpu_v_cache
))
cpu_k_cache
=
torch
.
zeros_like
(
tpu_k_cache
,
device
=
"cpu"
)
cpu_v_cache
=
torch
.
zeros_like
(
tpu_v_cache
,
device
=
"cpu"
)
cpu_k_cache
=
torch
.
zeros
(
cpu_cache_shape
,
dtype
=
dtype
,
device
=
"cpu"
)
cpu_v_cache
=
torch
.
zeros_like
(
cpu_k_cache
)
self
.
cpu_cache
.
append
((
cpu_k_cache
,
cpu_v_cache
))
self
.
_warmup_model
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment