Unverified commit 9ddac563, authored by Shanshan Shen and committed by GitHub
Browse files

[Platform] move current_memory_usage() into platform (#11369)


Signed-off-by: Shanshan Shen <467638484@qq.com>
parent 1a51b9f8
...@@ -143,6 +143,13 @@ class CudaPlatformBase(Platform): ...@@ -143,6 +143,13 @@ class CudaPlatformBase(Platform):
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 16 cache_config.block_size = 16
@classmethod
def get_current_memory_usage(
    cls,
    device: Optional[torch.types.Device] = None,
) -> float:
    """Return the CUDA memory usage for ``device``.

    The peak-memory statistics are reset first, and the value of
    ``torch.cuda.max_memory_allocated`` read immediately afterwards is
    returned.

    Args:
        device: Device to query; forwarded unchanged to both
            ``torch.cuda`` calls (``None`` by default).

    Returns:
        The value reported by ``torch.cuda.max_memory_allocated``.
    """
    # Reset before reading so the "max" counter is not carrying a peak
    # recorded earlier in the process.
    torch.cuda.reset_peak_memory_stats(device)
    peak_after_reset = torch.cuda.max_memory_allocated(device)
    return peak_after_reset
@classmethod @classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1) -> str: kv_cache_dtype, block_size, use_v1) -> str:
......
...@@ -277,6 +277,15 @@ class Platform: ...@@ -277,6 +277,15 @@ class Platform:
return False return False
return True return True
@classmethod
def get_current_memory_usage(
    cls,
    device: Optional[torch.types.Device] = None,
) -> float:
    """Return the memory usage in bytes.

    This base implementation does no measurement of its own.

    Args:
        device: Device whose memory usage is requested (``None`` by
            default).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
""" """
......
...@@ -157,3 +157,10 @@ class RocmPlatform(Platform): ...@@ -157,3 +157,10 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_punica_wrapper(cls) -> str: def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
@classmethod
def get_current_memory_usage(
    cls,
    device: Optional[torch.types.Device] = None,
) -> float:
    """Return the memory usage for ``device`` via the ``torch.cuda`` API.

    The peak-memory statistics are reset first, and the value of
    ``torch.cuda.max_memory_allocated`` read immediately afterwards is
    returned.

    Args:
        device: Device to query; forwarded unchanged to both
            ``torch.cuda`` calls (``None`` by default).

    Returns:
        The value reported by ``torch.cuda.max_memory_allocated``.
    """
    # Reset before reading so the "max" counter is not carrying a peak
    # recorded earlier in the process.
    torch.cuda.reset_peak_memory_stats(device)
    peak_after_reset = torch.cuda.max_memory_allocated(device)
    return peak_after_reset
...@@ -94,3 +94,10 @@ class XPUPlatform(Platform): ...@@ -94,3 +94,10 @@ class XPUPlatform(Platform):
def is_pin_memory_available(cls): def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on XPU.") logger.warning("Pin memory is not supported on XPU.")
return False return False
@classmethod
def get_current_memory_usage(
    cls,
    device: Optional[torch.types.Device] = None,
) -> float:
    """Return the XPU memory usage for ``device``.

    The peak-memory statistics are reset first, and the value of
    ``torch.xpu.max_memory_allocated`` read immediately afterwards is
    returned.

    Args:
        device: Device to query; forwarded unchanged to both
            ``torch.xpu`` calls (``None`` by default).

    Returns:
        The value reported by ``torch.xpu.max_memory_allocated``.
    """
    # Reset before reading so the "max" counter is not carrying a peak
    # recorded earlier in the process.
    torch.xpu.reset_peak_memory_stats(device)
    peak_after_reset = torch.xpu.max_memory_allocated(device)
    return peak_after_reset
...@@ -710,13 +710,7 @@ class DeviceMemoryProfiler: ...@@ -710,13 +710,7 @@ class DeviceMemoryProfiler:
def current_memory_usage(self) -> float: def current_memory_usage(self) -> float:
# Return the memory usage in bytes. # Return the memory usage in bytes.
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_cuda_alike(): return current_platform.get_current_memory_usage(self.device)
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif current_platform.is_xpu():
torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
return mem
def __enter__(self): def __enter__(self):
self.initial_memory = self.current_memory_usage() self.initial_memory = self.current_memory_usage()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment