OpenDAS / bitsandbytes · Commits · c4cfe4fb

Commit c4cfe4fb, authored Apr 01, 2023 by Tim Dettmers
Added bf16 Adam.
Parent: 8645d1f7

Showing 6 changed files with 78 additions and 87 deletions (+78 -87)
Makefile                     +4  -3
bitsandbytes/functional.py   +30 -38
csrc/kernels.cu              +2  -0
csrc/ops.cu                  +2  -0
csrc/pythonInterface.c       +22 -21
tests/test_optim.py          +18 -25
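For context, a minimal usage sketch of what this commit enables: running the 8-bit blockwise Adam optimizer on bfloat16 parameters. This example is not part of the commit; it assumes a CUDA GPU and a bitsandbytes build that includes this change.

```python
import torch
import bitsandbytes as bnb

# bf16 weights; with enough elements the optimizer keeps 8-bit blockwise state (uint8).
model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
opt.step()       # bf16 gradients can now dispatch to the cadam_8bit_blockwise_bf16 entry point
opt.zero_grad()
```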
Makefile  (+4 -3)

...
@@ -12,6 +12,7 @@ CUDA_VERSION:=
 endif

 NVCC := $(CUDA_HOME)/bin/nvcc
 ###########################################
...
@@ -59,9 +60,9 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90

-all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
-	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
-	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+all: $(BUILD_DIR) env
+	$(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

 cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
...
bitsandbytes/functional.py  (+30 -38)

...
@@ -73,6 +73,7 @@ if COMPILED_WITH_CUDA:
     str2optimizer8bit_blockwise["adam"] = (
         lib.cadam_8bit_blockwise_fp32,
         lib.cadam_8bit_blockwise_fp16,
+        lib.cadam_8bit_blockwise_bf16,
     )
     str2optimizer8bit_blockwise["momentum"] = (
         lib.cmomentum_8bit_blockwise_fp32,
...
@@ -1125,28 +1126,23 @@ def optimizer_update_8bit_blockwise(
     skip_zeros=False,
 ) -> None:

+    optim_func = None
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        str2optimizer8bit_blockwise[optimizer_name][0](
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(absmax1),
-            get_ptr(absmax2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0]
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        str2optimizer8bit_blockwise[optimizer_name][1](
+        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1]
+    elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8
+          and len(str2optimizer8bit_blockwise[optimizer_name]) == 3):
+        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
+        )
+
+    is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
+    prev_device = pre_call(g.device)
+    optimizer_func(
         get_ptr(p),
         get_ptr(g),
         get_ptr(state1),
...
@@ -1165,11 +1161,7 @@ def optimizer_update_8bit_blockwise(
         ct.c_bool(skip_zeros),
         ct.c_int32(g.numel()),
     )
-    else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
-        )
+    post_call(prev_device)


 def percentile_clipping(
     grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
...
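The refactor above replaces two duplicated ctypes call sites with a single dispatch on the gradient dtype, with bf16 handled only when the optimizer's function tuple actually carries a third entry. A minimal sketch of that selection pattern in isolation (illustrative names, not the library API):

```python
import torch

def pick_8bit_blockwise_func(funcs, g, state1):
    """Select the C entry point by gradient dtype, mirroring the refactor above.

    `funcs` corresponds to an entry of str2optimizer8bit_blockwise:
    (fp32_func, fp16_func[, bf16_func]).
    """
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        return funcs[0]
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        return funcs[1]
    elif g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and len(funcs) == 3:
        return funcs[2]
    raise ValueError(
        f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
    )
```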
csrc/kernels.cu  (+2 -0)

...
@@ -2988,6 +2988,8 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)

 #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
 template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
...
csrc/ops.cu  (+2 -0)

...
@@ -741,3 +741,5 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
 template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
 template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
+
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
csrc/pythonInterface.c  (+22 -21)

...
@@ -57,19 +57,20 @@ MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC8(rmsprop, RMSPROP, half, 16)

 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
-void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
+void fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \
 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
 { optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\

-MAKE_BLOCKWISE8(adam, ADAM, half, 16)
-MAKE_BLOCKWISE8(adam, ADAM, float, 32)
-MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
-MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
-MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
-MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
-MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
-MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
+MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
+MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)

 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
...
@@ -194,20 +195,20 @@ extern "C"
 MAKE_CFUNC8(rmsprop, half, 16)

 #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
-void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
+void c##fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \
 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
-{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
+{ fname##_8bit_blockwise_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \

-MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
-MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
-MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
-MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
-MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
-MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
-MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
-MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
+MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
+MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)

 void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
 void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
...
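The macro suffix change (fp##gbits → ##gbits) matters because the Python side resolves these C entry points by name via ctypes, so a "bf16" variant can sit alongside "fp16" and "fp32". A hedged sketch of checking which symbols a local build exports (the library filename below is illustrative and depends on the CUDA version it was built against):

```python
import ctypes

# Illustrative path; pick the .so that matches your build.
lib = ctypes.CDLL("./bitsandbytes/libbitsandbytes_cuda117.so")

for suffix in ("fp32", "fp16", "bf16"):
    # getattr with a default swallows the AttributeError raised for missing symbols.
    fn = getattr(lib, f"cadam_8bit_blockwise_{suffix}", None)
    print(f"cadam_8bit_blockwise_{suffix}:", "exported" if fn is not None else "missing")
```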
tests/test_optim.py  (+18 -25)

...
@@ -26,6 +26,8 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)

+str2bf16support = {}
+str2bf16support['adam8bit_blockwise'] = True

 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
...
@@ -238,7 +240,7 @@ def test_global_config(dim1, dim2, gtype):
 dim1 = [1024]
 dim2 = [32, 1024, 4097]
-gtype = [torch.float32, torch.float16]
+gtype = [torch.float32, torch.float16, torch.bfloat16]
 optimizer_names = [
     "adam8bit",
     "momentum8bit",
...
@@ -256,6 +258,7 @@ names = [
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
+    if gtype == torch.bfloat16 and optim_name not in str2bf16support: return
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
...
@@ -269,7 +272,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     if gtype == torch.float32:
         atol, rtol = 3e-3, 1e-3
         patol, prtol = 1e-5, 1e-3
+    elif gtype == torch.bfloat16:
+        atol, rtol = 3e-3, 1e-3
+        patol, prtol = 1e-4, 1e-2
     else:
         atol, rtol = 3e-3, 1e-3
         patol, prtol = 1e-5, 1e-3
...
@@ -314,6 +319,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         err = torch.abs(p1 - p2)
         relerr = err / torch.abs(p1)
-        assert err.mean() < 0.0001
-        assert relerr.mean() < 0.001
+        if g.dtype == torch.bfloat16:
+            assert err.mean() < 0.00015
+            assert relerr.mean() < 0.0015
+        else:
+            assert err.mean() < 0.0001
+            assert relerr.mean() < 0.001
...
@@ -335,12 +344,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     bnb_optimizer = str2optimizers[optim_name][1]([p2])
     bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
     rm_path(path)
-    torch.testing.assert_allclose(
-        raws1cpy, bnb_optimizer.state[p2][name2]
-    )
-    torch.testing.assert_allclose(
-        qmap1, bnb_optimizer.state[p2][qmap]
-    )
+    torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
+    torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])

     if "blockwise" in optim_name:
         s1 = F.dequantize_blockwise(
...
@@ -357,28 +362,16 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
         )
         torch.testing.assert_allclose(s1cpy, s1)

         num_not_close = (
-            torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol,)
-            == 0
+            torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
         )
         assert num_not_close.sum().item() < 20
-    torch.testing.assert_allclose(
-        p1, p2.float(), atol=patol, rtol=prtol)
+    torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)

     # the parameters diverge quickly. Here we keep them close
     # together so we can test against the Adam error
     p1.data = p1.data.to(gtype).float()
     p2.copy_(p1.data)
     torch.testing.assert_allclose(p1.to(gtype), p2)
     for (name1, name2, qmap, max_val), s in zip(
         str2statenames[optim_name], dequant_states):
         torch_optimizer.state[p1][name1].copy_(s.data)

     # print(sum(errors)/len(errors))
...
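The looser bf16 thresholds above (1.5e-4 absolute, 1.5e-3 relative on the parameter error) can be probed outside the test harness. A hedged, illustrative sketch, not part of the test suite; it assumes a CUDA GPU and a bitsandbytes build containing this commit, uses synthetic gradients, and prints the errors rather than asserting on them:

```python
import torch
import bitsandbytes as bnb

gtype = torch.bfloat16
p1 = (torch.randn(1024, 4097, device="cuda") * 0.1)   # fp32 reference parameter
p2 = p1.clone().to(gtype)                              # bf16 copy for the 8-bit optimizer
p1.requires_grad_()
p2.requires_grad_()

torch_adam = torch.optim.Adam([p1], lr=1e-3)
bnb_adam = bnb.optim.Adam8bit([p2], lr=1e-3)           # blockwise 8-bit state by default

for _ in range(5):
    g = torch.randn_like(p1) * 0.01
    p1.grad = g
    p2.grad = g.to(gtype)                              # bf16 grads hit the new bf16 kernel
    torch_adam.step()
    bnb_adam.step()

err = (p1.detach() - p2.detach().float()).abs()
relerr = err / p1.detach().abs()
print(err.mean().item(), relerr.mean().item())         # compare against the bf16 bounds above
```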