Unverified Commit e2d8f573 authored by Jun Ru Anderson, committed by GitHub

[feat] add mixed precision Adam (#40)



Add support for mixed-precision (half-precision params, full-precision gradients) and memory-efficient (half-precision params and half-precision gradients) training with Adam.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 585f177b
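For orientation before the diff below, here is a minimal usage sketch of the two new modes, assuming the fused CUDA extension from this PR is built and a GPU is available. The `mixed_precision=True` flag, the FP16 model, and the train-step pattern come from the benchmark and test changes in this commit; the toy `nn.Linear` model and tensor shapes are illustrative only.

```python
import torch
import torch.nn as nn
from fairscale.optim.adam import Adam  # fused, GPU-only Adam from this PR

model = nn.Linear(1024, 1024).cuda().half()  # FP16 parameters and grads

# Mixed precision: the optimizer keeps an FP32 master copy of the FP16
# params (see _build_fp32_params below) and writes updated FP16 values
# back into the model on every step.
optimizer = Adam(model.parameters(), lr=1e-3, mixed_precision=True)

# Memory-efficient alternative (no FP32 master copy): omit the flag and
# keep params and grads in FP16 throughout.
# optimizer = Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024, device="cuda", dtype=torch.half)
loss = model(x).float().pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```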
@@ -8,7 +8,7 @@ import torch.nn as nn
import torchtext
from torchtext.data.utils import get_tokenizer
import fairscale.nn.pipe.pipe as pipe
from fairscale.nn import Pipe
try:
from fairscale.optim.adam import Adam # type: ignore
@@ -129,13 +129,15 @@ def make_model(device, ntokens):
dropout = 0
initrange = 0.1
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).to(device)
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).half().to(device)
balance = generate_balance(min(num_devices, 4), len(model))
p = Pipe(model, balance)
criterion = nn.CrossEntropyLoss()
lr = 0.01 # learning rate
optimizer = Adam(model.parameters(), lr=lr)
optimizer = Adam(p.parameters(), lr=lr, mixed_precision=True)
return model, criterion, optimizer
return p, criterion, optimizer
def train(train_data, model, criterion, optimizer, bptt, ntokens):
@@ -221,7 +223,7 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
if can_benchmark and len(model.balance) == 4:
# Assert that words per second is within 3 standard deviations of the average
# of six golden runs
assert wps > 20052.1 - (3 * 359)
assert wps > 27799.2 - (3 * 522.145)
print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
@@ -230,10 +232,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
# Assert that memory usage on each GPU is within 10% of golden run
# Right-hand-side is golden run bytes * 110%
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 365916160 * 1.1
assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1281024 * 1.1
assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2788864 * 1.1
assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 190724608 * 1.1
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 210479616 * 1.1
assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 640512 * 1.1
assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1605120 * 1.1
assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 113801216 * 1.1
print("No regression detected")
@@ -259,7 +261,5 @@ if __name__ == "__main__":
device = torch.device("cuda")
ntokens, train_data, val_data, test_data = get_data(device)
model, criterion, optimizer = make_model(device, ntokens)
balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(model, balance)
benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens)
del p
benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens)
del model
@@ -4,6 +4,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <assert.h>
#include <cmath>
#include "ATen/TensorUtils.h"
// #include "ATen/Type.h"
@@ -19,9 +20,7 @@ typedef enum{
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
template <int DEPTH, typename T, typename GRAD_T>
template <int DEPTH, typename PARAM_T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
@@ -40,26 +39,26 @@ struct AdamFunctor
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
GRAD_T* p = (GRAD_T *)tl.addresses[0][tensor_loc];
PARAM_T* p = (PARAM_T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
float* m = (float *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
float* v = (float *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
at::Half* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy = (at::Half*)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
PARAM_T incoming_p[ILP];
float incoming_m[ILP];
float incoming_v[ILP];
GRAD_T incoming_g[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
@@ -74,10 +73,10 @@ struct AdamFunctor
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = static_cast<T>(p[i]);
incoming_p[ii] = static_cast<PARAM_T>(p[i]);
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
incoming_g[ii] = static_cast<GRAD_T>(g[i]);
}
}
@@ -91,7 +90,7 @@ struct AdamFunctor
int j = i_start + threadIdx.x + ii*blockDim.x;
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
float scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
@@ -100,8 +99,8 @@ struct AdamFunctor
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = (GRAD_T)(incoming_p[ii] - (step_size*update));
if (DEPTH == 5) p_copy[j] = p[j];
p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
}
}
}
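For reference, the per-element arithmetic in the kernel body above (the m/v accumulation and the mode-dependent denominator) restates to the following update; m and v are held in FP32 regardless of the parameter and gradient types, and when DEPTH == 5 the updated value is also written to the FP16 p_copy.

```latex
% Per-element update performed by AdamFunctor; grad_scale, decay, step_size
% and the mode flag are the kernel arguments shown above.
\begin{aligned}
\hat{g} &= g / \mathrm{grad\_scale} \\
m &\leftarrow \beta_1 m + (1-\beta_1)\,\hat{g} \\
v &\leftarrow \beta_2 v + (1-\beta_2)\,\hat{g}^{2} \\
\mathrm{denom} &=
  \begin{cases}
    \sqrt{v + \epsilon} & \text{ADAM\_MODE\_0 (eps inside sqrt)} \\
    \sqrt{v} + \epsilon & \text{ADAM\_MODE\_1 (eps outside sqrt)}
  \end{cases} \\
p &\leftarrow p - \mathrm{step\_size}\left(\frac{m}{\mathrm{denom}} + \mathrm{decay}\cdot p\right)
\end{aligned}
```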
@@ -135,24 +134,65 @@ void fused_adam_cuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4, "expected tensor lists of size 4");
// check that the model and gradients are FP32
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
AT_ASSERTM(tensor_lists[3][0].scalar_type() == at::ScalarType::Float);
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, float, float>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay
);
assert(tl_sz == 4 || tl_sz == 5);
if(tl_sz == 5) {
// Mixed precision case
assert(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
assert(tensor_lists[3][0].scalar_type() == at::ScalarType::Half);
assert(tensor_lists[4][0].scalar_type() == at::ScalarType::Half);
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, float, at::Half>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay
);
} else {
// tl_sz == 4
assert(tensor_lists[0][0].scalar_type() == tensor_lists[3][0].scalar_type());
if(tensor_lists[0][0].scalar_type() == at::ScalarType::Float) {
// Full precision case
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, float, float>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay
);
} else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
// "Memory Efficient Training" case
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, at::Half, at::Half>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay
);
} else {
throw "Parameters must be of type float or half";
}
}
THCudaCheck(cudaGetLastError());
}
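The dispatch above recognizes three layouts of `tensor_lists`. Below is a small Python restatement of that logic as a sketch; the helper name is hypothetical, while the dtype checks mirror the asserts in `fused_adam_cuda`. In every case the m and v lists (indices 1 and 2) are read and written as FP32 by the kernel.

```python
import torch

def pick_adam_kernel(tensor_lists):
    """Hypothetical helper mirroring the C++ dispatch in fused_adam_cuda."""
    tl_sz = len(tensor_lists)
    p_dtype = tensor_lists[0][0].dtype  # params (or FP32 master copy)
    g_dtype = tensor_lists[3][0].dtype  # gradients
    if tl_sz == 5:
        # Mixed precision: FP32 master params, FP16 grads, FP16 param copy
        assert p_dtype == torch.float32 and g_dtype == torch.float16
        return "AdamFunctor<5, float, at::Half>"
    assert tl_sz == 4 and p_dtype == g_dtype
    if p_dtype == torch.float32:
        return "AdamFunctor<4, float, float>"        # full precision
    if p_dtype == torch.float16:
        return "AdamFunctor<4, at::Half, at::Half>"  # memory-efficient
    raise TypeError("parameters must be of type float or half")
```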
@@ -17,6 +17,7 @@ try:
class Adam(torch.optim.Optimizer):
state: dict
defaults: dict
"""
Implements Adam algorithm. Currently GPU-only.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
@@ -55,9 +56,11 @@ try:
weight_decay: Optional[float] = 0.0,
max_grad_norm: Optional[float] = 0.0,
amsgrad: Optional[bool] = False,
mixed_precision: Optional[bool] = False,
):
self.mixed_precision = mixed_precision
parameters: List[Any] = list(params)
self._use_multi_tensor = False
self._overflow_buf = torch.cuda.IntTensor([0]) # type: ignore
if amsgrad:
@@ -70,14 +73,53 @@ try:
"weight_decay": weight_decay,
"max_grad_norm": max_grad_norm,
}
super().__init__(params, defaults)
super().__init__(parameters, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
if mixed_precision:
self._build_fp32_params(parameters)
def _build_fp32_params(self, params: Any) -> None:
# create FP32 copy of parameters and grads
fp32_params = []
for p in params:
p32 = torch.nn.Parameter(p.data.float()).to(p.device)
p32.grad = torch.zeros_like(p32.data)
fp32_params.append(p32)
params = fp32_params
self.fp32_param_groups = []
param_groups = list(params)
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
for param_group in param_groups:
params = param_group["params"]
if isinstance(params, torch.Tensor):
param_group["params"] = [params]
else:
param_group["params"] = list(params)
for name, default in self.defaults.items():
param_group.setdefault(name, default)
params = param_group["params"]
param_set = set()
for group in self.param_groups:
param_set.update(set(group["params"]))
self.fp32_param_groups.append(param_group)
@property
def supports_memory_efficient_fp16(self) -> bool:
return True
def step(self, closure: Optional[Callable[[], float]] = None, scale: Optional[float] = 1.0) -> Optional[float]:
@property
def _step_supports_amp_scaling(self) -> bool:
return False
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
@@ -95,11 +137,13 @@ try:
if closure is not None:
loss = closure()
for group in self.param_groups:
for i in range(len(self.param_groups)):
group = self.param_groups[i]
bias_correction = 1 if group["bias_correction"] else 0
tensorlists: Dict[torch.device, List[List[torch.Tensor]]] = dict()
for p in group["params"]:
for j in range(len(group["params"])):
p = group["params"][j]
# note: p.grad should not ever be set for correct
# operation of mixed precision optimizer that sometimes
# sends None gradients
@@ -126,15 +170,26 @@ try:
beta1, beta2 = group["betas"]
state["step"] += 1
out_p = torch.tensor([])
out_p = p.data if self.mixed_precision else torch.tensor([])
param = self.fp32_param_groups[i]["params"][j] if self.mixed_precision else p
scale = 1.0
if self.mixed_precision:
pl = [param.data, exp_avg, exp_avg_sq, grad, out_p]
if p.device not in tensorlists:
tensorlists[p.device] = [[], [], [], [], []]
pl = [p.data, exp_avg, exp_avg_sq, grad]
for tl, t in zip(tensorlists[p.device], pl):
tl.append(t)
else:
pl = [param.data, exp_avg, exp_avg_sq, grad]
if p.device not in tensorlists:
tensorlists[p.device] = [[], [], [], []]
if p.device not in tensorlists:
tensorlists[p.device] = [[], [], [], []]
for tl, t in zip(tensorlists[p.device], pl):
tl.append(t)
for tl, t in zip(tensorlists[p.device], pl):
tl.append(t)
for tensordevice, tensorlist in tensorlists.items():
with torch.cuda.device(tensordevice):
......
@@ -44,6 +44,84 @@ def test_step():
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
for group in optimizer.param_groups:
for p in group["params"]:
if p.requires_grad:
assert p.dtype == torch.float32
with pytest.raises(AttributeError):
optimizer.fp32_param_groups
@skip_if_no_cuda
@skip_if_no_adam
def test_step_me():
weight = torch.randn(10, 5).cuda().half().requires_grad_()
bias = torch.randn(10).cuda().half().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
for group in optimizer.param_groups:
for p in group["params"]:
if p.requires_grad:
assert p.dtype == torch.float16
with pytest.raises(AttributeError):
optimizer.fp32_param_groups
@skip_if_no_cuda
@skip_if_no_adam
def test_step_mixed_precision():
weight = torch.randn(10, 5).cuda().half().requires_grad_()
bias = torch.randn(10).cuda().half().requires_grad_()
input = torch.randn(5).half().cuda()
optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)
for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
def assert_almost_zero(x):
assert abs(x) < 1e-3
return 1.0
assert fp32_p.dtype == torch.float32
if fp16_p.requires_grad:
assert fp16_p.dtype == torch.float16
(fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)
@skip_if_no_cuda
@@ -74,6 +152,34 @@ def test_step_multigpu():
assert fn().item() < initial_value
@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu_mixed_precision():
if not torch.cuda.device_count() > 1:
return
weight = torch.randn(10, 5).cuda(0).half().requires_grad_()
bias = torch.randn(10).cuda(1).half().requires_grad_()
input = torch.randn(5).cuda(0).half()
optimizer = Adam([weight, bias], lr=1e-3, mixed_precision=True)
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _i in range(5):
optimizer.step(fn)
assert fn().item() < initial_value
@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict():
@@ -111,6 +217,26 @@ def test_state_dict():
assert torch.equal(bias, bias_c)
@skip_if_no_cuda
@skip_if_no_adam
def test_build_fp32_params():
weight = torch.randn(10, 5).cuda().half().requires_grad_()
bias = torch.randn(10).cuda().half().requires_grad_()
optimizer = Adam([weight, bias], lr=1e-3)
optimizer._build_fp32_params([weight, bias])
for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
def assert_almost_zero(x):
assert abs(x) < 1e-3
return 1.0
assert fp32_p.dtype == torch.float32
if fp16_p.requires_grad:
assert fp16_p.dtype == torch.float16
(fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)
@skip_if_no_cuda
@skip_if_no_adam
def test_invalid_beta():
......