Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
5e456be5
Unverified
Commit
5e456be5
authored
Apr 10, 2023
by
justheuristic
Committed by
GitHub
Apr 10, 2023
Browse files
Support 1650, 1660
parent
49a04253
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
1 deletion
+12
-1
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+12
-1
No files found.
bitsandbytes/autograd/_functions.py
View file @
5e456be5
...
@@ -221,6 +221,17 @@ bmm_cublas = MatMul8bit.apply
...
@@ -221,6 +221,17 @@ bmm_cublas = MatMul8bit.apply
matmul_cublas
=
MatMul8bit
.
apply
matmul_cublas
=
MatMul8bit
.
apply
def supports_igemmlt(device: torch.device) -> bool:
    """Check whether ``device`` can run the optimized int8 (igemmlt) kernel.

    The kernel requires tensor cores. Tensor cores need CUDA compute
    capability >= (7, 5), but the GeForce 16 series reports capability 7.5
    while lacking tensor cores, so those chips are excluded by name.

    Args:
        device: the torch device to probe.

    Returns:
        True if the device supports the optimized int8 kernel, False otherwise.
    """
    # Robustness fix: a non-CUDA device (e.g. CPU) can never run the CUDA
    # int8 kernel.  Without this guard, torch.cuda.get_device_capability
    # raises on non-CUDA devices; returning False is the correct answer.
    if device.type != "cuda":
        return False
    if torch.cuda.get_device_capability(device=device) < (7, 5):
        return False
    device_name = torch.cuda.get_device_name(device=device)
    # https://en.wikipedia.org/wiki/GeForce_16_series
    nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660')
    # These devices are technically CUDA-7.5-capable, but they lack tensor cores.
    if any(model_name in device_name for model_name in nvidia16_models):
        return False
    return True
@
dataclass
@
dataclass
class
MatmulLtState
:
class
MatmulLtState
:
tile_indices
:
Optional
[
torch
.
Tensor
]
=
None
tile_indices
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -270,7 +281,7 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -270,7 +281,7 @@ class MatMul8bitLt(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
A
,
B
,
out
=
None
,
bias
=
None
,
state
=
MatmulLtState
):
def
forward
(
ctx
,
A
,
B
,
out
=
None
,
bias
=
None
,
state
=
MatmulLtState
):
using_igemmlt
=
torch
.
cuda
.
get_device_capability
(
device
=
A
.
device
)
>=
(
7
,
5
)
and
not
state
.
force_no_igemmlt
using_igemmlt
=
supports_igemmlt
(
A
.
device
)
and
not
state
.
force_no_igemmlt
# default of pytorch behavior if inputs are empty
# default of pytorch behavior if inputs are empty
ctx
.
is_empty
=
False
ctx
.
is_empty
=
False
if
prod
(
A
.
shape
)
==
0
:
if
prod
(
A
.
shape
)
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment