OpenDAS / tilelang · Commits · a407c4a9 (unverified commit)

[Bugfix] make cuda driver api compat with cuda12/13, along with tests (#1379)

Authored by Yunqian Fan on Dec 07, 2025; committed by GitHub on Dec 07, 2025.
Parent: 3f8e6b59

Showing 2 changed files with 135 additions and 132 deletions:
- testing/python/carver/test_tilelang_carver_cuda_driver_properties.py (+76, −0)
- tilelang/carver/arch/driver/cuda_driver.py (+59, −132)
testing/python/carver/test_tilelang_carver_cuda_driver_properties.py · new file (mode 100644) @ a407c4a9
```python
import tilelang.testing
from tilelang.carver.arch.driver.cuda_driver import (
    get_cuda_device_properties,
    get_device_name,
    get_shared_memory_per_block,
    get_device_attribute,
    get_max_dynamic_shared_size_bytes,
    get_persisting_l2_cache_max_size,
    get_num_sms,
    get_registers_per_block,
)
import torch


class _cudaDeviceAttrNames:
    r"""
    This struct carries all properties that are of int32_t.
    refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd
    """
    cudaDevAttrMaxThreadsPerBlock: int = 1
    cudaDevAttrMaxSharedMemoryPerBlock: int = 8
    cudaDevAttrMultiProcessorCount: int = 16
    cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81
    cudaDevAttrMaxPersistingL2CacheSize: int = 108


def test_driver_get_device_properties():
    prop = get_cuda_device_properties()
    assert prop is not None, "Failed to get CUDA device properties"
    assert isinstance(prop, torch.cuda._CudaDeviceProperties), (
        "Returned object is not of type _CudaDeviceProperties")


def test_device_get_device_name():
    tl_device_name = get_device_name()
    th_device_name = torch.cuda.get_device_name()
    assert tl_device_name == th_device_name, "Device names do not match"


def test_device_get_shared_memory_per_block():
    tl_smem = get_shared_memory_per_block()
    driver_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerBlock)
    assert tl_smem == driver_smem, "Shared memory per block values do not match"


def test_device_get_persisting_l2_cache_size():
    tl_cache_size = get_persisting_l2_cache_max_size()
    driver_cache_size = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
    assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match"


def test_device_get_num_sms():
    tl_num_sms = get_num_sms()
    driver_num_sms = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMultiProcessorCount)
    assert tl_num_sms == driver_num_sms, "Number of SMs do not match"


def test_device_get_registers_per_block():
    tl_regs_per_block = get_registers_per_block()
    driver_regs_per_block = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock)
    assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"


def test_device_get_max_dynamic_shared_size_bytes():
    tl_dynamic_smem = get_max_dynamic_shared_size_bytes()
    driver_dynamic_smem = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor)
    assert tl_dynamic_smem == driver_dynamic_smem, (
        "Max dynamic shared size bytes values do not match")


if __name__ == "__main__":
    tilelang.testing.main()
```
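The suite compares each tilelang wrapper against a raw `cudaDeviceGetAttribute` query, so every test needs a visible GPU. A minimal guard, not part of the commit and using only standard `pytest` and `torch` calls, would let the module skip cleanly on CPU-only machines:

```python
# Hypothetical module-level guard (not in the commit): skip all tests in this
# file when no CUDA device is visible, instead of failing inside the asserts.
import pytest
import torch

pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA device required to compare driver properties",
)
```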
tilelang/carver/arch/driver/cuda_driver.py @ a407c4a9
```diff
@@ -2,113 +2,43 @@ from __future__ import annotations
 import ctypes
 import sys
 
+try:
+    import torch.cuda._CudaDeviceProperties as _CudaDeviceProperties
+except ImportError:
+    _CudaDeviceProperties = type("DummyCudaDeviceProperties", (), {})
 
 
-class cudaDeviceProp(ctypes.Structure):
-    _fields_ = [
-        ("name", ctypes.c_char * 256),
-        ("uuid", ctypes.c_byte * 16),  # cudaUUID_t
-        ("luid", ctypes.c_char * 8),
-        ("luidDeviceNodeMask", ctypes.c_uint),
-        ("totalGlobalMem", ctypes.c_size_t),
-        ("sharedMemPerBlock", ctypes.c_size_t),
-        ("regsPerBlock", ctypes.c_int),
-        ("warpSize", ctypes.c_int),
-        ("memPitch", ctypes.c_size_t),
-        ("maxThreadsPerBlock", ctypes.c_int),
-        ("maxThreadsDim", ctypes.c_int * 3),
-        ("maxGridSize", ctypes.c_int * 3),
-        ("clockRate", ctypes.c_int),
-        ("totalConstMem", ctypes.c_size_t),
-        ("major", ctypes.c_int),
-        ("minor", ctypes.c_int),
-        ("textureAlignment", ctypes.c_size_t),
-        ("texturePitchAlignment", ctypes.c_size_t),
-        ("deviceOverlap", ctypes.c_int),
-        ("multiProcessorCount", ctypes.c_int),
-        ("kernelExecTimeoutEnabled", ctypes.c_int),
-        ("integrated", ctypes.c_int),
-        ("canMapHostMemory", ctypes.c_int),
-        ("computeMode", ctypes.c_int),
-        ("maxTexture1D", ctypes.c_int),
-        ("maxTexture1DMipmap", ctypes.c_int),
-        ("maxTexture1DLinear", ctypes.c_int),
-        ("maxTexture2D", ctypes.c_int * 2),
-        ("maxTexture2DMipmap", ctypes.c_int * 2),
-        ("maxTexture2DLinear", ctypes.c_int * 3),
-        ("maxTexture2DGather", ctypes.c_int * 2),
-        ("maxTexture3D", ctypes.c_int * 3),
-        ("maxTexture3DAlt", ctypes.c_int * 3),
-        ("maxTextureCubemap", ctypes.c_int),
-        ("maxTexture1DLayered", ctypes.c_int * 2),
-        ("maxTexture2DLayered", ctypes.c_int * 3),
-        ("maxTextureCubemapLayered", ctypes.c_int * 2),
-        ("maxSurface1D", ctypes.c_int),
-        ("maxSurface2D", ctypes.c_int * 2),
-        ("maxSurface3D", ctypes.c_int * 3),
-        ("maxSurface1DLayered", ctypes.c_int * 2),
-        ("maxSurface2DLayered", ctypes.c_int * 3),
-        ("maxSurfaceCubemap", ctypes.c_int),
-        ("maxSurfaceCubemapLayered", ctypes.c_int * 2),
-        ("surfaceAlignment", ctypes.c_size_t),
-        ("concurrentKernels", ctypes.c_int),
-        ("ECCEnabled", ctypes.c_int),
-        ("pciBusID", ctypes.c_int),
-        ("pciDeviceID", ctypes.c_int),
-        ("pciDomainID", ctypes.c_int),
-        ("tccDriver", ctypes.c_int),
-        ("asyncEngineCount", ctypes.c_int),
-        ("unifiedAddressing", ctypes.c_int),
-        ("memoryClockRate", ctypes.c_int),
-        ("memoryBusWidth", ctypes.c_int),
-        ("l2CacheSize", ctypes.c_int),
-        ("persistingL2CacheMaxSize", ctypes.c_int),
-        ("maxThreadsPerMultiProcessor", ctypes.c_int),
-        ("streamPrioritiesSupported", ctypes.c_int),
-        ("globalL1CacheSupported", ctypes.c_int),
-        ("localL1CacheSupported", ctypes.c_int),
-        ("sharedMemPerMultiprocessor", ctypes.c_size_t),
-        ("regsPerMultiprocessor", ctypes.c_int),
-        ("managedMemory", ctypes.c_int),
-        ("isMultiGpuBoard", ctypes.c_int),
-        ("multiGpuBoardGroupID", ctypes.c_int),
-        ("reserved2", ctypes.c_int * 2),
-        ("reserved1", ctypes.c_int * 1),
-        ("reserved", ctypes.c_int * 60),
-    ]
-
-
-def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None:
-    if sys.platform == "win32":
-        libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll")
-    else:
-        libcudart = ctypes.cdll.LoadLibrary("libcudart.so")
-    prop = cudaDeviceProp()
-    cudaGetDeviceProperties = libcudart.cudaGetDeviceProperties
-    cudaGetDeviceProperties.argtypes = [ctypes.POINTER(cudaDeviceProp), ctypes.c_int]
-    cudaGetDeviceProperties.restype = ctypes.c_int
-    ret = cudaGetDeviceProperties(ctypes.byref(prop), device_id)
-    if ret == 0:
-        return prop
-    else:
-        raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}")
+class cudaDeviceAttrNames:
+    r"""
+    refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd
+    """
+    cudaDevAttrMaxThreadsPerBlock: int = 1
+    cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81
+    cudaDevAttrMaxPersistingL2CacheSize: int = 108
+
+
+def get_cuda_device_properties(device_id: int = 0) -> _CudaDeviceProperties | None:
+    try:
+        import torch.cuda
+        if not torch.cuda.is_available():
+            return None
+        return torch.cuda.get_device_properties(torch.device(device_id))
+    except ImportError:
+        return None
 
 
 def get_device_name(device_id: int = 0) -> str | None:
     prop = get_cuda_device_properties(device_id)
     if prop:
-        return prop.name.decode()
+        return prop.name
     else:
         raise RuntimeError("Failed to get device properties.")
 
 
 def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
     assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
     prop = get_cuda_device_properties(device_id)
     if prop:
-        # Convert size_t to int to avoid overflow issues
-        shared_mem = int(prop.sharedMemPerBlock)
+        shared_mem = int(prop.shared_memory_per_block)
         if format == "bytes":
             return shared_mem
         elif format == "kb":
```
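The hunk above is the heart of the CUDA 12/13 compatibility fix: the old path read `cudaDeviceProp` through a hand-copied ctypes layout, which silently misreads fields whenever the struct changes between runtime versions — presumably the incompatibility this commit addresses. The new path delegates to torch, whose bindings match the runtime it was built against. A short usage sketch of the new path (assuming a CUDA-enabled torch install; the attribute names are the ones the new code itself relies on):

```python
# Sketch of the torch-backed path (assumes torch built with CUDA support).
import torch

if torch.cuda.is_available():
    prop = torch.cuda.get_device_properties(torch.device(0))
    print(prop.name)                     # device name as a str (no .decode())
    print(prop.multi_processor_count)    # SM count
    print(prop.shared_memory_per_block)  # static shared memory per block, bytes
```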
```diff
@@ -117,8 +47,6 @@ def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> in
             return shared_mem // (1024 * 1024)
         else:
             raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb")
-    else:
-        raise RuntimeError("Failed to get device properties.")
 
 
 def get_device_attribute(attr: int, device_id: int = 0) -> int:
```
```diff
@@ -130,7 +58,11 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int:
     value = ctypes.c_int()
     cudaDeviceGetAttribute = libcudart.cudaDeviceGetAttribute
-    cudaDeviceGetAttribute.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int]
+    cudaDeviceGetAttribute.argtypes = [
+        ctypes.POINTER(ctypes.c_int),
+        ctypes.c_int,
+        ctypes.c_int,
+    ]
     cudaDeviceGetAttribute.restype = ctypes.c_int
     ret = cudaDeviceGetAttribute(ctypes.byref(value), attr, device_id)
```
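For values torch does not expose, the file keeps a ctypes path, but through `cudaDeviceGetAttribute`, whose `(int*, attr, device)` signature has stayed stable across CUDA releases, unlike the `cudaDeviceProp` layout. A self-contained sketch of that call (library name simplified to the common Linux soname; real code should probe candidate names):

```python
# Minimal, self-contained version of the attribute query used above.
# Assumes Linux with libcudart.so resolvable by the loader; 16 is
# cudaDevAttrMultiProcessorCount in the CUDA runtime enum.
import ctypes

libcudart = ctypes.CDLL("libcudart.so")
libcudart.cudaDeviceGetAttribute.argtypes = [
    ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int
]
libcudart.cudaDeviceGetAttribute.restype = ctypes.c_int

value = ctypes.c_int()
ret = libcudart.cudaDeviceGetAttribute(ctypes.byref(value), 16, 0)
if ret == 0:  # 0 == cudaSuccess
    print("SM count:", value.value)
```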
```diff
@@ -148,10 +80,8 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
     Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
     """
     assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
-    prop = get_cuda_device_properties(device_id)
-    if prop:
-        # Convert size_t to int to avoid overflow issues
-        shared_mem = int(prop.sharedMemPerMultiprocessor)
+    shared_mem = get_device_attribute(
+        cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id)
     if format == "bytes":
         return shared_mem
     elif format == "kb":
```
```diff
@@ -160,16 +90,11 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
         return shared_mem // (1024 * 1024)
     else:
         raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb")
-    else:
-        raise RuntimeError("Failed to get device properties.")
 
 
 def get_persisting_l2_cache_max_size(device_id: int = 0) -> int:
-    prop = get_cuda_device_properties(device_id)
-    if prop:
-        return prop.persistingL2CacheMaxSize
-    else:
-        raise RuntimeError("Failed to get device properties for persisting L2 cache max size.")
+    prop = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize, device_id)
+    return prop
 
 
 def get_num_sms(device_id: int = 0) -> int:
```
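Note that `cudaDevAttrMaxPersistingL2CacheSize` is only meaningful on GPUs with a persisting-L2 window (compute capability 8.0 and newer); older devices typically report 0, so the new `return prop` propagates 0 rather than raising. A caller-side sketch, not from the commit:

```python
# Caller-side sketch: treat 0 as "feature unavailable" rather than an error
# when sizing persisting-L2 access policy windows.
from tilelang.carver.arch.driver.cuda_driver import get_persisting_l2_cache_max_size

l2_persist = get_persisting_l2_cache_max_size()
if l2_persist == 0:
    print("no persisting L2 window on this device")
else:
    print(f"persisting L2 cache max size: {l2_persist} bytes")
```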
```diff
@@ -186,15 +111,17 @@ def get_num_sms(device_id: int = 0) -> int:
         RuntimeError: If unable to get the device properties.
     """
     prop = get_cuda_device_properties(device_id)
-    if prop:
-        return prop.multiProcessorCount
-    else:
-        raise RuntimeError("Failed to get device properties.")
+    if prop is None:
+        raise RuntimeError("Failed to get device properties.")
+    return prop.multi_processor_count
 
 
 def get_registers_per_block(device_id: int = 0) -> int:
-    prop = get_cuda_device_properties(device_id)
-    if prop:
-        return prop.regsPerBlock
-    else:
-        raise RuntimeError("Failed to get device properties.")
+    """
+    Get the maximum number of 32-bit registers available per block.
+    """
+    prop = get_device_attribute(
+        cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock,
+        device_id,
+    )
+    return prop
```
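One caveat against the CUDA enum the file itself links: attribute 1 is `cudaDevAttrMaxThreadsPerBlock`, whereas the per-block register limit is a separate attribute (`cudaDevAttrMaxRegistersPerBlock`, listed as 12 in the runtime docs), so `get_registers_per_block` as written returns the thread limit. If the register count is what is wanted, a sketch under that reading of the docs:

```python
# Sketch only: 12 is cudaDevAttrMaxRegistersPerBlock per the NVIDIA runtime
# docs linked in cudaDeviceAttrNames; verify against your toolkit's
# driver_types.h before relying on it.
from tilelang.carver.arch.driver.cuda_driver import get_device_attribute

CUDA_DEV_ATTR_MAX_REGISTERS_PER_BLOCK = 12  # assumed value from the docs page

regs = get_device_attribute(CUDA_DEV_ATTR_MAX_REGISTERS_PER_BLOCK, device_id=0)
print(f"32-bit registers per block: {regs}")
```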