Unverified commit 9d6c7b6a authored by Jun Ru Anderson, committed by GitHub

[fix] fix tests and state_dict; refactor tests (#45)



Refactor the tests to remove duplicated code. Fix the state_dict test to instantiate the second optimizer with the correct precision. Fix Adam.load_state_dict to cast the optimizer state tensors (exp_avg, exp_avg_sq) to the correct dtype.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 8ee5a8ff
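For context, a condensed sketch of the round trip the fixed state_dict test now exercises: the second optimizer has to be constructed with the same precision as the first, and load_state_dict casts the moment buffers to the optimizer's working dtype. This is illustrative only; it assumes a CUDA device and the Adam/Precision names used in the diff below (the exact import path may differ and is guarded by a try/except in the tests).

```python
import torch
from fairscale.optim import Adam, Precision  # import path assumed; tests guard this import

# Half-precision parameters, one optimizer step to populate state.
weight = torch.randn(2, 1).cuda().half().requires_grad_()
bias = torch.randn(2).cuda().half().requires_grad_()
input = torch.randn(1).half().cuda()

optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
optimizer.step()

# Round trip: the replacement optimizer must use the same precision (the test fix),
# so the reloaded exp_avg / exp_avg_sq buffers end up with the right dtype.
weight_c = weight.data.clone().requires_grad_()
bias_c = bias.data.clone().requires_grad_()
optimizer_c = Adam([weight_c, bias_c], lr=1e-3, precision=optimizer.precision)
optimizer_c.load_state_dict(optimizer.state_dict())
```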
@@ -70,17 +70,14 @@ try:
             precision: Optional[Precision] = None,
         ):
             parameters: List[Any] = list(params)
-            if precision is None:
-                precision = (
+            self.precision = precision
+            if self.precision is None:
+                self.precision = (
                     Precision.FULL_PRECISION if parameters[0].dtype == torch.float32 else Precision.MIXED_PRECISION
                 )
-            self.mixed_precision = False
-            if precision is Precision.MIXED_PRECISION:
-                self.mixed_precision = True
-            if precision is not Precision.FULL_PRECISION:
+            if self.precision is not Precision.FULL_PRECISION:
                 assert parameters[0].dtype == torch.float16
             self.optim_type = torch.float16 if precision is Precision.PURE_FP16 else torch.float32
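The hunk above stores the precision on the optimizer and infers it from the first parameter's dtype when the caller passes None. A minimal standalone illustration of that inference rule (a hypothetical stand-in enum, not fairscale's code):

```python
import torch
from enum import Enum, auto

class Precision(Enum):  # stand-in mirroring the enum referenced in the diff
    FULL_PRECISION = auto()
    MIXED_PRECISION = auto()
    MEMORY_EFFICIENT_MIXED_PRECISION = auto()
    PURE_FP16 = auto()

def infer_precision(first_param: torch.Tensor) -> Precision:
    # fp32 parameters imply full precision; fp16 parameters imply mixed precision
    return Precision.FULL_PRECISION if first_param.dtype == torch.float32 else Precision.MIXED_PRECISION

assert infer_precision(torch.randn(2, 1)) is Precision.FULL_PRECISION
assert infer_precision(torch.randn(2, 1).half()) is Precision.MIXED_PRECISION
```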
@@ -144,20 +141,16 @@ try:
         def _step_supports_amp_scaling(self) -> bool:
             return False

         def state_dict(self) -> Dict[str, Any]:
             d = super().state_dict()
             d["optim_type"] = self.optim_type
-            d["mixed_precision"] = self.mixed_precision
             d["fp32_param_groups"] = self.fp32_param_groups
             d["state"] = self.state
             return d

+        @property
+        def mixed_precision(self) -> bool:
+            return self.precision is Precision.MIXED_PRECISION
+
         def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             super().load_state_dict(state_dict)
             self.optim_type = state_dict["optim_type"]
-            self.mixed_precision = state_dict["mixed_precision"]
             self.fp32_param_groups = state_dict["fp32_param_groups"]
             self.state = state_dict["state"]
+            for group in self.param_groups:
+                for p in group["params"]:
+                    self.state[p]["exp_avg"] = self.state[p]["exp_avg"].type(self.optim_type)
+                    self.state[p]["exp_avg_sq"] = self.state[p]["exp_avg_sq"].type(self.optim_type)

         def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
             """Performs a single optimization step.
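The load_state_dict fix above casts exp_avg and exp_avg_sq to self.optim_type after loading. The same cast-on-load pattern, shown standalone with torch.optim.Adam (illustrative only, not the fairscale implementation; target_dtype stands in for optim_type):

```python
import torch

param = torch.randn(4, requires_grad=True)
opt = torch.optim.Adam([param], lr=1e-3)
param.sum().backward()
opt.step()  # populates exp_avg / exp_avg_sq in opt.state

saved = opt.state_dict()

opt2 = torch.optim.Adam([param], lr=1e-3)
opt2.load_state_dict(saved)

# Force the moment buffers to the dtype the optimizer expects to work in,
# mirroring the cast added to Adam.load_state_dict above.
target_dtype = torch.float32
for group in opt2.param_groups:
    for p in group["params"]:
        state = opt2.state[p]
        state["exp_avg"] = state["exp_avg"].type(target_dtype)
        state["exp_avg_sq"] = state["exp_avg_sq"].type(target_dtype)
```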
@@ -20,12 +20,26 @@ skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda
 skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")

+def assert_almost_zero(x):
+    assert abs(x) < 2 * 1e-3
+    return 1.0
+
+def make_full_precision_params():
+    weight = torch.randn(2, 1).cuda().requires_grad_()
+    bias = torch.randn(2).cuda().requires_grad_()
+    input = torch.randn(1).cuda()
+    return weight, bias, input
+
+def make_half_precision_params():
+    weight = torch.randn(2, 1).cuda().half().requires_grad_()
+    bias = torch.randn(2).cuda().half().requires_grad_()
+    input = torch.randn(1).half().cuda()
+    return weight, bias, input
+
 def step_test(optimizer, weight, bias, input):
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+
     def fn():
         optimizer.zero_grad()
         y = weight.mv(input)
@@ -56,7 +70,7 @@ def state_dict_test(optimizer, weight, bias, input):
     # Clone the weights and construct new optimizer for them
     weight_c = weight.data.clone().requires_grad_()
     bias_c = bias.data.clone().requires_grad_()
-    optimizer_c = Adam([weight_c, bias_c], lr=1e-3)
+    optimizer_c = Adam([weight_c, bias_c], lr=1e-3, precision=optimizer.precision)
     fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)

     # Load state dict
     state_dict = deepcopy(optimizer.state_dict())
@@ -69,16 +83,17 @@ def state_dict_test(optimizer, weight, bias, input):
     (bias - bias_c).to("cpu").detach().apply_(assert_almost_zero)

-def assert_almost_zero(x):
-    assert abs(x) < 1e-3
-    return 1.0
-
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_full_precision_inferred():
-    weight = torch.randn(10, 5).cuda().requires_grad_()
-    bias = torch.randn(10).cuda().requires_grad_()
-    input = torch.randn(5).cuda()
+    weight, bias, input = make_full_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
     step_test(optimizer, weight, bias, input)

     for group in optimizer.param_groups:
@@ -87,17 +102,17 @@ def test_step_full_precision_inferred():
                 assert p.dtype == torch.float32

     assert not optimizer.fp32_param_groups
+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32

 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_mixed_precision_inferred():
-    weight = torch.randn(10, 5).cuda().half().requires_grad_()
-    bias = torch.randn(10).cuda().half().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
     step_test(optimizer, weight, bias, input)

     assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)
@@ -114,42 +129,49 @@ def test_step_mixed_precision_inferred():
             assert fp16_p.dtype == torch.float16
             (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)

+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32

 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_memory_efficient():
-    weight = torch.randn(10, 5).cuda().half().requires_grad_()
-    bias = torch.randn(10).cuda().half().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
     step_test(optimizer, weight, bias, input)

     for group in optimizer.param_groups:
         for p in group["params"]:
             if p.requires_grad:
                 assert p.dtype == torch.float16

     assert not optimizer.fp32_param_groups
+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32

 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_pure_fp16():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
     step_test(optimizer, weight, bias, input)

     for group in optimizer.param_groups:
         for p in group["params"]:
             if p.requires_grad:
                 assert p.dtype == torch.float16

+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16
     assert not optimizer.fp32_param_groups
@@ -163,8 +185,6 @@ def test_step_multigpu():
     input = torch.randn(5).cuda(0)

     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
     step_test(optimizer, weight, bias, input)
@@ -178,8 +198,6 @@ def test_step_multigpu_mixed_precision():
     input = torch.randn(5).cuda(0).half()

     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
     step_test(optimizer, weight, bias, input)
@@ -193,8 +211,6 @@ def test_step_pure_fp16_multigpu():
     input = torch.randn(5).half().cuda(0)

     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
     step_test(optimizer, weight, bias, input)

     assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
@@ -206,9 +222,7 @@ def test_step_pure_fp16_multigpu():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_full_precision():
-    weight = torch.randn(10, 5).float().cuda().requires_grad_()
-    bias = torch.randn(10).float().cuda().requires_grad_()
-    input = torch.randn(5).float().cuda()
+    weight, bias, input = make_full_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3)

     state_dict_test(optimizer, weight, bias, input)
@@ -217,9 +231,7 @@ def test_state_dict_full_precision():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_mixed_precision():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)

     state_dict_test(optimizer, weight, bias, input)
@@ -228,9 +240,7 @@ def test_state_dict_mixed_precision():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_memory_efficient():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)

     state_dict_test(optimizer, weight, bias, input)
@@ -239,9 +249,7 @@ def test_state_dict_memory_efficient():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_pure_fp16():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)

     state_dict_test(optimizer, weight, bias, input)