"...composable_kernel.git" did not exist on "146972f447503ec8443889855bc0b80f1e3d364e"
Unverified Commit 8f50c122 authored by Yunqian Fan's avatar Yunqian Fan Committed by GitHub
Browse files

[Fix] typo in cuda attr (#1380)

* [Bugfix] make cuda driver api compat with cuda12/13, along with tests

* fix typo in cudaDevAttr
parent a407c4a9
...@@ -20,6 +20,7 @@ class _cudaDeviceAttrNames: ...@@ -20,6 +20,7 @@ class _cudaDeviceAttrNames:
cudaDevAttrMaxThreadsPerBlock: int = 1 cudaDevAttrMaxThreadsPerBlock: int = 1
cudaDevAttrMaxSharedMemoryPerBlock: int = 8 cudaDevAttrMaxSharedMemoryPerBlock: int = 8
cudaDevAttrMaxRegistersPerBlock: int = 12
cudaDevAttrMultiProcessorCount: int = 16 cudaDevAttrMultiProcessorCount: int = 16
cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81
cudaDevAttrMaxPersistingL2CacheSize: int = 108 cudaDevAttrMaxPersistingL2CacheSize: int = 108
...@@ -60,7 +61,8 @@ def test_device_get_num_sms(): ...@@ -60,7 +61,8 @@ def test_device_get_num_sms():
def test_device_get_registers_per_block(): def test_device_get_registers_per_block():
tl_regs_per_block = get_registers_per_block() tl_regs_per_block = get_registers_per_block()
driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock) driver_regs_per_block = get_device_attribute(
_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock)
assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"
......
...@@ -14,6 +14,7 @@ class cudaDeviceAttrNames: ...@@ -14,6 +14,7 @@ class cudaDeviceAttrNames:
""" """
cudaDevAttrMaxThreadsPerBlock: int = 1 cudaDevAttrMaxThreadsPerBlock: int = 1
cudaDevAttrMaxRegistersPerBlock: int = 12
cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81
cudaDevAttrMaxPersistingL2CacheSize: int = 108 cudaDevAttrMaxPersistingL2CacheSize: int = 108
...@@ -38,6 +39,8 @@ def get_device_name(device_id: int = 0) -> str | None: ...@@ -38,6 +39,8 @@ def get_device_name(device_id: int = 0) -> str | None:
def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None: def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
prop = get_cuda_device_properties(device_id) prop = get_cuda_device_properties(device_id)
if prop is None:
raise RuntimeError("Failed to get device properties.")
shared_mem = int(prop.shared_memory_per_block) shared_mem = int(prop.shared_memory_per_block)
if format == "bytes": if format == "bytes":
return shared_mem return shared_mem
...@@ -121,7 +124,7 @@ def get_registers_per_block(device_id: int = 0) -> int: ...@@ -121,7 +124,7 @@ def get_registers_per_block(device_id: int = 0) -> int:
Get the maximum number of 32-bit registers available per block. Get the maximum number of 32-bit registers available per block.
""" """
prop = get_device_attribute( prop = get_device_attribute(
cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock, cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock,
device_id, device_id,
) )
return prop return prop
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment