Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
96bc209b
Commit
96bc209b
authored
Aug 02, 2022
by
Titus von Koeller
Browse files
tentative refactoring of the compute capabilities code
parent
59a615b3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
21 deletions
+35
-21
bitsandbytes/cuda_setup/compute_capability.py
bitsandbytes/cuda_setup/compute_capability.py
+35
-21
No files found.
bitsandbytes/cuda_setup/compute_capability.py
View file @
96bc209b
...
@@ -2,27 +2,28 @@ import ctypes
...
@@ -2,27 +2,28 @@ import ctypes
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
CUDA_SUCCESS
=
0
@
dataclass
@
dataclass
class
CudaLibVals
:
class
CudaLibVals
:
# code bits taken from
# code bits taken from
# https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
# https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
nGpus
=
ctypes
.
c_int
()
nGpus
:
ctypes
.
c_int
=
field
(
default
=
ctypes
.
c_int
()
)
cc_major
=
ctypes
.
c_int
()
cc_major
:
ctypes
.
c_int
=
field
(
default
=
ctypes
.
c_int
()
)
cc_minor
=
ctypes
.
c_int
()
cc_minor
:
ctypes
.
c_int
=
field
(
default
=
ctypes
.
c_int
()
)
device
=
ctypes
.
c_int
()
device
:
ctypes
.
c_int
=
field
(
default
=
ctypes
.
c_int
()
)
error_str
=
ctypes
.
c_char_p
()
error_str
:
ctypes
.
c_char_p
=
field
(
default
=
ctypes
.
c_char_p
()
)
cuda
:
ctypes
.
CDLL
=
field
(
init
=
False
,
repr
=
False
)
cuda
:
ctypes
.
CDLL
=
field
(
init
=
False
,
repr
=
False
)
ccs
:
List
[
str
,
...]
=
field
(
init
=
False
)
ccs
:
List
[
str
,
...]
=
field
(
init
=
False
)
def
load_cuda_lib
(
self
):
def
_initialize_driver_API
(
self
):
self
.
check_cuda_result
(
self
.
cuda
.
cuInit
(
0
))
def
_load_cuda_lib
(
self
):
"""
"""
1. find libcuda.so library (GPU driver) (/usr/lib)
1. find libcuda.so library (GPU driver) (/usr/lib)
init_device -> init variables -> call function by reference
init_device -> init variables -> call function by reference
"""
"""
libnames
=
(
"libcuda.so"
)
libnames
=
"libcuda.so"
for
libname
in
libnames
:
for
libname
in
libnames
:
try
:
try
:
self
.
cuda
=
ctypes
.
CDLL
(
libname
)
self
.
cuda
=
ctypes
.
CDLL
(
libname
)
...
@@ -33,32 +34,45 @@ class CudaLibVals:
...
@@ -33,32 +34,45 @@ class CudaLibVals:
else
:
else
:
raise
OSError
(
"could not load any of: "
+
" "
.
join
(
libnames
))
raise
OSError
(
"could not load any of: "
+
" "
.
join
(
libnames
))
def
check_cuda_result
(
self
,
result_val
):
def
call_cuda_func
(
self
,
function_obj
,
**
kwargs
):
CUDA_SUCCESS
=
0
# constant taken from cuda.h
pass
# if (CUDA_SUCCESS := function_obj(
def
_error_handle
(
cuda_lib_call_return_value
):
"""
"""
2. call extern C function to determine CC
2. call extern C function to determine CC
(see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
(see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
"""
"""
cls_fields
:
Tuple
[
Field
,
...]
=
fields
(
self
.
__class__
)
CUDA_SUCCESS
=
0
# constant taken from cuda.h
if
result_val
!=
0
:
if
cuda_lib_call_return_value
!=
CUDA_SUCCESS
:
self
.
cuda
.
cuGetErrorString
(
result_val
,
ctypes
.
byref
(
self
.
error_str
))
self
.
cuda
.
cuGetErrorString
(
cuda_lib_call_return_value
,
ctypes
.
byref
(
self
.
error_str
),
)
print
(
"Count not initialize CUDA - failure!"
)
print
(
"Count not initialize CUDA - failure!"
)
raise
Exception
(
"CUDA exception!"
)
raise
Exception
(
"CUDA exception!"
)
return
result
_val
return
cuda_lib_call_return
_val
ue
def
__post_init__
(
self
):
def
__post_init__
(
self
):
self
.
load_cuda_lib
()
self
.
_load_cuda_lib
()
self
.
check_cuda_result
(
self
.
cuda
.
cuInit
(
0
))
self
.
_initialize_driver_API
()
self
.
check_cuda_result
(
self
.
cuda
,
self
.
cuda
.
cuDeviceGetCount
(
ctypes
.
byref
(
self
.
nGpus
)))
self
.
check_cuda_result
(
self
.
cuda
,
self
.
cuda
.
cuDeviceGetCount
(
ctypes
.
byref
(
self
.
nGpus
))
)
tmp_ccs
=
[]
tmp_ccs
=
[]
for
gpu_index
in
range
(
self
.
nGpus
.
value
):
for
gpu_index
in
range
(
self
.
nGpus
.
value
):
check_cuda_result
(
check_cuda_result
(
self
.
cuda
,
self
.
cuda
.
cuDeviceGet
(
ctypes
.
byref
(
self
.
device
),
gpu_index
)
self
.
cuda
,
self
.
cuda
.
cuDeviceGet
(
ctypes
.
byref
(
self
.
device
),
gpu_index
),
)
)
check_cuda_result
(
check_cuda_result
(
self
.
cuda
,
self
.
cuda
,
self
.
cuda
.
cuDeviceComputeCapability
(
self
.
cuda
.
cuDeviceComputeCapability
(
ctypes
.
byref
(
self
.
cc_major
),
ctypes
.
byref
(
self
.
cc_minor
),
self
.
device
ctypes
.
byref
(
self
.
cc_major
),
ctypes
.
byref
(
self
.
cc_minor
),
self
.
device
,
),
),
)
)
tmp_ccs
.
append
(
f
"
{
self
.
cc_major
.
value
}
.
{
self
.
cc_minor
.
value
}
"
)
tmp_ccs
.
append
(
f
"
{
self
.
cc_major
.
value
}
.
{
self
.
cc_minor
.
value
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment