OpenDAS / bitsandbytes / Commits / 74399248

Commit 74399248 authored Oct 05, 2021 by Tim Dettmers

Initial commit
Showing 2 changed files with 575 additions and 0 deletions.
tests/test_functional.py  +213 -0
tests/test_optim.py       +362 -0
tests/test_functional.py  0 → 100644 (new file)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import bitsandbytes as bnb

from itertools import product
from bitsandbytes import functional as F


def setup():
    pass


def teardown():
    pass


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half'])
def test_estimate_quantiles(dtype):
    A = torch.rand(1024, 1024, device='cuda')
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    percs = torch.linspace(1/512, 511/512, 256, device=A.device)
    torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)

    A = torch.randn(1024, 1024, device='cuda')
    A = A.to(dtype)
    code = F.estimate_quantiles(A)

    quantiles = torch.quantile(A.float(), percs)
    diff = torch.abs(code - quantiles)
    assert (diff > 5e-02).sum().item() == 0
def test_quantile_quantization():
    for i in range(100):
        A1 = torch.randn(1024, 1024, device='cuda')
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0075

        A1 = torch.rand(1024, 1024, device='cuda')
        code = F.estimate_quantiles(A1)
        C = F.quantize_no_absmax(A1, code)
        A2 = F.dequantize_no_absmax(C, code)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
        assert diff < 0.001
def test_dynamic_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device='cuda')
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
    print(sum(diffs)/len(diffs))
    print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device='cuda')
        C, S = F.quantize(A1)
        A2 = F.dequantize(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
        assert diff < 0.004
def test_dynamic_blockwise_quantization():
    diffs = []
    reldiffs = []
    for i in range(100):
        A1 = torch.randn(1024, 1024, device='cuda')
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    print(sum(diffs)/len(diffs))
    print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
        A1 = torch.rand(1024, 1024, device='cuda')
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    #print(sum(diffs)/len(diffs))
def test_dynamic_blockwise_stochastic_quantization():
    diffs = []
    reldiffs = []
    rand = torch.rand(1024).cuda()
    for i in range(100):
        A1 = torch.randn(1024, 1024, device='cuda')
        C1, S1 = F.quantize_blockwise(A1, rand=rand)
        C2, S2 = F.quantize_blockwise(A1)
        # a maximum distance of 1 between quantized values
        torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
        fraction_smaller = (C1 < C2).float().sum()/C1.numel()
        fraction_larger = (C1 > C2).float().sum()/C1.numel()
        torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
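
The symmetry check above relies on how stochastic rounding behaves: a value that falls between two quantization codes is rounded up or down with probability proportional to its distance from each neighbour, so the up- and down-roundings should roughly cancel over many elements. A minimal pure-PyTorch sketch of that idea (illustrative only, not the kernel exercised by quantize_blockwise):

import torch

torch.manual_seed(0)
x = torch.rand(1_000_000) * 255        # values lying between integer codes
down = torch.floor(x)
frac = x - down                        # distance to the lower code
noise = torch.rand_like(x)

stochastic = torch.where(noise < frac, down + 1, down)  # round up with probability frac
nearest = torch.round(x)

# stochastic codes differ from round-to-nearest by at most 1 ...
assert (stochastic - nearest).abs().max() <= 1
# ... and the up/down deviations approximately balance
print((stochastic > nearest).float().mean(), (stochastic < nearest).float().mean())
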
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half'])
def test_percentile_clipping(gtype):
    gnorm_vec1 = torch.zeros(100, device='cuda')
    gnorm_vec2 = torch.zeros(100, device='cuda')
    n = 4
    step = 0
    percentile = 5
    for i in range(1000):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device='cuda')
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
        assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1

        gnorm2 = torch.norm(g.float())
        if step == 1:
            gnorm_vec1[:] = gnorm2
        else:
            gnorm_vec1[step % 100] = gnorm2

        vals, idx = torch.sort(gnorm_vec1)
        clip1 = vals[percentile]

        torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
        torch.testing.assert_allclose(clip1, clip2)
        torch.testing.assert_allclose(gnorm1, gnorm2)
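
The gnorm_scale returned above is intended to be multiplied into the gradient: it stays 1.0 while the current gradient norm is below the tracked percentile value and becomes clip/gnorm otherwise, which caps the effective norm at the clip value. A short sketch of that usage, mirroring what test_adam_percentile_clipping in tests/test_optim.py does for its reference optimizer (illustrative only):

import torch
import bitsandbytes.functional as F

gnorm_vec = torch.zeros(100, device='cuda')  # rolling history of the last 100 squared gradient norms
step = 0
for i in range(10):
    step += 1
    g = torch.randn(4, 4, device='cuda')
    gnorm, clip, gnorm_scale = F.percentile_clipping(g, gnorm_vec, step, percentile=5)
    g = g * gnorm_scale  # rescales g so its norm does not exceed the 5th-percentile clip value
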
def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()
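
bnb.nn.StableEmbedding takes the same leading (num_embeddings, embedding_dim) arguments as torch.nn.Embedding; the test above only constructs and re-initializes it. A minimal usage sketch, assuming a CUDA device is available (illustrative only):

import torch
import bitsandbytes as bnb

emb = bnb.nn.StableEmbedding(1024, 1024).cuda()       # 1024 ids, 1024-dimensional vectors
ids = torch.randint(0, 1024, (8, 16), device='cuda')
out = emb(ids)                                        # -> shape (8, 16, 1024)
print(out.shape)
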
def test_dynamic_blockwise_quantization_cpu():
    #A1 = torch.randn(1024, 1024, device='cpu')
    #code = F.create_dynamic_map()
    #for i in range(1000):
    #    C, S = F.quantize_blockwise(A1, code=code)
    #    A2 = F.dequantize_blockwise(C, S)

    for i in range(10):
        # equivalence with GPU blockwise quantization
        A1 = torch.randn(1024, 1024, device='cpu')
        C1, S1 = F.quantize_blockwise(A1)
        C2, S2 = F.quantize_blockwise(A1.cuda())
        torch.testing.assert_allclose(S1[0], S2[0].cpu())
        # there seem to be some precision issues in CUDA vs CPU:
        # not all elements are usually close, with a couple of elements off in a million
        idx = torch.isclose(C1, C2.cpu())
        assert (idx == 0).sum().item() < 15

    diffs = []
    reldiffs = []
    for i in range(10):
        A1 = torch.randn(1024, 1024, device='cpu')
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2)
        reldiff = diff / torch.abs(A1 + 1e-8)
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
    #print(sum(diffs)/len(diffs))
    #print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(10):
        A1 = torch.rand(1024, 1024, device='cpu')
        C, S = F.quantize_blockwise(A1)
        A2 = F.dequantize_blockwise(C, S)
        diff = torch.abs(A1 - A2).mean().item()
        assert diff < 0.0033
        diffs.append(diff)
        torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
    #print(sum(diffs)/len(diffs))
def test_histogram():
    dim1, dim2 = 32, 32
    source = torch.rand(dim1, dim2, device='cuda')
    idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
    idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
    histogram1 = torch.zeros((256, 256)).cuda()
    histogram2 = torch.zeros((256, 256)).cuda()

    F.histogram_scatter_add_2d(histogram2, idx1, idx2, source)

    for i in range(dim1):
        for j in range(dim2):
            histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j]

    torch.testing.assert_allclose(histogram1, histogram2)
    torch.testing.assert_allclose(histogram1.sum(), source.sum())
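
The double loop above spells out the reference semantics: every source[i, j] is accumulated into the histogram cell addressed by (idx1[i, j], idx2[i, j]). The same reference can be written without the Python loop via index_put_ with accumulate=True; a small sketch (illustrative only, not the CUDA kernel under test):

import torch

dim1, dim2 = 32, 32
source = torch.rand(dim1, dim2, device='cuda')
idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda')
idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda')

histogram = torch.zeros((256, 256), device='cuda')
# scatter-add each source value into the cell (idx1[i, j], idx2[i, j])
histogram.index_put_((idx1.long(), idx2.long()), source, accumulate=True)
print(histogram.sum(), source.sum())  # the totals agree
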
tests/test_optim.py  0 → 100644 (new file)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

from os.path import join
from itertools import product

import apex


def get_temp_dir():
    path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
    os.makedirs(path, exist_ok=True)
    return path


def rm_path(path):
    shutil.rmtree(path)
str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)

str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))

str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1: return
    p1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
        g = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)

        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)

        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(), join(path, 'opt.pt'))
            del bnb_optimizer
            bnb_optimizer = None
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
            rm_path(path)
            torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)

        if gtype == torch.float16:
            # the Adam buffers should also be close because they are 32-bit,
            # but the parameters can diverge because they are 16-bit;
            # the differences grow larger and larger with each update
            # --> copy the state to keep weights close
            p1.data = p1.data.half().float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.half(), p2)
        if optim_name in ['lars', 'lamb']:
            assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0
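
Outside the test harness, the bnb optimizers are constructed the same way as their torch.optim counterparts, with the (lr, betas, eps) arguments used throughout this file. A minimal training-loop sketch under that assumption (illustrative only; swap in bnb.optim.Adam8bit for 8-bit optimizer state):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = bnb.optim.Adam(model.parameters(), 0.001, (0.9, 0.999), 1e-8)

for _ in range(10):
    x = torch.randn(32, 1024, device='cuda')
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
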
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
    if dim1 == 1 and dim2 == 1: return
    p1 = torch.randn(dim1, dim2, device='cpu', dtype=gtype)*0.1
    p2 = torch.randn(dim1, dim2, device='cpu', dtype=gtype)*0.1
    p3 = torch.randn(dim1, dim2, device='cpu', dtype=gtype)*0.1
    mask = torch.rand_like(p2) < 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)

    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
        g1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

        assert adam2.state[p3]['state1'].dtype == torch.uint8
        assert adam2.state[p3]['state2'].dtype == torch.uint8
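
The per-parameter override used above also works outside a test: parameters are registered with the GlobalOptimManager while they are still on CPU, individual tensors can be pinned to a different optim_bits, and the bnb optimizer picks the override up when it first creates state. A sketch mirroring the call order of test_global_config (illustrative only):

import torch
import bitsandbytes as bnb

mngr = bnb.optim.GlobalOptimManager.get_instance()
mngr.initialize()

p_default = torch.randn(1024, 32) * 0.1    # uses the optimizer default (32-bit state)
p_8bit = torch.randn(1024, 32) * 0.1       # overridden to 8-bit optimizer state
mngr.override_config(p_8bit, 'optim_bits', 8)
mngr.register_parameters([p_default, p_8bit])

p_default, p_8bit = p_default.cuda(), p_8bit.cuda()
adam = bnb.optim.Adam([p_default, p_8bit], 0.001, (0.9, 0.999), 1e-8)

p_default.grad = torch.randn_like(p_default) * 0.1
p_8bit.grad = torch.randn_like(p_8bit) * 0.1
adam.step()
print(adam.state[p_8bit]['state1'].dtype)  # torch.uint8, as asserted in the test above
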
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1: return
    p1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3
    else:
        atol, rtol = 3e-3, 1e-3
        patol, prtol = 1e-5, 1e-3

    errors = []
    relerrors = []

    for i in range(50):
        g = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            #print(bnb_optimizer.state[p2][max_val], name1)
            if 'blockwise' in optim_name:
                s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
            else:
                s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
            num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
            assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / torch.abs(p1)
        assert err.mean() < 0.0001
        assert relerr.mean() < 0.001

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                torch.save(bnb_optimizer.state_dict(), join(path, 'opt.pt'))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
                rm_path(path)
                torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
                torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])

                if 'blockwise' in optim_name:
                    s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
                else:
                    s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
                torch.testing.assert_allclose(s1cpy, s1)

                num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
                assert num_not_close.sum().item() < 20
            torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_allclose(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)

    #print(sum(errors)/len(errors))
    #print(sum(relerrors)/len(relerrors))
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    if dim1 == 1 and dim2 == 1: return
    p1 = torch.randn(dim1, dim2, device='cpu', dtype=gtype)*0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)

    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
        g1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i)
        g2 = g1.clone()
        p2.grad = g2

        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
        g1 = (g1.float()*gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_allclose(p1, p2)
            torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4)
            torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4)
        elif optim_bits == 8:
            torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3)
            torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3)
            adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1'])
            adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2'])

        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            torch.save(adam2.state_dict(), join(path, 'opt.pt'))
            del adam2
            adam2 = None
            adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
            adam2.load_state_dict(torch.load(join(path, 'opt.pt')))
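
For ordinary training the clipping does not have to be applied by hand as it is for the reference optimizer above; passing percentile_clipping (optionally together with optim_bits) to the bnb optimizer, exactly as in the adam2 constructor, enables it internally. A brief construction sketch under that assumption (illustrative only):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()
# 8-bit Adam that clips gradients against the 5th percentile of the recent gradient-norm history
optimizer = bnb.optim.Adam(model.parameters(), 0.001, (0.9, 0.999), 1e-8,
                           optim_bits=8, percentile_clipping=5)
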
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
#optimizer_names = ['lamb_apex', 'lamb8bit']
#optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ['adam8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1: return
    p1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.01
    p1.grad = g
    for i in range(5000):
        if i == 500:
            # first 500 iterations are burn-in; start timing here
            torch.cuda.synchronize()
            t0 = time.time()

        bnb_optimizer.step()

    torch.cuda.synchronize()
    s = time.time() - t0
    print('')
    params = 4500*4096*4096
    print(optim_name, gtype, s/params)
    #assert s < 3.9