Unverified commit bcc2306c, authored by Yufeng He, committed by GitHub
Browse files

[Bugfix] Respect VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY in prefetch offloader (#37699)


Signed-off-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 3abf8584
...@@ -8,6 +8,7 @@ from vllm.model_executor.offloader.base import ( ...@@ -8,6 +8,7 @@ from vllm.model_executor.offloader.base import (
create_offloader, create_offloader,
get_offloader, get_offloader,
set_offloader, set_offloader,
should_pin_memory,
) )
from vllm.model_executor.offloader.prefetch import PrefetchOffloader from vllm.model_executor.offloader.prefetch import PrefetchOffloader
from vllm.model_executor.offloader.uva import UVAOffloader from vllm.model_executor.offloader.uva import UVAOffloader
...@@ -20,4 +21,5 @@ __all__ = [ ...@@ -20,4 +21,5 @@ __all__ = [
"create_offloader", "create_offloader",
"get_offloader", "get_offloader",
"set_offloader", "set_offloader",
"should_pin_memory",
] ]
...@@ -10,7 +10,9 @@ from typing import TYPE_CHECKING ...@@ -10,7 +10,9 @@ from typing import TYPE_CHECKING
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import OffloadConfig from vllm.config import OffloadConfig
...@@ -18,6 +20,18 @@ if TYPE_CHECKING: ...@@ -18,6 +20,18 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def should_pin_memory() -> bool:
    """Decide whether weight offloading should allocate pinned host memory.

    Two conditions must both hold: the platform check
    ``is_pin_memory_available()`` succeeds, and the user has not opted out
    through ``VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY``. The opt-out exists
    because on unified-memory systems (e.g. GH200) pinned memory eats into
    GPU memory.
    """
    # Platform capability is consulted first; the env override is only read
    # when pinning is actually possible.
    if not is_pin_memory_available():
        return False
    return not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
""" """
class relation: class relation:
......
...@@ -20,8 +20,7 @@ import torch.nn as nn ...@@ -20,8 +20,7 @@ import torch.nn as nn
# Import prefetch_ops to register custom ops at module load time # Import prefetch_ops to register custom ops at module load time
import vllm.model_executor.offloader.prefetch_ops # noqa: F401 import vllm.model_executor.offloader.prefetch_ops # noqa: F401
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.offloader.base import BaseOffloader from vllm.model_executor.offloader.base import BaseOffloader, should_pin_memory
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -528,7 +527,7 @@ class _ModuleOffloader: ...@@ -528,7 +527,7 @@ class _ModuleOffloader:
gpu_buffer = offloader._gpu_buffer gpu_buffer = offloader._gpu_buffer
assert cpu_storage is not None, "CPU storage not initialized" assert cpu_storage is not None, "CPU storage not initialized"
assert gpu_buffer is not None, "GPU buffer not assigned" assert gpu_buffer is not None, "GPU buffer not assigned"
assert not is_pin_memory_available() or cpu_storage.is_pinned(), ( assert not should_pin_memory() or cpu_storage.is_pinned(), (
f"CPU storage for {name} is not pinned! " f"CPU storage for {name} is not pinned! "
"non_blocking=True H2D copy from non-pinned memory " "non_blocking=True H2D copy from non-pinned memory "
"causes stream synchronization that breaks " "causes stream synchronization that breaks "
...@@ -629,7 +628,7 @@ class _CpuParamOffloader(_BaseParamOffloader): ...@@ -629,7 +628,7 @@ class _CpuParamOffloader(_BaseParamOffloader):
original GPU tensor is garbage collected. original GPU tensor is garbage collected.
""" """
param = self._param param = self._param
pin_memory = is_pin_memory_available() pin_memory = should_pin_memory()
# Create pinned CPU storage and copy current GPU data # Create pinned CPU storage and copy current GPU data
self._cpu_storage = torch.empty_strided( self._cpu_storage = torch.empty_strided(
...@@ -666,7 +665,7 @@ class _CpuParamOffloader(_BaseParamOffloader): ...@@ -666,7 +665,7 @@ class _CpuParamOffloader(_BaseParamOffloader):
param = self._param param = self._param
if param.data.device.type == "cpu": if param.data.device.type == "cpu":
if is_pin_memory_available() and not param.data.is_pinned(): if should_pin_memory() and not param.data.is_pinned():
pinned = torch.empty_strided( pinned = torch.empty_strided(
size=param.data.size(), size=param.data.size(),
stride=param.data.stride(), stride=param.data.stride(),
......
...@@ -10,9 +10,9 @@ from torch.func import functional_call ...@@ -10,9 +10,9 @@ from torch.func import functional_call
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.offloader.base import BaseOffloader from vllm.model_executor.offloader.base import BaseOffloader, should_pin_memory
from vllm.utils.mem_utils import format_gib from vllm.utils.mem_utils import format_gib
from vllm.utils.platform_utils import is_pin_memory_available, is_uva_available from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -43,10 +43,7 @@ class UVAOffloader(BaseOffloader): ...@@ -43,10 +43,7 @@ class UVAOffloader(BaseOffloader):
self.cpu_offload_bytes = 0 self.cpu_offload_bytes = 0
self.cpu_offload_params = cpu_offload_params or set() self.cpu_offload_params = cpu_offload_params or set()
self.pin_memory = ( self.pin_memory = should_pin_memory()
is_pin_memory_available()
and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
)
self.uva_offloading = ( self.uva_offloading = (
is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment