Unverified Commit 8ee5a8ff authored by Jun Ru Anderson, committed by GitHub

[feat] allow fp16 optimizer state with Adam (#41)



Allow training with the optimizer state in fp16. Use an enum to select among full precision, mixed precision, memory-efficient mixed precision, and pure fp16. Improve the clarity of the testing code.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent e2d8f573
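
For context, a minimal usage sketch of the API introduced by this change, based on the `Adam`/`Precision` imports shown in the diff below. It assumes a CUDA device (the fused Adam kernel is CUDA-only); the toy linear model is a placeholder, not part of this commit.

```python
import torch
from fairscale.optim.adam import Adam, Precision

# Placeholder fp16 model; any module with fp16 parameters on CUDA would do.
model = torch.nn.Linear(32, 32).cuda().half()

# Keep the optimizer state (exp_avg, exp_avg_sq) in fp16 as well.
optimizer = Adam(model.parameters(), lr=1e-3, precision=Precision.PURE_FP16)

loss = model(torch.randn(8, 32, device="cuda", dtype=torch.float16)).sum()
loss.backward()
optimizer.step()
```

If `precision` is left as `None`, the optimizer infers full precision for fp32 parameters and mixed precision for fp16 parameters, per the inference logic further down in the diff.
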
@@ -11,7 +11,7 @@ from torchtext.data.utils import get_tokenizer
from fairscale.nn import Pipe
try:
from fairscale.optim.adam import Adam # type: ignore
from fairscale.optim.adam import Adam, Precision # type: ignore
can_benchmark = True
except ImportError:
@@ -135,7 +135,7 @@ def make_model(device, ntokens):
criterion = nn.CrossEntropyLoss()
lr = 0.01 # learning rate
optimizer = Adam(p.parameters(), lr=lr, mixed_precision=True)
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
return p, criterion, optimizer
......
@@ -20,7 +20,7 @@ typedef enum{
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
template <int DEPTH, typename PARAM_T, typename GRAD_T>
template <int DEPTH, typename PARAM_T, typename GRAD_T, typename OPTIM_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
@@ -41,9 +41,9 @@ struct AdamFunctor
PARAM_T* p = (PARAM_T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
float* m = (float *)tl.addresses[1][tensor_loc];
OPTIM_T* m = (OPTIM_T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
float* v = (float *)tl.addresses[2][tensor_loc];
OPTIM_T* v = (OPTIM_T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
@@ -56,8 +56,8 @@ struct AdamFunctor
n -= chunk_idx*chunk_size;
PARAM_T incoming_p[ILP];
float incoming_m[ILP];
float incoming_v[ILP];
OPTIM_T incoming_m[ILP];
OPTIM_T incoming_v[ILP];
GRAD_T incoming_g[ILP];
for(int i_start = 0;
@@ -91,14 +91,16 @@ struct AdamFunctor
if(j < n && j < chunk_size) {
float scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
m[j] = static_cast<OPTIM_T>(momentum);
v[j] = static_cast<OPTIM_T>(velocity);
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
denom = sqrtf(velocity + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
denom = sqrtf(velocity) + eps;
float update = (momentum/denom) + (decay*incoming_p[ii]);
p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
}
@@ -135,6 +137,7 @@ void fused_adam_cuda(
size_t tl_sz = tensor_lists.size();
assert(tl_sz == 4 || tl_sz == 5);
assert(tensor_lists[1][0].scalar_type() == tensor_lists[2][0].scalar_type());
if(tl_sz == 5) {
// Mixed precision case
@@ -146,7 +149,7 @@ void fused_adam_cuda(
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, float, at::Half>(),
AdamFunctor<5, float, at::Half, float>(),
beta1,
beta2,
eps,
@@ -160,12 +163,13 @@ void fused_adam_cuda(
assert(tensor_lists[0][0].scalar_type() == tensor_lists[3][0].scalar_type());
if(tensor_lists[0][0].scalar_type() == at::ScalarType::Float) {
// Full precision case
assert(tensor_lists[1][0].scalar_type() == at::ScalarType::Float);
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, float, float>(),
AdamFunctor<4, float, float, float>(),
beta1,
beta2,
eps,
@@ -175,21 +179,41 @@ void fused_adam_cuda(
decay
);
} else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
// "Memory Efficient Training" case
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, at::Half, at::Half>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay
);
if(tensor_lists[1][0].scalar_type() == at::ScalarType::Float) {
// FP16 model parameters and gradients; FP32 optimizer state
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, at::Half, at::Half, float>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay
);
} else if (tensor_lists[1][0].scalar_type() == at::ScalarType::Half) {
// Pure FP16 case
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, at::Half, at::Half, at::Half>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay
);
} else {
throw "Optimizer state must be of type float or half";
}
} else {
throw "Parameters must be of type float or half";
}
......
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
@@ -15,6 +16,12 @@ else:
try:
from fairscale import fused_adam_cuda # type: ignore
class Precision(Enum):
FULL_PRECISION = auto()
MIXED_PRECISION = auto()
MEMORY_EFFICIENT_MIXED_PRECISION = auto()
PURE_FP16 = auto()
class Adam(torch.optim.Optimizer):
state: dict
defaults: dict
@@ -39,6 +46,10 @@ try:
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
precision (Precision, optional): One of Precision.FULL_PRECISION,
Precision.MIXED_PRECISION, Precision.MEMORY_EFFICIENT_MIXED_PRECISION
or Precision.PURE_FP16. Inferred based on model parameter precision if
None. (default: None)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
@@ -56,11 +67,24 @@ try:
weight_decay: Optional[float] = 0.0,
max_grad_norm: Optional[float] = 0.0,
amsgrad: Optional[bool] = False,
mixed_precision: Optional[bool] = False,
precision: Optional[Precision] = None,
):
self.mixed_precision = mixed_precision
parameters: List[Any] = list(params)
if precision is None:
precision = (
Precision.FULL_PRECISION if parameters[0].dtype == torch.float32 else Precision.MIXED_PRECISION
)
self.mixed_precision = False
if precision is Precision.MIXED_PRECISION:
self.mixed_precision = True
if precision is not Precision.FULL_PRECISION:
assert parameters[0].dtype == torch.float16
self.optim_type = torch.float16 if precision is Precision.PURE_FP16 else torch.float32
self._overflow_buf = torch.cuda.IntTensor([0]) # type: ignore
if amsgrad:
@@ -76,7 +100,8 @@ try:
super().__init__(parameters, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
if mixed_precision:
self.fp32_param_groups: List[Any] = []
if self.mixed_precision:
self._build_fp32_params(parameters)
def _build_fp32_params(self, params: Any) -> None:
@@ -119,6 +144,21 @@ try:
def _step_supports_amp_scaling(self) -> bool:
return False
def state_dict(self) -> Dict[str, Any]:
d = super().state_dict()
d["optim_type"] = self.optim_type
d["mixed_precision"] = self.mixed_precision
d["fp32_param_groups"] = self.fp32_param_groups
d["state"] = self.state
return d
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
super().load_state_dict(state_dict)
self.optim_type = state_dict["optim_type"]
self.mixed_precision = state_dict["mixed_precision"]
self.fp32_param_groups = state_dict["fp32_param_groups"]
self.state = state_dict["state"]
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
Arguments:
@@ -161,9 +201,9 @@ try:
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
state["exp_avg"] = torch.zeros_like(p, dtype=self.optim_type)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
state["exp_avg_sq"] = torch.zeros_like(p, dtype=self.optim_type)
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
......
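
As a quick illustration of the new behavior (a sketch, not part of the commit): after a step, the dtype of the moment buffers tracks the chosen precision, which the tests below also verify.

```python
import torch
from fairscale.optim.adam import Adam, Precision

# Assumes CUDA; a single fp16 parameter stands in for a real model.
w = torch.randn(10, 5, device="cuda", dtype=torch.float16, requires_grad=True)

opt = Adam([w], lr=1e-3, precision=Precision.PURE_FP16)
w.sum().backward()
opt.step()

# With PURE_FP16 the state is stored in fp16; with the other modes it stays fp32.
assert opt.state[w]["exp_avg"].dtype == torch.float16
assert opt.state[w]["exp_avg_sq"].dtype == torch.float16
```
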
@@ -10,7 +10,7 @@ import pytest
import torch
try:
from fairscale.optim.adam import Adam
from fairscale.optim.adam import Adam, Precision
imported_adam = True
except ImportError:
@@ -20,17 +20,12 @@ skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda
skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")
@skip_if_no_cuda
@skip_if_no_adam
def test_step():
weight = torch.randn(10, 5).cuda().requires_grad_()
bias = torch.randn(10).cuda().requires_grad_()
input = torch.randn(5).cuda()
optimizer = Adam([weight, bias], lr=1e-3)
def assert_almost_zero(x):
assert abs(x) < 2 * 1e-3
return 1.0
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def step_test(optimizer, weight, bias, input):
def fn():
optimizer.zero_grad()
y = weight.mv(input)
@@ -44,71 +39,67 @@ def test_step():
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
for group in optimizer.param_groups:
for p in group["params"]:
if p.requires_grad:
assert p.dtype == torch.float32
with pytest.raises(AttributeError):
optimizer.fp32_param_groups
def state_dict_test(optimizer, weight, bias, input):
def fn_base(optimizer, weight, bias, input):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
return loss
fn = functools.partial(fn_base, optimizer, weight, bias, input)
# Prime the optimizer
for _i in range(5):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
weight_c = weight.data.clone().requires_grad_()
bias_c = bias.data.clone().requires_grad_()
optimizer_c = Adam([weight_c, bias_c], lr=1e-3)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict)
# Run both optimizations in parallel
for _i in range(5):
optimizer.step(fn)
optimizer_c.step(fn_c)
(weight - weight_c).to("cpu").detach().apply_(assert_almost_zero)
(bias - bias_c).to("cpu").detach().apply_(assert_almost_zero)
@skip_if_no_cuda
@skip_if_no_adam
def test_step_me():
weight = torch.randn(10, 5).cuda().half().requires_grad_()
bias = torch.randn(10).cuda().half().requires_grad_()
input = torch.randn(5).half().cuda()
def test_step_full_precision_inferred():
weight = torch.randn(10, 5).cuda().requires_grad_()
bias = torch.randn(10).cuda().requires_grad_()
input = torch.randn(5).cuda()
optimizer = Adam([weight, bias], lr=1e-3)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
step_test(optimizer, weight, bias, input)
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
for group in optimizer.param_groups:
for p in group["params"]:
if p.requires_grad:
assert p.dtype == torch.float16
with pytest.raises(AttributeError):
optimizer.fp32_param_groups
assert p.dtype == torch.float32
assert not optimizer.fp32_param_groups
@skip_if_no_cuda
@skip_if_no_adam
def test_step_mixed_precision():
def test_step_mixed_precision_inferred():
weight = torch.randn(10, 5).cuda().half().requires_grad_()
bias = torch.randn(10).cuda().half().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
optimizer = Adam([weight, bias], lr=1e-3)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
step_test(optimizer, weight, bias, input)
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)
for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
@@ -124,6 +115,44 @@ def test_step_mixed_precision():
(fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)
@skip_if_no_cuda
@skip_if_no_adam
def test_step_memory_efficient():
weight = torch.randn(10, 5).cuda().half().requires_grad_()
bias = torch.randn(10).cuda().half().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
step_test(optimizer, weight, bias, input)
for group in optimizer.param_groups:
for p in group["params"]:
if p.requires_grad:
assert p.dtype == torch.float16
assert not optimizer.fp32_param_groups
@skip_if_no_cuda
@skip_if_no_adam
def test_step_pure_fp16():
weight = torch.randn(10, 5).half().cuda().requires_grad_()
bias = torch.randn(10).half().cuda().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
step_test(optimizer, weight, bias, input)
assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16
assert not optimizer.fp32_param_groups
@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu():
@@ -136,20 +165,7 @@ def test_step_multigpu():
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
step_test(optimizer, weight, bias, input)
@skip_if_no_cuda
@@ -160,61 +176,75 @@ def test_step_multigpu_mixed_precision():
weight = torch.randn(10, 5).cuda(0).half().requires_grad_()
bias = torch.randn(10).cuda(1).half().requires_grad_()
input = torch.randn(5).cuda(0).half()
optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
optimizer = Adam([weight, bias], lr=1e-3)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
step_test(optimizer, weight, bias, input)
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
@skip_if_no_cuda
@skip_if_no_adam
def test_step_pure_fp16_multigpu():
if not torch.cuda.device_count() > 1:
return
weight = torch.randn(10, 5).half().cuda(0).requires_grad_()
bias = torch.randn(10).half().cuda(1).requires_grad_()
input = torch.randn(5).half().cuda(0)
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
step_test(optimizer, weight, bias, input)
assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16
@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict():
def test_state_dict_full_precision():
weight = torch.randn(10, 5).float().cuda().requires_grad_()
bias = torch.randn(10).float().cuda().requires_grad_()
input = torch.randn(5).float().cuda()
optimizer = Adam([weight, bias], lr=1e-3)
def fn_base(optimizer, weight, bias, input):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
return loss
state_dict_test(optimizer, weight, bias, input)
fn = functools.partial(fn_base, optimizer, weight, bias, input)
# Prime the optimizer
for _i in range(5):
optimizer.step(fn)
# Clone the weights and construct new optimizer for them
weight_c = weight.data.clone().requires_grad_()
bias_c = bias.data.clone().requires_grad_()
optimizer_c = Adam([weight_c, bias_c], lr=1e-3)
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
# Load state dict
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
for _i in range(5):
optimizer.step(fn)
optimizer_c.step(fn_c)
assert torch.equal(weight, weight_c)
assert torch.equal(bias, bias_c)
@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_mixed_precision():
weight = torch.randn(10, 5).half().cuda().requires_grad_()
bias = torch.randn(10).half().cuda().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)
state_dict_test(optimizer, weight, bias, input)
@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_memory_efficient():
weight = torch.randn(10, 5).half().cuda().requires_grad_()
bias = torch.randn(10).half().cuda().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
state_dict_test(optimizer, weight, bias, input)
@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_pure_fp16():
weight = torch.randn(10, 5).half().cuda().requires_grad_()
bias = torch.randn(10).half().cuda().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
state_dict_test(optimizer, weight, bias, input)
@skip_if_no_cuda
@@ -226,11 +256,6 @@ def test_build_fp32_params():
optimizer._build_fp32_params([weight, bias])
for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
def assert_almost_zero(x):
assert abs(x) < 1e-3
return 1.0
assert fp32_p.dtype == torch.float32
if fp16_p.requires_grad:
assert fp16_p.dtype == torch.float16
@@ -262,3 +287,30 @@ def test_amsgrad():
bias = torch.randn(10, requires_grad=True).float().cuda()
with pytest.raises(RuntimeError):
Adam([weight, bias], lr=1e-2, amsgrad=True)
@skip_if_no_cuda
@skip_if_no_adam
def test_mixed_precision_with_full_precision_parameters():
weight = torch.randn(10, 5, requires_grad=True).float().cuda()
bias = torch.randn(10, requires_grad=True).float().cuda()
with pytest.raises(AssertionError):
Adam([weight, bias], lr=1e-2, precision=Precision.MIXED_PRECISION)
@skip_if_no_cuda
@skip_if_no_adam
def test_memory_efficient_with_full_precision_parameters():
weight = torch.randn(10, 5, requires_grad=True).float().cuda()
bias = torch.randn(10, requires_grad=True).float().cuda()
with pytest.raises(AssertionError):
Adam([weight, bias], lr=1e-2, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
@skip_if_no_cuda
@skip_if_no_adam
def test_pure_fp16_with_full_precision_parameters():
weight = torch.randn(10, 5, requires_grad=True).float().cuda()
bias = torch.randn(10, requires_grad=True).float().cuda()
with pytest.raises(AssertionError):
Adam([weight, bias], lr=1e-2, precision=Precision.PURE_FP16)