Unverified Commit 7272bfae authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[Misc] Refactor platform to get device specific stream and event (#14411)


Signed-off-by: default avatarshen-shanshan <467638484@qq.com>
parent d9ac9e3d
...@@ -404,6 +404,15 @@ class Platform: ...@@ -404,6 +404,15 @@ class Platform:
) -> None: ) -> None:
"""Raises if this request is unsupported on this platform""" """Raises if this request is unsupported on this platform"""
def __getattr__(self, key: str):
device = getattr(torch, self.device_name, None)
if device is not None and hasattr(device, key):
return getattr(device, key)
else:
logger.warning("Current platform %s doesn't has '%s' attribute.",
self.device_name, key)
return None
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from vllm.model_executor.layers.spec_decode_base_sampler import ( from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler) SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -89,14 +90,14 @@ class AsyncMetricsCollector: ...@@ -89,14 +90,14 @@ class AsyncMetricsCollector:
self._rank = rank self._rank = rank
if isinstance(device_type, torch.device): if isinstance(device_type, torch.device):
device_type = device_type.type device_type = device_type.type
if device_type == 'cuda': stream = current_platform.Stream
self._copy_stream = torch.cuda.Stream() if stream is not None:
self._copy_stream = stream()
def maybe_collect_rejsample_metrics( def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]: self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform # Skip for any platform that doesn't have device Event
from vllm.platforms import current_platform if current_platform.Event is None:
if not current_platform.is_cuda_alike():
return None return None
# If a copy was initiated in the previous call, collect and return. # If a copy was initiated in the previous call, collect and return.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment