OpenDAS / vllm_cscc
Commit 3439c5a8 (unverified)
Authored Jun 26, 2024 by Woosuk Kwon; committed by GitHub, Jun 26, 2024
[Bugfix][TPU] Fix KV cache size calculation (#5860)
Parent: 6806998b
Showing 1 changed file with 7 additions and 6 deletions
vllm/worker/tpu_worker.py (+7 -6)
@@ -118,14 +118,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         xm.wait_device_ops()
 
         m = xm.get_memory_info(self.device)
-        program_size = 1024 * 1024 * 1024  # 1GB
-        free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
-        kv_cache_bytes = int(free_bytes *
-                             self.cache_config.gpu_memory_utilization)
-        kv_cache_dtype_btyes = get_dtype_size(self.cache_dtype)
+        total_memory_size = m["bytes_limit"]
+        usable_memory_size = int(total_memory_size *
+                                 self.cache_config.gpu_memory_utilization)
+        profiled = m["bytes_used"]  # Weights + intermediate activations.
+        kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_btyes = get_dtype_size(self.cache_dtype)
         block_size = self.cache_config.block_size
         num_tpu_blocks = (kv_cache_bytes //
-                          (kv_cache_dtype_btyes * block_size * num_layers * 2 *
+                          (dtype_btyes * block_size * num_layers * 2 *
                            head_size * num_kv_heads))
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
         return num_tpu_blocks, 0
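
To make the effect of the change easier to see, here is a minimal standalone sketch of the arithmetic before and after this commit, using plain integers instead of the worker's state. The function names `kv_blocks_old` and `kv_blocks_new` and all the numbers in the usage note are hypothetical, chosen only for illustration; they are not part of vLLM and are not measured values.

```python
GIB = 1024 * 1024 * 1024


def _blocks(kv_cache_bytes, dtype_bytes, block_size, num_layers,
            head_size, num_kv_heads):
    # Bytes for one block: key and value (the factor 2) for every layer.
    bytes_per_block = (dtype_bytes * block_size * num_layers * 2 *
                       head_size * num_kv_heads)
    num_blocks = kv_cache_bytes // bytes_per_block
    return (num_blocks // 8) * 8  # Round down to a multiple of 8.


def kv_blocks_old(bytes_limit, bytes_used, utilization, **shape):
    # Pre-commit: reserve a fixed 1 GiB, then scale the *free* memory.
    program_size = GIB
    free_bytes = max(bytes_limit - bytes_used - program_size, 0)
    return _blocks(int(free_bytes * utilization), **shape)


def kv_blocks_new(bytes_limit, bytes_used, utilization, **shape):
    # Post-commit: scale the *total* memory, then subtract profiled usage
    # (weights + intermediate activations).
    usable_memory_size = int(bytes_limit * utilization)
    return _blocks(max(usable_memory_size - bytes_used, 0), **shape)
```

With a hypothetical 16 GiB limit, 4 GiB in use after profiling, a utilization of 0.5, and a model shape giving 2 MiB per block (2-byte dtype, block size 16, 32 layers, head size 128, 8 KV heads), the old formula yields 2816 blocks and the new one 2048. The new form keeps total usage (profiled memory plus KV cache) within the configured fraction of the device limit, whereas the old form applied that fraction only to whatever was left over.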