Commit 761b730d authored by Lucas Wilkinson, committed by Kevin H. Luu
Browse files

[BugFix] Fix memory spike in workspace allocation (#30744)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
(cherry picked from commit 00a8d762)
parent f34eca5f
...@@ -1223,6 +1223,8 @@ steps: ...@@ -1223,6 +1223,8 @@ steps:
# FIXIT: find out which code initialize cuda before running the test # FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it # before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
# A lot of these tests are on the edge of OOMing
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# There is some Tensor Parallelism related processing logic in LoRA that # There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation. # requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_chatglm3_tp.py
......
...@@ -145,12 +145,20 @@ class WorkspaceManager: ...@@ -145,12 +145,20 @@ class WorkspaceManager:
for ubatch_id in range(self._num_ubatches): for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id] current_workspace = self._current_workspaces[ubatch_id]
if current_workspace is None: if (
current_workspace is None
or self._workspace_size_bytes(current_workspace) < required_bytes
):
# Delete old tensor before allocating new one to avoid
# memory spike from resize_(). resize_() allocates new
# memory before freeing old, which can cause OOM.
# Must clear the list reference first since local var
# is just a copy of the reference.
self._current_workspaces[ubatch_id] = None
del current_workspace
self._current_workspaces[ubatch_id] = torch.empty( self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device (required_bytes,), dtype=torch.uint8, device=self._device
) )
elif self._workspace_size_bytes(current_workspace) < required_bytes:
current_workspace.resize_(required_bytes)
if envs.VLLM_DEBUG_WORKSPACE: if envs.VLLM_DEBUG_WORKSPACE:
logger.info( logger.info(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment