Unverified commit 8ad54a99, authored by Kunshang Ji and committed by GitHub
Browse files

[Platform] Add current_platform.num_compute_units interface (#35042)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
parent 92510edc
...@@ -13,6 +13,7 @@ import torch ...@@ -13,6 +13,7 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.utils.platform_utils import num_compute_units
_TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {} _TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {}
_TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {} _TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {}
...@@ -988,7 +989,7 @@ def apply_top_k_top_p_triton( ...@@ -988,7 +989,7 @@ def apply_top_k_top_p_triton(
else: else:
p_ptr = logits # Dummy pointer (won't be read) p_ptr = logits # Dummy pointer (won't be read)
num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count num_sm = num_compute_units(logits.device.index)
NUM_PROGRAMS = min(num_sm, batch_size) NUM_PROGRAMS = min(num_sm, batch_size)
# Cache per-Triton Program buffer on each device. # Cache per-Triton Program buffer on each device.
......
...@@ -98,7 +98,7 @@ from vllm.utils import length_from_prompt_token_ids_or_embeds ...@@ -98,7 +98,7 @@ from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.nvtx_pytorch_hooks import PytHooks from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
get_dtype_size, get_dtype_size,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
...@@ -909,8 +909,8 @@ class GPUModelRunner( ...@@ -909,8 +909,8 @@ class GPUModelRunner(
# Note: used for model runner override. # Note: used for model runner override.
def _init_device_properties(self) -> None: def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties""" """Initialize attributes from torch.cuda.get_device_properties"""
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count self.num_sms = num_compute_units(self.device.index)
# Note: used for model runner override. # Note: used for model runner override.
def _sync_device(self) -> None: def _sync_device(self) -> None:
......
...@@ -23,6 +23,7 @@ from vllm.logger import init_logger ...@@ -23,6 +23,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -72,8 +73,7 @@ class SMControlContextManager: ...@@ -72,8 +73,7 @@ class SMControlContextManager:
"SM control is currently only supported on CUDA" "SM control is currently only supported on CUDA"
) )
props = torch.cuda.get_device_properties(torch.cuda.current_device()) total_sms = num_compute_units(torch.cuda.current_device().index)
total_sms = props.multi_processor_count
assert comm_sms < total_sms assert comm_sms < total_sms
self.total_sms = total_sms self.total_sms = total_sms
......
...@@ -28,9 +28,6 @@ class XPUModelRunner(GPUModelRunner): ...@@ -28,9 +28,6 @@ class XPUModelRunner(GPUModelRunner):
# FIXME: To be verified. # FIXME: To be verified.
self.cascade_attn_enabled = False self.cascade_attn_enabled = False
def _init_device_properties(self) -> None:
self.num_sms = None
def _sync_device(self) -> None: def _sync_device(self) -> None:
torch.xpu.synchronize() torch.xpu.synchronize()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment