OpenDAS / bitsandbytes · Commits

Commit bfa0e332
authored Aug 01, 2022 by Titus von Koeller
parent 597a8521

    ran black and isort for coherent code formatting

Showing 5 changed files with 1417 additions and 1007 deletions (+1417, -1007). The changed hunks below are shown in their formatted (post-commit) form.
tests/test_autograd.py               +137   -63
tests/test_cuda_setup_evaluator.py    +39   -38
tests/test_functional.py             +801  -652
tests/test_modules.py                +175  -122
tests/test_optim.py                  +265  -132
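The diff below is purely mechanical. For readers who want to reproduce the same kind of pass locally, a minimal sketch using the tools' Python APIs follows; it assumes black and isort are installed, and the repository's actual configuration and versions are not recorded on this page.

# Hypothetical sketch: apply an isort + black pass to one file via their Python APIs.
# Assumes `pip install black isort`; the repo's real config may differ.
from pathlib import Path

import black
import isort

path = Path("tests/test_autograd.py")
source = path.read_text()
source = isort.code(source)                            # group and sort imports
source = black.format_str(source, mode=black.Mode())   # normalize quotes, wrapping, spacing
path.write_text(source)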
tests/test_autograd.py · view file @ bfa0e332
from itertools import product

import pytest
import torch

import bitsandbytes as bnb

n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(
    product(
        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
    values,
    ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
...
...
@@ -32,9 +43,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
...
...
@@ -52,9 +65,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
...
...
@@ -78,16 +91,22 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)

        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
            )
            B = torch.randn(
                size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
            )
            target = torch.randn(
                size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
...
...
@@ -95,7 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)

            if any(req_grad):
...
...
@@ -120,16 +139,20 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02

        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(
                size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
...
...
@@ -141,9 +164,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
...
...
@@ -167,51 +190,96 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02


n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
# dim1 = (17,)
# dim2 = (7,)
# dim3 = (37,)
# dim4 = (23,)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
    values,
    ids=names,
)
def test_matmullt(
    dim1,
    dim2,
    dim3,
    dim4,
    funcs,
    dtype,
    req_grad,
    transpose,
    decomp,
    has_fp16_weights,
):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()
...
...
@@ -219,8 +287,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                (
                    state.CB,
                    CBt,
                    state.SCB,
                    SCBt,
                    coo_tensorB,
                ) = bnb.functional.double_quant(B2)
                B2 = state.CB

            if not transpose[0] and transpose[1]:
...
...
@@ -231,12 +306,12 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                out_bnb = funcs[1](A, B2.t(), state=state)

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            # print(f'abs error {err:.4f}')
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if has_fp16_weights:
                if any(req_grad):
...
...
@@ -263,8 +338,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                assert torch.abs(gradB1).sum() > 0.0
                assert torch.abs(gradB2).sum() > 0.0
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
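A note on the assertion pattern in this file: test_matmul and test_matmullt repeatedly count the elements that fall outside torch.isclose and require that the mismatching fraction stays under a small budget, rather than demanding exact agreement. A self-contained sketch of that idiom with synthetic tensors (not part of the file above):

# Standalone sketch of the mismatch-budget check used throughout these tests (synthetic data).
import torch

reference = torch.randn(1024, 1024)
candidate = reference + 0.002 * torch.randn_like(reference)  # stand-in for an 8-bit result
close = torch.isclose(candidate, reference, atol=0.01, rtol=0.1)
n = candidate.numel()
# At most 1.75% of the elements may disagree beyond the tolerance.
assert (close == 0).sum().item() < n * 0.0175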
tests/test_cuda_setup_evaluator.py · view file @ bfa0e332
import os
from typing import List, NamedTuple

import pytest

from bitsandbytes.cuda_setup import (
    CUDA_RUNTIME_LIB,
    get_cuda_runtime_lib_path,
    evaluate_cuda_setup,
    tokenize_paths,
)


class InputAndExpectedOutput(NamedTuple):
    input: str
    output: str


HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
    (
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
]


@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
    for path in tokenize_paths(request.param):
        test_dir.mkdir()
        if CUDA_RUNTIME_LIB in path:
            (test_input / CUDA_RUNTIME_LIB).touch()


@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path(
    tmp_path, test_input: str, expected: str
):
    for path in tokenize_paths(test_input):
        path.mkdir()
        (path / CUDA_RUNTIME_LIB).touch()
    assert get_cuda_runtime_lib_path(test_input) == expected
...
...
@@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
        (test_input / CUDA_RUNTIME_LIB).touch()
    with pytest.raises(FileNotFoundError) as err_info:
        get_cuda_runtime_lib_path(test_input)
    assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB})


def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
    existent_dir = tmp_path / "a/b"
    existent_dir.mkdir()
    non_existent_dir = tmp_path / "c/d"  # non-existent dir
    test_input = ":".join([str(existent_dir), str(non_existent_dir)])

    get_cuda_runtime_lib_path(test_input)
    std_err = capsys.readouterr().err

    assert all(match in std_err for match in {"WARNING", "non-existent"})


def test_full_system():
    ## this only tests the cuda version and not compute capability
    ld_path = os.environ["LD_LIBRARY_PATH"]
    paths = ld_path.split(":")
    version = ""
    for p in paths:
        if "cuda" in p:
            idx = p.rfind("cuda-")
            version = p[idx + 5 : idx + 5 + 4].replace("/", "")
            version = float(version)
            break

    binary_name = evaluate_cuda_setup()
    binary_name = binary_name.replace("libbitsandbytes_cuda", "")
    assert binary_name.startswith(str(version).replace(".", ""))
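For orientation, test_full_system derives the expected CUDA version by slicing it out of an LD_LIBRARY_PATH entry. The same extraction applied to a made-up path, as a standalone sketch:

# Standalone sketch of the version-extraction step used by test_full_system above;
# the path below is an invented example, not taken from any real machine.
p = "/usr/local/cuda-11.3/lib64"
idx = p.rfind("cuda-")
version = p[idx + 5 : idx + 5 + 4].replace("/", "")  # take the 4 characters after "cuda-"
print(float(version))  # 11.3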
tests/test_functional.py · view file @ bfa0e332
import math
import random
import time
from itertools import product

import einops
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol, atol, count):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
...
...
@@ -35,13 +39,14 @@ class FFN(torch.nn.Module):
        x = self.fc2(x)
        return x


class Timer(object):
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
...
...
@@ -49,66 +54,70 @@ class Timer(object):
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001
...
...
@@ -117,22 +126,22 @@ def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004
...
...
@@ -141,56 +150,60 @@ def test_dynamic_blockwise_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    # print(sum(diffs)/len(diffs))


def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximunm distance of quantized values of 1
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )


@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1

        gnorm2 = torch.norm(g.float())
        if step == 1:
...
...
@@ -208,74 +221,89 @@ def test_percentile_clipping(gtype):
def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()


def mean(xx):
    return sum(xx) / float(len(xx))


# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
    for vals in values_names
]


@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
...
...
@@ -284,43 +312,49 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))


def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
...
...
@@ -338,9 +372,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
...
...
@@ -352,40 +390,51 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)
        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)


n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
...
...
@@ -395,30 +444,30 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
...
...
@@ -429,31 +478,36 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
        out /= std
        out3 /= std
        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)
        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)
        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3


n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
...
...
@@ -462,8 +516,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
...
...
@@ -479,146 +533,174 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
        out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)


n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(
    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
        *vals
    )
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
    values,
    ids=names,
)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and out_order != "col32":
        return
    if dtype == torch.int32 and out_order != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(
            -128, 127, size=(dim1, dim2, dim3), device="cuda"
        ).to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = (
                A.shape[0]
                * A.shape[1]
                * (A.shape[2] + (32 - (A.shape[2] % 32)))
            )
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
            A.shape[1] + (32 - (A.shape[1] % 32))
        )
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = (
                    (row // 8) + (1 if row % 8 != 0 else 0)
                ) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(
            out, from_order=orderOut, to_order="row", state=S
        )
        torch.testing.assert_allclose(A, out2)


n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]
dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(
                -128, 127, size=(dim1, dim3), device="cuda"
            ).to(torch.int8)
        elif dims == 3:
            A = torch.randint(
                -128, 127, size=(dim1, dim2, dim3), device="cuda"
            ).to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
            torch.int8
        )
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())


dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(
                0, 0.5, size=(dim1, dim2, dim3), device="cuda"
            ).half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())
...
...
@@ -627,50 +709,56 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())


batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]


# values = list(product(batch, seq, model, hidden))
names = [
    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
...
...
@@ -679,77 +767,77 @@ def test_bench_8bit_training(batch, seq, model, hidden):
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t())  # fc1
#out2 = torch.matmul(out1, w2.t())# fc2
out1 = torch.matmul(A, w1.t())  # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
#d1 = torch.matmul(grad, w2) # delta1
#d2 = torch.matmul(d1, w1) # delta2
# d1 = torch.matmul(grad, w2) # delta1
# d2 = torch.matmul(d1, w1) # delta2
#grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
#grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
# grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
# grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
#torch.cuda.empty_cache()
# torch.cuda.empty_cache()
#Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#CTw1, Sw1 = F.transform2(Cw1, formatB)
#CTw2, Sw2 = F.transform2(Cw2, formatB)
#CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
#CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
#CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
#C32A, SA = F.transform2(CA, 'col32')
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# C32A, SA = F.transform2(CA, 'col32')
## fc1
#out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
#Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
#C32out1, Sout1 = F.transform2(Cout1, 'col32')
#out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
#Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
#C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
#Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
#C32d1, Sd1 = F.transform2(Cd1, 'col32')
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
#C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
#CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
#C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
#CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#CTw1, Sw1 = F.transform2(Cw1, formatB)
#CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
#CTw2, Sw2 = F.transform2(Cw2, formatB)
#CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(k):
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
...
...
@@ -802,74 +890,76 @@ def test_bench_8bit_training(batch, seq, model, hidden):
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
#torch.cuda.synchronize()
#t8 = time.time() - t0
#print(t8)
# torch.cuda.synchronize()
# t8 = time.time() - t0
# print(t8)
n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()
#dim1 = [2*1024]
#dim4 = [2*1024]
# dim1 = [2*1024]
# dim4 = [2*1024]
#dim1 = [4]
#dim4 = [4]
# dim1 = [4]
# dim4 = [4]
dims = (2,)
#ldb = list(range(256, 1*1024, 256))
formatB = ['col_turing', 'col_ampere']
values = list(product(dim1, dim4, dims, formatB))
names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values]
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
inner = torch.randint(1, 128, size=(1,)).item()
formatB = F.get_special_format_str()
for i in range(k):
A = torch.randn(dim1, inner, device='cuda')
B = torch.randn(dim4, inner, device='cuda')
A = torch.randn(dim1, inner, device="cuda")
B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half())
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
A2, SA = F.nvidia_transform(A1, 'col32')
A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, formatB)
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, 'row', state=SC)
C3, S = F.nvidia_transform(C2, "row", state=SC)
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
n = C1.numel()
p = 0.06
assert count / n < p, f'error in more than {p} of elements: {count}/{n}={count/n}'
assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
torch.testing.assert_allclose(C5, C4)
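# The comparison above is deliberately loose: rather than an exact allclose,
# the test only requires that fewer than p = 6% of elements fall outside
# atol=0.01/rtol=0.1, which absorbs the rounding error of vector-wise
# quantization, and then checks that mm_dequant() matches the reference
# vectorwise_mm_dequant() path exactly.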
#print(C2)
# print(C2)
n = 2
dim1 = [1*1024]
dim2 = [1*1024]
#dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dims = (2,)
#ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
for i in range(k):
threshold = 3.0
A = torch.randn(dim1, dim2, device='cuda').half()
A = torch.randn(dim1, dim2, device="cuda").half()
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
if dims == 2:
...
...
@@ -880,11 +970,22 @@ def test_colrow_absmax(dim1, dim2, dims):
else:
assert False
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4)
nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0] + 1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
A_blocked = einops.rearrange(
    torch.abs(A),
    "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
    row_tiles=16,
    block_size=64 * 4,
)
nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
nnz_block_ptr1 = torch.zeros(
    nnz_rows1_counts.shape[0] + 1,
    dtype=nnz_rows1_counts.dtype,
    device=nnz_rows1_counts.device,
)
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
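# The reference nnz_block_ptr1 counts, per 64*4-wide block, how many entries
# exceed the outlier threshold and accumulates them with cumsum (leaving a
# leading zero), mirroring a CSR-style block pointer for the expected
# non-zero layout.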
...
...
@@ -898,19 +999,20 @@ def test_colrow_absmax(dim1, dim2, dims):
assert nnz_block_ptr2 is None
n = 2
#dim1 = [8*1024]
#dim2 = [4*1024]
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4*1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
values = list(product(dim1, dim2))
names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
for i in range(k):
A = torch.randn(dim1, dim2, device='cuda').half()
A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
...
...
@@ -920,18 +1022,21 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
# allow for 1:500 error due to rounding differences
min_error = 1/500
if num_not_close_cols > (min_error*n):
print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}')
min_error = 1 / 500
if num_not_close_cols > (min_error * n):
print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
assert False
if num_not_close_rows > (min_error*n):
print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}')
if num_not_close_rows > (min_error * n):
print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
assert False
torch.testing.assert_allclose(Srow.flatten(), statsA)
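# Row-wise and column-wise quantization round independently, so values on a
# rounding boundary can land one integer step apart between the two paths;
# the checks above therefore tolerate up to 1 mismatch in 500 elements
# before failing.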
...
...
@@ -939,21 +1044,23 @@ def test_double_quant(dim1, dim2):
n = 4
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4*1024, size=(n,)).tolist()
inner = torch.randint(1, 4*1024, size=(n,)).tolist()
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim1 = [6]
dim4 = [4]
inner = [8]
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k):
A = torch.randn(dim1, inner, device='cuda').half()
B = torch.randn(dim4, inner, device='cuda').half()
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
out1 = torch.matmul(A.half(), B.t().half())
...
...
@@ -967,30 +1074,32 @@ def test_integrated_igemmlt(dim1, dim4, inner):
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
A2, SA = F.nvidia_transform(C1a, 'col32')
B2, SB = F.nvidia_transform(C2a, 'col_turing')
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(C2a, "col_turing")
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
A2, SA = F.nvidia_transform(A1, 'col32')
B2, SB = F.nvidia_transform(B1, 'col_turing')
A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
C3, S = F.nvidia_transform(C2, 'row', state=SC)
C3, S = F.nvidia_transform(C2, "row", state=SC)
out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1*1.01
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out3).mean().item()
assert err2 <= err1 * 1.01
n = 6
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4*1024, size=(n,)).tolist()
inner = torch.randint(1, 4*1024, size=(n,)).tolist()
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
...
...
@@ -999,79 +1108,79 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
A = torch.randn(dim1, inner, device='cuda').half()
B = torch.randn(dim4, inner, device='cuda').half()
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
A2, SA = F.nvidia_transform(C1a, 'col32')
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0*inner*scale
row_scale = torch.ones_like(maxA)/c
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
C3, S = F.nvidia_transform(outC32, 'row', state=SC)
C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
else:
scale = maxval/120
out3 = C3*maxA*absmaxB*c/(127*127)
scale = maxval / 120
out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector')
CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear')
CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
out4 = C*SA*SB/(127*127)
#out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
out4 = C * SA * SB / (127 * 127)
# out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
#print('='*80)
#print(out1)
#print(out2)
#print(out3)
# print('='*80)
# print(out1)
# print(out2)
# print(out3)
#print(out1)
#print(out2)
#print(out3)
err1.append(torch.abs(out1 - out2).mean().item())
err2.append(torch.abs(out1 - out3).mean().item())
err3.append(torch.abs(out1 - out4).mean().item())
# print(out1)
# print(out2)
# print(out3)
err1.append(torch.abs(out1 - out2).mean().item())
err2.append(torch.abs(out1 - out3).mean().item())
err3.append(torch.abs(out1 - out4).mean().item())
#assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print('')
print(sum(err1)/len(err1))
print(sum(err2)/len(err2))
print(sum(err3)/len(err3))
# assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print("")
print(sum(err1)/len(err1))
print(sum(err2)/len(err2))
print(sum(err3)/len(err3))
dim1 = [1024, 2048]
inner = [12288*4, 4096*4]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
A = torch.randn(dim1, inner, device='cuda').half()
B = torch.randn(dim4, inner, device='cuda').half()
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# warmup
for i in range(k):
...
...
@@ -1082,23 +1191,22 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
print('16', time.time() - t0)
print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
A2, SA = F.nvidia_transform(C1a, 'col32')
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0*inner*scale
row_scale = maxA/c
c = 10.0 * inner * scale
row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
torch.cuda.synchronize()
print('row-wise', time.time() - t0)
print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
...
...
@@ -1107,32 +1215,39 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB)
torch.cuda.synchronize()
print('vector-wise', time.time() - t0)
print("vector-wise", time.time() - t0)
n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
#dim1 = [8*1024]
#dim2 = [4*1024]
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]
dim3 = [0]
dtype = [torch.int8]
a_order = ['row']
out_order = ['col32', 'col_turing', 'col_ampere']
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype)
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype)
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
A.view(-1)[-1] = -1
if transpose:
...
...
@@ -1144,53 +1259,55 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
assert S1[0][0] == S2[0][0]
assert S1[0][1] == S2[0][1]
#print(out1)
#print(out2)
# print(out1)
# print(out2)
torch.testing.assert_allclose(out1, out2)
n = 2
#dim1 = torch.randint(2,1024, size=(n,)).tolist()
#dim2 = torch.randint(2,1024, size=(n,)).tolist()
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]
dtype = [torch.int8]
#a_order = ['col_turing', 'col_ampere']
a_order = ['col_turing']
out_order = ['row']
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
for i in range(1):
A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype)
A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
out2, S2 = F.transform(A, to_order=orderA)
A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2)
A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
assert A2.shape[0] == A.shape[0]
assert A2.shape[1] == A.shape[1]
print('')
print("")
print(A)
print(out2)
print(A2)
#torch.testing.assert_allclose(A, A2)
# torch.testing.assert_allclose(A, A2)
def test_overflow():
formatB = F.get_special_format_str()
print(formatB)
for i in range(2):
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
Ca, Sa = F.nvidia_transform(a, 'col32')
Ca, Sa = F.nvidia_transform(a, "col32")
Cb, Sb = F.nvidia_transform(b, formatB)
c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
...
...
@@ -1198,46 +1315,51 @@ def test_overflow():
n = 2
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4*1024, size=(n,)).tolist()
#dim1 = [4]
#dim2 = [5]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
values = list(product(dim1, dim2))
names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
threshold = 3.00
for i in range(k):
A = torch.randn(dim1, dim2, device='cuda').half()
A = torch.randn(dim1, dim2, device="cuda").half()
idx = (torch.abs(A) >= threshold)
idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
if coo_tensor is not None:
A1 = A*idx
A1 = A * idx
A2 = torch.zeros_like(A)
A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
torch.testing.assert_allclose(A1, A2)
A1 = A*(idx==0)
A2 = (CA.float()*statsA.unsqueeze(1)/127).half()
torch.testing.assert_allclose(A*(idx==0), A2, rtol=0.05, atol=1.5e-2)
A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
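# With a threshold, double_quant() splits A into an int8 tensor plus a COO
# sparse tensor of outliers (|a| >= threshold): the sparse values must equal A
# exactly on those positions, and dequantizing CA with the row-wise absmax
# statistics must reproduce the sub-threshold entries within ~5% tolerance.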
n = 2
dim1 = torch.randint(1, 1*1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1*1024, size=(n,)).tolist()
#dim1 = [7]
#dim2 = [11]
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
#dim3 = 17
# dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
...
...
@@ -1249,8 +1371,10 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
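# The COOSparseTensor is built directly from the threshold mask: torch.where()
# supplies the row/column indices and A[idx] the values, so spmm_coo() can be
# checked against a dense matmul with the masked matrix A2.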
...
...
@@ -1262,18 +1386,17 @@ def test_spmm_coo(dim1, dim2, transposed_B):
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
def test_spmm_bench():
batch = 2
model = 1024*1
hidden = model*4
model = 1024 * 1
hidden = model * 4
seq = 1024
dim1 = batch*seq
dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
A = torch.randn(dim1, dim2, device='cuda').half()
B = torch.randn(dim2, dim3, device='cuda').half()
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.randn(dim2, dim3, device="cuda").half()
for i in range(10):
C1 = bnb.matmul(A, B)
...
...
@@ -1282,14 +1405,16 @@ def test_spmm_bench():
for i in range(k):
C1 = bnb.matmul(A, B)
torch.cuda.synchronize()
t8 = time.time() - t0
t8 = time.time() - t0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
print(nnz/idx.numel())
print(nnz/idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
for i in range(10):
out2 = F.spmm_coo(cooA, B)
...
...
@@ -1299,20 +1424,22 @@ def test_spmm_bench():
for i in range(k):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
tsp = time.time() - t0
tsp = time.time() - t0
print(tsp, t8)
print(tsp/t8)
print(tsp/t8)
n = 2
dim1 = torch.randint(256, 1*1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1*1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0
formatB = 'col_turing'
formatB = "col_turing"
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
...
...
@@ -1322,13 +1449,13 @@ def test_integrated_sparse_decomp(dim1, dim2):
CTw1, Sw1 = F.transform(Cw1, formatB)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
...
...
@@ -1338,8 +1465,8 @@ def test_integrated_sparse_decomp(dim1, dim2):
out4 = F.spmm_coo(coo_tensor, w1.t())
out5 = out3 + out4
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
err1 = torch.abs(out1 - out2).mean().item()
err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
...
...
@@ -1350,91 +1477,95 @@ def test_matmuls():
c2 = bnb.matmul(a, b)
c3 = bnb.matmul(a, b)
err1 = torch.abs(c1 - c2).mean().item()
err2 = torch.abs(c1 - c3).mean().item()
err1 = torch.abs(c1 - c2).mean().item()
err2 = torch.abs(c1 - c3).mean().item()
assert err1 < 0.2
assert err2 < 0.2
n = 2
#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1*2048]
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
#dim1 = [32]
#dim2 = [32]
#dtype = [torch.float16, torch.int8]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ['zeros', 'ones']
values = list(product(dim1, dim2, dtype, out_function))
names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
#threshold = 2.8
#threshold = 0.0
A = torch.randn(dim1, dim2, device='cuda').half()
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
if dtype == torch.float16:
B = torch.randn(dim2, dim2*4, device='cuda').half()
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
else:
B = torch.randn(dim2, dim2*4, device='cuda').half()
B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
B, SB = F.vectorwise_quant(B, quant_type='linear')
#B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
B, SB = F.vectorwise_quant(B, quant_type="linear")
# B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
print('')
print("")
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
out1 += out.clone()
out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
#print(B)
#print(out1)
#print(out2)
p = 200/(2048*12288*4)
# print(B)
# print(out1)
# print(out2)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p*n)
count = math.ceil(p * n)
std = out1.std()
out1 /= std
out2 /= std
assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
#assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
#torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
# torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
#Bt = torch.randn(dim2*4, dim2, device='cuda').half()
#torch.cuda.synchronize()
#t0 = time.time()
#print(A2.shape, B.shape)
#for i in range(100):
# Bt = torch.randn(dim2*4, dim2, device='cuda').half()
# torch.cuda.synchronize()
# t0 = time.time()
# print(A2.shape, B.shape)
# for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
# #out2 = F.spmm_coo(cooA, B)
# #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
#torch.cuda.synchronize()
#print(time.time() - t0)
# torch.cuda.synchronize()
# print(time.time() - t0)
def test_layout():
a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16)
a1 = torch.arange(16*64, device='cuda').reshape(16, 64).byte()
a2, s2 = F.transform(a1, 'col_turing')
a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
a2, s2 = F.transform(a1, "col_turing")
print(a2.shape)
print(a1.flatten()[8*64:8*64+32])
print(a1.flatten()[8 * 64 : 8 * 64 + 32])
for i in range(4):
print(a2.flatten()[i*8*32:i*8*32+32], 0)
print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
def test_coo2csr():
...
...
@@ -1444,14 +1575,16 @@ def test_coo2csr():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
idx = (A2 != 0)
torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
idx = A2 != 0
torch.testing.assert_allclose(A2[idx], csrA.values)
...
...
@@ -1462,41 +1595,43 @@ def test_coo2csc():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
idx = (A2.t() != 0)
idx = A2.t() != 0
torch.testing.assert_allclose(A2.t()[idx], cscA.values)
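# For the CSC layout the colptr differences give per-column non-zero counts;
# because torch tensors are row-major, the expected values are read through
# A2.t() so they line up with the column-major ordering of cscA.values.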
n = 2
#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1*2048]
#dim2 = [12288]
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
#dim1 = [2]
#dim2 = [2]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0
#threshold = 2.8
#threshold = 0.0
A = torch.randn(dim1, dim2, device='cuda').half()
B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16)
# threshold = 2.8
# threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
...
...
@@ -1507,12 +1642,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A*idx
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
out3 = out3*statsBt.half()/127
out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
...
...
@@ -1521,56 +1658,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
p = 200/(2048*12288*4)
p = 200 / (2048 * 12288 * 4)
n = out1.numel()
count = math.ceil(p*n)
count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(100):
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
#torch.cuda.synchronize()
#print('fp16', time.time() - t0)
# torch.cuda.synchronize()
# print('fp16', time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo(cooA, B)
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
print('cusparse fp16', time.time() - t0)
print("cusparse fp16", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt)
out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize()
print('int8', time.time() - t0)
print("int8", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize()
print('int8+dequant', time.time() - t0)
print("int8+dequant", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out2 = torch.matmul(A, B)
out2 = torch.matmul(A, B)
torch.cuda.synchronize()
print('matmul', time.time() - t0)
print("matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out = out1 + out2
out = out1 + out2
torch.cuda.synchronize()
print('sparse+ matmul', time.time() - t0)
print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
...
...
@@ -1578,33 +1713,36 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize()
print('partial matmul', time.time() - t0)
print("partial matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize()
print('partial matmul', time.time() - t0)
print("partial matmul", time.time() - t0)
batch_size = 1
seqdim = 2048
values = []
values.append((batch_size, seqdim, 768, 4*768))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device='cuda').half()
B = torch.empty(hidden, model, dtype=torch.float16, device='cuda')
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
...
...
@@ -1613,31 +1751,37 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half())
linearMixedBit.eval()
# warmup
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print('')
print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s')
print(f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
bnb.matmul(A, B)
torch.cuda.synchronize()
print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s')
print(f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize()
...
...
@@ -1645,7 +1789,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
torch.cuda.synchronize()
print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s')
print(f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
BA, statsB = F.vectorwise_quant(B, dim=1)
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
...
...
@@ -1654,26 +1800,30 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1)
C32A, SA = F.nvidia_transform(CA, 'col32')
C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
torch.cuda.synchronize()
print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s')
print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear')
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear')
C32A, SA = F.nvidia_transform(CA, 'col32')
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
out = Cout*statsB*statsA*(1.0/(127*127))
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
out = Cout * statsB * statsA * (1.0 / (127 * 127))
torch.cuda.synchronize()
print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s')
print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
linear8bit(A)
torch.cuda.synchronize()
...
...
@@ -1681,8 +1831,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linear8bit(A)
torch.cuda.synchronize()
print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s')
print(f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
linearMixedBit(A)
torch.cuda.synchronize()
...
...
@@ -1690,65 +1841,66 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linearMixedBit(A)
torch.cuda.synchronize()
print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s')
print(f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time() - t0:.4f}s")
def test_zeropoint():
def min_max(x):
maxA = torch.amax(x, dim=1, keepdim=True)
minA = torch.amin(x, dim=1, keepdim=True)
midpoint = (maxA - minA)/2.0
dyna = 252/(maxA - minA)
#dyna *= 0.98
x = dyna*x
x = x - torch.round((dyna*(minA + midpoint)))
midpoint = (maxA - minA) / 2.0
dyna = 252 / (maxA - minA)
# dyna *= 0.98
x = dyna * x
x = x - torch.round((dyna * (minA + midpoint)))
return x.to(torch.int8), minA, midpoint, dyna
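# min_max() is a rough asymmetric (zero-point style) quantizer: each row's
# range [minA, maxA] is mapped onto ~252 integer steps and shifted by the
# rounded midpoint term. For example, a row spanning [-1, 3] gives
# dyna = 252/4 = 63, so the value 3 maps to 3*63 - round(63*(-1 + 2)) = 126,
# which still fits in int8.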
batch = 2
seq = 2
model = 4
hidden = 2*model
#batch = 4
#seq = 2048
#model = 1024
#hidden = 8*model
A = torch.randn(batch*seq, model, device='cuda').half() - 0.4
B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half())
#A[0] = 0
#B[:, 0] = 0
#A = A*(A>0)
#A[0, 0] = 0
#A[0, 0] = 6.0
hidden = 2 * model
# batch = 4
# seq = 2048
# model = 1024
# hidden = 8*model
A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())
# A[0] = 0
# B[:, 0] = 0
# A = A*(A>0)
# A[0, 0] = 0
# A[0, 0] = 6.0
Ac, minA, midpoint, dyna = min_max(A)
#print(Ac[0, 0], 'zero')
#print(Ac, Ac.min(), Ac.max())
Bc, maxB = F.vectorwise_quant(B, quant_type='linear')
# print(Ac[0, 0], 'zero')
# print(Ac, Ac.min(), Ac.max())
Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
out = F.igemm(Ac, Bc)
out2 = torch.matmul(A, B)
offset = B.sum(0)*torch.round(dyna*(minA + midpoint))/dyna
out2 = torch.matmul(A, B)
offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
out = out.float()
#print(out.shape, maxB.shape, scale.shape, offset.shape)
norm1 = maxB/127
C4 = (out/dyna)*norm1 + offset
# print(out.shape, maxB.shape, scale.shape, offset.shape)
norm1 = maxB / 127
C4 = (out / dyna) * norm1 + offset
B1 = torch.nn.Parameter(B.clone())
B2 = torch.nn.Parameter(B.clone())
B3 = torch.nn.Parameter(B.clone())
B4 = torch.nn.Parameter(B.clone())
C1 = torch.matmul(A, B1)
C2 = bnb.matmul_cublas(A, B2, None, 'linear')
C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint')
C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint')
C2 = bnb.matmul_cublas(A, B2, None, "linear")
C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
print(err1, err2, err3)
#assert err1 > err2
# assert err1 > err2
loss1 = C1.mean()
loss2 = C2.mean()
...
...
@@ -1765,40 +1917,38 @@ def test_zeropoint():
print(B2.grad)
print(B3.grad)
print(B4.grad)
err1 = torch.abs(B1.grad - B2.grad).mean().item()
err2 = torch.abs(B1.grad - B3.grad).mean().item()
err3 = torch.abs(B1.grad - B4.grad).mean().item()
err1 = torch.abs(B1.grad - B2.grad).mean().item()
err2 = torch.abs(B1.grad - B3.grad).mean().item()
err3 = torch.abs(B1.grad - B4.grad).mean().item()
print(err1, err2, err3)
def test_zp():
def quant_zp(x):
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
if dyna == 0:
dyna = 1
qx = 254./dyna
if dyna == 0:
dyna = 1
qx = 254.0 / dyna
minx = x.min()
#zpx = torch.round(minx* qx)
#zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min() * qx) - 127
x = (qx*x) + zpx
# zpx = torch.round(minx* qx)
# zpx = 127 - torch.round(x.max()* qx)
zpx = torch.round(x.min() * qx) - 127
x = (qx * x) + zpx
return x, qx, zpx
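# quant_zp() is the per-tensor analogue: x_q = qx*x + zpx, with qx covering
# ~254 steps and zpx acting as a zero-point offset derived from the minimum.
# The C6/C7 checks below rely on the usual zero-point correction identity
#   A@B ~= (CA@CB - zpa*qb*B.sum(0) - zpb*qa*A.sum(1) - zpa*zpb*K) / (qa*qb),
# with K = A.shape[1], which is exactly what those lines compute.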
batch = 2
seq = 512
model = 1024
hidden = 4*model
A = torch.randn(batch*seq, model, device='cuda').half()*0.1
B = torch.randn(model, hidden, device='cuda').half()*0.1
hidden = 4 * model
A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B)
#A, SA = F.vectorwise_quant(A, quant_type='linear')
#B, SB = F.vectorwise_quant(B, quant_type='linear')
# A, SA = F.vectorwise_quant(A, quant_type='linear')
# B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float()
B = B.float()
...
...
@@ -1806,69 +1956,68 @@ def test_zp():
C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1
#C2 = torch.matmul(A-zp, B)
#C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B - zp)
C2 -= A.sum(1).view(-1, 1)*zp
# C2 = torch.matmul(A-zp, B)
# C2 += B.sum(0).view(1, -1)*zp
C2 = torch.matmul(A, B - zp)
C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
print(ca.min(), ca.max())
print((ca - cza).min(), (ca - cza).max())
print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
C5 = torch.matmul((A*scale) - zp, B)
C5 += B.sum(0)*zp
C5 = torch.matmul((A * scale) - zp, B)
C5 += B.sum(0) * zp
C5 /= scale
CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B)
C4 -= B.sum(0)*zpa
C4 -= B.sum(0) * zpa
C4 /= qa
zpb = 1
zpa = 1
qa = 2
qb = 2
C6 = torch.matmul((A*qa) + zpa, (B*qb) + zpb)
C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
C6 -= zpa*zpb*A.shape[1]
C6 /= qa*qb
C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C6 -= zpa * zpb * A.shape[1]
C6 /= qa * qb
CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB)
C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
C7 -= zpa*zpb*A.shape[1]
C7 /= qa*qb
C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
C7 -= zpa * zpb * A.shape[1]
C7 /= qa * qb
print('')
#print(C0.flatten()[:10])
print("")
# print(C0.flatten()[:10])
print(C1.flatten()[:10])
print(C2.flatten()[:10])
print(C3.flatten()[:10])
print(C5.flatten()[:10])
print(C6.flatten()[:10])
print(C7.flatten()[:10])
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
err4 = torch.abs(C1 - C5).mean().item()
err5 = torch.abs(C1 - C6).mean().item()
err6 = torch.abs(C1 - C7).mean().item()
err1 = torch.abs(C1 - C2).mean().item()
err2 = torch.abs(C1 - C3).mean().item()
err3 = torch.abs(C1 - C4).mean().item()
err4 = torch.abs(C1 - C5).mean().item()
err5 = torch.abs(C1 - C6).mean().item()
err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6)
def
test_extract_outliers
():
for
i
in
range
(
k
):
shapeA
=
(
4096
,
4096
*
4
)
shapeA
=
(
4096
,
4096
*
4
)
idx
=
torch
.
unique
(
torch
.
randint
(
0
,
shapeA
[
1
],
size
=
(
10
,)).
int
()).
cuda
()
#idx = torch.Tensor([0]).int().cuda()
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
'
cuda
'
).
to
(
torch
.
int8
)
#
idx = torch.Tensor([0]).int().cuda()
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
"
cuda
"
).
to
(
torch
.
int8
)
outliers1
=
A
[:,
idx
.
long
()]
CA
,
SA
=
F
.
transform
(
A
,
'
col_turing
'
)
CA
,
SA
=
F
.
transform
(
A
,
"
col_turing
"
)
outliers2
=
F
.
extract_outliers
(
CA
,
SA
,
idx
)
...
...
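For reference, a minimal sketch (illustrative, not part of this commit) of the zero-point identity that the C6/C7 corrections above rely on; the tensors and scalar values below are made up for the check:

import torch

# With Aq = qa*A + zpa and Bq = qb*B + zpb (scalar scale and zero point),
# Aq @ Bq = qa*qb*(A @ B) + qb*zpa*B.sum(0) + qa*zpb*A.sum(1) + zpa*zpb*K,
# where K = A.shape[1]; subtracting those terms and dividing by qa*qb
# recovers A @ B, which is exactly what the C6/C7 code does.
A = torch.randn(4, 3)
B = torch.randn(3, 5)
qa, qb, zpa, zpb = 2.0, 3.0, 1.0, 1.0
lhs = (qa * A + zpa) @ (qb * B + zpb)
rhs = (
    qa * qb * (A @ B)
    + qb * zpa * B.sum(0).view(1, -1)
    + qa * zpb * A.sum(1).view(-1, 1)
    + zpa * zpb * A.shape[1]
)
assert torch.allclose(lhs, rhs, atol=1e-5)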
@@ -1877,7 +2026,7 @@ def test_extract_outliers():
        torch.testing.assert_allclose(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)
...
tests/test_modules.py
View file @
bfa0e332
from itertools import product

import pytest
import torch
from torch import nn

import bitsandbytes as bnb


class MockArgs(object):
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super(MLP8bit, self).__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
...
@@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module):
def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1
        return x

    def quant(x, quant_type, dim=1):
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None

    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None

    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
...
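A minimal, self-contained sketch (illustrative, not part of this commit) of the row-wise absmax scheme that quant()/dequant() above implement for quant_type="vector"; the tensor below is made up:

import torch

x = torch.randn(4, 8)
max1 = torch.amax(torch.abs(x), dim=1, keepdim=True)  # per-row absmax scale
xq = torch.round(x / max1 * 127).to(torch.int8)        # int8 codes
x_hat = xq.float() / 127 * max1                        # dequantized approximation
assert torch.abs(x - x_hat).max() <= max1.max() / 254 + 1e-6  # half-step error bound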
@@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function):
    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
...
@@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function):
    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]
...
@@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function):
        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args
...
@@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function):
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )
        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
...
@@ -263,7 +286,7 @@ class Linear8bit(nn.Module):
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
...
@@ -275,12 +298,11 @@ class Linear8bit(nn.Module):
        return LinearFunction.apply(x, self.weight, self.bias, self.args)


def test_linear8bit():
    l0 = torch.nn.Linear(32, 64).cuda().half()
    l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
    l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()

    l0.weight.data = l2.weight.data.clone()
    l0.bias.data = l2.bias.data.clone()
...
@@ -292,8 +314,8 @@ def test_linear8bit():
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        t = torch.randn(16, 8, 64, device="cuda").half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()
...
@@ -318,16 +340,20 @@ def test_linear8bit():
    assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
    assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
    assert_all_approx_close(
        l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
    )
    assert_all_approx_close(
        l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
    )

    err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
    err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
    err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()

    assert err1 * 0.8 < err2
    assert err2 * 0.8 < err3
    assert err3 * 0.8 < err1

    l0.weight.grad = None
    l1.weight.grad = None
...
@@ -341,23 +367,28 @@ def test_linear8bit():
threshold = [0.0, 3.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None


def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
    )
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
...
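A minimal inference sketch (illustrative, not part of this commit) of the Linear8bitLt module exercised above; the threshold value here is arbitrary and a CUDA device is assumed:

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(32, 64, threshold=6.0).cuda().half()
layer.eval()
x = torch.randn(16, 8, 32, device="cuda").half()
with torch.no_grad():
    y = layer(x)  # fp16 output from the int8 matmul path
assert y.dtype == torch.float16 and y.shape == (16, 8, 64)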
@@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient():
    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
...
@@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient():
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
...
@@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
    l1 = (
        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16
...
@@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold):
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .to(torch.float16)
        .to("cuda")
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"
tests/test_optim.py
View file @
bfa0e332
import ctypes
import os
import shutil
import time
import uuid
from itertools import product
from os.path import join

import pytest
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

# import apex

k = 20


def get_temp_dir():
    path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
    os.makedirs(path, exist_ok=True)
    return path


def rm_path(path):
    shutil.rmtree(path)


str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)
# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["lars"] = (
    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
)
# str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers["lars8bit"] = (
    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)
str2optimizers["adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lamb8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [
    ("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])
...
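A minimal usage sketch (illustrative, not part of this commit) of how the str2optimizers table above pairs a reference optimizer with its bitsandbytes counterpart; it assumes the table is in scope, the shapes are made up, and a CUDA device is available:

import torch

ref_fn, bnb_fn = str2optimizers["momentum"]
p_ref = torch.randn(16, 16, device="cuda", requires_grad=True)
p_bnb = p_ref.detach().clone().requires_grad_(True)
ref_opt, bnb_opt = ref_fn([p_ref]), bnb_fn([p_bnb])
g = torch.randn_like(p_ref) * 0.01
p_ref.grad, p_bnb.grad = g.clone(), g.clone()
ref_opt.step()
bnb_opt.step()  # per-parameter state ends up in ref_opt.state / bnb_opt.state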
@@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()
...
@@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_allclose(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2],
                atol=atol,
                rtol=rtol,
            )

        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_allclose(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                )

        if gtype == torch.float16:
            # the adam buffers should also be close because they are 32-bit
...
@@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
            p1.data = p1.data.half().float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.half(), p2)

    if optim_name in ["lars", "lamb"]:
        assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1

    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
...
@@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype):
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)

    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
    p1 = p1.cuda()
...
@@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype):
    atol, rtol = 1e-4, 1e-3

    for i in range(50):
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = [
    "adam8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lamb8bit",
    "lars8bit",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048
...
@@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    relerrors = []

    for i in range(50):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()
...
@@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            # print(bnb_optimizer.state[p2][max_val], name1)
            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                    blocksize=blocksize,
                )
            else:
                s1 = F.dequantize(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
            num_not_close = (
                torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                )
                == 0
            )
            assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / torch.abs(p1)
        assert err.mean() < 0.0001
        assert relerr.mean() < 0.001
...
@@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                rm_path(path)
                torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])

                if "blockwise" in optim_name:
                    s1 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )
                else:
                    s1 = F.dequantize(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                    )
                torch.testing.assert_allclose(s1cpy, s1)

                num_not_close = (
                    torch.isclose(
                        torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                    )
                    == 0
                )
                assert num_not_close.sum().item() < 20
            torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
...
@@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            p1.data = p1.data.to(gtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.to(gtype), p2)
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                torch_optimizer.state[p1][name1].copy_(s.data)

    # print(sum(errors)/len(errors))
    # print(sum(relerrors)/len(relerrors))


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
...
@@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
    )

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
        g2 = g1.clone()
        p2.grad = g2

        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
            g1, gnorm_vec, step, 5
        )
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
...
@@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_allclose(p1, p2)
            torch.testing.assert_allclose(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
            torch.testing.assert_allclose(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
            torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_allclose(
                adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
            )
            torch.testing.assert_allclose(
                adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])

        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            torch.save(adam2.state_dict(), join(path, "opt.pt"))
            del adam2
            adam2 = None
            adam2 = bnb.optim.Adam(
                [p2],
                lr,
                (beta1, beta2),
                eps,
                optim_bits=optim_bits,
                percentile_clipping=5,
            )
            adam2.load_state_dict(torch.load(join(path, "opt.pt")))


dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    for i in range(k):
        if i == k // 5:
            # 100 iterations for burn-in
            torch.cuda.synchronize()
            t0 = time.time()
...
@@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
        bnb_optimizer.step()

    torch.cuda.synchronize()
    s = time.time() - t0
    print("")
    params = (k - k // 5) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9
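For scale (illustrative, not part of this commit), the printed value above is seconds per timed parameter update; with k = 20 and dim1 = dim2 = 4096 as parametrized here:

k, dim1, dim2 = 20, 4096, 4096
params = (k - k // 5) * dim1 * dim2  # 16 * 4096 * 4096 = 268,435,456 timed updates
# printed value = elapsed_seconds / params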