Unverified commit a407c4a9, authored by Yunqian Fan and committed by GitHub

[Bugfix] make cuda driver api compat with cuda12/13, along with tests (#1379)

parent 3f8e6b59
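
Background for the change: the previous implementation mirrored the CUDA runtime's cudaDeviceProp struct with a hand-written ctypes layout, and that layout is not stable across CUDA major versions, so reads against a CUDA 12-era field list can land on the wrong offsets under CUDA 13. Querying attributes through cudaDeviceGetAttribute (or going through torch.cuda.get_device_properties) does not depend on the struct layout at all. The snippet below is a minimal, standalone sketch of the attribute-based approach, not the library's code; it assumes a Linux machine with the CUDA runtime (libcudart.so) on the loader path, and the attribute IDs are the cudaDeviceAttr values also listed in the new test file.

import ctypes

# Attribute IDs from the cudaDeviceAttr enum in the CUDA runtime API docs.
cudaDevAttrMaxSharedMemoryPerBlock = 8
cudaDevAttrMultiProcessorCount = 16


def query_device_attribute(attr: int, device_id: int = 0) -> int:
    """Sketch: read one device attribute via cudaDeviceGetAttribute.

    Unlike reading fields out of a ctypes mirror of cudaDeviceProp, this does
    not depend on the struct layout of any particular CUDA version.
    """
    libcudart = ctypes.CDLL("libcudart.so")  # assumes Linux + CUDA runtime installed
    value = ctypes.c_int()
    # cudaError_t cudaDeviceGetAttribute(int *value, cudaDeviceAttr attr, int device)
    ret = libcudart.cudaDeviceGetAttribute(ctypes.byref(value), attr, device_id)
    if ret != 0:
        raise RuntimeError(f"cudaDeviceGetAttribute failed with error {ret}")
    return value.value


if __name__ == "__main__":
    print("SMs:", query_device_attribute(cudaDevAttrMultiProcessorCount))
    print("smem/block:", query_device_attribute(cudaDevAttrMaxSharedMemoryPerBlock))
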
import tilelang.testing
from tilelang.carver.arch.driver.cuda_driver import (
    get_cuda_device_properties,
    get_device_name,
    get_shared_memory_per_block,
    get_device_attribute,
    get_max_dynamic_shared_size_bytes,
    get_persisting_l2_cache_max_size,
    get_num_sms,
    get_registers_per_block,
)
import torch


class _cudaDeviceAttrNames:
    r"""
    Device attribute IDs whose values are reported as int32_t.
    Refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd
    """
    cudaDevAttrMaxThreadsPerBlock: int = 1
    cudaDevAttrMaxSharedMemoryPerBlock: int = 8
    cudaDevAttrMultiProcessorCount: int = 16
    cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81
    cudaDevAttrMaxPersistingL2CacheSize: int = 108


def test_driver_get_device_properties():
    prop = get_cuda_device_properties()
    assert prop is not None, "Failed to get CUDA device properties"
    assert isinstance(
        prop,
        torch.cuda._CudaDeviceProperties), ("Returned object is not of type _CudaDeviceProperties")


def test_device_get_device_name():
    tl_device_name = get_device_name()
    th_device_name = torch.cuda.get_device_name()
    assert tl_device_name == th_device_name, "Device names do not match"


def test_device_get_shared_memory_per_block():
    tl_smem = get_shared_memory_per_block()
    driver_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerBlock)
    assert tl_smem == driver_smem, "Shared memory per block values do not match"


def test_device_get_persisting_l2_cache_size():
    tl_cache_size = get_persisting_l2_cache_max_size()
    driver_cache_size = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
    assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match"


def test_device_get_num_sms():
    tl_num_sms = get_num_sms()
    driver_num_sms = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMultiProcessorCount)
    assert tl_num_sms == driver_num_sms, "Number of SMs does not match"


def test_device_get_registers_per_block():
    tl_regs_per_block = get_registers_per_block()
    driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock)
    assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"


def test_device_get_max_dynamic_shared_size_bytes():
    tl_dynamic_smem = get_max_dynamic_shared_size_bytes()
    driver_dynamic_smem = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor)
    assert tl_dynamic_smem == driver_dynamic_smem, (
        "Max dynamic shared size bytes values do not match")


if __name__ == "__main__":
    tilelang.testing.main()
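
The same kind of cross-check the tests perform can be reproduced against PyTorch's public API alone. The following standalone sketch (not part of the commit) compares the raw cudaDevAttrMultiProcessorCount query with the value torch reports, assuming a CUDA-capable torch build and at least one visible GPU.

import torch

from tilelang.carver.arch.driver.cuda_driver import get_device_attribute

if torch.cuda.is_available():
    # 16 == cudaDevAttrMultiProcessorCount (see the attribute table above).
    sm_from_driver = get_device_attribute(16)
    sm_from_torch = torch.cuda.get_device_properties(0).multi_processor_count
    assert sm_from_driver == sm_from_torch, (sm_from_driver, sm_from_torch)
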
tilelang/carver/arch/driver/cuda_driver.py

@@ -2,113 +2,43 @@ from __future__ import annotations
 import ctypes
 import sys

-class cudaDeviceProp(ctypes.Structure):
-    _fields_ = [
-        ("name", ctypes.c_char * 256),
-        ("uuid", ctypes.c_byte * 16),  # cudaUUID_t
-        ("luid", ctypes.c_char * 8),
-        ("luidDeviceNodeMask", ctypes.c_uint),
-        ("totalGlobalMem", ctypes.c_size_t),
-        ("sharedMemPerBlock", ctypes.c_size_t),
-        ("regsPerBlock", ctypes.c_int),
-        ("warpSize", ctypes.c_int),
-        ("memPitch", ctypes.c_size_t),
-        ("maxThreadsPerBlock", ctypes.c_int),
-        ("maxThreadsDim", ctypes.c_int * 3),
-        ("maxGridSize", ctypes.c_int * 3),
-        ("clockRate", ctypes.c_int),
-        ("totalConstMem", ctypes.c_size_t),
-        ("major", ctypes.c_int),
-        ("minor", ctypes.c_int),
-        ("textureAlignment", ctypes.c_size_t),
-        ("texturePitchAlignment", ctypes.c_size_t),
-        ("deviceOverlap", ctypes.c_int),
-        ("multiProcessorCount", ctypes.c_int),
-        ("kernelExecTimeoutEnabled", ctypes.c_int),
-        ("integrated", ctypes.c_int),
-        ("canMapHostMemory", ctypes.c_int),
-        ("computeMode", ctypes.c_int),
-        ("maxTexture1D", ctypes.c_int),
-        ("maxTexture1DMipmap", ctypes.c_int),
-        ("maxTexture1DLinear", ctypes.c_int),
-        ("maxTexture2D", ctypes.c_int * 2),
-        ("maxTexture2DMipmap", ctypes.c_int * 2),
-        ("maxTexture2DLinear", ctypes.c_int * 3),
-        ("maxTexture2DGather", ctypes.c_int * 2),
-        ("maxTexture3D", ctypes.c_int * 3),
-        ("maxTexture3DAlt", ctypes.c_int * 3),
-        ("maxTextureCubemap", ctypes.c_int),
-        ("maxTexture1DLayered", ctypes.c_int * 2),
-        ("maxTexture2DLayered", ctypes.c_int * 3),
-        ("maxTextureCubemapLayered", ctypes.c_int * 2),
-        ("maxSurface1D", ctypes.c_int),
-        ("maxSurface2D", ctypes.c_int * 2),
-        ("maxSurface3D", ctypes.c_int * 3),
-        ("maxSurface1DLayered", ctypes.c_int * 2),
-        ("maxSurface2DLayered", ctypes.c_int * 3),
-        ("maxSurfaceCubemap", ctypes.c_int),
-        ("maxSurfaceCubemapLayered", ctypes.c_int * 2),
-        ("surfaceAlignment", ctypes.c_size_t),
-        ("concurrentKernels", ctypes.c_int),
-        ("ECCEnabled", ctypes.c_int),
-        ("pciBusID", ctypes.c_int),
-        ("pciDeviceID", ctypes.c_int),
-        ("pciDomainID", ctypes.c_int),
-        ("tccDriver", ctypes.c_int),
-        ("asyncEngineCount", ctypes.c_int),
-        ("unifiedAddressing", ctypes.c_int),
-        ("memoryClockRate", ctypes.c_int),
-        ("memoryBusWidth", ctypes.c_int),
-        ("l2CacheSize", ctypes.c_int),
-        ("persistingL2CacheMaxSize", ctypes.c_int),
-        ("maxThreadsPerMultiProcessor", ctypes.c_int),
-        ("streamPrioritiesSupported", ctypes.c_int),
-        ("globalL1CacheSupported", ctypes.c_int),
-        ("localL1CacheSupported", ctypes.c_int),
-        ("sharedMemPerMultiprocessor", ctypes.c_size_t),
-        ("regsPerMultiprocessor", ctypes.c_int),
-        ("managedMemory", ctypes.c_int),
-        ("isMultiGpuBoard", ctypes.c_int),
-        ("multiGpuBoardGroupID", ctypes.c_int),
-        ("reserved2", ctypes.c_int * 2),
-        ("reserved1", ctypes.c_int * 1),
-        ("reserved", ctypes.c_int * 60)
-    ]
-
-
-def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None:
-    if sys.platform == "win32":
-        libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll")
-    else:
-        libcudart = ctypes.cdll.LoadLibrary("libcudart.so")
-    prop = cudaDeviceProp()
-    cudaGetDeviceProperties = libcudart.cudaGetDeviceProperties
-    cudaGetDeviceProperties.argtypes = [ctypes.POINTER(cudaDeviceProp), ctypes.c_int]
-    cudaGetDeviceProperties.restype = ctypes.c_int
-    ret = cudaGetDeviceProperties(ctypes.byref(prop), device_id)
-    if ret == 0:
-        return prop
-    else:
-        raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}")
+try:
+    import torch.cuda._CudaDeviceProperties as _CudaDeviceProperties
+except ImportError:
+    _CudaDeviceProperties = type("DummyCudaDeviceProperties", (), {})
+
+
+class cudaDeviceAttrNames:
+    r"""
+    refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd
+    """
+    cudaDevAttrMaxThreadsPerBlock: int = 1
+    cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81
+    cudaDevAttrMaxPersistingL2CacheSize: int = 108
+
+
+def get_cuda_device_properties(device_id: int = 0) -> _CudaDeviceProperties | None:
+    try:
+        import torch.cuda
+
+        if not torch.cuda.is_available():
+            return None
+        return torch.cuda.get_device_properties(torch.device(device_id))
+    except ImportError:
+        return None


 def get_device_name(device_id: int = 0) -> str | None:
     prop = get_cuda_device_properties(device_id)
     if prop:
-        return prop.name.decode()
-    else:
-        raise RuntimeError("Failed to get device properties.")
+        return prop.name


 def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
     assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
     prop = get_cuda_device_properties(device_id)
-    if prop:
-        # Convert size_t to int to avoid overflow issues
-        shared_mem = int(prop.sharedMemPerBlock)
+    shared_mem = int(prop.shared_memory_per_block)
     if format == "bytes":
         return shared_mem
     elif format == "kb":

@@ -117,8 +47,6 @@ def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> in
         return shared_mem // (1024 * 1024)
     else:
         raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb")
-    else:
-        raise RuntimeError("Failed to get device properties.")


 def get_device_attribute(attr: int, device_id: int = 0) -> int:

@@ -130,7 +58,11 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int:
     value = ctypes.c_int()
     cudaDeviceGetAttribute = libcudart.cudaDeviceGetAttribute
-    cudaDeviceGetAttribute.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int]
+    cudaDeviceGetAttribute.argtypes = [
+        ctypes.POINTER(ctypes.c_int),
+        ctypes.c_int,
+        ctypes.c_int,
+    ]
     cudaDeviceGetAttribute.restype = ctypes.c_int
     ret = cudaDeviceGetAttribute(ctypes.byref(value), attr, device_id)

@@ -148,10 +80,8 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
     Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
     """
     assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
-    prop = get_cuda_device_properties(device_id)
-    if prop:
-        # Convert size_t to int to avoid overflow issues
-        shared_mem = int(prop.sharedMemPerMultiprocessor)
+    shared_mem = get_device_attribute(
+        cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)
     if format == "bytes":
         return shared_mem
     elif format == "kb":

@@ -160,16 +90,11 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
         return shared_mem // (1024 * 1024)
     else:
         raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb")
-    else:
-        raise RuntimeError("Failed to get device properties.")


 def get_persisting_l2_cache_max_size(device_id: int = 0) -> int:
-    prop = get_cuda_device_properties(device_id)
-    if prop:
-        return prop.persistingL2CacheMaxSize
-    else:
-        raise RuntimeError("Failed to get device properties for persisting L2 cache max size.")
+    prop = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize, device_id)
+    return prop


 def get_num_sms(device_id: int = 0) -> int:

@@ -186,15 +111,17 @@ def get_num_sms(device_id: int = 0) -> int:
         RuntimeError: If unable to get the device properties.
     """
     prop = get_cuda_device_properties(device_id)
-    if prop:
-        return prop.multiProcessorCount
-    else:
-        raise RuntimeError("Failed to get device properties.")
+    if prop is None:
+        raise RuntimeError("Failed to get device properties.")
+    return prop.multi_processor_count


 def get_registers_per_block(device_id: int = 0) -> int:
-    prop = get_cuda_device_properties(device_id)
-    if prop:
-        return prop.regsPerBlock
-    else:
-        raise RuntimeError("Failed to get device properties.")
+    """
+    Get the maximum number of 32-bit registers available per block.
+    """
+    prop = get_device_attribute(
+        cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock,
+        device_id,
+    )
+    return prop
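
Taken together, the reworked helpers form a small capability-query layer. A typical call pattern looks like the sketch below (it assumes a CUDA device is present; output values are device-dependent).

from tilelang.carver.arch.driver.cuda_driver import (
    get_device_name,
    get_max_dynamic_shared_size_bytes,
    get_num_sms,
    get_persisting_l2_cache_max_size,
    get_shared_memory_per_block,
)

# Sketch: print a short capability summary for device 0 (the default).
print("device:", get_device_name())
print("SMs:", get_num_sms())
print("static smem per block:", get_shared_memory_per_block(format="kb"), "KB")
print("max dynamic smem:", get_max_dynamic_shared_size_bytes(format="kb"), "KB")
print("persisting L2 cache:", get_persisting_l2_cache_max_size(), "bytes")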