Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
1e88edd8
Commit
1e88edd8
authored
Jul 25, 2022
by
Tim Dettmers
Browse files
Removed rowscale (segfaults on ampere).
parent
8b1fd32e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
10 deletions
+6
-10
Makefile
Makefile
+0
-1
bitsandbytes/functional.py
bitsandbytes/functional.py
+4
-9
tests/test_functional.py
tests/test_functional.py
+2
-0
No files found.
Makefile
View file @
1e88edd8
...
...
@@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_61,code
=
sm_61
# Pascal
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_70,code
=
sm_70
# Volta
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_72,code
=
sm_72
# Volta
COMPUTE_CAPABILITY
+=
-gencode
arch
=
compute_72,code
=
sm_72
# Volta
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92
:=
-gencode
arch
=
compute_30,code
=
sm_30
...
...
bitsandbytes/functional.py
View file @
1e88edd8
...
...
@@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ct
.
c_long
(
strideA
),
ct
.
c_long
(
strideB
),
ct
.
c_long
(
strideC
),
ct
.
c_uint32
(
num_batch
))
return
out
def
igemmlt
(
A
,
B
,
SA
,
SB
,
out
=
None
,
Sout
=
None
,
row_scale
=
None
,
dtype
=
torch
.
int32
):
def
igemmlt
(
A
,
B
,
SA
,
SB
,
out
=
None
,
Sout
=
None
,
dtype
=
torch
.
int32
):
shapeA
=
SA
[
0
]
shapeB
=
SB
[
0
]
dimsA
=
len
(
shapeA
)
...
...
@@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
elif
dimsA
==
3
and
out
is
None
:
out
,
Sout
=
get_transform_buffer
((
shapeA
[
0
],
shapeA
[
1
],
shapeB
[
0
]),
dtype
,
A
.
device
,
'col32'
,
'row'
)
if
row_scale
is
not
None
:
assert
row_scale
.
numel
()
==
out
.
shape
[
0
]
assert
dimsB
!=
3
,
'len(B.shape)==3 not supported'
assert
A
.
device
.
type
==
'cuda'
assert
B
.
device
.
type
==
'cuda'
...
...
@@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
ptrA
=
get_ptr
(
A
)
ptrB
=
get_ptr
(
B
)
ptrC
=
get_ptr
(
out
)
ptrRowScale
=
get_ptr
(
row_scale
)
k
=
shapeA
[
-
1
]
lda
=
ct
.
c_int32
(
m
*
32
)
...
...
@@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
k
=
ct
.
c_int32
(
k
)
has_error
=
0
ptrRowScale
=
get_ptr
(
None
)
if
formatB
==
'col_turing'
:
if
dtype
==
torch
.
int32
:
has_error
=
lib
.
cigemmlt_turing_32
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
elif
row_scale
is
None
:
has_error
=
lib
.
cigemmlt_turing_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
else
:
has_error
=
lib
.
cigemmlt_turing_8
_rowscale
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
has_error
=
lib
.
cigemmlt_turing_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
elif
formatB
==
'col_ampere'
:
if
dtype
==
torch
.
int32
:
has_error
=
lib
.
cigemmlt_ampere_32
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
elif
row_scale
is
None
:
has_error
=
lib
.
cigemmlt_ampere_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
else
:
has_error
=
lib
.
cigemmlt_ampere_8
_rowscale
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
has_error
=
lib
.
cigemmlt_ampere_8
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
)
if
has_error
==
1
:
raise
Exception
(
'cublasLt ran into an error!'
)
...
...
tests/test_functional.py
View file @
1e88edd8
...
...
@@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist()
values
=
list
(
zip
(
dim1
,
dim4
,
inner
))
names
=
[
'dim1_{0}_dim4_{1}_inner_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, inner"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
def
test_igemmlt_row_scale
(
dim1
,
dim4
,
inner
):
formatB
=
F
.
get_special_format_str
()
err1
,
err2
,
err3
=
[],
[],
[]
...
...
@@ -1064,6 +1065,7 @@ dim4 = [12288, 4096]
values
=
list
(
zip
(
dim1
,
dim4
,
inner
))
names
=
[
'dim1_{0}_dim4_{1}_inner_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, inner"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
skip
(
"Row scale has some bugs for ampere"
)
def
test_row_scale_bench
(
dim1
,
dim4
,
inner
):
err1
,
err2
,
err3
=
[],
[],
[]
relerr1
,
relerr2
=
[],
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment