Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
7dc198fe
Commit
7dc198fe
authored
Apr 17, 2023
by
Tim Dettmers
Browse files
Added 32-bit optimizer for bfloat16 gradients.
parent
b8ea2b41
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
65 additions
and
86 deletions
+65
-86
bitsandbytes/cextension.py
bitsandbytes/cextension.py
+1
-1
bitsandbytes/functional.py
bitsandbytes/functional.py
+34
-55
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+7
-0
csrc/kernels.cu
csrc/kernels.cu
+5
-2
csrc/ops.cu
csrc/ops.cu
+1
-0
csrc/pythonInterface.c
csrc/pythonInterface.c
+6
-4
tests/test_optim.py
tests/test_optim.py
+11
-24
No files found.
bitsandbytes/cextension.py
View file @
7dc198fe
...
@@ -23,7 +23,7 @@ try:
...
@@ -23,7 +23,7 @@ try:
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
https://github.com/TimDettmers/bitsandbytes/issues'''
)
https://github.com/TimDettmers/bitsandbytes/issues'''
)
lib
.
cadam
32
bit_
g
32
lib
.
cadam
_8
bit_
blockwise_fp
32
lib
.
get_context
.
restype
=
ct
.
c_void_p
lib
.
get_context
.
restype
=
ct
.
c_void_p
lib
.
get_cusparse
.
restype
=
ct
.
c_void_p
lib
.
get_cusparse
.
restype
=
ct
.
c_void_p
COMPILED_WITH_CUDA
=
True
COMPILED_WITH_CUDA
=
True
...
...
bitsandbytes/functional.py
View file @
7dc198fe
...
@@ -28,7 +28,7 @@ name2qmap = {}
...
@@ -28,7 +28,7 @@ name2qmap = {}
if
COMPILED_WITH_CUDA
:
if
COMPILED_WITH_CUDA
:
"""C FUNCTIONS FOR OPTIMIZERS"""
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit
=
{}
str2optimizer32bit
=
{}
str2optimizer32bit
[
"adam"
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer32bit
[
"adam"
]
=
(
lib
.
cadam32bit_g
fp
32
,
lib
.
cadam32bit_g
fp16
,
lib
.
cadam32bit_gbf
16
)
str2optimizer32bit
[
"momentum"
]
=
(
str2optimizer32bit
[
"momentum"
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
,
lib
.
cmomentum32bit_g16
,
...
@@ -41,11 +41,6 @@ if COMPILED_WITH_CUDA:
...
@@ -41,11 +41,6 @@ if COMPILED_WITH_CUDA:
lib
.
cadagrad32bit_g32
,
lib
.
cadagrad32bit_g32
,
lib
.
cadagrad32bit_g16
,
lib
.
cadagrad32bit_g16
,
)
)
str2optimizer32bit
[
"lars"
]
=
(
lib
.
cmomentum32bit_g32
,
lib
.
cmomentum32bit_g16
,
)
str2optimizer32bit
[
"lamb"
]
=
(
lib
.
cadam32bit_g32
,
lib
.
cadam32bit_g16
)
str2optimizer8bit
=
{}
str2optimizer8bit
=
{}
str2optimizer8bit
[
"adam"
]
=
(
str2optimizer8bit
[
"adam"
]
=
(
...
@@ -998,53 +993,37 @@ def optimizer_update_32bit(
...
@@ -998,53 +993,37 @@ def optimizer_update_32bit(
if
max_unorm
>
0.0
:
if
max_unorm
>
0.0
:
param_norm
=
torch
.
norm
(
p
.
data
.
float
())
param_norm
=
torch
.
norm
(
p
.
data
.
float
())
if
optimizer_name
not
in
str2optimizer32bit
:
raise
NotImplementedError
(
f
'Optimizer not implemented:
{
optimizer_name
}
. Choices:
{
","
.
join
(
str2optimizer32bit
.
keys
())
}
'
)
if
g
.
dtype
==
torch
.
float32
and
state1
.
dtype
==
torch
.
float32
:
optim_func
=
None
str2optimizer32bit
[
optimizer_name
][
0
](
if
g
.
dtype
==
torch
.
float32
:
get_ptr
(
g
),
optim_func
=
str2optimizer32bit
[
optimizer_name
][
0
]
get_ptr
(
p
),
elif
g
.
dtype
==
torch
.
float16
:
get_ptr
(
state1
),
optim_func
=
str2optimizer32bit
[
optimizer_name
][
1
]
get_ptr
(
state2
),
elif
(
g
.
dtype
==
torch
.
bfloat16
and
len
(
str2optimizer32bit
[
optimizer_name
])
==
3
):
get_ptr
(
unorm_vec
),
optim_func
=
str2optimizer32bit
[
optimizer_name
][
2
]
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
float32
:
str2optimizer32bit
[
optimizer_name
][
1
](
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()),
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
)
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
)
is_on_gpu
([
g
,
p
,
state1
,
state2
,
unorm_vec
])
prev_device
=
pre_call
(
g
.
device
)
optim_func
(
get_ptr
(
g
),
get_ptr
(
p
),
get_ptr
(
state1
),
get_ptr
(
state2
),
get_ptr
(
unorm_vec
),
ct
.
c_float
(
max_unorm
),
ct
.
c_float
(
param_norm
),
ct
.
c_float
(
beta1
),
ct
.
c_float
(
beta2
),
ct
.
c_float
(
eps
),
ct
.
c_float
(
weight_decay
),
ct
.
c_int32
(
step
),
ct
.
c_float
(
lr
),
ct
.
c_float
(
gnorm_scale
),
ct
.
c_bool
(
skip_zeros
),
ct
.
c_int32
(
g
.
numel
()))
post_call
(
prev_device
)
def
optimizer_update_8bit
(
def
optimizer_update_8bit
(
...
@@ -1199,12 +1178,12 @@ def optimizer_update_8bit_blockwise(
...
@@ -1199,12 +1178,12 @@ def optimizer_update_8bit_blockwise(
optim_func
=
None
optim_func
=
None
if
g
.
dtype
==
torch
.
float32
and
state1
.
dtype
==
torch
.
uint8
:
if
g
.
dtype
==
torch
.
float32
and
state1
.
dtype
==
torch
.
uint8
:
optim
izer
_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
0
]
optim_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
0
]
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
uint8
:
elif
g
.
dtype
==
torch
.
float16
and
state1
.
dtype
==
torch
.
uint8
:
optim
izer
_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
1
]
optim_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
1
]
elif
(
g
.
dtype
==
torch
.
bfloat16
and
state1
.
dtype
==
torch
.
uint8
and
elif
(
g
.
dtype
==
torch
.
bfloat16
and
state1
.
dtype
==
torch
.
uint8
and
len
(
str2optimizer8bit_blockwise
[
optimizer_name
])
==
3
):
len
(
str2optimizer8bit_blockwise
[
optimizer_name
])
==
3
):
optim
izer
_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
2
]
optim_func
=
str2optimizer8bit_blockwise
[
optimizer_name
][
2
]
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
f
"Gradient+optimizer bit data type combination not supported: grad
{
g
.
dtype
}
, optimizer
{
state1
.
dtype
}
"
...
@@ -1213,7 +1192,7 @@ def optimizer_update_8bit_blockwise(
...
@@ -1213,7 +1192,7 @@ def optimizer_update_8bit_blockwise(
is_on_gpu
([
p
,
g
,
state1
,
state2
,
qmap1
,
qmap2
,
absmax1
,
absmax2
])
is_on_gpu
([
p
,
g
,
state1
,
state2
,
qmap1
,
qmap2
,
absmax1
,
absmax2
])
prev_device
=
pre_call
(
g
.
device
)
prev_device
=
pre_call
(
g
.
device
)
optim
izer
_func
(
optim_func
(
get_ptr
(
p
),
get_ptr
(
p
),
get_ptr
(
g
),
get_ptr
(
g
),
get_ptr
(
state1
),
get_ptr
(
state1
),
...
...
bitsandbytes/nn/modules.py
View file @
7dc198fe
...
@@ -178,6 +178,13 @@ class Params4bit(torch.nn.Parameter):
...
@@ -178,6 +178,13 @@ class Params4bit(torch.nn.Parameter):
s
[
0
]
=
s
[
0
].
to
(
device
)
s
[
0
]
=
s
[
0
].
to
(
device
)
if
self
.
compress_statistics
:
if
self
.
compress_statistics
:
# TODO: refactor this. This is a nightmare
# TODO: refactor this. This is a nightmare
# for 4-bit:
# state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
# state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
#s[-2][0] = s[-2][0].to(device) # offset
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
# for 8-bit
s
[
-
2
][
0
]
=
s
[
-
2
][
0
].
to
(
device
)
# offset
s
[
-
2
][
0
]
=
s
[
-
2
][
0
].
to
(
device
)
# offset
s
[
-
2
][
1
][
0
]
=
s
[
-
2
][
1
][
0
].
to
(
device
)
# nested quantiation state statitics
s
[
-
2
][
1
][
0
]
=
s
[
-
2
][
1
][
0
].
to
(
device
)
# nested quantiation state statitics
s
[
-
2
][
1
][
1
]
=
s
[
-
2
][
1
][
1
].
to
(
device
)
# nested quantiation codebook
s
[
-
2
][
1
][
1
]
=
s
[
-
2
][
1
][
1
].
to
(
device
)
# nested quantiation codebook
...
...
csrc/kernels.cu
View file @
7dc198fe
...
@@ -2981,12 +2981,15 @@ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8
...
@@ -2981,12 +2981,15 @@ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8
const float beta1, const float beta2, const float eps, const float weight_decay, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit2State
(
ADAM
,
half
)
MAKE_PreconditionOptimizer32bit2State
(
ADAM
,
float
)
MAKE_PreconditionOptimizer32bit2State
(
ADAM
,
float
)
MAKE_PreconditionOptimizer32bit2State
(
ADAM
,
half
)
MAKE_PreconditionOptimizer32bit2State
(
ADAM
,
__nv_bfloat16
)
template
__global__
void
kOptimizer32bit2State
<
float
,
ADAM
>(
float
*
g
,
float
*
p
,
float
*
state1
,
float
*
state2
,
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
bool
skip_zeros
,
const
int
n
);
template
__global__
void
kOptimizer32bit2State
<
half
,
ADAM
>(
half
*
g
,
half
*
p
,
float
*
state1
,
float
*
state2
,
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
template
__global__
void
kOptimizer32bit2State
<
half
,
ADAM
>(
half
*
g
,
half
*
p
,
float
*
state1
,
float
*
state2
,
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
bool
skip_zeros
,
const
int
n
);
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
bool
skip_zeros
,
const
int
n
);
template
__global__
void
kOptimizer32bit2State
<
float
,
ADAM
>(
float
*
g
,
float
*
p
,
float
*
state1
,
float
*
state2
,
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
template
__global__
void
kOptimizer32bit2State
<
__nv_b
float
16
,
ADAM
>(
__nv_b
float
16
*
g
,
__nv_b
float
16
*
p
,
float
*
state1
,
float
*
state2
,
float
*
unorm
,
const
float
max_unorm
,
const
float
param_norm
,
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
bool
skip_zeros
,
const
int
n
);
const
float
beta1
,
const
float
beta2
,
const
float
eps
,
const
float
weight_decay
,
const
int
step
,
const
float
lr
,
const
float
gnorm_scale
,
const
bool
skip_zeros
,
const
int
n
);
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
...
...
csrc/ops.cu
View file @
7dc198fe
...
@@ -703,6 +703,7 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
...
@@ -703,6 +703,7 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
MAKE_optimizer32bit
(
ADAM
,
half
)
MAKE_optimizer32bit
(
ADAM
,
half
)
MAKE_optimizer32bit
(
ADAM
,
float
)
MAKE_optimizer32bit
(
ADAM
,
float
)
MAKE_optimizer32bit
(
ADAM
,
__nv_bfloat16
)
MAKE_optimizer32bit
(
MOMENTUM
,
half
)
MAKE_optimizer32bit
(
MOMENTUM
,
half
)
MAKE_optimizer32bit
(
MOMENTUM
,
float
)
MAKE_optimizer32bit
(
MOMENTUM
,
float
)
MAKE_optimizer32bit
(
RMSPROP
,
half
)
MAKE_optimizer32bit
(
RMSPROP
,
half
)
...
...
csrc/pythonInterface.c
View file @
7dc198fe
...
@@ -29,8 +29,9 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \
...
@@ -29,8 +29,9 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \
MAKE_FUNC32
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_FUNC32
(
momentum
,
MOMENTUM
,
float
,
32
)
MAKE_FUNC32
(
momentum
,
MOMENTUM
,
half
,
16
)
MAKE_FUNC32
(
momentum
,
MOMENTUM
,
half
,
16
)
MAKE_FUNC32
(
adam
,
ADAM
,
float
,
32
)
MAKE_FUNC32
(
adam
,
ADAM
,
float
,
fp32
)
MAKE_FUNC32
(
adam
,
ADAM
,
half
,
16
)
MAKE_FUNC32
(
adam
,
ADAM
,
half
,
fp16
)
MAKE_FUNC32
(
adam
,
ADAM
,
__nv_bfloat16
,
bf16
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
float
,
32
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_FUNC32
(
rmsprop
,
RMSPROP
,
half
,
16
)
MAKE_FUNC32
(
adagrad
,
ADAGRAD
,
float
,
32
)
MAKE_FUNC32
(
adagrad
,
ADAGRAD
,
float
,
32
)
...
@@ -173,8 +174,9 @@ extern "C"
...
@@ -173,8 +174,9 @@ extern "C"
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32
(
adam
,
float
,
32
)
MAKE_CFUNC32
(
adam
,
float
,
fp32
)
MAKE_CFUNC32
(
adam
,
half
,
16
)
MAKE_CFUNC32
(
adam
,
half
,
fp16
)
MAKE_CFUNC32
(
adam
,
__nv_bfloat16
,
bf16
)
MAKE_CFUNC32
(
momentum
,
float
,
32
)
MAKE_CFUNC32
(
momentum
,
float
,
32
)
MAKE_CFUNC32
(
momentum
,
half
,
16
)
MAKE_CFUNC32
(
momentum
,
half
,
16
)
MAKE_CFUNC32
(
rmsprop
,
float
,
32
)
MAKE_CFUNC32
(
rmsprop
,
float
,
32
)
...
...
tests/test_optim.py
View file @
7dc198fe
...
@@ -44,10 +44,6 @@ str2optimizers["momentum"] = (
...
@@ -44,10 +44,6 @@ str2optimizers["momentum"] = (
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
torch
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
lambda
pxx
:
bnb
.
optim
.
SGD
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
)
)
str2optimizers
[
"lars"
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS
(
pxx
,
0.01
,
0.9
),
)
str2optimizers
[
"rmsprop"
]
=
(
str2optimizers
[
"rmsprop"
]
=
(
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
lambda
pxx
:
bnb
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
...
@@ -64,10 +60,6 @@ str2optimizers["rmsprop8bit"] = (
...
@@ -64,10 +60,6 @@ str2optimizers["rmsprop8bit"] = (
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
torch
.
optim
.
RMSprop
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
lambda
pxx
:
bnb
.
optim
.
RMSprop8bit
(
pxx
,
0.01
,
0.9
,
block_wise
=
False
),
)
)
str2optimizers
[
"lars8bit"
]
=
(
lambda
pxx
:
bnb
.
optim
.
PytorchLARS
(
pxx
,
0.01
,
0.9
),
lambda
pxx
:
bnb
.
optim
.
LARS8bit
(
pxx
,
0.01
,
0.9
),
)
str2optimizers
[
"adam8bit_blockwise"
]
=
(
str2optimizers
[
"adam8bit_blockwise"
]
=
(
torch
.
optim
.
Adam
,
torch
.
optim
.
Adam
,
...
@@ -85,7 +77,6 @@ str2optimizers["rmsprop8bit_blockwise"] = (
...
@@ -85,7 +77,6 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames
=
{}
str2statenames
=
{}
str2statenames
[
"adam"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
str2statenames
[
"adam"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
str2statenames
[
"momentum"
]
=
[(
"momentum_buffer"
,
"state1"
)]
str2statenames
[
"momentum"
]
=
[(
"momentum_buffer"
,
"state1"
)]
str2statenames
[
"lars"
]
=
[(
"momentum_buffer"
,
"state1"
)]
str2statenames
[
"lamb"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
str2statenames
[
"lamb"
]
=
[(
"exp_avg"
,
"state1"
),
(
"exp_avg_sq"
,
"state2"
)]
str2statenames
[
"rmsprop"
]
=
[(
"square_avg"
,
"state1"
)]
str2statenames
[
"rmsprop"
]
=
[(
"square_avg"
,
"state1"
)]
str2statenames
[
"adam8bit"
]
=
[
str2statenames
[
"adam8bit"
]
=
[
...
@@ -106,7 +97,6 @@ str2statenames["momentum8bit"] = [
...
@@ -106,7 +97,6 @@ str2statenames["momentum8bit"] = [
str2statenames
[
"momentum8bit_blockwise"
]
=
[
str2statenames
[
"momentum8bit_blockwise"
]
=
[
(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"absmax1"
)
(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"absmax1"
)
]
]
str2statenames
[
"lars8bit"
]
=
[(
"momentum_buffer"
,
"state1"
,
"qmap1"
,
"max1"
)]
str2statenames
[
"rmsprop8bit"
]
=
[(
"square_avg"
,
"state1"
,
"qmap1"
,
"max1"
)]
str2statenames
[
"rmsprop8bit"
]
=
[(
"square_avg"
,
"state1"
,
"qmap1"
,
"max1"
)]
str2statenames
[
"rmsprop8bit_blockwise"
]
=
[
str2statenames
[
"rmsprop8bit_blockwise"
]
=
[
(
"square_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
)
(
"square_avg"
,
"state1"
,
"qmap1"
,
"absmax1"
)
...
@@ -114,14 +104,10 @@ str2statenames["rmsprop8bit_blockwise"] = [
...
@@ -114,14 +104,10 @@ str2statenames["rmsprop8bit_blockwise"] = [
dim1
=
[
1024
]
dim1
=
[
1024
]
dim2
=
[
32
,
1024
,
4097
,
1
]
dim2
=
[
32
,
1024
,
4097
,
1
]
gtype
=
[
torch
.
float32
,
torch
.
float16
]
gtype
=
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
optimizer_names
=
[
"adam"
,
"momentum"
,
"rmsprop"
,
"lars"
]
optimizer_names
=
[
"adam"
,
"momentum"
,
"rmsprop"
]
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
values
=
list
(
product
(
dim1
,
dim2
,
gtype
,
optimizer_names
))
names
=
[
names
=
[
"dim1_{}_dim2_{}_gtype_{}_optim_{}"
.
format
(
*
vals
)
for
vals
in
values
]
"dim1_{}_dim2_{}_gtype_{}_optim_{}"
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, gtype, optim_name"
,
values
,
ids
=
names
)
def
test_optimizer32bit
(
dim1
,
dim2
,
gtype
,
optim_name
):
def
test_optimizer32bit
(
dim1
,
dim2
,
gtype
,
optim_name
):
if
dim1
==
1
and
dim2
==
1
:
if
dim1
==
1
and
dim2
==
1
:
...
@@ -135,6 +121,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
...
@@ -135,6 +121,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if
gtype
==
torch
.
float32
:
if
gtype
==
torch
.
float32
:
atol
,
rtol
=
1e-6
,
1e-5
atol
,
rtol
=
1e-6
,
1e-5
elif
gtype
==
torch
.
bfloat16
:
atol
,
rtol
=
1e-3
,
1e-2
else
:
else
:
atol
,
rtol
=
1e-4
,
1e-3
atol
,
rtol
=
1e-4
,
1e-3
...
@@ -173,14 +161,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
...
@@ -173,14 +161,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
rtol
=
rtol
,
rtol
=
rtol
,
)
)
if
gtype
=
=
torch
.
float
16
:
if
gtype
!
=
torch
.
float
32
:
# the adam buffers should also be close because they are 32-bit
# the adam buffers should also be close because they are 32-bit
# but the paramters can diverge because they are 16-bit
# but the paramters can diverge because they are 16-bit
# the difference grow larger and larger with each update
# the difference grow larger and larger with each update
# --> copy the state to keep weights close
# --> copy the state to keep weights close
p1
.
data
=
p1
.
data
.
half
(
).
float
()
p1
.
data
=
p1
.
data
.
to
(
p2
.
dtype
).
float
()
p2
.
copy_
(
p1
.
data
)
p2
.
copy_
(
p1
.
data
)
torch
.
testing
.
assert_allclose
(
p1
.
half
(
),
p2
)
torch
.
testing
.
assert_allclose
(
p1
.
to
(
p2
.
dtype
),
p2
)
if
optim_name
in
[
"lars"
,
"lamb"
]:
if
optim_name
in
[
"lars"
,
"lamb"
]:
assert
bnb_optimizer
.
state
[
p2
][
"unorm_vec"
]
>
0.0
assert
bnb_optimizer
.
state
[
p2
][
"unorm_vec"
]
>
0.0
...
@@ -246,7 +234,6 @@ optimizer_names = [
...
@@ -246,7 +234,6 @@ optimizer_names = [
"momentum8bit"
,
"momentum8bit"
,
"rmsprop8bit"
,
"rmsprop8bit"
,
"adam8bit_blockwise"
,
"adam8bit_blockwise"
,
"lars8bit"
,
"momentum8bit_blockwise"
,
"momentum8bit_blockwise"
,
"rmsprop8bit_blockwise"
,
"rmsprop8bit_blockwise"
,
]
]
...
@@ -321,10 +308,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
...
@@ -321,10 +308,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerr
=
err
/
torch
.
abs
(
p1
)
relerr
=
err
/
torch
.
abs
(
p1
)
if
g
.
dtype
==
torch
.
bfloat16
:
if
g
.
dtype
==
torch
.
bfloat16
:
assert
err
.
mean
()
<
0.00015
assert
err
.
mean
()
<
0.00015
assert
relerr
.
mean
()
<
0.001
5
assert
relerr
.
mean
()
<
0.001
6
else
:
else
:
assert
err
.
mean
()
<
0.0001
assert
err
.
mean
()
<
0.0001
2
assert
relerr
.
mean
()
<
0.001
assert
relerr
.
mean
()
<
0.001
2
errors
.
append
(
err
.
mean
().
item
())
errors
.
append
(
err
.
mean
().
item
())
relerrors
.
append
(
relerr
.
mean
().
item
())
relerrors
.
append
(
relerr
.
mean
().
item
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment