Unverified Commit 8ad54a99 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Platform] Add current_platform.num_compute_units interface (#35042)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
Signed-off-by: default avatarKunshang Ji <jikunshang95@gmail.com>
parent 92510edc
......@@ -13,6 +13,7 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.platform_utils import num_compute_units
_TRITON_TABLE_CACHE: dict[tuple[torch.device], tuple[torch.Tensor, torch.Tensor]] = {}
_TRITON_BUFFER_CACHE: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {}
......@@ -988,7 +989,7 @@ def apply_top_k_top_p_triton(
else:
p_ptr = logits # Dummy pointer (won't be read)
num_sm = torch.cuda.get_device_properties(logits.device).multi_processor_count
num_sm = num_compute_units(logits.device.index)
NUM_PROGRAMS = min(num_sm, batch_size)
# Cache per-Triton Program buffer on each device.
......
......@@ -98,7 +98,7 @@ from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.utils.torch_utils import (
get_dtype_size,
kv_cache_dtype_str_to_dtype,
......@@ -909,8 +909,8 @@ class GPUModelRunner(
# Note: used for model runner override.
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties"""
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count
self.num_sms = num_compute_units(self.device.index)
# Note: used for model runner override.
def _sync_device(self) -> None:
......
......@@ -23,6 +23,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
logger = init_logger(__name__)
......@@ -72,8 +73,7 @@ class SMControlContextManager:
"SM control is currently only supported on CUDA"
)
props = torch.cuda.get_device_properties(torch.cuda.current_device())
total_sms = props.multi_processor_count
total_sms = num_compute_units(torch.cuda.current_device().index)
assert comm_sms < total_sms
self.total_sms = total_sms
......
......@@ -28,9 +28,6 @@ class XPUModelRunner(GPUModelRunner):
# FIXME: To be verified.
self.cascade_attn_enabled = False
def _init_device_properties(self) -> None:
self.num_sms = None
def _sync_device(self) -> None:
torch.xpu.synchronize()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment