Unverified commit 9d6c7b6a, authored by Jun Ru Anderson, committed by GitHub

[fix] fix tests and state_dict; refactor tests (#45)



Refactor tests to remove duplicated code. Fix the state_dict test to instantiate the second optimizer with the correct precision. Fix Adam.load_state_dict to make optimizer state the right type.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 8ee5a8ff
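
Note: the load_state_dict change is easiest to see as a save/reload round trip. The sketch below is not part of this commit; it assumes the Adam/Precision imports used by the tests in this diff, a CUDA device, and fairscale's fused Adam extension, and it exercises the same save/reload path as the state_dict tests below.

    import torch
    from fairscale.optim import Adam, Precision  # import path assumed, as in the tests

    weight = torch.randn(2, 1).cuda().half().requires_grad_()
    bias = torch.randn(2).cuda().half().requires_grad_()
    input = torch.randn(1).half().cuda()

    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
    (weight.mv(input) + bias).sum().backward()
    optimizer.step()  # creates the exp_avg / exp_avg_sq state tensors

    # Rebuild an optimizer with the same precision and reload the saved state.
    # With this fix, load_state_dict casts the state back to the optimizer's
    # optim_type (float16 for PURE_FP16) instead of keeping whatever dtype came in.
    optimizer_c = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
    optimizer_c.load_state_dict(optimizer.state_dict())
    assert optimizer_c.state[weight]["exp_avg"].dtype == torch.float16
    assert optimizer_c.state[bias]["exp_avg_sq"].dtype == torch.float16
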
@@ -70,17 +70,14 @@ try:
             precision: Optional[Precision] = None,
         ):
             parameters: List[Any] = list(params)
+            self.precision = precision

-            if precision is None:
-                precision = (
+            if self.precision is None:
+                self.precision = (
                     Precision.FULL_PRECISION if parameters[0].dtype == torch.float32 else Precision.MIXED_PRECISION
                 )

-            self.mixed_precision = False
-            if precision is Precision.MIXED_PRECISION:
-                self.mixed_precision = True
-            if precision is not Precision.FULL_PRECISION:
+            if self.precision is not Precision.FULL_PRECISION:
                 assert parameters[0].dtype == torch.float16

             self.optim_type = torch.float16 if precision is Precision.PURE_FP16 else torch.float32
@@ -144,20 +141,16 @@ try:
         def _step_supports_amp_scaling(self) -> bool:
             return False

-        def state_dict(self) -> Dict[str, Any]:
-            d = super().state_dict()
-            d["optim_type"] = self.optim_type
-            d["mixed_precision"] = self.mixed_precision
-            d["fp32_param_groups"] = self.fp32_param_groups
-            d["state"] = self.state
-            return d
+        @property
+        def mixed_precision(self) -> bool:
+            return self.precision is Precision.MIXED_PRECISION

         def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             super().load_state_dict(state_dict)
-            self.optim_type = state_dict["optim_type"]
-            self.mixed_precision = state_dict["mixed_precision"]
-            self.fp32_param_groups = state_dict["fp32_param_groups"]
-            self.state = state_dict["state"]
+            for group in self.param_groups:
+                for p in group["params"]:
+                    self.state[p]["exp_avg"] = self.state[p]["exp_avg"].type(self.optim_type)
+                    self.state[p]["exp_avg_sq"] = self.state[p]["exp_avg_sq"].type(self.optim_type)

         def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
             """Performs a single optimization step.
...
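
With mixed_precision now derived from self.precision rather than stored and re-serialized, the flag cannot drift out of sync with the configured precision. A small illustrative sketch under the same assumptions as above (fairscale Adam/Precision, CUDA available):

    import torch
    from fairscale.optim import Adam, Precision  # import path assumed

    p = torch.randn(2, 1).cuda().half().requires_grad_()

    # mixed_precision is a read-only property computed from the configured precision.
    assert Adam([p], lr=1e-3, precision=Precision.MIXED_PRECISION).mixed_precision
    assert not Adam([p], lr=1e-3, precision=Precision.PURE_FP16).mixed_precision
    assert not Adam([p], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION).mixed_precision
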
@@ -20,12 +20,26 @@ skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda
 skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")

-def assert_almost_zero(x):
-    assert abs(x) < 2 * 1e-3
-    return 1.0
+def make_full_precision_params():
+    weight = torch.randn(2, 1).cuda().requires_grad_()
+    bias = torch.randn(2).cuda().requires_grad_()
+    input = torch.randn(1).cuda()
+    return weight, bias, input
+
+
+def make_half_precision_params():
+    weight = torch.randn(2, 1).cuda().half().requires_grad_()
+    bias = torch.randn(2).cuda().half().requires_grad_()
+    input = torch.randn(1).half().cuda()
+    return weight, bias, input


 def step_test(optimizer, weight, bias, input):
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+
     def fn():
         optimizer.zero_grad()
         y = weight.mv(input)
@@ -56,7 +70,7 @@ def state_dict_test(optimizer, weight, bias, input):
     # Clone the weights and construct new optimizer for them
     weight_c = weight.data.clone().requires_grad_()
     bias_c = bias.data.clone().requires_grad_()
-    optimizer_c = Adam([weight_c, bias_c], lr=1e-3)
+    optimizer_c = Adam([weight_c, bias_c], lr=1e-3, precision=optimizer.precision)
     fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
     # Load state dict
     state_dict = deepcopy(optimizer.state_dict())
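
Why state_dict_test now passes precision explicitly: with half-precision parameter clones, an Adam constructed without a precision argument infers MIXED_PRECISION (see the constructor hunk above), which need not match the optimizer under test. A small sketch under the same assumptions as above:

    import torch
    from fairscale.optim import Adam, Precision  # import path assumed

    weight_c = torch.randn(2, 1).cuda().half().requires_grad_()

    inferred = Adam([weight_c], lr=1e-3)  # fp16 params, no precision given
    explicit = Adam([weight_c], lr=1e-3, precision=Precision.PURE_FP16)
    assert inferred.precision is Precision.MIXED_PRECISION  # inferred from dtype
    assert explicit.precision is Precision.PURE_FP16        # matches the optimizer under test
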
@@ -69,16 +83,17 @@ def state_dict_test(optimizer, weight, bias, input):
     (bias - bias_c).to("cpu").detach().apply_(assert_almost_zero)


+def assert_almost_zero(x):
+    assert abs(x) < 1e-3
+    return 1.0
+
+
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_full_precision_inferred():
-    weight = torch.randn(10, 5).cuda().requires_grad_()
-    bias = torch.randn(10).cuda().requires_grad_()
-    input = torch.randn(5).cuda()
+    weight, bias, input = make_full_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()

     step_test(optimizer, weight, bias, input)

     for group in optimizer.param_groups:
@@ -87,17 +102,17 @@ def test_step_full_precision_inferred():
                 assert p.dtype == torch.float32

     assert not optimizer.fp32_param_groups
+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_mixed_precision_inferred():
-    weight = torch.randn(10, 5).cuda().half().requires_grad_()
-    bias = torch.randn(10).cuda().half().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()

     step_test(optimizer, weight, bias, input)

     assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)
@@ -114,42 +129,49 @@ def test_step_mixed_precision_inferred():
             assert fp16_p.dtype == torch.float16
             (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)

+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_memory_efficient():
-    weight = torch.randn(10, 5).cuda().half().requires_grad_()
-    bias = torch.randn(10).cuda().half().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()

     step_test(optimizer, weight, bias, input)

     for group in optimizer.param_groups:
         for p in group["params"]:
             if p.requires_grad:
                 assert p.dtype == torch.float16

     assert not optimizer.fp32_param_groups
+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


 @skip_if_no_cuda
 @skip_if_no_adam
 def test_step_pure_fp16():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()

     step_test(optimizer, weight, bias, input)

+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            if p.requires_grad:
+                assert p.dtype == torch.float16
+
     assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
     assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
     assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
     assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16

     assert not optimizer.fp32_param_groups
@@ -163,8 +185,6 @@ def test_step_multigpu():
     input = torch.randn(5).cuda(0)

     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()

     step_test(optimizer, weight, bias, input)
@@ -178,8 +198,6 @@ def test_step_multigpu_mixed_precision():
     input = torch.randn(5).cuda(0).half()

     optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()

     step_test(optimizer, weight, bias, input)
@@ -193,8 +211,6 @@ def test_step_pure_fp16_multigpu():
     input = torch.randn(5).half().cuda(0)

     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()

     step_test(optimizer, weight, bias, input)

     assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
@@ -206,9 +222,7 @@ def test_step_pure_fp16_multigpu():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_full_precision():
-    weight = torch.randn(10, 5).float().cuda().requires_grad_()
-    bias = torch.randn(10).float().cuda().requires_grad_()
-    input = torch.randn(5).float().cuda()
+    weight, bias, input = make_full_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3)

     state_dict_test(optimizer, weight, bias, input)
@@ -217,9 +231,7 @@ def test_state_dict_full_precision():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_mixed_precision():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)

     state_dict_test(optimizer, weight, bias, input)
@@ -228,9 +240,7 @@ def test_state_dict_mixed_precision():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_memory_efficient():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)

     state_dict_test(optimizer, weight, bias, input)
@@ -239,9 +249,7 @@ def test_state_dict_memory_efficient():
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict_pure_fp16():
-    weight = torch.randn(10, 5).half().cuda().requires_grad_()
-    bias = torch.randn(10).half().cuda().requires_grad_()
-    input = torch.randn(5).half().cuda()
+    weight, bias, input = make_half_precision_params()
     optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)

     state_dict_test(optimizer, weight, bias, input)
...