Unverified Commit 5d7dc352 authored by Jiarui Fang, committed by GitHub

[hotfix] run cpu adam unittest in pytest (#424)

parent 54229cd3
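The fix itself is small: the class-based runner in this file is not collectible by pytest because the class defines __init__, so the commit adds a module-level test_* wrapper (and shrinks the tensor shapes so the parameter sweep finishes quickly). A minimal, self-contained sketch of the collection issue, with a stub standing in for the real harness:

    # Illustrative stub, not the real test harness: pytest matches the
    # class name "Test", but refuses to collect any Test* class that
    # defines __init__ (it warns and skips it), so before this fix the
    # file exposed no test to pytest and only ran via plain `python`.

    class Test():

        def __init__(self):    # the constructor is what blocks pytest collection
            self.opt_id = 0

        def test_cpu_adam(self):
            assert True    # the real checks live in check_res()


    def test_cpu_adam():    # module-level test_* function: pytest collects this
        test_case = Test()
        test_case.test_cpu_adam()


    if __name__ == "__main__":
        test_cpu_adam()

With the wrapper in place, running the file under pytest and running it directly as a script exercise the same code path.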
@@ -29,12 +29,12 @@
import math
import torch
import colossalai

try:
    import cpu_adam
except ImportError:
    raise ImportError("import cpu_adam error")


def torch_adam_update(
    step,
    lr,
@@ -42,7 +41,6 @@ def torch_adam_update(
    beta2,
    eps,
    weight_decay,
-    bias_correction,
    param,
    grad,
    exp_avg,
@@ -52,8 +51,8 @@
):
    if loss_scale > 0:
        grad.div_(loss_scale)
-    bias_correction1 = 1 - beta1 ** step
-    bias_correction2 = 1 - beta2 ** step
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step
    if weight_decay != 0:
        if use_adamw:
@@ -73,12 +72,13 @@
class Test():

    def __init__(self):
        self.opt_id = 0

    def assertLess(self, data_diff, threshold, msg):
        assert data_diff < threshold, msg

    def assertTrue(self, condition, msg):
        assert condition, msg
@@ -89,7 +89,6 @@ class Test():
        eps,
        beta1,
        beta2,
        weight_decay,
        shape,
        grad_dtype,
@@ -118,8 +117,8 @@ class Test():
            eps,
            weight_decay,
            True,
-            p_data.view(-1),  # fp32 data
-            p_grad.view(-1),  # fp32 grad
+            p_data.view(-1),  # fp32 data
+            p_grad.view(-1),  # fp32 grad
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            loss_scale,
@@ -132,15 +131,14 @@ class Test():
            beta2,
            eps,
            weight_decay,
-            True,
-            p_data_copy,  # fp32 data
-            p_grad_copy,  # fp32 grad
+            p_data_copy,  # fp32 data
+            p_grad_copy,  # fp32 grad
            exp_avg_copy,
            exp_avg_sq_copy,
            loss_scale,
            use_adamw,
        )

        if loss_scale > 0:
            p_grad.div_(loss_scale)
@@ -158,16 +156,14 @@ class Test():
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")

        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(
-            max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}"
-        )
+        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")

    def test_cpu_adam(self):
        lr = 0.9
        eps = 1e-6
        weight_decay = 0

        for use_adamw in [False, True]:
-            for shape in [(1023, ), (32, 1024)]:
+            for shape in [(23,), (8, 24)]:
                for step in range(1, 2):
                    for lr in [0.01]:
                        for eps in [1e-8]:
@@ -175,7 +171,7 @@ class Test():
                                for beta2 in [0.999]:
                                    for weight_decay in [0.001]:
                                        for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2 ** 5]:
+                                            for loss_scale in [-1, 2**5]:
                                                self.check_res(
                                                    step,
                                                    lr,
@@ -191,7 +187,11 @@
                                                )


+def test_cpu_adam():
+    test_case = Test()
+    test_case.test_cpu_adam()


if __name__ == "__main__":
    test = Test()
    test.test_cpu_adam()
    print('All is well.')
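For context, the torch-side reference that the C++ cpu_adam kernel is checked against: the diff shows only fragments of torch_adam_update, so the sketch below reconstructs the elided body from the standard Adam/AdamW update rules (an assumption, not the file's verbatim code); the parameter order follows the fragments above, with bias_correction dropped as in this commit.

    import math

    import torch


    def torch_adam_update(step, lr, beta1, beta2, eps, weight_decay,
                          param, grad, exp_avg, exp_avg_sq, loss_scale, use_adamw):
        # Reconstructed sketch of the reference update, not verbatim file code.
        # A loss_scale of -1 in the test means "no scaling"; only positive
        # scales are undone here, mirroring `if loss_scale > 0` in the diff.
        if loss_scale > 0:
            grad.div_(loss_scale)

        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step

        if weight_decay != 0:
            if use_adamw:
                # AdamW: decoupled weight decay applied straight to the weights.
                param.mul_(1 - lr * weight_decay)
            else:
                # Classic Adam: fold the L2 penalty into the gradient.
                grad = grad.add(param, alpha=weight_decay)

        # Exponential moving averages of the gradient and its square.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # Bias-corrected step: param -= lr / bc1 * exp_avg / (sqrt(v_hat) + eps).
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)


    if __name__ == "__main__":
        p, g = torch.zeros(8), torch.ones(8)
        m, v = torch.zeros(8), torch.zeros(8)
        torch_adam_update(1, 0.01, 0.9, 0.999, 1e-8, 0.001, p, g, m, v, -1, False)
        print(p)  # every entry moves against the gradient (-0.01 each)

The test in this file sweeps this update against the C++ kernel over shapes, dtypes, loss scales, and Adam/AdamW modes, asserting the max elementwise differences stay under a threshold.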