Commit 0b407479 (unverified), authored by Kunshang Ji, committed by GitHub
Browse files

[misc]refactor `Platform.set_device` method (#20262)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
parent 5eaf5700
...@@ -75,6 +75,13 @@ class CpuPlatform(Platform): ...@@ -75,6 +75,13 @@ class CpuPlatform(Platform):
def get_device_total_memory(cls, device_id: int = 0) -> int: def get_device_total_memory(cls, device_id: int = 0) -> int:
return psutil.virtual_memory().total return psutil.virtual_memory().total
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.cpu.set_device(device)
@classmethod @classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return False return False
......
...@@ -77,7 +77,7 @@ class CudaPlatformBase(Platform): ...@@ -77,7 +77,7 @@ class CudaPlatformBase(Platform):
""" """
Set the device for the current platform. Set the device for the current platform.
""" """
super().set_device(device) torch.cuda.set_device(device)
# With this trick we can force the device to be set eagerly # With this trick we can force the device to be set eagerly
# see https://github.com/pytorch/pytorch/issues/155668 # see https://github.com/pytorch/pytorch/issues/155668
# for why and when it is needed # for why and when it is needed
......
...@@ -45,6 +45,13 @@ class HpuPlatform(Platform): ...@@ -45,6 +45,13 @@ class HpuPlatform(Platform):
def inference_mode(cls): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.hpu.set_device(device)
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
......
...@@ -305,7 +305,7 @@ class Platform: ...@@ -305,7 +305,7 @@ class Platform:
""" """
Set the device for the current platform. Set the device for the current platform.
""" """
torch.cuda.set_device(device) raise NotImplementedError
@classmethod @classmethod
def pre_register_and_update(cls, def pre_register_and_update(cls,
......
...@@ -241,6 +241,17 @@ class RocmPlatform(Platform): ...@@ -241,6 +241,17 @@ class RocmPlatform(Platform):
logger.info("Using ROCmFlashAttention backend.") logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.cuda.set_device(device)
# With this trick we can force the device to be set eagerly
# see https://github.com/pytorch/pytorch/issues/155668
# for why and when it is needed
_ = torch.zeros(1, device=device)
@classmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_capability(cls, def get_device_capability(cls,
......
...@@ -55,6 +55,13 @@ class TpuPlatform(Platform): ...@@ -55,6 +55,13 @@ class TpuPlatform(Platform):
logger.info("Using Pallas V1 backend.") logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.tpu.set_device(device)
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
chip_type, _ = device.get_local_chips() chip_type, _ = device.get_local_chips()
......
...@@ -45,6 +45,13 @@ class XPUPlatform(Platform): ...@@ -45,6 +45,13 @@ class XPUPlatform(Platform):
logger.info("Using Flash Attention backend on V1 engine.") logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.xpu.set_device(device)
@classmethod @classmethod
def get_device_capability( def get_device_capability(
cls, cls,
......
...@@ -130,7 +130,7 @@ class Worker(WorkerBase): ...@@ -130,7 +130,7 @@ class Worker(WorkerBase):
# This env var set by Ray causes exceptions with graph building. # This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device) current_platform.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype) _check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect() gc.collect()
......
...@@ -132,7 +132,7 @@ class XPUWorker(Worker): ...@@ -132,7 +132,7 @@ class XPUWorker(Worker):
if self.device_config.device.type == "xpu" and current_platform.is_xpu( if self.device_config.device.type == "xpu" and current_platform.is_xpu(
): ):
self.device = torch.device(f"xpu:{self.local_rank}") self.device = torch.device(f"xpu:{self.local_rank}")
torch.xpu.set_device(self.device) current_platform.set_device(self.device)
torch.xpu.empty_cache() torch.xpu.empty_cache()
self.init_gpu_memory = torch.xpu.get_device_properties( self.init_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory self.local_rank).total_memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment