Unverified Commit e2d8f573 authored by Jun Ru Anderson, committed by GitHub

[feat] add mixed precision Adam (#40)



Add support for mixed-precision (half-precision params, full-precision gradients) and memory-efficient (half-precision params and gradients) training with Adam.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 585f177b
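
Below is a minimal usage sketch of the two modes this commit enables, mirroring the unit tests in the diff. It assumes a CUDA device and a fairscale build with the fused Adam CUDA extension; the toy parameters and loss are illustrative only, not part of the change.

```python
import torch
from fairscale.optim.adam import Adam

# Toy half-precision parameters on a GPU (illustrative only).
weight = torch.randn(10, 5).cuda().half().requires_grad_()
bias = torch.randn(10).cuda().half().requires_grad_()
x = torch.randn(5).cuda().half()

# Memory-efficient mode: fp16 params and fp16 grads, no fp32 master copy.
optimizer = Adam([weight, bias], lr=1e-3)

# Mixed-precision mode: fp16 params, with fp32 copies of params and grads
# maintained by the optimizer (exposed as optimizer.fp32_param_groups).
optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)

def closure():
    optimizer.zero_grad()
    loss = (weight.mv(x) + bias).pow(2).sum()  # toy loss
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)
```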
@@ -8,7 +8,7 @@ import torch.nn as nn
 import torchtext
 from torchtext.data.utils import get_tokenizer
-import fairscale.nn.pipe.pipe as pipe
+from fairscale.nn import Pipe
 try:
     from fairscale.optim.adam import Adam  # type: ignore
@@ -129,13 +129,15 @@ def make_model(device, ntokens):
     dropout = 0
     initrange = 0.1
-    model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).to(device)
+    model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).half().to(device)
+    balance = generate_balance(min(num_devices, 4), len(model))
+    p = Pipe(model, balance)
 
     criterion = nn.CrossEntropyLoss()
     lr = 0.01  # learning rate
-    optimizer = Adam(model.parameters(), lr=lr)
-    return model, criterion, optimizer
+    optimizer = Adam(p.parameters(), lr=lr, mixed_precision=True)
+    return p, criterion, optimizer
 
 def train(train_data, model, criterion, optimizer, bptt, ntokens):
@@ -221,7 +223,7 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
     if can_benchmark and len(model.balance) == 4:
         # Assert that words per second is within 3 standard deviations of the average
         # of six golden runs
-        assert wps > 20052.1 - (3 * 359)
+        assert wps > 27799.2 - (3 * 522.145)
 
         print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
         print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
@@ -230,10 +232,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
         # Assert that memory usage on each GPU is within 10% of golden run
         # Right-hand-side is golden run bytes * 110%
-        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 365916160 * 1.1
-        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1281024 * 1.1
-        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2788864 * 1.1
-        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 190724608 * 1.1
+        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 210479616 * 1.1
+        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 640512 * 1.1
+        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1605120 * 1.1
+        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 113801216 * 1.1
         print("No regression detected")
@@ -259,7 +261,5 @@ if __name__ == "__main__":
     device = torch.device("cuda")
     ntokens, train_data, val_data, test_data = get_data(device)
     model, criterion, optimizer = make_model(device, ntokens)
-    balance = generate_balance(min(num_devices, 4), len(model))
-    p = pipe.Pipe(model, balance)
-    benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens)
-    del p
+    benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens)
+    del model
@@ -4,6 +4,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <stdio.h>
+#include <assert.h>
 #include <cmath>
 #include "ATen/TensorUtils.h"
 // #include "ATen/Type.h"
@@ -19,9 +20,7 @@ typedef enum{
   ADAM_MODE_1 =1 // eps outside square root
 } adamMode_t;
 
-template <int DEPTH, typename T, typename GRAD_T>
+template <int DEPTH, typename PARAM_T, typename GRAD_T>
 struct AdamFunctor
 {
     __device__ __forceinline__ void operator()(
@@ -40,26 +39,26 @@ struct AdamFunctor
         int chunk_idx = tl.block_to_chunk[blockIdx.x];
         int n = tl.sizes[tensor_loc];
 
-        GRAD_T* p = (GRAD_T *)tl.addresses[0][tensor_loc];
+        PARAM_T* p = (PARAM_T *)tl.addresses[0][tensor_loc];
         p += chunk_idx*chunk_size;
-        T* m = (T *)tl.addresses[1][tensor_loc];
+        float* m = (float *)tl.addresses[1][tensor_loc];
         m += chunk_idx*chunk_size;
-        T* v = (T *)tl.addresses[2][tensor_loc];
+        float* v = (float *)tl.addresses[2][tensor_loc];
         v += chunk_idx*chunk_size;
         GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
         g += chunk_idx*chunk_size;
-        GRAD_T* p_copy = NULL;
+        at::Half* p_copy = NULL;
         if (DEPTH == 5) {
-            p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
+            p_copy = (at::Half*)tl.addresses[4][tensor_loc];
             p_copy += chunk_idx*chunk_size;
         }
 
         n -= chunk_idx*chunk_size;
 
-        T incoming_p[ILP];
-        T incoming_m[ILP];
-        T incoming_v[ILP];
-        T incoming_g[ILP];
+        PARAM_T incoming_p[ILP];
+        float incoming_m[ILP];
+        float incoming_v[ILP];
+        GRAD_T incoming_g[ILP];
 
         for(int i_start = 0;
             i_start < n && i_start < chunk_size;
@@ -74,10 +73,10 @@ struct AdamFunctor
                 int i = i_start + threadIdx.x + ii*blockDim.x;
                 if (i < n && i < chunk_size) {
-                    incoming_p[ii] = static_cast<T>(p[i]);
+                    incoming_p[ii] = static_cast<PARAM_T>(p[i]);
                     incoming_m[ii] = m[i];
                     incoming_v[ii] = v[i];
-                    incoming_g[ii] = static_cast<T>(g[i]);
+                    incoming_g[ii] = static_cast<GRAD_T>(g[i]);
                 }
             }
@@ -91,7 +90,7 @@ struct AdamFunctor
                 int j = i_start + threadIdx.x + ii*blockDim.x;
                 if(j < n && j < chunk_size) {
-                    T scaled_grad = incoming_g[ii]/grad_scale;
+                    float scaled_grad = incoming_g[ii]/grad_scale;
                     m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
                     v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
                     float denom;
@@ -100,8 +99,8 @@ struct AdamFunctor
                     else // Mode 1
                         denom = sqrtf(v[j]) + eps;
                     float update = (m[j]/denom) + (decay*incoming_p[ii]);
-                    p[j] = (GRAD_T)(incoming_p[ii] - (step_size*update));
-                    if (DEPTH == 5)  p_copy[j] = p[j];
+                    p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
+                    if (DEPTH == 5)  p_copy[j] = (at::Half) p[j];
                 }
             }
         }
@@ -135,11 +134,32 @@ void fused_adam_cuda(
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     size_t tl_sz = tensor_lists.size();
-    AT_ASSERTM(tl_sz == 4, "expected tensor lists of size 4");
-    // check that the model and gradients are FP32
-    AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
-    AT_ASSERTM(tensor_lists[3][0].scalar_type() == at::ScalarType::Float);
+    assert(tl_sz == 4 || tl_sz == 5);
+    if(tl_sz == 5) {
+        // Mixed precision case
+        assert(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
+        assert(tensor_lists[3][0].scalar_type() == at::ScalarType::Half);
+        assert(tensor_lists[4][0].scalar_type() == at::ScalarType::Half);
+        multi_tensor_apply<5>(
+            BLOCK_SIZE,
+            chunk_size,
+            noop_flag,
+            tensor_lists,
+            AdamFunctor<5, float, at::Half>(),
+            beta1,
+            beta2,
+            eps,
+            grad_scale,
+            step_size,
+            (adamMode_t) mode,
+            decay
+        );
+    } else {
+        // tl_sz == 4
+        assert(tensor_lists[0][0].scalar_type() == tensor_lists[3][0].scalar_type());
+        if(tensor_lists[0][0].scalar_type() == at::ScalarType::Float) {
+            // Full precision case
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
@@ -154,5 +174,25 @@ void fused_adam_cuda(
                 (adamMode_t) mode,
                 decay
             );
+        } else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
+            // "Memory Efficient Training" case
+            multi_tensor_apply<4>(
+                BLOCK_SIZE,
+                chunk_size,
+                noop_flag,
+                tensor_lists,
+                AdamFunctor<4, at::Half, at::Half>(),
+                beta1,
+                beta2,
+                eps,
+                grad_scale,
+                step_size,
+                (adamMode_t) mode,
+                decay
+            );
+        } else {
+            throw "Parameters must be of type float or half";
+        }
+    }
 
     THCudaCheck(cudaGetLastError());
 }
@@ -17,6 +17,7 @@ try:
     class Adam(torch.optim.Optimizer):
         state: dict
+        defaults: dict
         """
         Implements Adam algorithm. Currently GPU-only.
         It has been proposed in `Adam: A Method for Stochastic Optimization`_.
@@ -55,9 +56,11 @@ try:
             weight_decay: Optional[float] = 0.0,
             max_grad_norm: Optional[float] = 0.0,
             amsgrad: Optional[bool] = False,
+            mixed_precision: Optional[bool] = False,
         ):
-            self._use_multi_tensor = False
+            self.mixed_precision = mixed_precision
+            parameters: List[Any] = list(params)
             self._overflow_buf = torch.cuda.IntTensor([0])  # type: ignore
 
             if amsgrad:
@@ -70,14 +73,53 @@ try:
                 "weight_decay": weight_decay,
                 "max_grad_norm": max_grad_norm,
             }
-            super().__init__(params, defaults)
+            super().__init__(parameters, defaults)
             self.eps_mode = 0 if eps_inside_sqrt else 1
+            if mixed_precision:
+                self._build_fp32_params(parameters)
+
+        def _build_fp32_params(self, params: Any) -> None:
+            # create FP32 copy of parameters and grads
+            fp32_params = []
+            for p in params:
+                p32 = torch.nn.Parameter(p.data.float()).to(p.device)
+                p32.grad = torch.zeros_like(p32.data)
+                fp32_params.append(p32)
+            params = fp32_params
+
+            self.fp32_param_groups = []
+            param_groups = list(params)
+            if not isinstance(param_groups[0], dict):
+                param_groups = [{"params": param_groups}]
+
+            for param_group in param_groups:
+                params = param_group["params"]
+                if isinstance(params, torch.Tensor):
+                    param_group["params"] = [params]
+                else:
+                    param_group["params"] = list(params)
+
+                for name, default in self.defaults.items():
+                    param_group.setdefault(name, default)
+
+                params = param_group["params"]
+
+                param_set = set()
+                for group in self.param_groups:
+                    param_set.update(set(group["params"]))
+
+                self.fp32_param_groups.append(param_group)
 
         @property
         def supports_memory_efficient_fp16(self) -> bool:
             return True
 
-        def step(self, closure: Optional[Callable[[], float]] = None, scale: Optional[float] = 1.0) -> Optional[float]:
+        @property
+        def _step_supports_amp_scaling(self) -> bool:
+            return False
+
+        def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
             """Performs a single optimization step.
             Arguments:
                 closure (callable, optional): A closure that reevaluates the model
@@ -95,11 +137,13 @@ try:
             if closure is not None:
                 loss = closure()
 
-            for group in self.param_groups:
+            for i in range(len(self.param_groups)):
+                group = self.param_groups[i]
                 bias_correction = 1 if group["bias_correction"] else 0
                 tensorlists: Dict[torch.device, List[List[torch.Tensor]]] = dict()
 
-                for p in group["params"]:
+                for j in range(len(group["params"])):
+                    p = group["params"][j]
                     # note: p.grad should not ever be set for correct
                     # operation of mixed precision optimizer that sometimes
                     # sends None gradients
@@ -126,9 +170,20 @@ try:
                     beta1, beta2 = group["betas"]
                     state["step"] += 1
-                    out_p = torch.tensor([])
-                    pl = [p.data, exp_avg, exp_avg_sq, grad]
+                    out_p = p.data if self.mixed_precision else torch.tensor([])
+                    param = self.fp32_param_groups[i]["params"][j] if self.mixed_precision else p
+                    scale = 1.0
+
+                    if self.mixed_precision:
+                        pl = [param.data, exp_avg, exp_avg_sq, grad, out_p]
+                        if p.device not in tensorlists:
+                            tensorlists[p.device] = [[], [], [], [], []]
+                        for tl, t in zip(tensorlists[p.device], pl):
+                            tl.append(t)
+                    else:
+                        pl = [param.data, exp_avg, exp_avg_sq, grad]
                         if p.device not in tensorlists:
                             tensorlists[p.device] = [[], [], [], []]
...
@@ -44,6 +44,84 @@ def test_step():
     for _i in range(5):
         optimizer.step(fn)
     assert fn().item() < initial_value
+
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            if p.requires_grad:
+                assert p.dtype == torch.float32
+
+    with pytest.raises(AttributeError):
+        optimizer.fp32_param_groups
+
+
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_step_me():
+    weight = torch.randn(10, 5).cuda().half().requires_grad_()
+    bias = torch.randn(10).cuda().half().requires_grad_()
+    input = torch.randn(5).half().cuda()
+    optimizer = Adam([weight, bias], lr=1e-3)
+
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+
+    def fn():
+        optimizer.zero_grad()
+        y = weight.mv(input)
+        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
+            y = y.cuda(bias.get_device())
+        loss = (y + bias).pow(2).sum()
+        loss.backward()
+        return loss
+
+    initial_value = fn().item()
+    for _i in range(5):
+        optimizer.step(fn)
+    assert fn().item() < initial_value
+
+    for group in optimizer.param_groups:
+        for p in group["params"]:
+            if p.requires_grad:
+                assert p.dtype == torch.float16
+
+    with pytest.raises(AttributeError):
+        optimizer.fp32_param_groups
+
+
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_step_mixed_precision():
+    weight = torch.randn(10, 5).cuda().half().requires_grad_()
+    bias = torch.randn(10).cuda().half().requires_grad_()
+    input = torch.randn(5).half().cuda()
+    optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
+
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+
+    def fn():
+        optimizer.zero_grad()
+        y = weight.mv(input)
+        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
+            y = y.cuda(bias.get_device())
+        loss = (y + bias).pow(2).sum()
+        loss.backward()
+        return loss
+
+    initial_value = fn().item()
+    for _i in range(5):
+        optimizer.step(fn)
+    assert fn().item() < initial_value
+
+    assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)
+    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
+        for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
+
+            def assert_almost_zero(x):
+                assert abs(x) < 1e-3
+                return 1.0
+
+            assert fp32_p.dtype == torch.float32
+            if fp16_p.requires_grad:
+                assert fp16_p.dtype == torch.float16
+                (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)
 
 
 @skip_if_no_cuda
@@ -74,6 +152,34 @@ def test_step_multigpu():
     assert fn().item() < initial_value
 
+
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_step_multigpu_mixed_precision():
+    if not torch.cuda.device_count() > 1:
+        return
+    weight = torch.randn(10, 5).cuda(0).half().requires_grad_()
+    bias = torch.randn(10).cuda(1).half().requires_grad_()
+    input = torch.randn(5).cuda(0).half()
+    optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
+
+    # to check if the optimizer can be printed as a string
+    optimizer.__repr__()
+
+    def fn():
+        optimizer.zero_grad()
+        y = weight.mv(input)
+        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
+            y = y.cuda(bias.get_device())
+        loss = (y + bias).pow(2).sum()
+        loss.backward()
+        return loss
+
+    initial_value = fn().item()
+    for _i in range(5):
+        optimizer.step(fn)
+    assert fn().item() < initial_value
+
 
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_state_dict():
@@ -111,6 +217,26 @@ def test_state_dict():
     assert torch.equal(bias, bias_c)
 
+
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_build_fp32_params():
+    weight = torch.randn(10, 5).cuda().half().requires_grad_()
+    bias = torch.randn(10).cuda().half().requires_grad_()
+    optimizer = Adam([weight, bias], lr=1e-3)
+    optimizer._build_fp32_params([weight, bias])
+    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
+        for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
+
+            def assert_almost_zero(x):
+                assert abs(x) < 1e-3
+                return 1.0
+
+            assert fp32_p.dtype == torch.float32
+            if fp16_p.requires_grad:
+                assert fp16_p.dtype == torch.float16
+                (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)
 
 
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_invalid_beta():
...