Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
380e1863
Unverified
Commit
380e1863
authored
Oct 18, 2024
by
Joe Runde
Committed by
GitHub
Oct 18, 2024
Browse files
🐛
fix torch memory profiling (#9516)
Signed-off-by:
Joe Runde
<
Joseph.Runde@ibm.com
>
parent
337ed766
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
11 deletions
+14
-11
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+1
-2
tests/worker/test_profile.py
tests/worker/test_profile.py
+6
-5
vllm/worker/worker.py
vllm/worker/worker.py
+7
-4
No files found.
tests/quantization/test_bitsandbytes.py
View file @
380e1863
...
...
@@ -107,8 +107,7 @@ def validate_generated_texts(hf_runner,
quantization
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
tensor_parallel_size
=
vllm_tp_size
,
enforce_eager
=
False
,
gpu_memory_utilization
=
0.8
)
as
llm
:
enforce_eager
=
False
)
as
llm
:
vllm_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
vllm_logs
=
log_generated_texts
(
prompts
,
vllm_outputs
,
"VllmRunner"
)
...
...
tests/worker/test_profile.py
View file @
380e1863
...
...
@@ -54,16 +54,17 @@ def test_gpu_memory_profiling():
gpu_blocks
,
_
=
worker
.
determine_num_available_blocks
()
# Peak vram usage by torch should be 0.7077 GiB
# No
n-torch allocations should be 0.0079 GiB
# No
memory should be allocated outside of torch
# 9.0 GiB should be the utilization target
# 8.2
84
3 GiB should be available for the KV cache
# 8.2
92
3 GiB should be available for the KV cache
block_size
=
CacheEngine
.
get_cache_block_size
(
engine_config
.
cache_config
,
engine_config
.
model_config
,
engine_config
.
parallel_config
)
expected_blocks
=
(
8.2
84
3
*
1024
**
3
)
//
block_size
expected_blocks
=
(
8.2
92
3
*
1024
**
3
)
//
block_size
# Check within a small tolerance for portability
# Hardware, kernel, or dependency changes could all affect memory
# utilization
assert
abs
(
gpu_blocks
-
expected_blocks
)
<
5
# utilization.
# A 10 block tolerance here should be about 6MB of wiggle room.
assert
abs
(
gpu_blocks
-
expected_blocks
)
<
10
vllm/worker/worker.py
View file @
380e1863
...
...
@@ -232,10 +232,11 @@ class Worker(LocalOrDistributedWorkerBase):
# gpu outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass
torch
.
cuda
.
empty_cache
()
# After emptying the torch cache, any other increase in gpu ram should
# be from non-torch allocations.
non_torch_allocations
=
free_memory_pre_profile
-
\
torch
.
cuda
.
mem_get_info
()[
0
]
torch_allocated_bytes
=
torch
.
cuda
.
memory_stats
(
)[
"allocated_bytes.all.current"
]
total_allocated_bytes
=
torch
.
cuda
.
mem_get_info
(
)[
1
]
-
torch
.
cuda
.
mem_get_info
()[
0
]
non_torch_allocations
=
total_allocated_bytes
-
torch_allocated_bytes
if
non_torch_allocations
>
0
:
peak_memory
+=
non_torch_allocations
...
...
@@ -259,10 +260,12 @@ class Worker(LocalOrDistributedWorkerBase):
logger
.
info
(
"Memory profiling results: total_gpu_memory=%.2fGiB"
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
" memory_usage_post_profile=%.2fGib"
" non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
" gpu_memory_utilization=%.2f"
,
total_gpu_memory
/
(
1024
**
3
),
(
total_gpu_memory
-
free_memory_pre_profile
)
/
(
1024
**
3
),
(
peak_memory
-
non_torch_allocations
)
/
(
1024
**
3
),
total_allocated_bytes
/
(
1024
**
3
),
non_torch_allocations
/
(
1024
**
3
),
available_kv_cache_memory
/
(
1024
**
3
),
self
.
cache_config
.
gpu_memory_utilization
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment