OpenDAS / bitsandbytes

Commit a24aae30, authored Jul 06, 2023 by Jeongseok Kang

Merge branch 'main' into fix/libcuda-to-torch

Parents: 2b4cc256, 4395d68c
Showing 8 changed files with 1016 additions and 335 deletions (+1016, -335):
- deploy.sh (+0, -11)
- examples/int8_inference_huggingface.py (+27, -0)
- setup.py (+2, -2)
- tests/test_autograd.py (+205, -10)
- tests/test_functional.py (+486, -165)
- tests/test_modules.py (+129, -41)
- tests/test_optim.py (+108, -106)
- tests/test_triton.py (+59, -0)
deploy.sh

```diff
@@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
 fi

-make clean
-export CUDA_HOME=$BASE_PATH/cuda-10.2
-make cuda10x_nomatmul CUDA_VERSION=102
-
-if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then
-    # Control will enter here if $DIRECTORY doesn't exist.
-    echo "Compilation unsuccessul!" 1>&2
-    exit 64
-fi
-
 make clean
 export CUDA_HOME=$BASE_PATH/cuda-11.0
 make cuda110_nomatmul CUDA_VERSION=110
```
examples/int8_inference_huggingface.py (new file, mode 100644)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_NEW_TOKENS = 128
model_name = 'decapoda-research/llama-7b-hf'

text = 'Hamburg is in which country?\n'
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(text, return_tensors="pt").input_ids

free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'

n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}

model = AutoModelForCausalLM.from_pretrained(
  model_name,
  device_map='auto',
  load_in_8bit=True,
  max_memory=max_memory
)
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```
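Worth noting: the example computes `free_in_GB` but never uses it; the per-GPU budget comes from the f-string, which reserves roughly 2 GB of headroom below the free memory on the current device and applies the same cap to every GPU. A minimal sketch of that budgeting idiom, factored into a helper (`make_max_memory_map` is a hypothetical name, not part of the commit):

```python
import torch

def make_max_memory_map(headroom_gb: int = 2) -> dict:
    # hypothetical helper, not part of the commit
    # mem_get_info() returns (free_bytes, total_bytes) for the current device
    free_gb = int(torch.cuda.mem_get_info()[0] / 1024**3)
    budget = f"{free_gb - headroom_gb}GB"
    # Accelerate-style map: device index -> maximum memory to use on it
    return {i: budget for i in range(torch.cuda.device_count())}
```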
setup.py

```diff
@@ -18,10 +18,10 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.38.0",
+    version=f"0.39.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
-    description="8-bit optimizers and matrix multiplication routines.",
+    description="k-bit optimizers and matrix multiplication routines.",
     license="MIT",
     keywords="gpu optimizers optimization 8-bit quantization compression",
     url="https://github.com/TimDettmers/bitsandbytes",
```
tests/test_autograd.py

```diff
@@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradA1, gradA2, atol=0.015, rtol=0.1
                 )
             if req_grad[1]:
@@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradB1, gradB2, atol=0.18, rtol=0.3
                 )
@@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
             assert (idx == 0).sum().item() < n * 0.01
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 out_bnb, out_torch, atol=0.027, rtol=0.2
             )
@@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradA1, gradA2, atol=0.015, rtol=0.1
                 )
             if req_grad[1]:
@@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradA1, gradA2, atol=0.015, rtol=0.1
                 )
             if req_grad[1]:
```
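The hunks above are a mechanical migration from `torch.testing.assert_allclose` to `torch.testing.assert_close`, repeated across all the test files in this commit. For background not stated in the diff itself: `assert_allclose` has been deprecated in PyTorch in favor of `assert_close`, which keeps explicit tolerances but also checks dtype and device by default. A minimal sketch of the pattern:

```python
import torch

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.001, 2.001])

# old, deprecated API:
#   torch.testing.assert_allclose(a, b, atol=0.01, rtol=0.1)
# replacement used throughout this commit; note that rtol/atol are
# keyword arguments and that dtype/device mismatches now fail by default
torch.testing.assert_close(a, b, atol=0.01, rtol=0.1)
```

The remaining hunks of this file follow below.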
```diff
@@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
 dim2.append(0)

 decomp = [0.0, 6.0]
-funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ["matmul"]
+funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
+str_funcs = ["matmullt", 'switchback_bnb']
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
 req_grad = list(product([True, False], repeat=3))
 req_grad_str = []
@@ -407,7 +407,7 @@ def test_matmullt(
                 bias.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradA1, gradA2, atol=0.015, rtol=0.1
                 )
             if req_grad[1]:
```
```diff
@@ -423,9 +423,204 @@ def test_matmullt(
                 assert (idx == 0).sum().item() <= n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() <= n * 0.02
-                torch.testing.assert_allclose(
+                torch.testing.assert_close(
                     gradB1, gradB2, atol=0.18, rtol=0.3
                 )
             if req_grad[2]:
-                torch.testing.assert_allclose(gradBias1, gradBias2)
+                torch.testing.assert_close(gradBias1, gradBias2)
+
+
+n = 1
+k = 3
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
+
+dim2.append(0)
+
+funcs = [(torch.matmul, bnb.matmul_4bit)]
+str_funcs = ["matmul"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True:
+            strval += 'T'
+        else:
+            strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+compress_statistics = [False, True]
+has_fp16_weights = [True, False]
+has_bias = [True, False]
+quant_type = ['fp4', 'nf4']
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
+def test_matmul_4bit(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False
+
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            bias = None
+            bias2 = None
+            if has_bias:
+                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias2 = bias.clone()
+            torch.nn.init.xavier_uniform_(B)
+
+            B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+
+            if has_bias:
+                out_torch += bias
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.115
+                #assert err < 0.20
+
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias1 = bias.grad
+                    bias.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias2 = bias.grad
+                    bias.grad = None
+
+                if req_grad[0]:
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
+
+                if req_grad[2]:
+                    torch.testing.assert_close(gradBias1, gradBias2)
+
+
+funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
+str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True:
+            strval += 'T'
+        else:
+            strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+has_fp16_weights = [True, False]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    req_grad = list(req_grad)
+    req_grad[2] = False
+
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            torch.nn.init.xavier_uniform_(B)
+
+            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
+            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B, fw_code, bw_code)
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.115
+                #assert err < 0.20
+
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+
+                if req_grad[0]:
+                    torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
+
+                if req_grad[1]:
+                    n = gradB1.numel()
+                    if dim2 > 0:
+                        assert torch.abs(gradB1).sum() > 0.0
+                        assert torch.abs(gradB2).sum() > 0.0
+                    else:
+                        assert torch.abs(gradB1).sum() == 0.0
+                        assert torch.abs(gradB2).sum() == 0.0
+
+                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.1
+                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.02
+                    grad_err = (gradB1 - gradB2).abs().mean()
+                    assert grad_err.item() < 0.003
+                    torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
```
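For readers skimming the new tests: the 4-bit path quantizes the weight once and then passes the returned quantization state into every matmul. A minimal usage sketch distilled from `test_matmul_4bit` above (same calls, smaller shapes; assumes a CUDA build of bitsandbytes at this commit):

```python
import torch
import bitsandbytes as bnb

A = torch.randn(32, 64, dtype=torch.float16, device="cuda")
B = torch.randn(48, 64, dtype=torch.float16, device="cuda")

# quantize the weight to 4 bits (quant_type is 'fp4' or 'nf4'),
# keeping per-block absmax statistics in quant_state
B4, quant_state = bnb.functional.quantize_4bit(B, quant_type="nf4")

# 4-bit matmul against the transposed quantized weight, as in the test
out = bnb.matmul_4bit(A, B4.t(), quant_state)

# fp16 reference; the test bounds the mean absolute error at 0.115
ref = torch.matmul(A, B.t())
print((out - ref).abs().float().mean().item())
```

The fp8 test builds its forward and backward codebooks with `create_fp8_map(signed, exponent_bits, mantissa_bits, total_bits)`: an e4m3 map for the forward pass and an e5m2 map for gradients, mirroring common fp8 training practice.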
tests/test_functional.py

```diff
@@ -18,12 +18,15 @@ torch.set_printoptions(
 k = 20

-def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
+def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
     idx = torch.isclose(a, b, rtol, atol)
     sumval = (idx == 0).sum().item()
     if sumval > count:
-        print(f"Too many values not close: assert {sumval} < {count}")
-        torch.testing.assert_allclose(a, b, rtol, atol)
+        if throw:
+            print(f"Too many values not close: assert {sumval} < {count}")
+            torch.testing.assert_close(a, b, rtol, atol)
+
+    return sumval

 class FFN(torch.nn.Module):
```
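The new `throw` flag turns the helper into a dual-purpose tool: with `throw=True` it still raises through `torch.testing.assert_close`, while with `throw=False` it only returns the mismatch count, which the new `test_cutlass3_gemm` further down uses to tolerate a bounded number of non-deterministic tensor-core mismatches. A self-contained sketch of the counting mode (the helper body mirrors the version above):

```python
import torch

def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
    # mirrors the updated helper above
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count and throw:
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
    return sumval

a = torch.zeros(100)
b = torch.zeros(100)
b[:3] = 1.0  # three mismatching elements

# counting mode: returns the mismatch count instead of raising
assert assert_all_approx_close(a, b, count=0, throw=False) == 3
```

The rest of the file's hunks continue below.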
```diff
@@ -97,7 +100,7 @@ def test_estimate_quantiles(dtype):
     code = F.estimate_quantiles(A)

     percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
-    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
+    torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)

     A = torch.randn(1024, 1024, device="cuda")
     A = A.to(dtype)
@@ -122,7 +125,7 @@ def test_quantile_quantization():
         C = F.quantize_no_absmax(A1, code)
         A2 = F.dequantize_no_absmax(C, code)
         diff = torch.abs(A1 - A2).mean().item()
-        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
+        torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
         assert diff < 0.001
```
```diff
@@ -146,63 +149,49 @@ def test_dynamic_quantization():
         C, S = F.quantize(A1)
         A2 = F.dequantize(C, S)
         diff = torch.abs(A1 - A2).mean().item()
-        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+        torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
         assert diff < 0.004


-def test_dynamic_blockwise_quantization():
+@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
+@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
+def test_dynamic_blockwise_quantization(nested, blocksize):
     #print('')
-    for blocksize in [4096, 2048, 1024, 512]:
-        diffs = []
-        reldiffs = []
-        for i in range(100):
-            A1 = torch.randn(1024, 1024, device="cuda")
-            C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-            A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-            diff = torch.abs(A1 - A2)
-            reldiff = diff / torch.abs(A1 + 1e-8)
-            diffs.append(diff.mean().item())
-            reldiffs.append(reldiff.mean().item())
-        abserr = sum(diffs)/len(diffs)
-        relerr = sum(reldiffs)/len(reldiffs)
-        assert abserr < 0.011
-        assert relerr < 0.018
-        #print('randn', blocksize, sum(diffs)/len(diffs))
-        #print('randn', blocksize, sum(reldiffs)/len(reldiffs))
-
-        diffs = []
-        for i in range(100):
-            A1 = torch.rand(1024, 1024, device="cuda")
-            C, S = F.quantize_blockwise(A1, blocksize=blocksize)
-            A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
-            diff = torch.abs(A1 - A2)
-            reldiff = diff / torch.abs(A1 + 1e-8)
-            diffs.append(diff.mean().item())
-            reldiffs.append(reldiff.mean().item())
-            #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
-        abserr = sum(diffs)/len(diffs)
-        relerr = sum(reldiffs)/len(reldiffs)
-        assert abserr < 0.0035
-        assert relerr < 0.015
-        #print('rand', blocksize, sum(diffs)/len(diffs))
-        #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
-
-
-def test_dynamic_blockwise_stochastic_quantization():
-    diffs = []
-    reldiffs = []
-    rand = torch.rand(1024).cuda()
-    for i in range(100):
-        A1 = torch.randn(1024, 1024, device="cuda")
-        C1, S1 = F.quantize_blockwise(A1, rand=rand)
-        C2, S2 = F.quantize_blockwise(A1)
-        # a maximunm distance of quantized values of 1
-        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
-        fraction_smaller = (C1 < C2).float().sum() / C1.numel()
-        fraction_larger = (C1 > C2).float().sum() / C1.numel()
-        torch.testing.assert_allclose(
-            fraction_larger, fraction_smaller, atol=0.01, rtol=0
-        )
+    diffs = []
+    reldiffs = []
+    for i in range(100):
+        A1 = torch.randn(1024, 1024, device="cuda")
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
+        diff = torch.abs(A1 - A2)
+        reldiff = diff / torch.abs(A1 + 1e-8)
+        diffs.append(diff.mean().item())
+        reldiffs.append(reldiff.mean().item())
+    abserr = sum(diffs)/len(diffs)
+    relerr = sum(reldiffs)/len(reldiffs)
+    assert abserr < 0.011
+    assert relerr < 0.018
+    #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
+    #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
+
+    diffs = []
+    for i in range(100):
+        A1 = torch.rand(1024, 1024, device="cuda")
+        C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+        A2 = F.dequantize_blockwise(C, S)
+        diff = torch.abs(A1 - A2)
+        reldiff = diff / torch.abs(A1 + 1e-8)
+        diffs.append(diff.mean().item())
+        reldiffs.append(reldiff.mean().item())
+        #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
+    abserr = sum(diffs)/len(diffs)
+    relerr = sum(reldiffs)/len(reldiffs)
+    assert abserr < 0.0035
+    assert relerr < 0.015
+    #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+    #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))


 @pytest.mark.parametrize(
```
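The rewritten test drives `quantize_blockwise`/`dequantize_blockwise` through pytest parameters instead of an inner loop, extends coverage down to blocksize 64, and adds the `nested` flag, which additionally quantizes the per-block absmax statistics. A minimal round-trip sketch in the spirit of the test (assumes a CUDA build of bitsandbytes at this commit):

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device="cuda")

# 8-bit dynamic quantization in blocks of 64 values; nested=True also
# compresses the per-block statistics to save memory
C, S = F.quantize_blockwise(A, blocksize=64, nested=True)
A2 = F.dequantize_blockwise(C, S)

# mean absolute error; the test asserts < 0.011 for randn input
print((A - A2).abs().mean().item())
```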
```diff
@@ -231,9 +220,9 @@ def test_percentile_clipping(gtype):
     vals, idx = torch.sort(gnorm_vec1)
     clip1 = vals[percentile]

-    torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
-    torch.testing.assert_allclose(clip1, clip2)
-    torch.testing.assert_allclose(gnorm1, gnorm2)
+    torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
+    torch.testing.assert_close(clip1, clip2)
+    torch.testing.assert_close(gnorm1, gnorm2)

 def quant(x):
@@ -315,7 +304,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
     dim2 = dim2 - (dim2 % 32)
     errors = []
     relerrors = []
-    print("")
+    # print("")
     for i in range(5):
         if batched:
             A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
@@ -327,7 +316,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
             B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
         maxA, Ac = quant_methods[0](A, 1)
         maxB, Bc = quant_methods[1](B, 0)
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(
             quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
         )
         if batched:
@@ -344,8 +333,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
         relerr = err / torch.abs(out2)
         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
-    print(mean(errors))
-    print(mean(relerrors))
+    # print(mean(errors))
+    # print(mean(relerrors))

 def test_stable_embedding():
@@ -398,7 +387,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
             out2 = torch.matmul(A.t().float(), B.t().float())
             out = F.igemm(A.t(), B.t())
-        torch.testing.assert_allclose(out.float(), out2)
+        torch.testing.assert_close(out.float(), out2)

     for i in range(k):
         shapeA = (batch_dim, seq_dim, hidden_dim)
@@ -416,7 +405,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
             out2 = torch.matmul(A.float(), B.t().float())
             out = F.igemm(A, B.t())
-        torch.testing.assert_allclose(out.float(), out2)
+        torch.testing.assert_close(out.float(), out2)

 n = 3
@@ -447,7 +436,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
         )
         out = F.igemm(A, B, out=iout)

-        torch.testing.assert_allclose(out.float(), out2)
+        torch.testing.assert_close(out.float(), out2)

 n = 2
@@ -572,7 +561,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
                 A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
             )
             out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
-        torch.testing.assert_allclose(out.float(), out2.float())
+        torch.testing.assert_close(out.float(), out2.float())

 n = 1
@@ -630,9 +619,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
     out, S = F.nvidia_transform(A, to_order=orderOut)

     if orderOut == "row":
-        torch.testing.assert_allclose(A.flatten(), out.flatten())
+        torch.testing.assert_close(A.flatten(), out.flatten())
     elif orderOut == "col":
-        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
+        torch.testing.assert_close(A.t().flatten(), out.flatten())
     elif orderOut == "col32":
         if dims == 2:
             n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
@@ -665,14 +654,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
                 assert A.flatten()[i + j] == A[row, col]
                 # assert A.flatten()[i+j] == out.flatten()[row2+col2]
-                # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
-                # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
+                # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
+                # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

     if orderOut == "col32":
         out2, S = F.nvidia_transform(
             out, from_order=orderOut, to_order="row", state=S
         )
-        torch.testing.assert_allclose(A, out2)
+        torch.testing.assert_close(A, out2)

 n = 1
```
```diff
@@ -716,7 +705,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
         B2, SB = F.transform(B, "col_turing")
         C2, SC = F.igemmlt(A2, B2, SA, SB)
         C3, S = F.nvidia_transform(C2, "row", state=SC)
-        torch.testing.assert_allclose(C1, C3.float())
+        torch.testing.assert_close(C1, C3.float())

         # transpose
         B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
@@ -727,7 +716,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
         B2t, SBt = F.transform(B, "col_turing", transpose=True)
         C2, SC = F.igemmlt(A2, B2t, SA, SBt)
         C3, S = F.nvidia_transform(C2, "row", state=SC)
-        torch.testing.assert_allclose(C1, C3.float())
+        torch.testing.assert_close(C1, C3.float())

 dim1 = [32]
@@ -773,7 +762,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
         # print(C1.flatten()[:10])
         # print(C2.flatten()[:10])
-        # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
+        # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

         # transpose
         # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
@@ -782,7 +771,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
         # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
         # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
         # C3, S = F.transform(C2, 'row', state=SC)
-        # torch.testing.assert_allclose(C1, C3.float())
+        # torch.testing.assert_close(C1, C3.float())

 batch_size = 2
@@ -1001,7 +990,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
         #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"

         C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
-        #torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1)
+        #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
         n = C5.numel()
         assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
```
```diff
@@ -1051,16 +1040,16 @@ def test_colrow_absmax(dim1, dim2, dims):
         )
         nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

-        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
-        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
-        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
+        torch.testing.assert_close(col_stats1_trunc, col_stats2)
+        torch.testing.assert_close(row_stats1_trunc, row_stats2)
+        torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)

         row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
             A, threshold=0.0
         )

-        torch.testing.assert_allclose(col_stats1, col_stats2)
-        torch.testing.assert_allclose(row_stats1, row_stats2)
+        torch.testing.assert_close(col_stats1, col_stats2)
+        torch.testing.assert_close(row_stats1, row_stats2)
         assert nnz_block_ptr2 is None
@@ -1084,8 +1073,8 @@ def test_double_quant(dim1, dim2):
         CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

         # max difference is 1 due to rounding differences
-        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
-        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
+        torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
+        torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)

         n = CAt.numel()
         num_not_close_rows = (
@@ -1108,8 +1097,8 @@ def test_double_quant(dim1, dim2):
             )
             assert False

-        torch.testing.assert_allclose(Srow.flatten(), statsA)
-        torch.testing.assert_allclose(Scol.flatten(), statsAt)
+        torch.testing.assert_close(Srow.flatten().float(), statsA)
+        torch.testing.assert_close(Scol.flatten().float(), statsAt)

 n = 4
@@ -1134,10 +1123,10 @@ def test_integrated_igemmlt(dim1, dim4, inner):
         A1, maxA = F.vectorwise_quant(A, dim=1)
         B1, maxB = F.vectorwise_quant(B, dim=1)

-        torch.testing.assert_allclose(maxA.flatten(), stats1a)
-        torch.testing.assert_allclose(maxB.flatten(), stats2a)
-        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
-        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
+        torch.testing.assert_close(maxA.flatten().float(), stats1a)
+        torch.testing.assert_close(maxB.flatten().float(), stats2a)
+        torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
+        torch.testing.assert_close(C2a, B1, rtol=0, atol=1)

         A2, SA = F.nvidia_transform(C1a, "col32")
         B2, SB = F.nvidia_transform(C2a, "col_turing")
@@ -1339,7 +1328,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
     # print(out1)
     # print(out2)

-    torch.testing.assert_allclose(out1, out2)
+    torch.testing.assert_close(out1, out2)

 n = 2
@@ -1401,11 +1390,11 @@ def test_coo_double_quant(dim1, dim2):
             A2[
                 coo_tensor.rowidx.long(), coo_tensor.colidx.long()
             ] = coo_tensor.values
-            torch.testing.assert_allclose(A1, A2)
+            torch.testing.assert_close(A1, A2)

             A1 = A * (idx == 0)
             A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
-            torch.testing.assert_allclose(
+            torch.testing.assert_close(
                 A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
             )
@@ -1613,7 +1602,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
     idx_col = torch.randint(0, A2.shape[-1], size=(15,))
-    # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
+    # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
     # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
     # torch.cuda.synchronize()
@@ -1644,9 +1633,9 @@ def test_coo2csr():
     counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
     assert counts.numel() == A.shape[0]

-    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
+    torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
     idx = A2 != 0
-    torch.testing.assert_allclose(A2[idx], csrA.values)
+    torch.testing.assert_close(A2[idx], csrA.values)

 def test_coo2csc():
@@ -1664,10 +1653,10 @@ def test_coo2csc():
     counts = cscA.colptr[1:] - cscA.colptr[:-1]
     assert counts.numel() == A.shape[1]

-    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
+    torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
     # torch uses row-major -> use transpose to transfer to col-major
     idx = A2.t() != 0
-    torch.testing.assert_allclose(A2.t()[idx], cscA.values)
+    torch.testing.assert_close(A2.t()[idx], cscA.values)

 n = 2
@@ -1717,7 +1706,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     max_count, max_idx = torch.sort(counts, descending=True)
     print(torch.median(max_count.float()))

-    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
+    torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)

     p = 200 / (2048 * 12288 * 4)
     n = out1.numel()
```
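A recurring detail in these hunks: several call sites gained explicit casts (`nnz_block_ptr1.int()`, `Srow.flatten().float()`, `counts.long()`). Unlike the old `assert_allclose`, `torch.testing.assert_close` compares dtypes by default, so mixed-dtype pairs must either be cast, as this commit does, or compared with the check disabled. A hedged sketch of the two options:

```python
import torch

a = torch.tensor([1, 2, 3], dtype=torch.int32)
b = torch.tensor([1, 2, 3], dtype=torch.int64)

# option 1 (what this commit does): cast one side explicitly
torch.testing.assert_close(a.long(), b)

# option 2: disable the dtype check and let values be compared
# in a common dtype
torch.testing.assert_close(a, b, check_dtype=False)
```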
```diff
@@ -1787,38 +1776,43 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
 batch_size = 1
 seqdim = 1
 values = []
-values.append((batch_size, seqdim, 768, 4 * 768))
-# values.append((batch_size, seqdim, 1024, 4*1024))
-# values.append((batch_size, seqdim, 1536, 4*1536))
-# values.append((batch_size, seqdim, 2048, 4*2048))
-# values.append((batch_size, seqdim, 2560, 4*2560))
-# values.append((batch_size, seqdim, 4096, 4*4096))
-# values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 768, 4 * 768))
+#values.append((batch_size, seqdim, 1024, 4*1024))
+#values.append((batch_size, seqdim, 1536, 4*1536))
+#values.append((batch_size, seqdim, 2048, 4*2048))
+#values.append((batch_size, seqdim, 2560, 4*2560))
+values.append((batch_size, seqdim, 4096, 4 * 4096))
+values.append((batch_size, seqdim, 5120, 4 * 5120))
+values.append((batch_size, seqdim, 6656, 4 * 6656))
+values.append((batch_size, seqdim, 8192, 4 * 8192))
+#values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]

 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
-    iters = 128
+    iters = 80
     formatB = F.get_special_format_str()

     A = torch.randn(batch, seq, model, device="cuda").half()
     B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
     torch.nn.init.xavier_uniform_(B)

-    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+    B_fp4, state = F.quantize_fp4(B)
+    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
+    B_nf4, state_nf4 = F.quantize_nf4(B)
+
+    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
     linear8bit.eval()

     outliers = torch.randint(0, model, size=(5,)).cuda()
     A[:, :, outliers] = 8.0

-    linearMixedBit = (
-        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-    )
-    linearMixedBit.eval()
+    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
+    #linearMixedBit.eval()
+
+    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

     # warmup
     for i in range(iters):
@@ -1831,61 +1825,80 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print(
         f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
     )

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
-
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        bnb.matmul(A, B, threshold=6.0)
+        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
     torch.cuda.synchronize()
-    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
-    C32A, SA = F.transform(CA, "col32")
-    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
-    CxB, SB = F.transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+        bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
     torch.cuda.synchronize()
-    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    BA, statsB = F.vectorwise_quant(B, dim=1)
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1)
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
+    print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B, threshold=6.0)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
+    #C32A, SA = F.transform(CA, "col32")
+    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
+    #CxB, SB = F.transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #torch.cuda.synchronize()
+    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    #BA, statsB = F.vectorwise_quant(B, dim=1)
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+    #torch.cuda.synchronize()
     #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        out = Cout * statsB * statsA * (1.0 / (127 * 127))
-    torch.cuda.synchronize()
+    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
+    #torch.cuda.synchronize()
     #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linear8bit(A)
```
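The benchmark's timing discipline is unchanged by the rewrite and worth calling out: CUDA kernels launch asynchronously, so each timed region is bracketed by `torch.cuda.synchronize()` before reading the wall clock. A minimal sketch of the pattern used throughout `test_bench_matmul` (the `bench` helper is a hypothetical name, not part of the commit):

```python
import time
import torch

def bench(fn, iters=80):
    # hypothetical helper distilled from the benchmark's timing pattern
    fn()                               # warmup launch
    torch.cuda.synchronize()           # drain pending kernels before timing
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()           # wait for the timed kernels to finish
    return (time.time() - t0) / iters

A = torch.randn(1, 1, 4096, device="cuda").half()
B = torch.randn(4 * 4096, 4096, device="cuda").half()
print(f"fp16 matmul: {bench(lambda: torch.matmul(A, B.t())):.6f}s/iter")
```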
```diff
@@ -1894,9 +1907,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linear8bit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linearMixedBit(A)
     torch.cuda.synchronize()
@@ -1904,9 +1915,23 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linearMixedBit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    #linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    #linear8bit_train_thresh(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

 def test_zeropoint():
     def quant_zp(x):
```
```diff
@@ -2009,7 +2034,7 @@ def test_extract_outliers():
     assert outliers2.shape[0] == shapeA[0]
     assert outliers2.shape[1] == idx.numel()

-    torch.testing.assert_allclose(outliers1, outliers2)
+    torch.testing.assert_close(outliers1, outliers2)

     CA, SA = F.transform(A, "col_ampere")
@@ -2018,7 +2043,7 @@ def test_extract_outliers():
     assert outliers2.shape[0] == shapeA[0]
     assert outliers2.shape[1] == idx.numel()

-    torch.testing.assert_allclose(outliers1, outliers2)
+    torch.testing.assert_close(outliers1, outliers2)
@@ -2050,7 +2075,6 @@ def test_fp8_quant():
         p_bits = 7 - e_bits
         code = F.create_fp8_map(True, e_bits, p_bits).cuda()
-        print(e_bits, p_bits)

         abserr = []
         relerr = []
         for i in range(100):
@@ -2149,7 +2173,7 @@ def test_few_bit_quant():
             #assert err2.mean() <= err1
             else:
-                torch.testing.assert_allclose(q1, q2)
+                torch.testing.assert_close(q1, q2)
     #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
     #assert False
```
...
@@ -2181,7 +2205,9 @@ def test_kbit_quantile_estimation():
...
@@ -2181,7 +2205,9 @@ def test_kbit_quantile_estimation():
def
test_bench_dequantization
():
def
test_bench_dequantization
():
a
=
torch
.
rand
(
1024
,
1024
,
device
=
'cuda'
).
half
()
a
=
torch
.
rand
(
1024
,
1024
,
device
=
'cuda'
).
half
()
qa
,
SA
=
F
.
quantize_blockwise
(
a
)
code
=
F
.
create_fp8_map
(
True
,
3
,
0
,
4
).
cuda
()
qa
,
SA
=
F
.
quantize_blockwise
(
a
,
code
=
code
)
print
(
qa
.
max
())
max_theoretical_mu
=
1024
*
1024
*
2
/
1024
**
3
/
672
*
1000
*
1000
max_theoretical_mu
=
1024
*
1024
*
2
/
1024
**
3
/
672
*
1000
*
1000
#print(max_theoretical_mu)
#print(max_theoretical_mu)
...
@@ -2189,7 +2215,302 @@ def test_bench_dequantization():
...
@@ -2189,7 +2215,302 @@ def test_bench_dequantization():
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
for
i
in
range
(
100
):
F
.
de
quantize_blockwise
(
qa
,
SA
,
blocksize
=
2048
)
qa
,
SA
=
F
.
quantize_blockwise
(
a
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
#print((time.time()-t0)/1e6)
#print((time.time()-t0)/1e6)
def
test_fp4_quant
():
vals
=
list
(
product
([
0
,
1
],
repeat
=
4
))
code
=
{}
for
bits
in
vals
:
result
=
0
bias
=
3
sign
,
e1
,
e2
,
p1
=
bits
idx
=
sign
*
8
+
e1
*
4
+
e2
*
2
+
p1
*
1
sign
=
-
1.0
if
sign
else
1.0
exp
=
e1
*
2
+
e2
*
1
if
exp
==
0
:
# sub-normal
if
p1
==
0
:
result
=
0
else
:
result
=
sign
*
0.0625
else
:
# normal
exp
=
2
**
(
-
exp
+
bias
+
1
)
frac
=
1.5
if
p1
else
1.0
result
=
sign
*
exp
*
frac
code
[
idx
]
=
result
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
()
qa
,
SA
=
F
.
quantize_fp4
(
A1
,
blocksize
=
64
)
A2
=
F
.
dequantize_fp4
(
qa
,
SA
)
err
=
(
A1
-
A2
).
abs
().
float
()
relerr
=
(
err
/
A1
.
abs
().
float
()).
mean
()
idx
=
err
>
1.0
err
=
err
.
mean
()
assert
err
.
item
()
<
0.1
assert
relerr
.
item
()
<
0.28
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
'fp4'
,
'nf4'
])
def
test_4bit_compressed_stats
(
quant_type
):
for
blocksize
in
[
128
,
64
]:
errs1
=
[]
errs2
=
[]
for
i
in
range
(
10
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cuda'
).
half
()
q2
,
SA2
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
quant_type
=
quant_type
)
q3
,
SA3
=
F
.
quantize_4bit
(
A1
,
blocksize
=
blocksize
,
compress_statistics
=
True
,
quant_type
=
quant_type
)
A2
=
F
.
dequantize_4bit
(
q2
,
SA2
,
quant_type
=
quant_type
)
A3
=
F
.
dequantize_4bit
(
q3
,
SA3
,
quant_type
=
quant_type
)
err
=
(
A1
-
A2
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
errs1
.
append
(
err
.
item
())
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
err
=
(
A1
-
A3
).
abs
().
float
()
relerr
=
(
err
/
(
A1
.
abs
().
float
()
+
1e-15
)).
mean
()
err
=
err
.
mean
()
errs2
.
append
(
err
.
item
())
assert
err
.
item
()
<
0.11
assert
relerr
.
item
()
<
0.28
#print(sum(errs1)/len(errs1), blocksize, quant_type)
#print(sum(errs2)/len(errs2), blocksize, quant_type)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+def test_bench_4bit_dequant(quant_type):
+    blocksize = 256
+    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
+    qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
+
+    input_size = a.numel()/2
+    output_size = a.numel()*2
+    num_bytes = input_size+output_size
+    GB = num_bytes/1e9
+    max_theoretical_s = GB/768
+    #print(max_theoretical_s*1e6)
+    b = torch.randn(128, 1024*12, device='cuda').half()
+
+    iters = 5
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
+        #b.copy_(a)
+    torch.cuda.synchronize()
+    #print((time.time()-t0)/iters*1e6)
+
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    torch.matmul(b, a.t())
+    #torch.cuda.synchronize()
+    #print((time.time()-t0)/iters*1e6)
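The `max_theoretical_s` bound above is plain bandwidth arithmetic: the packed 4-bit input costs half a byte per element to read, the fp16 output two bytes to write, and the constant 768 is an assumed device memory bandwidth in GB/s. The same arithmetic, spelled out:

    n = (1024 * 12 * 4) * (1024 * 12)    # elements in the benchmark matrix
    bytes_moved = n / 2 + n * 2          # 4-bit read plus fp16 write
    seconds = (bytes_moved / 1e9) / 768  # assumed ~768 GB/s memory bandwidth
    print(seconds * 1e6, 'us lower bound per dequantize call')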
+def test_normal_map_tree():
+    code = F.create_normal_map()
+    values = code[:8].tolist() + code[-8:].tolist()
+    num_pivots = 1
+    print(values)
+    while num_pivots < 16:
+        idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
+        print(idx)
+        num_pivots *= 2
+        pivots = []
+        for i in idx:
+            pivots.append((values[i-1]+values[i])/2)
+        print(pivots)
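The printed pivot indices halve the search space at each level, which is the access pattern of a binary search over the 16-entry normal map. A hedged sketch of the equivalent scalar lookup (the midpoint list mirrors the `pivots` computed above; the dummy code table is ours):

    import bisect

    def nearest_code_index(x, values):
        # values: 16 sorted code points, e.g. derived from F.create_normal_map()
        mids = [(values[i - 1] + values[i]) / 2 for i in range(1, len(values))]
        return bisect.bisect_right(mids, x)

    print(nearest_code_index(0.4, [i / 7.5 - 1 for i in range(16)]))  # dummy sorted code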
+#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+def test_cutlass3_gemm(dtype):
+    debug = True
+    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4096, 5120, 6656, 8192]:
+    for dim in [4096]:
+    #for dim in [128+1]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(100):
+            A = torch.randn(1, dim, dtype=dtype, device='cuda')
+            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-1] = 0
+            #B[:, :-1] = 0
+
+            C1 = torch.matmul(A, B.t())
+            C2 = F.cutlass3_gemm(A, B.t())
+
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
+
+            errs.append(err)
+            relerrs.append(relerr)
+
+            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #    print('')
+            #    print(i, err, relerr)
+            #    print(A.flatten()[-6:])
+            #    print(B.flatten()[-6:])
+            #    out = A.flatten()[-6:]*B.flatten()[-6:]
+            #    print(out)
+            #    print(out[:-1].sum())
+            #    print('='*80)
+            #    print(C1.flatten()[-6:])
+            #    print(C2.flatten()[-6:])
+            #    #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.0014*(dim/256))+1
+
+        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
+        #print(c/math.sqrt(dim))
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
+#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
+@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
+def test_gemm_4bit(dtype):
+    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4096, 5120, 6656, 8192]:
+    #for dim in [32]:
+    for dim in [4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(1):
+            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-1] = 0
+            #B[:, :-1] = 0
+
+            qB, state = F.quantize_nf4(B)
+            F.dequantize_nf4(qB, state)
+
+            C3 = torch.matmul(A, B.t())
+            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
+            C1 = bnb.matmul_4bit(A, qB.t(), state)
+            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
+            print(C1.shape, C2.shape)
+
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
+
+            errs.append(err)
+            relerrs.append(relerr)
+
+            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+                print('')
+                print(i, err, relerr)
+                print(A.flatten()[-6:])
+                print(B.flatten()[-6:])
+                out = A.flatten()[-6:]*B.flatten()[-6:]
+                print(out)
+                print(out[:-1].sum())
+                print('='*80)
+                print(C1.flatten()[-6:])
+                print(C2.flatten()[-6:])
+                #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.0014*(dim/256))+1
+
+        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
+        #print(c/math.sqrt(dim))
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
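For reference, the 4-bit inference path that this test cross-checks against fp16 looks like the following minimal sketch in user code (same calls as in the test body):

    import math
    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    dim = 4096
    A = torch.randn(1, dim, dtype=torch.float16, device='cuda')
    B = torch.randn(4 * dim, dim, dtype=torch.float16, device='cuda') / math.sqrt(dim)

    qB, state = F.quantize_nf4(B)           # pack the weight matrix to 4-bit NF4
    C = bnb.matmul_4bit(A, qB.t(), state)   # fp16 activations against NF4 weights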
+@pytest.mark.skip("Row scale has some bugs for ampere")
+def test_managed():
+    n = 32*10
+    A = F.get_paged(n, n, dtype=torch.float32)
+    B = F.get_paged(n, n, dtype=torch.uint8)
+    B2 = F.get_paged(n, n, dtype=torch.float32)
+    assert A.is_paged
+    assert B.is_paged
+    assert A.page_deviceid == 0
+    assert B.page_deviceid == 0
+    F.fill(A, 17.0)
+    F.fill(B, 17)
+    F.fill(B2, 2)
+    assert (A == 17).sum().item() == n*n
+    assert (B == 17).sum().item() == n*n
+    C = A*B.float()
+    assert (C == 289).sum().item() == n*n
+    F._mul(A, B2)
+    F._mul(A, B2)
+    F._mul(A, B2)
+    assert (A == 17*(2**3)).sum().item() == n*n
+    # F.prefetch_tensor(A)
+    # F.prefetch_tensor(B)
+    # F.fill(B2, 17.0)
+    # F._mul(A, B2)
+    # F.prefetch_tensor(A, to_cpu=True)
+    # F.prefetch_tensor(B, to_cpu=True)
+    # F.prefetch_tensor(B2, to_cpu=True)
+    # torch.cuda.synchronize()
+    # assert (A==17).sum().item() == n*n
+    # torch.testing.assert_close(A, torch.ones(A.shape)*289)
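`F.get_paged` allocates tensors in CUDA unified (paged) memory, which the driver can transparently evict to host RAM; the commented-out `prefetch_tensor` calls hint at the intended manual prefetching workflow. A compressed sketch of the API as the test uses it:

    import torch
    import bitsandbytes.functional as F

    A = F.get_paged(320, 320, dtype=torch.float32)  # backed by unified memory
    assert A.is_paged and A.page_deviceid == 0
    F.fill(A, 17.0)       # custom kernels treat it like an ordinary CUDA tensor
    print((A == 17).sum().item())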
tests/test_modules.py
View file @ a24aae30
...
@@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
     sumval = (idx == 0).sum().item()
     if sumval > count:
         print(f"Too many values not close: assert {sumval} < {count}")
-        torch.testing.assert_allclose(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol, atol)


 class LinearFunction(torch.autograd.Function):
...
@@ -330,18 +330,15 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(
-        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-    )
-    l2 = torch.nn.Sequential(
-        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-    )
-    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+    l1[0].weight.data.copy_(l2[0].weight.data)
+    l1[1].weight.data.copy_(l2[1].weight.data)
+    l1[0].bias.data.copy_(l2[0].bias.data)
+    l1[1].bias.data.copy_(l2[1].bias.data)
+    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)

     acc_steps = 10
...
@@ -371,26 +368,17 @@ def test_linear8bitlt_accumulated_gradient():
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
+            l1[0].bias.data.copy_(l2[0].bias.data)
+            l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
-            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
+            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)


-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
-        .cuda()
-        .half()
-    )
+        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8

     l1.eval()
...
@@ -446,13 +434,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
-        .half()
-        .to("cuda")
-    )
+        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
...
@@ -499,15 +481,16 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
     scale = grad_ref.abs().mean()

-    torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+    torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
     idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
     assert (idx == 0).sum().item() <= b1.numel() * 0.005


-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32

     for i in range(100):
...
@@ -517,11 +500,116 @@ def test_linear8bitlt_fp32_bias():
         assert l1.bias.dtype == torch.float16

     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64, bias=False).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         assert l1.bias is None

+modules = []
+modules.append(bnb.nn.Linear8bitLt)
+modules.append(bnb.nn.Linear4bit)
+modules.append(bnb.nn.LinearFP4)
+modules.append(bnb.nn.LinearNF4)
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
+modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
+names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("module", modules, ids=names)
+def test_kbit_backprop(module):
+    b = 17
+    dim1 = 37
+    dim2 = 83
+
+    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
+    ref[1].weight.requires_grad = False
+    torch.nn.init.kaiming_normal_(ref[0].weight)
+    torch.nn.init.kaiming_normal_(ref[1].weight)
+    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
+    kbit[0].weight.detach().copy_(ref[0].weight)
+    kbit[1].weight.detach().copy_(ref[1].weight)
+    kbit[0].bias.detach().copy_(ref[0].bias)
+    kbit[1].bias.detach().copy_(ref[1].bias)
+    ref = ref.half().cuda()
+    kbit = kbit.half().cuda()
+
+    errs1 = []
+    errs2 = []
+    relerrs1 = []
+    relerrs2 = []
+    for i in range(100):
+        batch = torch.randn(b, dim1).half().cuda()
+        out1 = ref(batch)
+        out2 = kbit(batch)
+        out1.mean().backward()
+        out2.mean().backward()
+
+        grad1 = ref[0].weight.grad
+        grad2 = kbit[0].weight.grad
+        bgrad1 = ref[0].bias.grad
+        bgrad2 = kbit[0].bias.grad
+
+        err1 = (out1 - out2).abs().float()
+        err2 = (grad1 - grad2).abs().float()
+        relerr1 = (err1/(out1.abs().float()+1e-9))
+        relerr2 = (err2/(grad1.abs().float()+1e-9))
+        errs1.append(err1.mean().item())
+        errs2.append(err2.mean().item())
+        relerrs1.append(relerr1.mean().item())
+        relerrs2.append(relerr2.mean().item())
+
+        if isinstance(module, bnb.nn.Linear8bitLt):
+            torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
+            torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        else:
+            torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
+            torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
+        ref.zero_grad()
+        kbit.zero_grad()
+
+        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
+        assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
+    print('out', sum(errs1)/len(errs1))
+    print('grad', sum(errs2)/len(errs2))
+    print('rel out', sum(relerrs1)/len(relerrs1))
+    print('rel grad', sum(relerrs2)/len(relerrs2))
+
+def test_fp8linear():
+    b = 10
+    h = 1024
+    inp = torch.randn(b, h).cuda()
+    fp32 = torch.nn.Linear(h, h*2).cuda()
+    fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
+    fp32b = torch.nn.Linear(h*2, h).cuda()
+    fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
+
+    fp8.weight.data.copy_(fp32.weight.data)
+    fp8.bias.data.copy_(fp32.bias.data)
+    fp8b.weight.data.copy_(fp32b.weight.data)
+    fp8b.bias.data.copy_(fp32b.bias.data)
+
+    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
+    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
+    err = (a - b).abs().mean()
+
+    a.mean().backward()
+    b.mean().backward()
+
+    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
+
+    assert err < 0.05
+    assert graderr < 0.00002
+    assert bgraderr < 0.00002
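The k-bit modules exercised by `test_kbit_backprop` are drop-in `nn.Linear` replacements; a hedged sketch of swapping one layer of an existing model, mirroring the weight-copy pattern in the test (4-bit quantization happens when the module is moved to the GPU):

    import torch
    import bitsandbytes as bnb

    ref = torch.nn.Linear(512, 512)
    layer = bnb.nn.LinearNF4(512, 512)
    layer.weight.detach().copy_(ref.weight)   # still full precision here
    layer = layer.half().cuda()               # weights quantized on the move to CUDA
    out = layer(torch.randn(4, 512, dtype=torch.float16, device='cuda'))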
tests/test_optim.py
View file @ a24aae30
...
@@ -19,11 +19,11 @@ import bitsandbytes.functional as F
 k = 20

 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
-    idx = torch.isclose(a, b, rtol, atol)
+    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
     error_count = (idx == 0).sum().item()
     if error_count > max_error_count:
         print(f"Too many values not close: assert {error_count} < {max_error_count}")
-        torch.testing.assert_allclose(a, b, rtol, atol)
+        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

 def get_temp_dir():
...
@@ -35,11 +35,8 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)

 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
-# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
 str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,
...
@@ -47,28 +44,20 @@ str2optimizers["momentum_pytorch"] = (
     bnb.optim.Adam,
 )
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
-# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
+str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
 str2optimizers["lion"] = (Lion, bnb.optim.Lion)
+str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
 )
 str2optimizers["lars"] = (
     lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
 )
 str2optimizers["rmsprop"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
 )
-str2optimizers["adam8bit"] = (
-    torch.optim.Adam,
-    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
-)
-str2optimizers["lion8bit"] = (
-    Lion,
-    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
-)
+str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
+str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
 str2optimizers["momentum8bit"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
...
@@ -77,19 +66,12 @@ str2optimizers["rmsprop8bit"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
 )
 str2optimizers["lars8bit"] = (
     lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
 )
-str2optimizers["adam8bit_blockwise"] = (
-    torch.optim.Adam,
-    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
-)
-str2optimizers["lion8bit_blockwise"] = (
-    Lion,
-    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
-)
+str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
+str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
+str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
+str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
+str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
...
@@ -101,53 +83,35 @@ str2optimizers["rmsprop8bit_blockwise"] = (
 str2statenames = {}
 str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["lion"] = [("exp_avg", "state1")]
+str2statenames["paged_lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lars"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["rmsprop"] = [("square_avg", "state1")]
-str2statenames["adam8bit"] = [
-    ("exp_avg", "state1", "qmap1", "max1"),
-    ("exp_avg_sq", "state2", "qmap2", "max2"),
-]
-str2statenames["lamb8bit"] = [
-    ("exp_avg", "state1", "qmap1", "max1"),
-    ("exp_avg_sq", "state2", "qmap2", "max2"),
-]
-str2statenames["adam8bit_blockwise"] = [
-    ("exp_avg", "state1", "qmap1", "absmax1"),
-    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
-]
-str2statenames["lion8bit"] = [
-    ("exp_avg", "state1", "qmap1", "max1")
-]
-str2statenames["momentum8bit"] = [
-    ("momentum_buffer", "state1", "qmap1", "max1")
-]
-str2statenames["momentum8bit_blockwise"] = [
-    ("momentum_buffer", "state1", "qmap1", "absmax1")
-]
-str2statenames["rmsprop8bit_blockwise"] = [
-    ("square_avg", "state1", "qmap1", "absmax1")
-]
-str2statenames["lion8bit_blockwise"] = [
-    ("exp_avg", "state1", "qmap1", "absmax1")
-]
+str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
+str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
+str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
+str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
+str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
+str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
+str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
 str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
+str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
+str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
+str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
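The paged optimizer variants registered above keep their state in unified memory so it can spill to CPU RAM under GPU memory pressure; from the caller's side they are drop-in replacements, e.g. this hedged sketch:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(1024, 1024).cuda()
    optim = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3)  # state in paged memory

    out = model(torch.randn(8, 1024, device='cuda'))
    out.mean().backward()
    optim.step()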
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
+gtype = [torch.float32, torch.float16, torch.bfloat16]
+optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
+    if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
...
@@ -159,6 +123,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     if gtype == torch.float32:
         atol, rtol = 1e-6, 1e-5
+    elif gtype == torch.bfloat16:
+        atol, rtol = 1e-3, 1e-2
     else:
         atol, rtol = 1e-4, 1e-3
...
@@ -172,9 +138,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     for name1, name2 in str2statenames[optim_name]:
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(
             torch_optimizer.state[p1][name1],
-            bnb_optimizer.state[p2][name2],
+            bnb_optimizer.state[p2][name2].cuda(),
             atol=atol,
             rtol=rtol,
         )
...
@@ -201,14 +167,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
             atol=atol, rtol=rtol,
             max_error_count=10)

-        if gtype == torch.float16:
+        if gtype != torch.float32:
             # the adam buffers should also be close because they are 32-bit
             # but the paramters can diverge because they are 16-bit
             # the difference grow larger and larger with each update
             # --> copy the state to keep weights close
-            p1.data = p1.data.half().float()
+            p1.data = p1.data.to(p2.dtype).float()
             p2.copy_(p1.data)
-            torch.testing.assert_allclose(p1.half(), p2)
+            torch.testing.assert_close(p1.to(p2.dtype), p2)
         if optim_name in ["lars", "lamb"]:
             assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
...
@@ -268,7 +234,7 @@ def test_global_config(dim1, dim2, gtype):
 dim1 = [1024]
 dim2 = [32, 1024, 4097]
-gtype = [torch.float32, torch.float16]
+gtype = [torch.float32, torch.float16, torch.bfloat16]
 optimizer_names = [
     "adam8bit",
     "lion8bit",
...
@@ -276,7 +242,6 @@ optimizer_names = [
"rmsprop8bit"
,
"rmsprop8bit"
,
"adam8bit_blockwise"
,
"adam8bit_blockwise"
,
"lion8bit_blockwise"
,
"lion8bit_blockwise"
,
"lars8bit"
,
"momentum8bit_blockwise"
,
"momentum8bit_blockwise"
,
"rmsprop8bit_blockwise"
,
"rmsprop8bit_blockwise"
,
]
]
...
@@ -288,6 +253,7 @@ names = [
...
@@ -288,6 +253,7 @@ names = [
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
+    if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
...
@@ -301,7 +267,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
@@ -301,7 +267,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if
gtype
==
torch
.
float32
:
if
gtype
==
torch
.
float32
:
atol
,
rtol
=
3e-3
,
1e-3
atol
,
rtol
=
3e-3
,
1e-3
patol
,
prtol
=
1e-5
,
1e-3
patol
,
prtol
=
1e-5
,
1e-3
elif
gtype
==
torch
.
bfloat16
:
atol
,
rtol
=
3e-3
,
1e-3
patol
,
prtol
=
1e-4
,
1e-2
else
:
else
:
atol
,
rtol
=
3e-3
,
1e-3
atol
,
rtol
=
3e-3
,
1e-3
patol
,
prtol
=
1e-5
,
1e-3
patol
,
prtol
=
1e-5
,
1e-3
...
@@ -309,7 +277,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
...
@@ -309,7 +277,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
errors
=
[]
errors
=
[]
relerrors
=
[]
relerrors
=
[]
for
i
in
range
(
5
0
):
for
i
in
range
(
10
0
):
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.01
g
=
torch
.
randn
(
dim1
,
dim2
,
device
=
"cuda"
,
dtype
=
gtype
)
*
0.01
p1
.
grad
=
g
.
clone
().
float
()
p1
.
grad
=
g
.
clone
().
float
()
p2
.
grad
=
g
.
clone
()
p2
.
grad
=
g
.
clone
()
...
@@ -343,13 +311,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
...
@@ -343,13 +311,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             )
             == 0
         )
-        assert num_not_close.sum().item() < 20
+        # assert num_not_close.sum().item() < 20
         dequant_states.append(s1.clone())

         err = torch.abs(p1 - p2)
         relerr = err / (torch.abs(p1) + 1e-9)
-        assert err.mean() < 0.0001
-        assert relerr.mean() < 0.001
+        if g.dtype == torch.bfloat16:
+            assert err.mean() < 0.00015
+            assert relerr.mean() < 0.0016
+        else:
+            assert err.mean() < 0.00012
+            assert relerr.mean() < 0.0012

         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
...
@@ -369,12 +341,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     bnb_optimizer = str2optimizers[optim_name][1]([p2])
     bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
     rm_path(path)
-    torch.testing.assert_allclose(
-        raws1cpy, bnb_optimizer.state[p2][name2]
-    )
-    torch.testing.assert_allclose(
-        qmap1, bnb_optimizer.state[p2][qmap]
-    )
+    torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
+    torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

     if "blockwise" in optim_name:
         s1 = F.dequantize_blockwise(
...
@@ -389,17 +357,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             absmax=bnb_optimizer.state[p2][max_val],
             A=bnb_optimizer.state[p2][name2],
         )
-        torch.testing.assert_allclose(s1cpy, s1)
+        torch.testing.assert_close(s1cpy, s1)

-        num_not_close = (
-            torch.isclose(
-                torch_optimizer.state[p1][name1],
-                s1,
-                atol=atol,
-                rtol=rtol,
-            )
-            == 0
-        )
+        num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
         assert num_not_close.sum().item() < 20

     # since Lion can have pretty noisy updates where things lie at the boundary
     # allow up to 5 errors for Lion
...
@@ -409,10 +369,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     # together so we can test against the Adam error
     p1.data = p1.data.to(gtype).float()
     p2.copy_(p1.data)
-    torch.testing.assert_allclose(p1.to(gtype), p2)
+    torch.testing.assert_close(p1.to(gtype), p2)
     for (name1, name2, qmap, max_val), s in zip(
         str2statenames[optim_name], dequant_states):
         torch_optimizer.state[p1][name1].copy_(s.data)

     # print(sum(errors)/len(errors))
...
@@ -473,28 +431,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
     # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
     if optim_bits == 32:
-        torch.testing.assert_allclose(p1, p2)
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(p1, p2)
+        torch.testing.assert_close(
             adam1.state[p1]["state1"],
             adam2.state[p2]["state1"],
             atol=5e-5,
             rtol=1e-4,
         )
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(
             adam1.state[p1]["state2"],
             adam2.state[p2]["state2"],
             atol=5e-5,
             rtol=1e-4,
         )
     elif optim_bits == 8:
-        torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
+        torch.testing.assert_close(
             adam1.state[p1]["state1"],
             adam2.state[p2]["state1"],
             atol=2,
             rtol=1e-3,
         )
-        torch.testing.assert_allclose(
+        torch.testing.assert_close(
             adam1.state[p1]["state2"],
             adam2.state[p2]["state2"],
             atol=2,
...
@@ -526,7 +484,7 @@ gtype = [torch.float32, torch.float16]
 # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
 # optimizer_names = ['lamb_apex', 'lamb8bit']
 # optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ["adam8bit_blockwise"]
+optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
...
@@ -557,3 +515,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     params = (k - k // 5) * dim1 * dim2
     print(optim_name, gtype, s / params)
     # assert s < 3.9

+dim1 = [2*1024]
+gtype = [torch.float16]
+#mode = ['torch', 'bnb']
+mode = ['bnb']
+optimizer_names = ['paged_adamw']
+#optimizer_names = ['paged_adamw8bit_blockwise']
+values = list(product(dim1, gtype, optimizer_names, mode))
+names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
+def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
+    layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
+    layers1 = layers1.to(gtype)
+    layers1 = layers1.cuda()
+
+    large_tensor = None
+    if mode == 'torch':
+        optim = str2optimizers[optim_name][0](layers1.parameters())
+    else:
+        optim = str2optimizers[optim_name][1](layers1.parameters())
+        # 12 GB
+        large_tensor = torch.empty((int(4.5e9),), device='cuda')
+
+    torch.cuda.synchronize()
+    time.sleep(5)
+
+    num_batches = 5
+    batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
+    lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda()
+
+    for i in range(num_batches):
+        print(i)
+        b = batches[i]
+        if i == 2:
+            torch.cuda.synchronize()
+            t0 = time.time()
+
+        out1 = layers1(b)
+
+        loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
+        loss1.backward()
+        optim.step()
+
+    torch.cuda.synchronize()
+    print(mode, time.time() - t0)
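The benchmark starts its clock only at the third batch so allocator warm-up and lazy initialization are excluded; the synchronize-bracket pattern it uses is the standard way to time asynchronous CUDA work, e.g.:

    import time
    import torch

    def cuda_time(fn):
        torch.cuda.synchronize()   # drain previously queued kernels
        t0 = time.time()
        fn()
        torch.cuda.synchronize()   # wait for the timed kernels to finish
        return time.time() - t0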
tests/test_triton.py
0 → 100644
View file @ a24aae30
+import pytest
+import torch
+
+from bitsandbytes.triton.triton_utils import is_triton_available
+from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
+from bitsandbytes.nn import Linear8bitLt
+
+
+@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.parametrize("vector_wise_quantization", [False, True])
+def test_switchback(vector_wise_quantization):
+    for dim in [83]:
+        for batch in [13]:
+
+            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
+            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
+            baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
+            switchback.weight.data.copy_(standard.weight)
+            switchback.bias.data.copy_(standard.bias)
+            baseline.weight.data.copy_(standard.weight)
+            baseline.bias.data.copy_(standard.bias)
+
+            x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
+            x2 = x1.clone().detach().requires_grad_(True)
+            x3 = x1.clone().detach().requires_grad_(True)
+
+            out_standard = standard(x1)
+            (2**10 * out_standard.abs().mean()).backward()
+
+            print(x2.dtype)
+            out_sb = switchback(x2)
+            (2**10 * out_sb.abs().mean()).backward()
+
+            out_baseline = baseline(x3)
+            (2**10 * out_baseline.abs().mean()).backward()
+
+            err_sb = (out_standard - out_sb).abs().mean()
+            err_baseline = (out_standard - out_baseline).abs().mean()
+            print('OUT', err_sb, err_baseline)
+            assert err_sb < 2 * err_baseline
+
+            err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
+            err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
+            print('GW2', err_sb, err_baseline)
+            assert err_sb < 2 * err_baseline
+
+            err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
+            err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
+            print('GW1', err_sb, err_baseline)
+            assert err_sb < 2 * err_baseline
+
+            err_sb = (x1.grad - x2.grad).abs().mean()
+            err_baseline = (x1.grad - x3.grad).abs().mean()
+            print('GX1', err_sb, err_baseline)
+            assert err_sb < 2 * err_baseline
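`SwitchBackLinear` is the Triton-based int8 training layer checked here against `Linear8bitLt`; the assertions only require its error to stay within 2x of the int8 baseline. A minimal usage sketch mirroring the test setup:

    import torch
    from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

    layer = SwitchBackLinear(128, 512, vector_wise_quantization=True).cuda().half()
    x = torch.randn(16, 128, device='cuda', dtype=torch.float16, requires_grad=True)
    y = layer(x)               # int8 matmuls via Triton in forward and backward
    y.abs().mean().backward()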