Unverified Commit 8ad54a99 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[Platform] Add current_platform.num_compute_units interface (#35042)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
Signed-off-by: default avatarKunshang Ji <jikunshang95@gmail.com>
parent 92510edc
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.platform_utils import num_compute_units
def cal_diff( def cal_diff(
...@@ -124,8 +125,7 @@ def test_cutlass_mla_decode( ...@@ -124,8 +125,7 @@ def test_cutlass_mla_decode(
q_pe = q_pe_padded q_pe = q_pe_padded
kv_cache_flat = blocked_k.squeeze(2) kv_cache_flat = blocked_k.squeeze(2)
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0")) sm_count = num_compute_units(device.index)
sm_count = device_properties.multi_processor_count
workspace_size = ops.sm100_cutlass_mla_get_workspace_size( workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seqlen * block_size, b, sm_count, num_kv_splits=1 max_seqlen * block_size, b, sm_count, num_kv_splits=1
) )
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import ( ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils.platform_utils import num_compute_units
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool: def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
...@@ -78,7 +79,7 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): ...@@ -78,7 +79,7 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
if has_zp: if has_zp:
zp = zp.to(dtype) zp = zp.to(dtype)
properties = torch.cuda.get_device_properties(qw.device.index) properties = torch.cuda.get_device_properties(qw.device.index)
sm_count = properties.multi_processor_count sm_count = num_compute_units(qw.device.index)
sm_version = properties.major * 10 + properties.minor sm_version = properties.major * 10 + properties.minor
n_32align = (n + 32 - 1) // 32 * 32 n_32align = (n + 32 - 1) // 32 * 32
......
...@@ -9,7 +9,7 @@ import vllm._custom_ops as ops ...@@ -9,7 +9,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx950 from vllm.platforms.rocm import on_gfx950
from vllm.utils.platform_utils import get_cu_count from vllm.utils.platform_utils import num_compute_units
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
BIAS_MODES = [0, 1, 2] BIAS_MODES = [0, 1, 2]
...@@ -121,7 +121,7 @@ def pad_fp8(weight): ...@@ -121,7 +121,7 @@ def pad_fp8(weight):
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") @pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = get_cu_count() cu_count = num_compute_units()
# Next ^2 of n # Next ^2 of n
N_p2 = 1 << (n - 1).bit_length() N_p2 = 1 << (n - 1).bit_length()
...@@ -186,7 +186,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): ...@@ -186,7 +186,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = get_cu_count() cu_count = num_compute_units()
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
...@@ -203,7 +203,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): ...@@ -203,7 +203,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = get_cu_count() cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
...@@ -222,7 +222,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): ...@@ -222,7 +222,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
cu_count = get_cu_count() cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
...@@ -267,7 +267,7 @@ def test_rocm_wvsplitk_fp8_kernel( ...@@ -267,7 +267,7 @@ def test_rocm_wvsplitk_fp8_kernel(
ref_out = torch._scaled_mm( ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
) )
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS) out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, num_compute_units(), BIAS)
if xnorm: if xnorm:
torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8)
......
...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import ( ...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.allspark_utils import (
check_allspark_supported_dtype_shape, check_allspark_supported_dtype_shape,
) )
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.utils.platform_utils import num_compute_units
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
...@@ -45,7 +46,7 @@ class AllSparkLinearKernel(MPLinearKernel): ...@@ -45,7 +46,7 @@ class AllSparkLinearKernel(MPLinearKernel):
# prepare the parameters required for the kernel # prepare the parameters required for the kernel
properties = torch.cuda.get_device_properties(device.index) properties = torch.cuda.get_device_properties(device.index)
sm_count = properties.multi_processor_count sm_count = num_compute_units(device.index)
sm_version = properties.major * 10 + properties.minor sm_version = properties.major * 10 + properties.minor
gemm_args = {} gemm_args = {}
gemm_args["sm_count"] = sm_count gemm_args["sm_count"] = sm_count
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.platform_utils import get_cu_count from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
...@@ -36,7 +36,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( ...@@ -36,7 +36,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
out_dtype, out_dtype,
As, As,
Bs, Bs,
get_cu_count(), num_compute_units(),
bias, bias,
) )
# Fallback # Fallback
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
...@@ -147,7 +148,7 @@ def matmul_persistent( ...@@ -147,7 +148,7 @@ def matmul_persistent(
assert bias is None or bias.dim() == 1, ( assert bias is None or bias.dim() == 1, (
"Currently assuming bias is 1D, let Horace know if you run into this" "Currently assuming bias is 1D, let Horace know if you run into this"
) )
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count NUM_SMS = num_compute_units(a.device.index)
M, K = a.shape M, K = a.shape
K, N = b.shape K, N = b.shape
dtype = a.dtype dtype = a.dtype
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from functools import lru_cache
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -22,6 +20,7 @@ from einops import rearrange ...@@ -22,6 +20,7 @@ from einops import rearrange
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, next_power_of_2 from vllm.utils.math_utils import cdiv, next_power_of_2
from vllm.utils.platform_utils import num_compute_units
from .utils import input_guard from .utils import input_guard
...@@ -162,15 +161,8 @@ def layer_norm_fwd_kernel( ...@@ -162,15 +161,8 @@ def layer_norm_fwd_kernel(
tl.store(Y_base, y, mask=mask) tl.store(Y_base, y, mask=mask)
@lru_cache
def _get_sm_count(device: torch.device) -> int:
"""Get and cache the SM count for a given device."""
props = torch.cuda.get_device_properties(device)
return props.multi_processor_count
def calc_rows_per_block(M: int, device: torch.device) -> int: def calc_rows_per_block(M: int, device: torch.device) -> int:
sm_count = _get_sm_count(device) sm_count = num_compute_units(device.index)
rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
rows_per_block = min(rows_per_block, 4) rows_per_block = min(rows_per_block, 4)
return rows_per_block return rows_per_block
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.platform_utils import num_compute_units
from .quant_utils import pack_cols, unpack_cols from .quant_utils import pack_cols, unpack_cols
...@@ -271,7 +272,7 @@ def marlin_make_workspace_new( ...@@ -271,7 +272,7 @@ def marlin_make_workspace_new(
) -> torch.Tensor: ) -> torch.Tensor:
# In the new marlin kernel, we use the num of threadblocks as workspace # In the new marlin kernel, we use the num of threadblocks as workspace
# size. The num of threadblocks is sms_count * max_blocks_per_sm. # size. The num of threadblocks is sms_count * max_blocks_per_sm.
sms = torch.cuda.get_device_properties(device).multi_processor_count sms = num_compute_units(device.index)
return torch.zeros( return torch.zeros(
sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
) )
......
...@@ -11,7 +11,7 @@ from vllm import envs ...@@ -11,7 +11,7 @@ from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.platform_utils import get_cu_count from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -149,7 +149,7 @@ def rocm_unquantized_gemm_impl( ...@@ -149,7 +149,7 @@ def rocm_unquantized_gemm_impl(
m = weight.shape[0] m = weight.shape[0]
k = weight.shape[1] k = weight.shape[1]
cu_count = get_cu_count() cu_count = num_compute_units()
if use_aiter_triton_gemm(n, m, k, x.dtype): if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
...@@ -199,7 +199,7 @@ def rocm_unquantized_gemm_impl( ...@@ -199,7 +199,7 @@ def rocm_unquantized_gemm_impl(
x_view = x.reshape(-1, x.size(-1)) x_view = x.reshape(-1, x.size(-1))
if m > 8 and 0 < n <= 4: if m > 8 and 0 < n <= 4:
cu_count = get_cu_count() cu_count = num_compute_units()
out = ops.wvSplitK(weight, x_view, cu_count, bias) out = ops.wvSplitK(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0]) return out.reshape(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None: elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
......
...@@ -26,6 +26,7 @@ from vllm.utils.deep_gemm import ( ...@@ -26,6 +26,7 @@ from vllm.utils.deep_gemm import (
m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nt_contiguous,
) )
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
def _generate_optimal_warmup_m_values( def _generate_optimal_warmup_m_values(
...@@ -44,7 +45,7 @@ def _generate_optimal_warmup_m_values( ...@@ -44,7 +45,7 @@ def _generate_optimal_warmup_m_values(
# DeepGEMM's possible block sizes # DeepGEMM's possible block sizes
block_ms = [64, 128, 256] block_ms = [64, 128, 256]
block_ns = list(range(16, min(257, n + 1), 16)) block_ns = list(range(16, min(257, n + 1), 16))
num_sms = torch.cuda.get_device_properties(device).multi_processor_count num_sms = num_compute_units(device.index)
m_values = set() m_values = set()
......
...@@ -538,6 +538,10 @@ class CudaPlatformBase(Platform): ...@@ -538,6 +538,10 @@ class CudaPlatformBase(Platform):
def support_static_graph_mode(cls) -> bool: def support_static_graph_mode(cls) -> bool:
return True return True
@classmethod
def num_compute_units(cls, device_id=0):
return torch.cuda.get_device_properties(device_id).multi_processor_count
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
...@@ -692,6 +692,16 @@ class Platform: ...@@ -692,6 +692,16 @@ class Platform:
""" """
return {} return {}
@classmethod
def num_compute_units(cls, device_id: int = 0) -> int:
"""
Get the number of compute units for the current platform.
(NVIDIA SM / AMD CU / Intel EU)
"""
raise NotImplementedError(
"num_compute_units is not implemented for the current platform."
)
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
......
...@@ -682,3 +682,7 @@ class RocmPlatform(Platform): ...@@ -682,3 +682,7 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def support_static_graph_mode(cls) -> bool: def support_static_graph_mode(cls) -> bool:
return True return True
@classmethod
def num_compute_units(cls, device_id=0):
return torch.cuda.get_device_properties(device_id).multi_processor_count
...@@ -277,3 +277,7 @@ class XPUPlatform(Platform): ...@@ -277,3 +277,7 @@ class XPUPlatform(Platform):
"""Copy blocks from XPU to host (CPU).""" """Copy blocks from XPU to host (CPU)."""
_src_cache = src_cache[:, src_block_indices] _src_cache = src_cache[:, src_block_indices]
dst_cache[:, dst_block_indices] = _src_cache.cpu() dst_cache[:, dst_block_indices] = _src_cache.cpu()
@classmethod
def num_compute_units(cls, device_id: int = 0) -> int:
return torch.xpu.get_device_properties(device_id).max_compute_units
...@@ -24,11 +24,6 @@ def xpu_is_initialized() -> bool: ...@@ -24,11 +24,6 @@ def xpu_is_initialized() -> bool:
return torch.xpu.is_initialized() return torch.xpu.is_initialized()
def get_cu_count(device_id: int = 0) -> int:
"""Returns the total number of compute units (CU) on single GPU."""
return torch.cuda.get_device_properties(device_id).multi_processor_count
def cuda_get_device_properties( def cuda_get_device_properties(
device, names: Sequence[str], init_cuda=False device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]: ) -> tuple[Any, ...]:
...@@ -57,3 +52,11 @@ def is_uva_available() -> bool: ...@@ -57,3 +52,11 @@ def is_uva_available() -> bool:
# UVA requires pinned memory. # UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed. # TODO: Add more requirements for UVA if needed.
return is_pin_memory_available() return is_pin_memory_available()
@cache
def num_compute_units(device_id: int = 0) -> int:
"""Get the number of compute units of the current device."""
from vllm.platforms import current_platform
return current_platform.num_compute_units(device_id)
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
...@@ -74,8 +75,7 @@ class SM100Workspace: ...@@ -74,8 +75,7 @@ class SM100Workspace:
# Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy # Pre-compute sm_count to avoid recomputing it. Use device 0 as a proxy
# (assumes all devices are similar) # (assumes all devices are similar)
properties = torch.cuda.get_device_properties(torch.device("cuda:0")) self._sm_count = num_compute_units(0)
self._sm_count = properties.multi_processor_count
def get_buf(self): def get_buf(self):
return self._workspace_buf return self._workspace_buf
......
...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
...@@ -130,8 +131,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -130,8 +131,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_num_splits = None self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
device_properties = torch.cuda.get_device_properties(self.device) num_sms = num_compute_units(self.device.index)
num_sms = device_properties.multi_processor_count
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros( self.cg_buf_tile_scheduler_metadata = torch.zeros(
......
...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -237,8 +238,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -237,8 +238,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2) # DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
props = torch.cuda.get_device_properties(device) sm_count = num_compute_units(device.index)
sm_count = props.multi_processor_count
self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config) self.mla_dims = get_mla_dims(self.model_config)
......
...@@ -9,6 +9,7 @@ from vllm.config import VllmConfig ...@@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -219,8 +220,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -219,8 +220,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
) )
self.reorder_batch_threshold += self.num_speculative_tokens self.reorder_batch_threshold += self.num_speculative_tokens
props = torch.cuda.get_device_properties(self.device) sm_count = num_compute_units(self.device.index)
sm_count = props.multi_processor_count
self.num_sms = sm_count self.num_sms = sm_count
self.decode_lens_buffer = torch.empty( self.decode_lens_buffer = torch.empty(
......
...@@ -13,7 +13,7 @@ from vllm.logger import init_logger ...@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -38,7 +38,7 @@ if current_platform.is_rocm(): ...@@ -38,7 +38,7 @@ if current_platform.is_rocm():
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(total_tokens): def num_programs(total_tokens):
return min(total_tokens, get_cu_count()) return min(total_tokens, num_compute_units())
@triton.jit @triton.jit
def cp_mha_gather_cache_kernel( def cp_mha_gather_cache_kernel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment