Unverified Commit 5251a69a authored by Jun Ru Anderson, committed by GitHub

[feat] optimizer state scaling (#44)



Implement scaling of optimizer state when using pure-fp16 training to avoid underflow. Update benchmark to use pure-fp16. Modify state_dict methods to store and load the optimizer state scale.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 46c3776b
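For context, here is a minimal sketch (not part of this commit) of the underflow problem the new scale addresses: a typical second-moment update is too small to represent in fp16, while the same value multiplied by the 2 ** 16 scale introduced here is representable. The gradient magnitude below is an arbitrary illustration.

import torch

grad = torch.tensor(1e-3)              # illustrative gradient magnitude (assumed, not from the commit)
beta2 = 0.999
v_update = (1 - beta2) * grad * grad   # ~1e-9 in fp32

print(v_update.half())                   # tensor(0., dtype=torch.float16) -- underflows
optim_scale = float(2 ** 16)             # default scale used for PURE_FP16 in this commit
print((v_update * optim_scale).half())   # ~6.6e-05 -- representable in fp16
# The kernel divides the stored state by optim_scale before using it, so the
# effective Adam math is unchanged.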
@@ -134,10 +134,10 @@ def make_model(device, ntokens):
     p = Pipe(model, balance)
     criterion = nn.CrossEntropyLoss()
-    lr = 0.01 # learning rate
+    lr = 0.0005 # learning rate
     try:
-        optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
+        optimizer = Adam(p.parameters(), lr=lr, precision=Precision.PURE_FP16)
     except NameError:
         optimizer = Adam(p.parameters(), lr=lr)
@@ -236,10 +236,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
         # Assert that memory usage on each GPU is within 10% of golden run
         # Right-hand-side is golden run bytes * 110%
-        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 210479616 * 1.1
+        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 193206272 * 1.1
         assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 640512 * 1.1
-        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1605120 * 1.1
-        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 113801216 * 1.1
+        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1412608 * 1.1
+        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 95364608 * 1.1
         print("No regression detected")
......
 #include <torch/extension.h>
 // CUDA forward declaration
-void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, float optim_scale, at::Tensor& found_inf, int step, int mode, int bias_correction, float decay);
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("adam", &fused_adam_cuda, "Multi tensor Adam optimized CUDA implementation.");
......
@@ -7,7 +7,6 @@
 #include <assert.h>
 #include <cmath>
 #include "ATen/TensorUtils.h"
-// #include "ATen/Type.h"
 #include "ATen/AccumulateType.h"
 #include <THC/THCGeneral.h>
 #include "multi_tensor_apply.cuh"
@@ -31,6 +30,9 @@ struct AdamFunctor
     const float b2,
     const float eps,
     const float grad_scale,
+    const bool use_optim_scaling,
+    const float optim_scale,
+    float* found_inf_ptr,
     const float step_size,
     adamMode_t mode,
     const float decay)
@@ -90,19 +92,43 @@ struct AdamFunctor
       int j = i_start + threadIdx.x + ii*blockDim.x;
       if(j < n && j < chunk_size) {
-        float scaled_grad = incoming_g[ii]/grad_scale;
-        float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
-        float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
-        m[j] = static_cast<OPTIM_T>(momentum);
-        v[j] = static_cast<OPTIM_T>(velocity);
-        float denom;
-        if (mode == ADAM_MODE_0)
-          denom = sqrtf(velocity + eps);
-        else // Mode 1
-          denom = sqrtf(velocity) + eps;
-        float update = (momentum/denom) + (decay*incoming_p[ii]);
-        p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
-        if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
+        if (use_optim_scaling) {
+          // Optimizer state is in half precision and must be scaled
+          float scaled_grad = incoming_g[ii]/grad_scale;
+          float momentum = b1 * (incoming_m[ii] / optim_scale) + (1-b1)*scaled_grad;
+          float velocity = b2 * (incoming_v[ii] / optim_scale) + (1-b2)*scaled_grad*scaled_grad;
+          m[j] = static_cast<OPTIM_T>(momentum * optim_scale);
+          v[j] = static_cast<OPTIM_T>(velocity * optim_scale);
+          if (!isfinite(m[j]) || !isfinite(v[j])) {
+            *found_inf_ptr = 1.f;
+          }
+          float denom;
+          if (mode == ADAM_MODE_0)
+            denom = sqrtf(velocity + eps);
+          else // Mode 1
+            denom = sqrtf(velocity) + eps;
+          float update = (momentum/denom) + (decay*incoming_p[ii]);
+          p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
+          if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
+        } else {
+          // Optimizer state is in floating point precision
+          float scaled_grad = incoming_g[ii]/grad_scale;
+          float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
+          float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
+          m[j] = static_cast<OPTIM_T>(momentum);
+          v[j] = static_cast<OPTIM_T>(velocity);
+          float denom;
+          if (mode == ADAM_MODE_0)
+            denom = sqrtf(velocity + eps);
+          else // Mode 1
+            denom = sqrtf(velocity) + eps;
+          float update = (momentum/denom) + (decay*incoming_p[ii]);
+          p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
+          if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
+        }
       }
     }
   }
@@ -118,6 +144,8 @@ void fused_adam_cuda(
     float beta2,
     float eps,
     float grad_scale,
+    float optim_scale,
+    at::Tensor& found_inf,
     int step,
     int mode,
     int bias_correction,
@@ -139,6 +167,9 @@ void fused_adam_cuda(
   assert(tl_sz == 4 || tl_sz == 5);
+  assert(tensor_lists[1][0].scalar_type() == tensor_lists[2][0].scalar_type());
+  bool use_optim_scaling = (tensor_lists[1][0].scalar_type() == at::ScalarType::Half);
+  float* found_inf_ptr = found_inf.data_ptr<float>();
 
   if(tl_sz == 5) {
     // Mixed precision case
     assert(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
@@ -154,6 +185,9 @@ void fused_adam_cuda(
         beta2,
         eps,
         grad_scale,
+        use_optim_scaling,
+        optim_scale,
+        found_inf_ptr,
         step_size,
         (adamMode_t) mode,
         decay
@@ -174,13 +208,17 @@ void fused_adam_cuda(
         beta2,
         eps,
         grad_scale,
+        use_optim_scaling,
+        optim_scale,
+        found_inf_ptr,
         step_size,
         (adamMode_t) mode,
         decay
       );
   } else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
     if(tensor_lists[1][0].scalar_type() == at::ScalarType::Float) {
-      // FP16 model parameters and gradients; FP32 optimizer state
+      // Memory-efficient mixed-precision case
+      // ie FP16 model parameters and gradients; FP32 optimizer state
       multi_tensor_apply<4>(
         BLOCK_SIZE,
         chunk_size,
@@ -191,6 +229,9 @@ void fused_adam_cuda(
         beta2,
         eps,
         grad_scale,
+        use_optim_scaling,
+        optim_scale,
+        found_inf_ptr,
         step_size,
         (adamMode_t) mode,
         decay
@@ -207,6 +248,9 @@ void fused_adam_cuda(
         beta2,
         eps,
         grad_scale,
+        use_optim_scaling,
+        optim_scale,
+        found_inf_ptr,
         step_size,
         (adamMode_t) mode,
         decay
......
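In plain Python, the scaled branch of the kernel above amounts to the following per-element update. This is a readability sketch of the kernel math, not the CUDA source; it assumes bias correction is enabled and uses the mode-1 denominator (eps added outside the square root).

import math
import torch

def scaled_adam_element(p, g, m, v, *, lr, b1, b2, eps, grad_scale, optim_scale, step, decay=0.0):
    # p, g, m, v are 0-dim fp16 tensors; m and v arrive pre-multiplied by optim_scale.
    scaled_grad = g.float() / grad_scale
    momentum = b1 * (m.float() / optim_scale) + (1 - b1) * scaled_grad
    velocity = b2 * (v.float() / optim_scale) + (1 - b2) * scaled_grad * scaled_grad
    new_m = (momentum * optim_scale).half()   # state is rescaled before being stored in fp16
    new_v = (velocity * optim_scale).half()
    found_inf = not bool(torch.isfinite(new_m) & torch.isfinite(new_v))  # mirrors *found_inf_ptr = 1.f
    step_size = lr * math.sqrt(1 - b2 ** step) / (1 - b1 ** step)  # host-side bias correction (assumed enabled)
    denom = math.sqrt(velocity) + eps  # kernel "mode 1": eps outside the sqrt
    update = momentum / denom + decay * p.float()
    new_p = (p.float() - step_size * update).half()
    return new_p, new_m, new_v, found_inf

Storing m and v pre-multiplied by optim_scale is what keeps small second-moment values representable in fp16; dividing by the scale on entry keeps the arithmetic identical to the unscaled branch.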
@@ -22,6 +22,23 @@ try:
         MEMORY_EFFICIENT_MIXED_PRECISION = auto()
         PURE_FP16 = auto()
 
+    class _MultiDeviceReplicator(object):
+        """
+        Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
+        """
+
+        def __init__(self, master_tensor: torch.Tensor):
+            assert master_tensor.is_cuda
+            self.master = master_tensor
+            self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+        def get(self, device: torch.device) -> torch.Tensor:
+            retval = self._per_device_tensors.get(device, None)
+            if retval is None:
+                retval = self.master.to(device=device, non_blocking=True, copy=True)
+                self._per_device_tensors[device] = retval
+            return retval
+
     class Adam(torch.optim.Optimizer):
         state: dict
         defaults: dict
@@ -81,7 +98,9 @@ try:
            assert parameters[0].dtype == torch.float16
 
            self.optim_type = torch.float16 if precision is Precision.PURE_FP16 else torch.float32
+           self._optim_scale = float(2 ** 16) if precision is Precision.PURE_FP16 else 1.0
+           self._steps_since_optim_scale_change = 0
+           self._optim_scale_update_freq = 2000 # This is the value that GradScaler uses by default
            self._overflow_buf = torch.cuda.IntTensor([0]) # type: ignore
 
            if amsgrad:
@@ -145,8 +164,14 @@ try:
        def mixed_precision(self) -> bool:
            return self.precision is Precision.MIXED_PRECISION
 
+       def state_dict(self) -> Dict[str, Any]:
+           d = super().state_dict()
+           d["optim_scale"] = self._optim_scale
+           return d
+
        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            super().load_state_dict(state_dict)
+           self._optim_scale = state_dict["optim_scale"]
 
            # TODO: Optimizer state gets cast to FP16 and back to FP32 for
            # mixed-precision and memory-efficient mixed-precision. Eventually
@@ -228,6 +253,9 @@ try:
                    for tl, t in zip(tensorlists[p.device], pl):
                        tl.append(t)
 
+           found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=list(tensorlists.keys())[0])
+           per_device_found_inf = _MultiDeviceReplicator(found_inf)
+
            for tensordevice, tensorlist in tensorlists.items():
                with torch.cuda.device(tensordevice):
                    fused_adam_cuda.adam(
@@ -239,12 +267,33 @@ try:
                        beta2,
                        group["eps"],
                        scale,
+                       self._optim_scale,
+                       per_device_found_inf.get(tensordevice),
                        state["step"],
                        self.eps_mode,
                        bias_correction,
                        group["weight_decay"],
                    )
 
+           if sum(v.item() for v in per_device_found_inf._per_device_tensors.values()):
+               self._steps_since_optim_scale_change = 0
+               self._optim_scale /= 2
+
+               if self._optim_scale < 1.0:
+                   raise RuntimeError("Optimizer state scale < 1. This may mean that gradients are exploding")
+
+               for group in self.param_groups:
+                   for p in group["params"]:
+                       self.state[p]["exp_avg"] = torch.zeros_like(p, dtype=self.optim_type)
+                       self.state[p]["exp_avg_sq"] = torch.zeros_like(p, dtype=self.optim_type)
+           else:
+               self._steps_since_optim_scale_change += 1
+
+           if self._steps_since_optim_scale_change == self._optim_scale_update_freq:
+               self._steps_since_optim_scale_change = 0
+               if self._optim_scale < 2 ** 16:
+                   self._optim_scale *= 2
+
            return loss
......
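Putting the pieces together, the PURE_FP16 path is exercised roughly like this. This is an illustrative sketch rather than code from the commit; the fairscale.optim import path and the toy model are assumptions.

import torch
from fairscale.optim import Adam, Precision  # import path assumed

model = torch.nn.Linear(32, 32).half().cuda()
optimizer = Adam(model.parameters(), lr=1e-3, precision=Precision.PURE_FP16)

optimizer.zero_grad()
x = torch.randn(8, 32, device="cuda", dtype=torch.float16)
loss = model(x).float().pow(2).mean()
loss.backward()
optimizer.step()               # may halve _optim_scale and reset state on overflow

ckpt = optimizer.state_dict()  # now also carries ckpt["optim_scale"]
restored = Adam(model.parameters(), lr=1e-3, precision=Precision.PURE_FP16)
restored.load_state_dict(ckpt)
assert restored._optim_scale == optimizer._optim_scale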
@@ -285,6 +285,38 @@ def test_state_dict_pure_fp16():
     state_dict_test(optimizer, weight, bias, input)
 
 
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_update_optim_scale():
+    weight, bias, input = make_half_precision_params()
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
+    optimizer._optim_scale_update_freq = 1
+    optimizer._optim_scale = 2 ** 15
+
+    optimizer.zero_grad()
+    loss = (weight.mv(input) + bias).pow(2).sum()
+    loss.backward()
+    optimizer.step()
+
+    assert optimizer._optim_scale == 2 ** 16
+
+
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_exploding_optimizer_state():
+    weight = torch.tensor([[float("inf")]]).half().cuda().requires_grad_()
+    input = torch.tensor([1.0]).half().cuda().requires_grad_()
+
+    optimizer = Adam([weight], lr=1e-3, precision=Precision.PURE_FP16)
+    optimizer._optim_scale = 1.0
+
+    optimizer.zero_grad()
+    loss = (weight.mv(input)).pow(2).sum()
+    loss.backward()
+    with pytest.raises(RuntimeError):
+        optimizer.step()
+
+
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_build_fp32_params():
......