OpenDAS / ColossalAI — Commit 5d7dc352 (unverified)

[hotfix] run cpu adam unittest in pytest (#424)

Authored Mar 16, 2022 by Jiarui Fang; committed by GitHub on Mar 16, 2022.
Parent: 54229cd3
Showing 1 changed file with 19 additions and 19 deletions.
tests/test_optimizer/unittest_cpu_adam.py (+19, -19)

```diff
@@ -29,12 +29,12 @@
 import math
 import torch
-import colossalai
 
 try:
     import cpu_adam
 except ImportError:
     raise ImportError("import cpu_adam error")
 
 
 def torch_adam_update(
     step,
     lr,
@@ -42,7 +42,6 @@ def torch_adam_update(
     beta2,
     eps,
     weight_decay,
-    bias_correction,
     param,
     grad,
     exp_avg,
@@ -52,8 +51,8 @@ def torch_adam_update(
 ):
     if loss_scale > 0:
         grad.div_(loss_scale)
 
     bias_correction1 = 1 - beta1 ** step
     bias_correction2 = 1 - beta2 ** step
 
     if weight_decay != 0:
         if use_adamw:
@@ -73,12 +72,13 @@ def torch_adam_update(
 class Test():
 
     def __init__(self):
         self.opt_id = 0
 
     def assertLess(self, data_diff, threshold, msg):
         assert data_diff < threshold, msg
 
     def assertTrue(self, condition, msg):
         assert condition, msg
@@ -89,7 +89,6 @@ class Test():
         eps,
         beta1,
         beta2,
         weight_decay,
         shape,
         grad_dtype,
@@ -118,8 +117,8 @@ class Test():
             eps,
             weight_decay,
             True,
             p_data.view(-1),  # fp32 data
             p_grad.view(-1),  # fp32 grad
             exp_avg.view(-1),
             exp_avg_sq.view(-1),
             loss_scale,
@@ -132,15 +131,14 @@ class Test():
             beta2,
             eps,
             weight_decay,
-            True,
             p_data_copy,  # fp32 data
             p_grad_copy,  # fp32 grad
             exp_avg_copy,
             exp_avg_sq_copy,
             loss_scale,
             use_adamw,
         )
 
         if loss_scale > 0:
             p_grad.div_(loss_scale)
@@ -158,16 +156,14 @@ class Test():
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
         self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(
-            max_exp_avg_sq_diff < threshold,
-            f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
+        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
 
     def test_cpu_adam(self):
         lr = 0.9
         eps = 1e-6
         weight_decay = 0
         for use_adamw in [False, True]:
-            for shape in [(1023,), (32, 1024)]:
+            for shape in [(23,), (8, 24)]:
                 for step in range(1, 2):
                     for lr in [0.01]:
                         for eps in [1e-8]:
@@ -175,7 +171,7 @@ class Test():
                             for beta2 in [0.999]:
                                 for weight_decay in [0.001]:
                                     for grad_dtype in [torch.half, torch.float]:
                                         for loss_scale in [-1, 2 ** 5]:
                                             self.check_res(
                                                 step,
                                                 lr,
@@ -191,7 +187,11 @@ class Test():
                                             )
 
 
+def test_cpu_adam():
+    test_case = Test()
+    test_case.test_cpu_adam()
+
 
 if __name__ == "__main__":
     test = Test()
     test.test_cpu_adam()
-    print('All is well.')
```
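The core of the hotfix is the new module-level `test_cpu_adam()` wrapper. pytest does not collect a test class that defines an `__init__` constructor (it emits a `PytestCollectionWarning` and skips it), so the `Test` helper above is invisible to collection; previously the file could only run as a plain script through its `__main__` block. A module-level `test_*` function, by contrast, is always collected when the file is passed to pytest explicitly, e.g. `pytest tests/test_optimizer/unittest_cpu_adam.py` (presumably how CI invokes it, since the filename does not match the default `test_*.py` discovery pattern). A minimal sketch of this collection behavior, with an illustrative file and function names:

```python
# collection_sketch.py -- run with: pytest collection_sketch.py -v
# (hypothetical file name, for illustration only)

class Test:
    """Skipped by pytest: classes with an __init__ constructor are not collected."""

    def __init__(self):
        self.opt_id = 0

    def test_inside_class(self):
        # Never reached via pytest collection, only via explicit calls.
        assert True


def test_wrapper():
    """Collected: a module-level test_* function always is."""
    Test().test_inside_class()
```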
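For reference, the update that `torch_adam_update` checks the fused `cpu_adam` kernel against is the standard Adam/AdamW step with bias correction. Below is a minimal dense sketch under the same argument conventions as the test; the function body is elided in the diff above, so `adam_reference` and its exact handling of fp16 gradients are assumptions, not the committed code:

```python
import torch


def adam_reference(step, lr, beta1, beta2, eps, weight_decay,
                   param, grad, exp_avg, exp_avg_sq, loss_scale, use_adamw):
    # Undo static loss scaling, mirroring the committed code: a positive
    # loss_scale means the gradients were multiplied by it in the fp16 path.
    if loss_scale > 0:
        grad.div_(loss_scale)

    if weight_decay != 0:
        if use_adamw:
            # AdamW: decoupled decay applied directly to the parameters.
            param.mul_(1 - lr * weight_decay)
        else:
            # Classic Adam: L2 penalty folded into the gradient.
            grad = grad.add(param, alpha=weight_decay)

    # Exponential moving averages of the gradient and its square.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias correction, identical to the two lines visible in the diff.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```

`check_res` then asserts that after a step the parameters and both moment buffers produced by the fused kernel and by the dense formulation agree within a threshold (`max_exp_avg_diff`, `max_exp_avg_sq_diff`), sweeping `use_adamw`, shape, dtype, and loss scale.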