OpenDAS / bitsandbytes, commit bfa0e332

Authored Aug 01, 2022 by Titus von Koeller
Commit message: ran black and isort for coherent code formatting
Parent: 597a8521

Changes: 25. Showing 5 changed files with 1417 additions and 1007 deletions (+1417 / -1007).
Changed files:

tests/test_autograd.py             +137  -63
tests/test_cuda_setup_evaluator.py  +39  -38
tests/test_functional.py           +801 -652
tests/test_modules.py              +175 -122
tests/test_optim.py                +265 -132
tests/test_autograd.py (view file @ bfa0e332)
from itertools import product

import pytest
import torch

import bitsandbytes as bnb

n = 1
k = 25

dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(
    product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
...
@@ -32,9 +43,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
...
@@ -52,9 +65,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
...
@@ -78,16 +91,22 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(
                size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
            )
            B = torch.randn(
                size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
            )
            target = torch.randn(
                size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
...
@@ -95,7 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.01
            torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)

            if any(req_grad):
...
@@ -120,16 +139,20 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(
                size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
            )
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
            target = torch.randn(
                size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
            )
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
...
@@ -141,9 +164,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
...
@@ -167,51 +190,96 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):

            if req_grad[1]:
                n = gradB1.numel()
                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.1
                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                assert (idx == 0).sum().item() < n * 0.02
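The values / str_values / names triple used above, and again for test_matmullt below, is a standard pytest idiom: the raw parameter tuples go into values while a parallel list of printable tuples produces the test ids. A minimal self-contained sketch of the same pattern with toy parameters (not the ones used in this file):

from itertools import product

import pytest

dims = [16, 32]
use_half = [False, True]
toy_values = list(product(dims, use_half))
toy_names = ["dim_{0}_half_{1}".format(*vals) for vals in toy_values]


@pytest.mark.parametrize("dim, half", toy_values, ids=toy_names)
def test_toy(dim, half):
    # pytest reports each case with a readable id such as "dim_16_half_False".
    assert dim % 16 == 0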
n = 1
k = 3

dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

# dim1 = (17,)
# dim2 = (7,)
# dim3 = (37,)
# dim4 = (23,)

decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        funcs,
        dtype,
        req_grad,
        transpose,
        decomp,
        has_fp16_weights,
    )
)
str_values = list(
    product(
        dim1,
        dim2,
        dim3,
        dim4,
        str_funcs,
        dtype,
        req_grad_str,
        str_transpose,
        decomp,
        has_fp16_weights,
    )
)
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
        *vals
    )
    for vals in str_values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
    values,
    ids=names,
)
def test_matmullt(
    dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(
                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
            )
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(
                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
            )
            target = torch.randn(
                size=(dim2, dim4),
                device="cuda",
                requires_grad=req_grad[1],
                dtype=dtype,
            )
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()
...
@@ -219,8 +287,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec

            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]:
                    B2 = B2.t().contiguous()
                (
                    state.CB,
                    CBt,
                    state.SCB,
                    SCBt,
                    coo_tensorB,
                ) = bnb.functional.double_quant(B2)
                B2 = state.CB

            if not transpose[0] and transpose[1]:
...
@@ -231,12 +306,12 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec

                out_bnb = funcs[1](A, B2.t(), state=state)

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            # print(f'abs error {err:.4f}')
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n * 0.001

            if has_fp16_weights:
                if any(req_grad):
...
@@ -263,8 +338,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec

                    assert torch.abs(gradB1).sum() > 0.0
                    assert torch.abs(gradB2).sum() > 0.0
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n * 0.02
                    torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
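A pattern that recurs throughout these tests is the pair of isclose checks: a tight tolerance that only a small fraction of elements may violate, plus a looser tolerance that almost none may violate. Collected into a helper, the pattern looks like this (a sketch for reference, not a helper defined in the repository):

import torch


def assert_mostly_close(out, ref, atol, rtol, max_mismatch_frac):
    # Count elements outside the tolerance band and require the fraction to stay small.
    close = torch.isclose(out, ref, atol=atol, rtol=rtol)
    mismatched = (close == 0).sum().item()
    assert mismatched < ref.numel() * max_mismatch_frac


# Mirrors the pair of checks used in test_matmul.
a = torch.randn(64, 64)
b = a + 0.001 * torch.randn(64, 64)
assert_mostly_close(b, a, atol=0.01, rtol=0.1, max_mismatch_frac=0.0175)
assert_mostly_close(b, a, atol=0.035, rtol=0.2, max_mismatch_frac=0.001)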
tests/test_cuda_setup_evaluator.py (view file @ bfa0e332)
import os
from typing import List, NamedTuple

import pytest

from bitsandbytes.cuda_setup import (
    CUDA_RUNTIME_LIB,
    evaluate_cuda_setup,
    get_cuda_runtime_lib_path,
    tokenize_paths,
)


class InputAndExpectedOutput(NamedTuple):
    input: str
    output: str


HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
    (
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    (
        f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
        f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
]
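All of the happy-path inputs above are colon-separated search paths, possibly with empty segments, in which exactly one entry is expected to hold the CUDA runtime. A rough sketch of the path tokenization step (illustration only; "libcudart.so" is used as a stand-in because the actual value of CUDA_RUNTIME_LIB is not shown in this excerpt, and the real tokenize_paths may differ in detail):

from pathlib import Path


def tokenize_paths_sketch(paths: str):
    # Split an LD_LIBRARY_PATH-style string into candidate directories,
    # dropping the empty segments produced by "::" or a leading/trailing ":".
    return [Path(p) for p in paths.split(":") if p]


print(tokenize_paths_sketch(":some/other/dir::dir/with/libcudart.so:"))
# [PosixPath('some/other/dir'), PosixPath('dir/with/libcudart.so')]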
@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
    for path in tokenize_paths(request.param):
        test_dir.mkdir()
        if CUDA_RUNTIME_LIB in path:
            (test_input / CUDA_RUNTIME_LIB).touch()


@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path(
    tmp_path, test_input: str, expected: str
):
    for path in tokenize_paths(test_input):
        path.mkdir()
        (path / CUDA_RUNTIME_LIB).touch()
    assert get_cuda_runtime_lib_path(test_input) == expected
...
@@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):

        (test_input / CUDA_RUNTIME_LIB).touch()
    with pytest.raises(FileNotFoundError) as err_info:
        get_cuda_runtime_lib_path(test_input)
    assert all(
        match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB}
    )
def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
    existent_dir = tmp_path / "a/b"
    existent_dir.mkdir()
    non_existent_dir = tmp_path / "c/d"  # non-existent dir
    test_input = ":".join([str(existent_dir), str(non_existent_dir)])

    get_cuda_runtime_lib_path(test_input)
    std_err = capsys.readouterr().err

    assert all(
        match in std_err for match in {"WARNING", "non-existent"}
    )
def test_full_system():
    ## this only tests the cuda version and not compute capability
    ld_path = os.environ["LD_LIBRARY_PATH"]
    paths = ld_path.split(":")
    version = ""
    for p in paths:
        if "cuda" in p:
            idx = p.rfind("cuda-")
            version = p[idx + 5 : idx + 5 + 4].replace("/", "")
            version = float(version)
            break

    binary_name = evaluate_cuda_setup()
    binary_name = binary_name.replace("libbitsandbytes_cuda", "")
    assert binary_name.startswith(str(version).replace(".", ""))
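The version-parsing loop in test_full_system can be read as a small helper: find the "cuda-X.Y" component of an LD_LIBRARY_PATH entry and slice out the version number. A standalone illustration using the same slicing as above:

def cuda_version_from_ld_path(ld_path: str):
    # Scan colon-separated entries for one like ".../cuda-11.3/lib64" and return 11.3.
    for p in ld_path.split(":"):
        if "cuda" in p:
            idx = p.rfind("cuda-")
            version = p[idx + 5 : idx + 5 + 4].replace("/", "")
            return float(version)
    return None


print(cuda_version_from_ld_path("/usr/lib:/usr/local/cuda-11.3/lib64"))  # 11.3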
tests/test_functional.py (view file @ bfa0e332)
import math
import random
import time
from itertools import product

import einops
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F

torch.set_printoptions(
    precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20


def assert_all_approx_close(a, b, rtol, atol, count):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)


class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
...
@@ -35,13 +39,14 @@ class FFN(torch.nn.Module):

        x = self.fc2(x)
        return x


class Timer(object):
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name="default"):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
...
@@ -49,66 +54,70 @@ class Timer(object):

        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name="default", evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg:
                self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print("Resetting benchmark data")


def setup():
    pass


def teardown():
    pass
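For reference, Timer above is a CUDA-event stopwatch used elsewhere in this file. A minimal usage sketch (requires a CUDA device; the statement that records the start event inside tick falls in an elided region above, so this shows the intended tick/tock pairing rather than verified behavior):

import torch

t = Timer()
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

t.tick("matmul")   # creates (and presumably records) the start/end events for this label
for _ in range(10):
    c = torch.matmul(a, b)
t.tock("matmul")   # records the end event, synchronizes, and prints the aggregated time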
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device="cuda")
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0


def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device="cuda")
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001
...
@@ -117,22 +126,22 @@ def test_dynamic_quantization():

    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004
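For reference, the two error metrics tracked in these loops are the mean absolute error and a guarded relative error; a CPU-sized stand-in for the same bookkeeping (the 1e-8 term simply keeps the denominator away from zero):

import torch

A1 = torch.randn(1024, 1024)
A2 = A1 + 0.005 * torch.randn(1024, 1024)   # stand-in for a dequantized tensor

diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
print(diff.mean().item(), reldiff.mean().item())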
...
@@ -141,56 +150,60 @@ def test_dynamic_blockwise_quantization():

    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device="cuda")
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    # print(sum(diffs)/len(diffs))


def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device="cuda")
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of quantized values of 1
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
        fraction_larger = (C1 > C2).float().sum() / C1.numel()
        torch.testing.assert_allclose(
            fraction_larger, fraction_smaller, atol=0.01, rtol=0
        )
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device="cuda")
    gnorm_vec2 = torch.zeros(100, device="cuda")
    n = 4
    step = 0
    percentile = 5
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device="cuda")
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec2, step, percentile=percentile
        )
        assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1

        gnorm2 = torch.norm(g.float())
        if step == 1:
...
@@ -208,74 +221,89 @@ def test_percentile_clipping(gtype):

def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def dequant(c, maxC):
    return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
    return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x / max1 * 127)
    return max1, x.to(torch.int8)


def quant_minmax(A):
    minA = A.min()
    maxA = A.max()


def mean(xx):
    return sum(xx) / float(len(xx))
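The helpers above implement plain symmetric absmax int8 quantization. A small CPU-sized round trip using quant and mm_dequant as defined above shows how mm_dequant rescales an integer matmul back to the float scale:

import torch

A = torch.randn(8, 16)
B = torch.randn(16, 4)

maxA, Ac = quant(A)               # per-tensor absmax scale plus int8 payload
maxB, Bc = quant(B)
C_int = Ac.float() @ Bc.float()   # stand-in for the int8 GEMM
C_approx = mm_dequant(maxA, maxB, C_int)

err = (C_approx - A @ B).abs().mean()
print(float(err))                 # small but nonzero: the quantization error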
# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
    (
        lambda x, dim: quant(x),
        lambda x, dim: quant(x),
        dequant,
        dequant,
        mm_dequant,
    )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ["linear", "vectorwise"]
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names
]


@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
            B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
            B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(
            quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
        )
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())

...
@@ -284,43 +312,49 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):

            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))
def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()


n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
    "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (
            (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        )
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
...
@@ -338,9 +372,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):

    for i in range(k):
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = (
            (32 * random.randint(1, 4), hidden_dim)
            if transpose[1]
            else (hidden_dim, 32 * random.randint(1, 4))
        )
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
...
@@ -352,40 +390,51 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):

n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        ).to(torch.int8)
        B = torch.randint(
            -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
        ).to(torch.int8)

        out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
        iout = torch.empty(
            A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
        )
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)
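The reference value in test_dim3_igemm is torch.einsum("bsi, bso->io", ...), which contracts away both the batch and sequence axes. The same result comes from flattening those two axes and doing an ordinary matmul, which is a quick way to see why the output has shape (hidden_dim, 1024):

import torch

b, s, i, o = 2, 3, 4, 5
A = torch.randn(b, s, i)
B = torch.randn(b, s, o)

ref = torch.einsum("bsi,bso->io", A, B)
flat = A.reshape(b * s, i).t() @ B.reshape(b * s, o)   # contract the flattened batch*seq axis

print(torch.allclose(ref, flat, atol=1e-6))  # True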
n = 2
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
    "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
...
@@ -395,30 +444,30 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):

    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(
            0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
        )
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
...
@@ -429,31 +478,36 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):

        out /= std
        out3 /= std
        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)
        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)
        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())
    # print(mean(errs))
    # print(mean(relerrs))
    # print(mean(errs2))
    # print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3
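The offset terms in test_minmax_igemm follow from the affine form of min_max: each element is stored as Ac = 127 * (A - minA - scale) / scale, so A = (scale / 127) * Ac + (minA + scale), and hence A @ B = (scale / 127) * (Ac @ B) + (minA + scale) * B.sum(0). That is exactly the reconstruction (out * maxB * scale / (127 * 127)) + offset once B's own per-column scale maxB / 127 is folded in. A CPU-sized numeric check of the identity, using a local copy of min_max (the one above is nested inside the test function):

import torch


def min_max_sketch(x):
    # Same affine scheme as the min_max helper inside test_minmax_igemm.
    maxA = torch.amax(x, dim=2, keepdim=True)
    minA = torch.amin(x, dim=2, keepdim=True)
    scale = (maxA - minA) / 2.0
    return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale


A = torch.normal(0.0, 0.5, size=(4, 8, 32))
B = torch.normal(0.0, 0.5, size=(32, 16))

Ac, minA, scale = min_max_sketch(A)
offset = B.sum(0) * (minA + scale)                 # contribution of the affine shift
approx = (Ac.float() @ B) * scale / 127 + offset   # undo the per-row scaling, add the shift back

err = (approx - A @ B).abs().mean()
print(float(err))                                  # only the 8-bit representation error remains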
n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values
]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
...
@@ -462,8 +516,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):

    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
...
@@ -479,146 +533,174 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
...
@@ -479,146 +533,174 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
        out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())


n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device="cuda")
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)
n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)
    for vals in values
]
@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and out_order != "col32":
        return
    if dtype == torch.int32 and out_order != "col32":
        return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == "row":
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == "col":
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == "col32":
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == "col_turing":
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                # assert A.flatten()[i+j] == out.flatten()[row2+col2]
                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == "col32":
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
        torch.testing.assert_allclose(A, out2)
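# The col32 branch above asserts that the transformed buffer pads the last
# dimension. A small helper mirroring that size expression (the helper name is
# hypothetical, for illustration only):
def col32_padded_numel(rows, cols):
    # Same expression as the dims == 2 check above: pad cols by (32 - cols % 32).
    return rows * (cols + (32 - cols % 32))


assert col32_padded_numel(3, 65) == 3 * 96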
n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

# dim1 = [2]
# dim2 = [2]
# dim3 = [2]
# dim4 = [2]

dims = (2, 3)
ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, "col_turing", transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, "row", state=SC)
        torch.testing.assert_allclose(C1, C3.float())
dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
        elif dims == 3:
            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
        B = torch.randn((dim4, dim3), device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())
...
@@ -627,50 +709,56 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, "col32")
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        # print('')
        # print(output.flatten()[:10])
        # print(C1.flatten()[:10])
        # print(C2.flatten()[:10])

        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        # C1 = torch.matmul(A.float(), B.float())

        # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        # C3, S = F.transform(C2, 'row', state=SC)
        # torch.testing.assert_allclose(C1, C3.float())
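# For context: the half-precision test above compares torch.matmul against
# bnb.matmul, the library's int8 matmul entry point, which internally goes through
# roughly the double_quant -> transform -> igemmlt -> mm_dequant steps exercised
# manually here. A minimal usage sketch (assumes a CUDA device is available;
# shapes are illustrative only):
import torch
import bitsandbytes as bnb

A = torch.randn(16, 64, device="cuda").half()
B = torch.randn(32, 64, device="cuda").half()
ref = torch.matmul(A, B.t())
out = bnb.matmul(A, B.t())
print((ref - out).abs().mean())  # a small quantization error is expected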
batch_size = 2
seqdim = 512
# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values = [
    (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
    (batch_size, seqdim, 5120, 3 * 5120),
    (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
]
# values = list(product(batch, seq, model, hidden))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    formatB = F.get_special_format_str()
    A = torch.randn(batch, seq, model, device="cuda").half()
    grad = torch.randn(batch, seq, model, device="cuda").half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
    print("")

    # torch.cuda.synchronize()
    ## warmup
    # for i in range(100):
    #    torch.matmul(A, w1.t())
    # torch.cuda.synchronize()

    dtype = torch.int8
    A = A.view(-1, A.shape[-1]).contiguous()
...
@@ -679,77 +767,77 @@ def test_bench_8bit_training(batch, seq, model, hidden):
    t0 = time.time()
    for i in range(k):
        out1 = torch.matmul(A, w1.t())  # fc1
        # out2 = torch.matmul(out1, w2.t())# fc2

        # d1 = torch.matmul(grad, w2) # delta1
        # d2 = torch.matmul(d1, w1) # delta2

        # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
        # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1

    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)

    # torch.cuda.empty_cache()

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)

    # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
    # C32A, SA = F.transform2(CA, 'col32')

    ## fc1
    # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
    ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)

    ## fc2
    # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
    # C32out1, Sout1 = F.transform2(Cout1, 'col32')
    # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
    ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)

    ## delta1
    # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
    # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
    ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
    ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)

    ## delta2
    # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
    # C32d1, Sd1 = F.transform2(Cd1, 'col32')
    ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
    ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)

    ## grad1
    # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
    # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
    ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
    ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)

    ## grad2
    # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
    # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
    ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)

    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)

    # CTw1, Sw1 = F.transform2(Cw1, formatB)
    # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
    # CTw2, Sw2 = F.transform2(Cw2, formatB)
    # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(k):
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
    #    #CTw1, Sw1 = F.transform2(Cw1, formatB)
    #    #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
...
@@ -802,74 +890,76 @@ def test_bench_8bit_training(batch, seq, model, hidden):
    #    #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
    #    #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)

    # torch.cuda.synchronize()
    # t8 = time.time() - t0
    # print(t8)
n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()

# dim1 = [2*1024]
# dim4 = [2*1024]

# dim1 = [4]
# dim4 = [4]

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
    inner = torch.randint(1, 128, size=(1,)).item()
    formatB = F.get_special_format_str()
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda")
        B = torch.randn(dim4, inner, device="cuda")
        C1 = torch.matmul(A.half(), B.t().half())

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
        n = C1.numel()
        p = 0.06
        assert (
            count / n < p
        ), f"error in more than {p} of elements: {count}/{n}={count/n}"

        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
        torch.testing.assert_allclose(C5, C4)
        # print(C2)
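# The dequantization checked above follows directly from how the codes are built:
# with row-wise scales maxA, maxB and codes Aq = round(127*A/maxA),
# Bq = round(127*B/maxB), the integer product rescales as
# C ~= (Aq @ Bq.T) * maxA * maxB.T / (127 * 127).
# A standalone sketch of that identity in plain PyTorch (not the CUDA kernels;
# the 0.1 bound is illustrative):
import torch

A = torch.randn(8, 32)
B = torch.randn(16, 32)
maxA = A.abs().amax(dim=1, keepdim=True)
maxB = B.abs().amax(dim=1, keepdim=True)
Aq = torch.round(A / maxA * 127)
Bq = torch.round(B / maxB * 127)
C_deq = (Aq @ Bq.t()) * maxA * maxB.t() / (127 * 127)
assert (C_deq - A @ B.t()).abs().mean() < 0.1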
n = 2
dim1 = [1 * 1024]
dim2 = [1 * 1024]
# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()

dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    for i in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device="cuda").half()
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
...
@@ -880,11 +970,22 @@ def test_colrow_absmax(dim1, dim2, dims):
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
            A, threshold=threshold
        )

        A_blocked = einops.rearrange(
            torch.abs(A),
            "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
            row_tiles=16,
            block_size=64 * 4,
        )
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(
            nnz_rows1_counts.shape[0] + 1,
            dtype=nnz_rows1_counts.dtype,
            device=nnz_rows1_counts.device,
        )
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
...
@@ -898,19 +999,20 @@ def test_colrow_absmax(dim1, dim2, dims):
        assert nnz_block_ptr2 is None
n = 2
# dim1 = [8*1024]
# dim2 = [4*1024]
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)
...
@@ -920,18 +1022,21 @@ def test_double_quant(dim1, dim2):
        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        n = CAt.numel()
        num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()

        # allow for 1:500 error due to rounding differences
        min_error = 1 / 500
        if num_not_close_cols > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
            )
            assert False
        if num_not_close_rows > (min_error * n):
            print(
                f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
            )
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
...
@@ -939,21 +1044,23 @@ def test_double_quant(dim1, dim2):
n = 4
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()

dim1 = [6]
dim4 = [4]
inner = [8]

values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()

        out1 = torch.matmul(A.half(), B.t().half())
...
@@ -967,30 +1074,32 @@ def test_integrated_igemmlt(dim1, dim4, inner):
        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(C2a, "col_turing")
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        A2, SA = F.nvidia_transform(A1, "col32")
        B2, SB = F.nvidia_transform(B1, "col_turing")
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, "row", state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out3).mean().item()
        assert err2 <= err1 * 1.01
n = 6
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
...
@@ -999,79 +1108,79 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
    relerr1, relerr2 = [], []
    scale = 1
    for i in range(k):
        A = torch.randn(dim1, inner, device="cuda").half()
        B = torch.randn(dim4, inner, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
        A2, SA = F.nvidia_transform(C1a, "col32")
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        c = 10.0 * inner * scale
        row_scale = torch.ones_like(maxA) / c
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
        C3, S = F.nvidia_transform(outC32, "row", state=SC)
        maxval = torch.abs(C3).max()
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval / 120
        out3 = C3 * maxA * absmaxB * c / (127 * 127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C * SA * SB / (127 * 127)
        # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)

        # print('='*80)
        # print(out1)
        # print(out2)
        # print(out3)

        err1.append(torch.abs(out1 - out2).mean().item())
        err2.append(torch.abs(out1 - out3).mean().item())
        err3.append(torch.abs(out1 - out4).mean().item())

    # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
    print("")
    print(sum(err1) / len(err1))
    print(sum(err2) / len(err2))
    print(sum(err3) / len(err3))
dim1 = [1024, 2048]
inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    A = torch.randn(dim1, inner, device="cuda").half()
    B = torch.randn(dim4, inner, device="cuda").half()
    torch.nn.init.xavier_uniform_(B)
    # warmpup
    for i in range(k):
...
@@ -1082,23 +1191,22 @@ def test_row_scale_bench(dim1, dim4, inner):
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("16", time.time() - t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
    A2, SA = F.nvidia_transform(C1a, "col32")
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    c = 10.0 * inner * scale
    row_scale = maxA / c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
    torch.cuda.synchronize()
    print("row-wise", time.time() - t0)

    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
...
@@ -1107,32 +1215,39 @@ def test_row_scale_bench(dim1, dim4, inner):
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print("vector-wise", time.time() - t0)


n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()
# dim1 = [8*1024]
# dim2 = [4*1024]

dim3 = [0]
dtype = [torch.int8]
a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = [
    "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(*vals)
    for vals in values
]


@pytest.mark.parametrize(
    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    for i in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)

        A.view(-1)[-1] = -1
        if transpose:
...
@@ -1144,53 +1259,55 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]
        # print(out1)
        # print(out2)

        torch.testing.assert_allclose(out1, out2)


n = 2
# dim1 = torch.randint(2,1024, size=(n,)).tolist()
# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
# a_order = ['col_turing', 'col_ampere']
a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
    for i in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)

        out2, S2 = F.transform(A, to_order=orderA)
        A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
        assert A2.shape[0] == A.shape[0]
        assert A2.shape[1] == A.shape[1]

        print("")
        print(A)
        print(out2)
        print(A2)

        # torch.testing.assert_allclose(A, A2)
def test_overflow():
    formatB = F.get_special_format_str()
    print(formatB)
    for i in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, "col32")
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
...
@@ -1198,46 +1315,51 @@ def test_overflow():
n = 2
dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim1 = [4]
# dim2 = [5]

values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    threshold = 3.00
    for i in range(k):
        A = torch.randn(dim1, dim2, device="cuda").half()

        idx = torch.abs(A) >= threshold
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)

        if coo_tensor is not None:
            A1 = A * idx
            A2 = torch.zeros_like(A)
            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            A1 = A * (idx == 0)
            A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
            torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
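# The threshold path above splits A into a below-threshold part (which is int8
# quantized) and a COO tensor holding the outliers; the matmul is later rebuilt
# as dense product plus sparse product. A self-contained sketch of that
# decomposition in plain PyTorch (dense_part / outlier_part are illustrative
# names, not bitsandbytes API):
import torch

threshold = 3.0
A = torch.randn(64, 128)
B = torch.randn(128, 32)

outlier_mask = A.abs() >= threshold
outlier_part = torch.where(outlier_mask, A, torch.zeros_like(A))  # kept in higher precision
dense_part = A - outlier_part  # this is what gets quantized to int8

# The two partial products add back up to the full matmul.
assert torch.allclose(dense_part @ B + outlier_part @ B, A @ B, atol=1e-4)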
n = 2
dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim1 = [7]
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    # dim3 = 17
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
...
@@ -1249,8 +1371,10 @@ def test_spmm_coo(dim1, dim2, transposed_B):
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(
            A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
        )
        A2 = A * idx
        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
...
@@ -1262,18 +1386,17 @@ def test_spmm_coo(dim1, dim2, transposed_B):
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)


def test_spmm_bench():
    batch = 2
    model = 1024 * 1
    hidden = model * 4
    seq = 1024
    dim1 = batch * seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.randn(dim2, dim3, device="cuda").half()
    for i in range(10):
        C1 = bnb.matmul(A, B)
...
@@ -1282,14 +1405,16 @@ def test_spmm_bench():
    for i in range(k):
        C1 = bnb.matmul(A, B)
    torch.cuda.synchronize()
    t8 = time.time() - t0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz / idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    for i in range(10):
        out2 = F.spmm_coo(cooA, B)
...
@@ -1299,20 +1424,22 @@ def test_spmm_bench():
for
i
in
range
(
k
):
for
i
in
range
(
k
):
out2
=
F
.
spmm_coo
(
cooA
,
B
)
out2
=
F
.
spmm_coo
(
cooA
,
B
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
tsp
=
time
.
time
()
-
t0
tsp
=
time
.
time
()
-
t0
print
(
tsp
,
t8
)
print
(
tsp
,
t8
)
print
(
tsp
/
t8
)
print
(
tsp
/
t8
)
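
For orientation, here is a minimal plain-PyTorch sketch of the COO extraction pattern that test_spmm_coo and test_spmm_bench exercise: entries whose magnitude exceeds the threshold are kept as (row, col, value) triplets, and the sparse product should agree with a dense matmul against the masked matrix. The helper name and shapes below are illustrative assumptions, not part of bitsandbytes.

import torch

def coo_spmm_reference(A, B, threshold):
    idx = A.abs() >= threshold                 # outlier mask
    rows, cols = torch.where(idx)              # COO coordinates
    vals = A[idx]                              # COO values
    A_sparse = torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, size=A.shape)
    return torch.sparse.mm(A_sparse, B)        # equals (A * idx) @ B

A = torch.randn(64, 32)
B = torch.randn(32, 16)
out_sparse = coo_spmm_reference(A, B, threshold=1.5)
out_dense = (A * (A.abs() >= 1.5)) @ B
assert torch.allclose(out_sparse, out_dense, atol=1e-5)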
n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
    threshold = 3.0
    formatB = "col_turing"
    for i in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        w1 = torch.randn(dim1, dim2).cuda().half()
...
@@ -1322,13 +1449,13 @@ def test_integrated_sparse_decomp(dim1, dim2):
        CTw1, Sw1 = F.transform(Cw1, formatB)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        C32A, SA = F.transform(CA, "col32")
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
        C32A, SA = F.transform(CA, "col32")
        out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
        out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
...
@@ -1338,8 +1465,8 @@ def test_integrated_sparse_decomp(dim1, dim2):
        out4 = F.spmm_coo(coo_tensor, w1.t())
        out5 = out3 + out4

        err1 = torch.abs(out1 - out2).mean().item()
        err2 = torch.abs(out1 - out5).mean().item()
        assert err2 < err1
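
The assertion above (err2 < err1) checks the idea behind the sparse decomposition: quantize the bulk of the activations to int8 and route the rare large-magnitude outlier entries through a higher-precision sparse matmul. Below is a simplified sketch of that decomposition in plain PyTorch, using a single symmetric per-tensor scale rather than the library's row/column-wise statistics; the helper name and the injected values are illustrative assumptions only.

import torch

def decomposed_matmul(A, W, threshold):
    outliers = A.abs() >= threshold
    A_regular = A * (~outliers)                    # small-magnitude entries -> int8 path
    A_outlier = A * outliers                       # rare large entries -> fp32 sparse path
    scale = A_regular.abs().max() / 127.0
    A_q = torch.round(A_regular / scale).to(torch.int8)
    return (A_q.float() * scale) @ W + A_outlier @ W

torch.manual_seed(0)
A = torch.randn(128, 64)
A[0, :4] = 12.0                                    # inject a few outliers
W = torch.randn(64, 32)
ref = A @ W
scale_all = A.abs().max() / 127.0
plain = (torch.round(A / scale_all) * scale_all) @ W
mixed = decomposed_matmul(A, W, threshold=6.0)
# The decomposed result is expected to track the fp32 reference more closely.
print((ref - plain).abs().mean().item(), (ref - mixed).abs().mean().item())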
...
@@ -1350,91 +1477,95 @@ def test_matmuls():
    c2 = bnb.matmul(a, b)
    c3 = bnb.matmul(a, b)
    err1 = torch.abs(c1 - c2).mean().item()
    err2 = torch.abs(c1 - c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
dim2 = [12288]
# dim1 = [32]
# dim2 = [32]
# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    out_func = getattr(torch, out_func)

    threshold = 3.3
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2 * 4, device="cuda").half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type="linear")
        # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)

    print("")
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out1 = torch.matmul(A2.half(), B.half())
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # print(B)
    # print(out1)
    # print(out2)
    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
    # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))

    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)

    # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
    # torch.cuda.synchronize()
    # t0 = time.time()
    # print(A2.shape, B.shape)
    # for i in range(100):
    #   #out3 = F.spmm_coo(cooA, Bt.t())
    #   #out2 = F.spmm_coo(cooA, B)
    #   #out2 = F.spmm_coo_very_sparse(cooA, B)
    #   #out1 = torch.matmul(A, Bt.t())
    # torch.cuda.synchronize()
    # print(time.time() - t0)
def test_layout():
    a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
    a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
    a2, s2 = F.transform(a1, "col_turing")
    print(a2.shape)

    print(a1.flatten()[8 * 64 : 8 * 64 + 32])
    for i in range(4):
        print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)


def test_coo2csr():
...
@@ -1444,14 +1575,16 @@ def test_coo2csr():
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    csrA = F.coo2csr(cooA)
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = A2 != 0
    torch.testing.assert_allclose(A2[idx], csrA.values)
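
The rowptr check above follows directly from how a CSR row pointer is built from COO row indices. A small plain-PyTorch sketch of that relationship (assuming the COO entries are sorted by row, as torch.where produces them here; the example values are made up):

import torch

rows = torch.tensor([0, 0, 2, 2, 2, 3])               # COO row indices, sorted
num_rows = 4
counts = torch.bincount(rows, minlength=num_rows)     # nonzeros per row
rowptr = torch.zeros(num_rows + 1, dtype=torch.long)
rowptr[1:] = torch.cumsum(counts, dim=0)              # CSR row pointer: [0, 2, 2, 5, 6]
assert torch.equal(rowptr[1:] - rowptr[:-1], counts)  # same identity the test checks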
...
@@ -1462,41 +1595,43 @@ def test_coo2csc():
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    cscA = F.coo2csc(cooA)
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = A2.t() != 0
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)


n = 2
# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1 = [1 * 2048]
# dim2 = [12288]
dim2 = [2048]
# dim1 = [2]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    threshold = 6.0
    # threshold = 2.8
    # threshold = 0.0
    A = torch.randn(dim1, dim2, device="cuda").half()
    B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    rowidx = torch.randint(0, A.shape[-1], size=(15,))
...
@@ -1507,12 +1642,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(
        A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
    )
    A2 = A * idx
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3 * statsBt.half() / 127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
...
@@ -1521,56 +1658,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)

    p = 200 / (2048 * 12288 * 4)
    n = out1.numel()
    count = math.ceil(p * n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    # torch.cuda.synchronize()
    # t0 = time.time()
    # for i in range(100):
    #    out2 = F.spmm_coo_very_sparse(cooA, B)
    # torch.cuda.synchronize()
    # print('fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print("cusparse fp16", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print("int8", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print("int8+dequant", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print("matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1 + out2
    torch.cuda.synchronize()
    print("sparse+ matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
...
@@ -1578,33 +1713,36 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print("partial matmul", time.time() - t0)
batch_size = 1
seqdim = 2048
values = []
values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device="cuda").half()
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
...
@@ -1613,31 +1751,37 @@ def test_bench_matmul(batch, seq, model, hidden):
    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = (
        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    )
    linearMixedBit.eval()

    # warmup
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print("")

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(
        f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        bnb.matmul(A, B)
    torch.cuda.synchronize()
    print(
        f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    C32A, SA = F.transform(CA, "col32")
    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    CxB, SB = F.transform(CB, to_order=formatB)
    torch.cuda.synchronize()
...
@@ -1645,7 +1789,9 @@ def test_bench_matmul(batch, seq, model, hidden):
    for i in range(100):
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    torch.cuda.synchronize()
    print(
        f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    BA, statsB = F.vectorwise_quant(B, dim=1)
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
...
@@ -1654,26 +1800,30 @@ def test_bench_matmul(batch, seq, model, hidden):
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1)
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    torch.cuda.synchronize()
    print(
        f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
        C32A, SA = F.nvidia_transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
        out = Cout * statsB * statsA * (1.0 / (127 * 127))
    torch.cuda.synchronize()
    print(
        f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linear8bit(A)
    torch.cuda.synchronize()
...
@@ -1681,8 +1831,9 @@ def test_bench_matmul(batch, seq, model, hidden):
    for i in range(100):
        linear8bit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )

    linearMixedBit(A)
    torch.cuda.synchronize()
...
@@ -1690,65 +1841,66 @@ def test_bench_matmul(batch, seq, model, hidden):
    for i in range(100):
        linearMixedBit(A)
    torch.cuda.synchronize()
    print(
        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
    )
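
The benchmark above repeats the standard CUDA timing pattern: kernels launch asynchronously, so a warmup plus torch.cuda.synchronize() before and after the timed loop are needed for meaningful wall-clock numbers. A small reusable sketch of that pattern (the bench helper is illustrative, not part of the test suite):

import time
import torch

def bench(fn, iters=100, warmup=10):
    for _ in range(warmup):        # warm up caches / JIT / cuBLAS handles
        fn()
    torch.cuda.synchronize()       # make sure warmup kernels have finished
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()       # wait for the timed kernels before reading the clock
    return (time.time() - t0) / iters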
def test_zeropoint():
    def min_max(x):
        maxA = torch.amax(x, dim=1, keepdim=True)
        minA = torch.amin(x, dim=1, keepdim=True)
        midpoint = (maxA - minA) / 2.0
        dyna = 252 / (maxA - minA)
        # dyna *= 0.98
        x = dyna * x
        x = x - torch.round((dyna * (minA + midpoint)))
        return x.to(torch.int8), minA, midpoint, dyna

    batch = 2
    seq = 2
    model = 4
    hidden = 2 * model
    # batch = 4
    # seq = 2048
    # model = 1024
    # hidden = 8*model
    A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
    B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())

    # A[0] = 0
    # B[:, 0] = 0
    # A = A*(A>0)
    # A[0, 0] = 0
    # A[0, 0] = 6.0

    Ac, minA, midpoint, dyna = min_max(A)
    # print(Ac[0, 0], 'zero')
    # print(Ac, Ac.min(), Ac.max())
    Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
    out = F.igemm(Ac, Bc)
    out2 = torch.matmul(A, B)
    offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
    out = out.float()
    # print(out.shape, maxB.shape, scale.shape, offset.shape)
    norm1 = maxB / 127
    C4 = (out / dyna) * norm1 + offset

    B1 = torch.nn.Parameter(B.clone())
    B2 = torch.nn.Parameter(B.clone())
    B3 = torch.nn.Parameter(B.clone())
    B4 = torch.nn.Parameter(B.clone())

    C1 = torch.matmul(A, B1)
    C2 = bnb.matmul_cublas(A, B2, None, "linear")
    C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
    C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    print(err1, err2, err3)
    # assert err1 > err2

    loss1 = C1.mean()
    loss2 = C2.mean()
...
@@ -1765,40 +1917,38 @@ def test_zeropoint():
    print(B2.grad)
    print(B3.grad)
    print(B4.grad)

    err1 = torch.abs(B1.grad - B2.grad).mean().item()
    err2 = torch.abs(B1.grad - B3.grad).mean().item()
    err3 = torch.abs(B1.grad - B4.grad).mean().item()
    print(err1, err2, err3)
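
The min_max helper above implements an asymmetric (zero-point) rowwise mapping: with dyna = 252 / (max - min) and the rounded shift, every row lands in roughly [-126, 126], i.e. safely inside int8 range. A quick standalone check of that range, mirroring the helper with made-up input values:

import torch

x = torch.randn(4, 16) * 3 + 1.5
maxA = torch.amax(x, dim=1, keepdim=True)
minA = torch.amin(x, dim=1, keepdim=True)
midpoint = (maxA - minA) / 2.0
dyna = 252 / (maxA - minA)
xq = dyna * x - torch.round(dyna * (minA + midpoint))
assert xq.min().item() >= -127 and xq.max().item() <= 127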
def test_zp():
    def quant_zp(x):
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 254.0 / dyna
        minx = x.min()
        # zpx = torch.round(minx* qx)
        # zpx = 127 - torch.round(x.max()* qx)
        zpx = torch.round(x.min() * qx) - 127
        x = (qx * x) + zpx
        return x, qx, zpx

    batch = 2
    seq = 512
    model = 1024
    hidden = 4 * model
    A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
    B = torch.randn(model, hidden, device="cuda").half() * 0.1

    C0 = torch.matmul(A, B)

    # A, SA = F.vectorwise_quant(A, quant_type='linear')
    # B, SB = F.vectorwise_quant(B, quant_type='linear')
    A = A.float()
    B = B.float()
...
@@ -1806,69 +1956,68 @@ def test_zp():
    C3 = bnb.matmul(A.half(), B.t().contiguous().half())

    zp = 1
    # C2 = torch.matmul(A-zp, B)
    # C2 += B.sum(0).view(1, -1)*zp
    C2 = torch.matmul(A, B - zp)
    C2 -= A.sum(1).view(-1, 1) * zp

    ca, cqa, cza = quant_zp(A)
    print(ca.min(), ca.max())
    print((ca - cza).min(), (ca - cza).max())

    zp = 1
    scale = 2.0
    C5 = torch.matmul((A * scale) - zp, B)
    C5 += B.sum(0) * zp
    C5 /= scale

    CA, qa, zpa = quant_zp(A)
    C4 = torch.matmul(CA, B)
    C4 -= B.sum(0) * zpa
    C4 /= qa

    zpb = 1
    zpa = 1
    qa = 2
    qb = 2
    C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
    C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C6 -= zpa * zpb * A.shape[1]
    C6 /= qa * qb

    CA, qa, zpa = quant_zp(A)
    CB, qb, zpb = quant_zp(B)
    C7 = torch.matmul(CA, CB)
    C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
    C7 -= zpa * zpb * A.shape[1]
    C7 /= qa * qb

    print("")
    # print(C0.flatten()[:10])
    print(C1.flatten()[:10])
    print(C2.flatten()[:10])
    print(C3.flatten()[:10])
    print(C5.flatten()[:10])
    print(C6.flatten()[:10])
    print(C7.flatten()[:10])
    err1 = torch.abs(C1 - C2).mean().item()
    err2 = torch.abs(C1 - C3).mean().item()
    err3 = torch.abs(C1 - C4).mean().item()
    err4 = torch.abs(C1 - C5).mean().item()
    err5 = torch.abs(C1 - C6).mean().item()
    err6 = torch.abs(C1 - C7).mean().item()
    print(err1, err2, err3, err4, err5, err6)
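
The correction terms applied to C6 and C7 above come from expanding the product of two affinely quantized operands. Writing k for the inner dimension, (qa*A + zpa) @ (qb*B + zpb) = qa*qb*(A @ B) + qa*zpb*A.sum(1) + qb*zpa*B.sum(0) + zpa*zpb*k, which is exactly what the test subtracts and rescales. A small numeric check of that identity (values are illustrative):

import torch

A = torch.randn(5, 7)
B = torch.randn(7, 3)
qa, qb, zpa, zpb = 2.0, 3.0, 1.0, -2.0
k = A.shape[1]
lhs = (qa * A + zpa) @ (qb * B + zpb)
rhs = (
    qa * qb * (A @ B)
    + qa * zpb * A.sum(1, keepdim=True)
    + qb * zpa * B.sum(0, keepdim=True)
    + zpa * zpb * k
)
assert torch.allclose(lhs, rhs, atol=1e-4)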
def test_extract_outliers():
    for i in range(k):
        shapeA = (4096, 4096 * 4)
        idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
        # idx = torch.Tensor([0]).int().cuda()
        A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
        outliers1 = A[:, idx.long()]

        CA, SA = F.transform(A, "col_turing")

        outliers2 = F.extract_outliers(CA, SA, idx)
...
@@ -1877,7 +2026,7 @@ def test_extract_outliers():
        torch.testing.assert_allclose(outliers1, outliers2)

        CA, SA = F.transform(A, "col_ampere")

        outliers2 = F.extract_outliers(CA, SA, idx)
...
tests/test_modules.py
View file @
bfa0e332
from itertools import product

import pytest
import torch
from torch import nn

import bitsandbytes as bnb


class MockArgs(object):
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super(MLP8bit, self).__init__()
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
        )

    def forward(self, x):
        x = self.fc1(x)
...
@@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module):
def get_args():
    args = MockArgs([])
    args.quant_type = "vector"
    args.use_8bit_training = "full"
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f"Too many values not close: assert {sumval} < {count}")
        torch.testing.assert_allclose(a, b, rtol, atol)
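
In other words, the helper tolerates up to count element-wise mismatches before escalating to a strict torch.testing.assert_allclose. A tiny illustrative use with made-up values:

a = torch.zeros(100)
b = torch.zeros(100)
b[:3] = 1.0                                                   # only 3 elements differ
assert_all_approx_close(a, b, atol=1e-3, rtol=0, count=10)    # 3 <= 10, so this passes quietly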
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi) / math.sqrt(2.0)
        # std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std * trim_value
        x = x / max1 * 127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x / 127 * max1
        return x

    def quant(x, quant_type, dim=1):
        if quant_type == "linear":
            max1 = torch.abs(x).max().float()
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "vector":
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x / max1 * 127).to(torch.int8)
            return xq, max1
        elif quant_type == "min-max":
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA - minA) / 2.0
            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else:
            return None
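
For the "vector" branch above, quantization and dequantization form a per-row round trip whose error is bounded by half a quantization step. A short standalone sketch with illustrative shapes:

import torch

x = torch.randn(4, 8)
max1 = torch.amax(torch.abs(x), dim=1, keepdim=True)    # per-row absmax
xq = torch.round(x / max1 * 127).to(torch.int8)         # int8 codes
x_hat = xq.float() * max1 / 127                         # dequantize
assert (x - x_hat).abs().max().item() <= (max1.max().item() / 254) + 1e-6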
    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == "linear":
            norm = S1 * S2 / (127 * 127)
            # double cast needed to prevent overflows
            return (xq.float() * norm).to(dtype)
        elif quant_type == "vector":
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3:
                S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3:
                S2 = S2.squeeze(0)
            # print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t() / 127
            else:
                x *= S1 / 127
            x *= S2 / 127
            return x.to(dtype)
        else:
            return None
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0) * (SA[0] + SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3:
            SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3:
            SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t() / 127
        else:
            x *= SB / 127
        x *= SA[1] / 127
        x += offset
        return x.to(dtype)
    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x / max1 * 127
        x = round_func(x) / 127 * max1
        # x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x * 127) / max1
        x = round_func(x) / 127 * max1
        return x

    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx - torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
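
round_stoachastic above rounds up with probability equal to the fractional part, which makes the rounding unbiased in expectation. A quick standalone check of that property, inlining the same logic (illustrative, meant to be run outside the class body):

import torch

x = torch.full((100000,), 0.3)
decimal = x - torch.floor(x)
rounded = torch.floor(x) + (torch.rand_like(decimal) < decimal).to(x.dtype)
print(rounded.mean().item())   # ~0.30, whereas torch.round(x) would always give 0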
    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
...
@@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function):
    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        # C = bnb.functional.quantize_no_absmax(code, w)
        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        # print(out)
        # out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
...
@@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function):
    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]
...
@@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function):
        w.copy_(unblocked_w)
        return unblocked_w

    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != "off":
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            # if torch.rand(1) < 0.01:
            # output32 = torch.matmul(x, weight.t())
            # err = torch.abs(output-output32).float()
            # relerr = err/(torch.abs(output32).float()+1e-8)
            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            # output = torch.matmul(x, weight.t())
            output = torch.einsum("bsi,oi->bso", x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args
...
@@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function):
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == "forward+wgrad":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == "full":
            grad_output8, S1 = LinearFunction.quant(
                grad_output, args.quant_type, dim=[0, 1]
            )
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(
                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
            )

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(
                grad_input8, S1, S3, grad_output.dtype, args.quant_type
            )

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)

        return grad_input, grad_weight, grad_bias, None
class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
...
@@ -263,7 +286,7 @@ class Linear8bit(nn.Module):
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter("bias", None)
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
...
@@ -275,12 +298,11 @@ class Linear8bit(nn.Module):
        return LinearFunction.apply(x, self.weight, self.bias, self.args)
def
test_linear8bit
():
def
test_linear8bit
():
l0
=
torch
.
nn
.
Linear
(
32
,
64
).
cuda
().
half
()
l0
=
torch
.
nn
.
Linear
(
32
,
64
).
cuda
().
half
()
l1
=
bnb
.
nn
.
Linear8bit
(
32
,
64
,
args
=
get_args
()).
cuda
().
half
()
l1
=
bnb
.
nn
.
Linear8bit
(
32
,
64
,
args
=
get_args
()).
cuda
().
half
()
l2
=
Linear8bit
(
32
,
64
,
args
=
get_args
()).
cuda
().
half
()
l2
=
Linear8bit
(
32
,
64
,
args
=
get_args
()).
cuda
().
half
()
l3
=
bnb
.
nn
.
Linear8bitLt
(
32
,
64
).
cuda
().
half
()
l3
=
bnb
.
nn
.
Linear8bitLt
(
32
,
64
).
cuda
().
half
()
l0
.
weight
.
data
=
l2
.
weight
.
data
.
clone
()
l0
.
weight
.
data
=
l2
.
weight
.
data
.
clone
()
l0
.
bias
.
data
=
l2
.
bias
.
data
.
clone
()
l0
.
bias
.
data
=
l2
.
bias
.
data
.
clone
()
...
@@ -292,8 +314,8 @@ def test_linear8bit():
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        t = torch.randn(16, 8, 64, device="cuda").half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()
...
@@ -318,16 +340,20 @@ def test_linear8bit():
        assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(
            l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
        )
        assert_all_approx_close(
            l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
        )

        err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
        err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
        err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()

        assert err1 * 0.8 < err2
        assert err2 * 0.8 < err3
        assert err3 * 0.8 < err1

        l0.weight.grad = None
        l1.weight.grad = None
...
@@ -341,23 +367,28 @@ def test_linear8bit():
threshold = [0.0, 3.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == "cuda"
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None
def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(
        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
    )
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
...
@@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient():
    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
...
@@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient():
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(
                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
            )
            assert_all_approx_close(
                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
            )
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
...
@@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]


@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
    l1 = (
        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
        .cuda()
        .half()
    )
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16
...
@@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold):
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"

    mlp = (
        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
        .to(torch.float16)
        .to("cuda")
    )

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0:
            assert mlp.fc1.state.idx is not None
        if threshold > 0:
            assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == "cuda"
    assert mlp.fc2.weight.device.type == "cuda"
tests/test_optim.py
View file @
bfa0e332
import ctypes
import os
import shutil
import time
import uuid
from itertools import product
from os.path import join

import pytest
import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

# import apex

k = 20


def get_temp_dir():
    path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
    os.makedirs(path, exist_ok=True)
    return path


def rm_path(path):
    shutil.rmtree(path)
str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)
# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
# str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers["lars"] = (
    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
)
str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers["lars8bit"] = (
    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)
str2optimizers["adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)

str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lamb8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [
    ("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
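

# Note (not part of this commit): a minimal sketch of how the two lookup tables
# above are consumed by the tests below. The tensor size and the
# "adam8bit_blockwise" key are illustrative assumptions, not values taken from
# this diff; the helper is never called by the test suite.
def _sketch_paired_optimizer_step():
    torch_ctor, bnb_ctor = str2optimizers["adam8bit_blockwise"]
    p_ref = torch.randn(64, 64, device="cuda") * 0.1
    p_bnb = p_ref.clone()
    opt_ref, opt_bnb = torch_ctor([p_ref]), bnb_ctor([p_bnb])
    g = torch.randn_like(p_ref) * 0.01
    p_ref.grad, p_bnb.grad = g.clone(), g.clone()
    opt_ref.step()
    opt_bnb.step()
    # Each entry maps the torch state name to the bnb state buffer (plus the
    # quantization map and absmax buffers for the 8-bit variants).
    for entry in str2statenames["adam8bit_blockwise"]:
        print(entry[0], "->", entry[1:])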
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])
...
@@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()
...
@@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_allclose(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2],
                atol=atol,
                rtol=rtol,
            )

        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_allclose(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                )

        if gtype == torch.float16:
            # the adam buffers should also be close because they are 32-bit
...
@@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
            p1.data = p1.data.half().float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.half(), p2)
    if optim_name in ["lars", "lamb"]:
        assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
...
@@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype):
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)

    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
    p1 = p1.cuda()
...
@@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype):
    atol, rtol = 1e-4, 1e-3

    for i in range(50):
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

    assert adam2.state[p3]["state1"].dtype == torch.uint8
    assert adam2.state[p3]["state2"].dtype == torch.uint8
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = [
    "adam8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lamb8bit",
    "lars8bit",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048
...
@@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    relerrors = []

    for i in range(50):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()
...
@@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            # print(bnb_optimizer.state[p2][max_val], name1)
            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                    blocksize=blocksize,
                )
            else:
                s1 = F.dequantize(
                    code=bnb_optimizer.state[p2][qmap],
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
            num_not_close = (
                torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                )
                == 0
            )
            assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / torch.abs(p1)
        assert err.mean() < 0.0001
        assert relerr.mean() < 0.001
...
@@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                rm_path(path)
                torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])

                if "blockwise" in optim_name:
                    s1 = F.dequantize_blockwise(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                        blocksize=blocksize,
                    )
                else:
                    s1 = F.dequantize(
                        code=bnb_optimizer.state[p2][qmap],
                        absmax=bnb_optimizer.state[p2][max_val],
                        A=bnb_optimizer.state[p2][name2],
                    )
                torch.testing.assert_allclose(s1cpy, s1)

                num_not_close = (
                    torch.isclose(
                        torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                    )
                    == 0
                )
                assert num_not_close.sum().item() < 20
            torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
...
@@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            p1.data = p1.data.to(gtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.to(gtype), p2)
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                torch_optimizer.state[p1][name1].copy_(s.data)

    # print(sum(errors)/len(errors))
    # print(sum(relerrors)/len(relerrors))
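

# Note (not part of this commit): a sketch of dequantizing an 8-bit blockwise
# Adam state outside the test above. The parameter size is an illustrative
# assumption; the "state1"/"qmap1"/"absmax1" keys and the 2048 blocksize mirror
# str2statenames["adam8bit_blockwise"] and the blocksize used in
# test_optimizer8bit. The helper is never called by the test suite.
def _sketch_dequantize_adam8bit_state():
    p = torch.randn(1024, 1024, device="cuda") * 0.1
    opt = bnb.optim.Adam8bit([p], block_wise=True)
    p.grad = torch.randn_like(p) * 0.01
    opt.step()
    state = opt.state[p]
    # state1 holds the uint8-quantized first moment; qmap1/absmax1 are the
    # quantization code and per-block scales needed to reconstruct it.
    exp_avg = F.dequantize_blockwise(
        code=state["qmap1"],
        absmax=state["absmax1"],
        A=state["state1"],
        blocksize=2048,
    )
    print(exp_avg.dtype, exp_avg.shape)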
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
...
@@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
    )

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
        g2 = g1.clone()
        p2.grad = g2

        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
            g1, gnorm_vec, step, 5
        )
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
...
@@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_allclose(p1, p2)
            torch.testing.assert_allclose(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
            torch.testing.assert_allclose(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
            torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_allclose(
                adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
            )
            torch.testing.assert_allclose(
                adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
            )
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            torch.save(adam2.state_dict(), join(path, "opt.pt"))
            del adam2
            adam2 = None
            adam2 = bnb.optim.Adam(
                [p2],
                lr,
                (beta1, beta2),
                eps,
                optim_bits=optim_bits,
                percentile_clipping=5,
            )
            adam2.load_state_dict(torch.load(join(path, "opt.pt")))
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    for i in range(k):
        if i == k // 5:
            # 100 iterations for burn-in
            torch.cuda.synchronize()
            t0 = time.time()
...
@@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
        bnb_optimizer.step()

    torch.cuda.synchronize()
    s = time.time() - t0
    print("")
    params = (k - k // 5) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9