Unverified Commit 8ee5a8ff authored by Jun Ru Anderson, committed by GitHub

[feat] allow fp16 optimizer state with Adam (#41)



Allow training with the optimizer state in fp16. Use an enum to select among full precision, mixed precision, memory-efficient mixed precision, and pure fp16. Also improve the clarity of the testing code.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent e2d8f573
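For reference, a minimal usage sketch of the API change described above (a sketch only, assuming a CUDA build of fairscale with the fused Adam extension; the linear layer is an illustrative stand-in, not part of this commit):

import torch
from fairscale.optim.adam import Adam, Precision

model = torch.nn.Linear(32, 32).cuda().half()   # fp16 parameters

# Pick the optimizer-state precision explicitly via the new enum:
#   Precision.FULL_PRECISION, Precision.MIXED_PRECISION,
#   Precision.MEMORY_EFFICIENT_MIXED_PRECISION or Precision.PURE_FP16.
optimizer = Adam(model.parameters(), lr=1e-3, precision=Precision.PURE_FP16)

# Or leave precision unset and let it be inferred from the parameter dtype:
# fp32 parameters -> FULL_PRECISION, fp16 parameters -> MIXED_PRECISION.
inferred = Adam(model.parameters(), lr=1e-3)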
@@ -11,7 +11,7 @@ from torchtext.data.utils import get_tokenizer
from fairscale.nn import Pipe
try:
-    from fairscale.optim.adam import Adam  # type: ignore
+    from fairscale.optim.adam import Adam, Precision  # type: ignore
    can_benchmark = True
except ImportError:
@@ -135,7 +135,7 @@ def make_model(device, ntokens):
    criterion = nn.CrossEntropyLoss()
    lr = 0.01  # learning rate
-    optimizer = Adam(p.parameters(), lr=lr, mixed_precision=True)
+    optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
    return p, criterion, optimizer
...
@@ -20,7 +20,7 @@ typedef enum{
  ADAM_MODE_1   =1  // eps outside square root
} adamMode_t;
-template <int DEPTH, typename PARAM_T, typename GRAD_T>
+template <int DEPTH, typename PARAM_T, typename GRAD_T, typename OPTIM_T>
struct AdamFunctor
{
    __device__ __forceinline__ void operator()(
@@ -41,9 +41,9 @@ struct AdamFunctor
        PARAM_T* p = (PARAM_T *)tl.addresses[0][tensor_loc];
        p += chunk_idx*chunk_size;
-        float* m = (float *)tl.addresses[1][tensor_loc];
+        OPTIM_T* m = (OPTIM_T *)tl.addresses[1][tensor_loc];
        m += chunk_idx*chunk_size;
-        float* v = (float *)tl.addresses[2][tensor_loc];
+        OPTIM_T* v = (OPTIM_T *)tl.addresses[2][tensor_loc];
        v += chunk_idx*chunk_size;
        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
        g += chunk_idx*chunk_size;
@@ -56,8 +56,8 @@ struct AdamFunctor
        n -= chunk_idx*chunk_size;
        PARAM_T incoming_p[ILP];
-        float incoming_m[ILP];
-        float incoming_v[ILP];
+        OPTIM_T incoming_m[ILP];
+        OPTIM_T incoming_v[ILP];
        GRAD_T incoming_g[ILP];
        for(int i_start = 0;
@@ -91,14 +91,16 @@ struct AdamFunctor
                if(j < n && j < chunk_size) {
                    float scaled_grad = incoming_g[ii]/grad_scale;
-                    m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
-                    v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
+                    float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
+                    float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
+                    m[j] = static_cast<OPTIM_T>(momentum);
+                    v[j] = static_cast<OPTIM_T>(velocity);
                    float denom;
                    if (mode == ADAM_MODE_0)
-                        denom = sqrtf(v[j] + eps);
+                        denom = sqrtf(velocity + eps);
                    else // Mode 1
-                        denom = sqrtf(v[j]) + eps;
-                    float update = (m[j]/denom) + (decay*incoming_p[ii]);
+                        denom = sqrtf(velocity) + eps;
+                    float update = (momentum/denom) + (decay*incoming_p[ii]);
                    p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
                    if (DEPTH == 5)  p_copy[j] = (at::Half) p[j];
                }
@@ -135,6 +137,7 @@ void fused_adam_cuda(
    size_t tl_sz = tensor_lists.size();
    assert(tl_sz == 4 || tl_sz == 5);
+    assert(tensor_lists[1][0].scalar_type() == tensor_lists[2][0].scalar_type());
    if(tl_sz == 5) {
        // Mixed precision case
@@ -146,7 +149,7 @@ void fused_adam_cuda(
            chunk_size,
            noop_flag,
            tensor_lists,
-            AdamFunctor<5, float, at::Half>(),
+            AdamFunctor<5, float, at::Half, float>(),
            beta1,
            beta2,
            eps,
@@ -160,12 +163,13 @@ void fused_adam_cuda(
        assert(tensor_lists[0][0].scalar_type() == tensor_lists[3][0].scalar_type());
        if(tensor_lists[0][0].scalar_type() == at::ScalarType::Float) {
            // Full precision case
+            assert(tensor_lists[1][0].scalar_type() == at::ScalarType::Float);
            multi_tensor_apply<4>(
                BLOCK_SIZE,
                chunk_size,
                noop_flag,
                tensor_lists,
-                AdamFunctor<4, float, float>(),
+                AdamFunctor<4, float, float, float>(),
                beta1,
                beta2,
                eps,
@@ -175,13 +179,14 @@ void fused_adam_cuda(
                decay
            );
        } else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
-            // "Memory Efficient Training" case
+            if(tensor_lists[1][0].scalar_type() == at::ScalarType::Float) {
+                // FP16 model parameters and gradients; FP32 optimizer state
            multi_tensor_apply<4>(
                BLOCK_SIZE,
                chunk_size,
                noop_flag,
                tensor_lists,
-                AdamFunctor<4, at::Half, at::Half>(),
+                AdamFunctor<4, at::Half, at::Half, float>(),
                beta1,
                beta2,
                eps,
@@ -190,6 +195,25 @@ void fused_adam_cuda(
                (adamMode_t) mode,
                decay
            );
+            } else if (tensor_lists[1][0].scalar_type() == at::ScalarType::Half) {
+                // Pure FP16 case
+                multi_tensor_apply<4>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<4, at::Half, at::Half, at::Half>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay
+                );
+            } else {
+                throw "Optimizer state must be of type float or half";
+            }
        } else {
            throw "Parameters must be of type float or half";
        }
...
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
@@ -15,6 +16,12 @@ else:
try:
    from fairscale import fused_adam_cuda  # type: ignore
+    class Precision(Enum):
+        FULL_PRECISION = auto()
+        MIXED_PRECISION = auto()
+        MEMORY_EFFICIENT_MIXED_PRECISION = auto()
+        PURE_FP16 = auto()
    class Adam(torch.optim.Optimizer):
        state: dict
        defaults: dict
@@ -39,6 +46,10 @@ try:
                adds eps to the bias-corrected second moment estimate before
                evaluating square root instead of adding it to the square root of
                second moment estimate as in the original paper. (default: False)
+            precision (Precision, optional): One of Precision.FULL_PRECISION,
+                Precision.MIXED_PRECISION, Precision.MEMORY_EFFICIENT_MIXED_PRECISION
+                or Precision.PURE_FP16. Inferred based on model parameter precision if
+                None. (default: None)
        .. _Adam: A Method for Stochastic Optimization:
            https://arxiv.org/abs/1412.6980
        .. _On the Convergence of Adam and Beyond:
@@ -56,11 +67,24 @@ try:
            weight_decay: Optional[float] = 0.0,
            max_grad_norm: Optional[float] = 0.0,
            amsgrad: Optional[bool] = False,
-            mixed_precision: Optional[bool] = False,
+            precision: Optional[Precision] = None,
        ):
-            self.mixed_precision = mixed_precision
            parameters: List[Any] = list(params)
+            if precision is None:
+                precision = (
+                    Precision.FULL_PRECISION if parameters[0].dtype == torch.float32 else Precision.MIXED_PRECISION
+                )
+            self.mixed_precision = False
+            if precision is Precision.MIXED_PRECISION:
+                self.mixed_precision = True
+            if precision is not Precision.FULL_PRECISION:
+                assert parameters[0].dtype == torch.float16
+            self.optim_type = torch.float16 if precision is Precision.PURE_FP16 else torch.float32
            self._overflow_buf = torch.cuda.IntTensor([0])  # type: ignore
            if amsgrad:
@@ -76,7 +100,8 @@ try:
            super().__init__(parameters, defaults)
            self.eps_mode = 0 if eps_inside_sqrt else 1
-            if mixed_precision:
+            self.fp32_param_groups: List[Any] = []
+            if self.mixed_precision:
                self._build_fp32_params(parameters)
        def _build_fp32_params(self, params: Any) -> None:
@@ -119,6 +144,21 @@ try:
        def _step_supports_amp_scaling(self) -> bool:
            return False
+        def state_dict(self) -> Dict[str, Any]:
+            d = super().state_dict()
+            d["optim_type"] = self.optim_type
+            d["mixed_precision"] = self.mixed_precision
+            d["fp32_param_groups"] = self.fp32_param_groups
+            d["state"] = self.state
+            return d
+        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+            super().load_state_dict(state_dict)
+            self.optim_type = state_dict["optim_type"]
+            self.mixed_precision = state_dict["mixed_precision"]
+            self.fp32_param_groups = state_dict["fp32_param_groups"]
+            self.state = state_dict["state"]
        def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
            """Performs a single optimization step.
            Arguments:
@@ -161,9 +201,9 @@ try:
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
-                        state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
+                        state["exp_avg"] = torch.zeros_like(p, dtype=self.optim_type)
                        # Exponential moving average of squared gradient values
-                        state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
+                        state["exp_avg_sq"] = torch.zeros_like(p, dtype=self.optim_type)
                    exp_avg = state["exp_avg"]
                    exp_avg_sq = state["exp_avg_sq"]
...
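The docstring and constructor changes above are what tie the Adam moment dtype to the chosen precision. A minimal sketch of the inference behaviour, assuming a CUDA build of fairscale with the fused Adam extension (the tensor here is illustrative, not from the commit):

import torch
from fairscale.optim.adam import Adam, Precision

w = torch.randn(8, 8).cuda().half().requires_grad_()
w.pow(2).sum().backward()   # give the parameter a gradient so step() has work to do

# fp16 parameters with precision left unset => Precision.MIXED_PRECISION is inferred,
# which keeps an fp32 master copy of the parameters.
mixed = Adam([w], lr=1e-3)
mixed.step()
assert len(mixed.fp32_param_groups) == len(mixed.param_groups)

# Precision.PURE_FP16 keeps the Adam moments in fp16 as well (optim_type == torch.float16).
pure = Adam([w], lr=1e-3, precision=Precision.PURE_FP16)
pure.step()
assert pure.state[w]["exp_avg"].dtype == torch.float16
assert not pure.fp32_param_groups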
@@ -10,7 +10,7 @@ import pytest
import torch
try:
-    from fairscale.optim.adam import Adam
+    from fairscale.optim.adam import Adam, Precision
    imported_adam = True
except ImportError:
@@ -20,17 +20,12 @@ skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda
skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")
-@skip_if_no_cuda
-@skip_if_no_adam
-def test_step():
-    weight = torch.randn(10, 5).cuda().requires_grad_()
-    bias = torch.randn(10).cuda().requires_grad_()
-    input = torch.randn(5).cuda()
-    optimizer = Adam([weight, bias], lr=1e-3)
-    # to check if the optimizer can be printed as a string
-    optimizer.__repr__()
+def assert_almost_zero(x):
+    assert abs(x) < 2 * 1e-3
+    return 1.0
+def step_test(optimizer, weight, bias, input):
    def fn():
        optimizer.zero_grad()
        y = weight.mv(input)
@@ -44,71 +39,67 @@ def test_step():
    for _i in range(5):
        optimizer.step(fn)
    assert fn().item() < initial_value
-    for group in optimizer.param_groups:
-        for p in group["params"]:
-            if p.requires_grad:
-                assert p.dtype == torch.float32
-    with pytest.raises(AttributeError):
-        optimizer.fp32_param_groups
+def state_dict_test(optimizer, weight, bias, input):
+    def fn_base(optimizer, weight, bias, input):
+        optimizer.zero_grad()
+        loss = (weight.mv(input) + bias).pow(2).sum()
+        loss.backward()
+        return loss
+    fn = functools.partial(fn_base, optimizer, weight, bias, input)
+    # Prime the optimizer
+    for _i in range(5):
+        optimizer.step(fn)
+    # Clone the weights and construct new optimizer for them
+    weight_c = weight.data.clone().requires_grad_()
+    bias_c = bias.data.clone().requires_grad_()
+    optimizer_c = Adam([weight_c, bias_c], lr=1e-3)
+    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
+    # Load state dict
+    state_dict = deepcopy(optimizer.state_dict())
+    optimizer_c.load_state_dict(state_dict)
+    # Run both optimizations in parallel
+    for _i in range(5):
+        optimizer.step(fn)
+        optimizer_c.step(fn_c)
+    (weight - weight_c).to("cpu").detach().apply_(assert_almost_zero)
+    (bias - bias_c).to("cpu").detach().apply_(assert_almost_zero)
@skip_if_no_cuda
@skip_if_no_adam
-def test_step_me():
-    weight = torch.randn(10, 5).cuda().half().requires_grad_()
-    bias = torch.randn(10).cuda().half().requires_grad_()
-    input = torch.randn(5).half().cuda()
+def test_step_full_precision_inferred():
+    weight = torch.randn(10, 5).cuda().requires_grad_()
+    bias = torch.randn(10).cuda().requires_grad_()
+    input = torch.randn(5).cuda()
    optimizer = Adam([weight, bias], lr=1e-3)
    # to check if the optimizer can be printed as a string
    optimizer.__repr__()
-    def fn():
-        optimizer.zero_grad()
-        y = weight.mv(input)
-        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
-            y = y.cuda(bias.get_device())
-        loss = (y + bias).pow(2).sum()
-        loss.backward()
-        return loss
-    initial_value = fn().item()
-    for _i in range(5):
-        optimizer.step(fn)
-    assert fn().item() < initial_value
+    step_test(optimizer, weight, bias, input)
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.requires_grad:
-                assert p.dtype == torch.float16
-    with pytest.raises(AttributeError):
-        optimizer.fp32_param_groups
+                assert p.dtype == torch.float32
+    assert not optimizer.fp32_param_groups
@skip_if_no_cuda
@skip_if_no_adam
-def test_step_mixed_precision():
+def test_step_mixed_precision_inferred():
    weight = torch.randn(10, 5).cuda().half().requires_grad_()
    bias = torch.randn(10).cuda().half().requires_grad_()
    input = torch.randn(5).half().cuda()
-    optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
+    optimizer = Adam([weight, bias], lr=1e-3)
    # to check if the optimizer can be printed as a string
    optimizer.__repr__()
-    def fn():
-        optimizer.zero_grad()
-        y = weight.mv(input)
-        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
-            y = y.cuda(bias.get_device())
-        loss = (y + bias).pow(2).sum()
-        loss.backward()
-        return loss
-    initial_value = fn().item()
-    for _i in range(5):
-        optimizer.step(fn)
-    assert fn().item() < initial_value
+    step_test(optimizer, weight, bias, input)
    assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)
    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
@@ -124,6 +115,44 @@ def test_step_mixed_precision():
            (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_step_memory_efficient():
+    weight = torch.randn(10, 5).cuda().half().requires_grad_()
+    bias = torch.randn(10).cuda().half().requires_grad_()
+    input = torch.randn(5).half().cuda()
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+    step_test(optimizer, weight, bias, input)
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            if p.requires_grad:
+                assert p.dtype == torch.float16
+    assert not optimizer.fp32_param_groups
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_step_pure_fp16():
+    weight = torch.randn(10, 5).half().cuda().requires_grad_()
+    bias = torch.randn(10).half().cuda().requires_grad_()
+    input = torch.randn(5).half().cuda()
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+    step_test(optimizer, weight, bias, input)
+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16
+    assert not optimizer.fp32_param_groups
@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu():
@@ -136,20 +165,7 @@ def test_step_multigpu():
    # to check if the optimizer can be printed as a string
    optimizer.__repr__()
-    def fn():
-        optimizer.zero_grad()
-        y = weight.mv(input)
-        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
-            y = y.cuda(bias.get_device())
-        loss = (y + bias).pow(2).sum()
-        loss.backward()
-        return loss
-    initial_value = fn().item()
-    for _i in range(5):
-        optimizer.step(fn)
-    assert fn().item() < initial_value
+    step_test(optimizer, weight, bias, input)
@skip_if_no_cuda
@@ -160,61 +176,75 @@ def test_step_multigpu_mixed_precision():
    weight = torch.randn(10, 5).cuda(0).half().requires_grad_()
    bias = torch.randn(10).cuda(1).half().requires_grad_()
    input = torch.randn(5).cuda(0).half()
-    optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
+    optimizer = Adam([weight, bias], lr=1e-3)
    # to check if the optimizer can be printed as a string
    optimizer.__repr__()
-    def fn():
-        optimizer.zero_grad()
-        y = weight.mv(input)
-        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
-            y = y.cuda(bias.get_device())
-        loss = (y + bias).pow(2).sum()
-        loss.backward()
-        return loss
-    initial_value = fn().item()
-    for _i in range(5):
-        optimizer.step(fn)
-    assert fn().item() < initial_value
+    step_test(optimizer, weight, bias, input)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_step_pure_fp16_multigpu():
+    if not torch.cuda.device_count() > 1:
+        return
+    weight = torch.randn(10, 5).half().cuda(0).requires_grad_()
+    bias = torch.randn(10).half().cuda(1).requires_grad_()
+    input = torch.randn(5).half().cuda(0)
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+    step_test(optimizer, weight, bias, input)
+    assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
+    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
+    assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
+    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16
@skip_if_no_cuda
@skip_if_no_adam
-def test_state_dict():
+def test_state_dict_full_precision():
    weight = torch.randn(10, 5).float().cuda().requires_grad_()
    bias = torch.randn(10).float().cuda().requires_grad_()
    input = torch.randn(5).float().cuda()
    optimizer = Adam([weight, bias], lr=1e-3)
-    def fn_base(optimizer, weight, bias, input):
-        optimizer.zero_grad()
-        loss = (weight.mv(input) + bias).pow(2).sum()
-        loss.backward()
-        return loss
-    fn = functools.partial(fn_base, optimizer, weight, bias, input)
-    # Prime the optimizer
-    for _i in range(5):
-        optimizer.step(fn)
-    # Clone the weights and construct new optimizer for them
-    weight_c = weight.data.clone().requires_grad_()
-    bias_c = bias.data.clone().requires_grad_()
-    optimizer_c = Adam([weight_c, bias_c], lr=1e-3)
-    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
-    # Load state dict
-    state_dict = deepcopy(optimizer.state_dict())
-    state_dict_c = deepcopy(optimizer.state_dict())
-    optimizer_c.load_state_dict(state_dict_c)
-    # Run both optimizations in parallel
-    for _i in range(5):
-        optimizer.step(fn)
-        optimizer_c.step(fn_c)
-    assert torch.equal(weight, weight_c)
-    assert torch.equal(bias, bias_c)
+    state_dict_test(optimizer, weight, bias, input)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_state_dict_mixed_precision():
+    weight = torch.randn(10, 5).half().cuda().requires_grad_()
+    bias = torch.randn(10).half().cuda().requires_grad_()
+    input = torch.randn(5).half().cuda()
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)
+    state_dict_test(optimizer, weight, bias, input)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_state_dict_memory_efficient():
+    weight = torch.randn(10, 5).half().cuda().requires_grad_()
+    bias = torch.randn(10).half().cuda().requires_grad_()
+    input = torch.randn(5).half().cuda()
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
+    state_dict_test(optimizer, weight, bias, input)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_state_dict_pure_fp16():
+    weight = torch.randn(10, 5).half().cuda().requires_grad_()
+    bias = torch.randn(10).half().cuda().requires_grad_()
+    input = torch.randn(5).half().cuda()
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
+    state_dict_test(optimizer, weight, bias, input)
@skip_if_no_cuda
@@ -226,11 +256,6 @@ def test_build_fp32_params():
    optimizer._build_fp32_params([weight, bias])
    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
        for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
-            def assert_almost_zero(x):
-                assert abs(x) < 1e-3
-                return 1.0
            assert fp32_p.dtype == torch.float32
            if fp16_p.requires_grad:
                assert fp16_p.dtype == torch.float16
@@ -262,3 +287,30 @@ def test_amsgrad():
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(RuntimeError):
        Adam([weight, bias], lr=1e-2, amsgrad=True)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_mixed_precision_with_full_precision_parameters():
+    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
+    bias = torch.randn(10, requires_grad=True).float().cuda()
+    with pytest.raises(AssertionError):
+        Adam([weight, bias], lr=1e-2, precision=Precision.MIXED_PRECISION)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_memory_efficient_with_full_precision_parameters():
+    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
+    bias = torch.randn(10, requires_grad=True).float().cuda()
+    with pytest.raises(AssertionError):
+        Adam([weight, bias], lr=1e-2, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_pure_fp16_with_full_precision_parameters():
+    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
+    bias = torch.randn(10, requires_grad=True).float().cuda()
+    with pytest.raises(AssertionError):
+        Adam([weight, bias], lr=1e-2, precision=Precision.PURE_FP16)