Unverified commit 5251a69a authored by Jun Ru Anderson, committed by GitHub

[feat] optimizer state scaling (#44)



Implement scaling of optimizer state when using pure-fp16 training to avoid underflow. Update benchmark to use pure-fp16. Modify state_dict methods to store and load the optimizer state scale.
Co-authored-by: Jun Ru Anderson <andersonic@fb.com>
parent 46c3776b
@@ -134,10 +134,10 @@ def make_model(device, ntokens):
     p = Pipe(model, balance)
     criterion = nn.CrossEntropyLoss()
-    lr = 0.01 # learning rate
+    lr = 0.0005 # learning rate
     try:
-        optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
+        optimizer = Adam(p.parameters(), lr=lr, precision=Precision.PURE_FP16)
     except NameError:
         optimizer = Adam(p.parameters(), lr=lr)
@@ -236,10 +236,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
     # Assert that memory usage on each GPU is within 10% of golden run
     # Right-hand-side is golden run bytes * 110%
-    assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 210479616 * 1.1
+    assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 193206272 * 1.1
     assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 640512 * 1.1
-    assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1605120 * 1.1
-    assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 113801216 * 1.1
+    assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1412608 * 1.1
+    assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 95364608 * 1.1
     print("No regression detected")
...
 #include <torch/extension.h>
 // CUDA forward declaration
-void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, float optim_scale, at::Tensor& found_inf, int step, int mode, int bias_correction, float decay);
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("adam", &fused_adam_cuda, "Multi tensor Adam optimized CUDA implementation.");
...
@@ -7,7 +7,6 @@
 #include <assert.h>
 #include <cmath>
 #include "ATen/TensorUtils.h"
-// #include "ATen/Type.h"
 #include "ATen/AccumulateType.h"
 #include <THC/THCGeneral.h>
 #include "multi_tensor_apply.cuh"
@@ -31,6 +30,9 @@ struct AdamFunctor
     const float b2,
     const float eps,
     const float grad_scale,
+    const bool use_optim_scaling,
+    const float optim_scale,
+    float* found_inf_ptr,
     const float step_size,
     adamMode_t mode,
     const float decay)
@@ -90,19 +92,43 @@ struct AdamFunctor
         int j = i_start + threadIdx.x + ii*blockDim.x;
         if(j < n && j < chunk_size) {
-          float scaled_grad = incoming_g[ii]/grad_scale;
-          float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
-          float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
-          m[j] = static_cast<OPTIM_T>(momentum);
-          v[j] = static_cast<OPTIM_T>(velocity);
-          float denom;
-          if (mode == ADAM_MODE_0)
-            denom = sqrtf(velocity + eps);
-          else // Mode 1
-            denom = sqrtf(velocity) + eps;
-          float update = (momentum/denom) + (decay*incoming_p[ii]);
-          p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
-          if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
+          if (use_optim_scaling) {
+            // Optimizer state is in half precision and must be scaled
+            float scaled_grad = incoming_g[ii]/grad_scale;
+            float momentum = b1 * (incoming_m[ii] / optim_scale) + (1-b1)*scaled_grad;
+            float velocity = b2 * (incoming_v[ii] / optim_scale) + (1-b2)*scaled_grad*scaled_grad;
+
+            m[j] = static_cast<OPTIM_T>(momentum * optim_scale);
+            v[j] = static_cast<OPTIM_T>(velocity * optim_scale);
+
+            if (!isfinite(m[j]) || !isfinite(v[j])) {
+              *found_inf_ptr = 1.f;
+            }
+
+            float denom;
+            if (mode == ADAM_MODE_0)
+              denom = sqrtf(velocity + eps);
+            else // Mode 1
+              denom = sqrtf(velocity) + eps;
+            float update = (momentum/denom) + (decay*incoming_p[ii]);
+            p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
+            if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
+          } else {
+            // Optimizer state is in floating point precision
+            float scaled_grad = incoming_g[ii]/grad_scale;
+            float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
+            float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
+            m[j] = static_cast<OPTIM_T>(momentum);
+            v[j] = static_cast<OPTIM_T>(velocity);
+            float denom;
+            if (mode == ADAM_MODE_0)
+              denom = sqrtf(velocity + eps);
+            else // Mode 1
+              denom = sqrtf(velocity) + eps;
+            float update = (momentum/denom) + (decay*incoming_p[ii]);
+            p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
+            if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
+          }
         }
       }
     }
@@ -118,6 +144,8 @@ void fused_adam_cuda(
     float beta2,
     float eps,
     float grad_scale,
+    float optim_scale,
+    at::Tensor& found_inf,
     int step,
     int mode,
     int bias_correction,
@@ -139,6 +167,9 @@ void fused_adam_cuda(
     assert(tl_sz == 4 || tl_sz == 5);
     assert(tensor_lists[1][0].scalar_type() == tensor_lists[2][0].scalar_type());
+    bool use_optim_scaling = (tensor_lists[1][0].scalar_type() == at::ScalarType::Half);
+    float* found_inf_ptr = found_inf.data_ptr<float>();
+
     if(tl_sz == 5) {
         // Mixed precision case
         assert(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
@@ -154,6 +185,9 @@ void fused_adam_cuda(
             beta2,
             eps,
             grad_scale,
+            use_optim_scaling,
+            optim_scale,
+            found_inf_ptr,
             step_size,
             (adamMode_t) mode,
             decay
@@ -174,13 +208,17 @@ void fused_adam_cuda(
             beta2,
             eps,
             grad_scale,
+            use_optim_scaling,
+            optim_scale,
+            found_inf_ptr,
             step_size,
             (adamMode_t) mode,
             decay
         );
     } else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
         if(tensor_lists[1][0].scalar_type() == at::ScalarType::Float) {
-            // FP16 model parameters and gradients; FP32 optimizer state
+            // Memory-efficient mixed-precision case
+            // ie FP16 model parameters and gradients; FP32 optimizer state
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
@@ -191,6 +229,9 @@ void fused_adam_cuda(
                 beta2,
                 eps,
                 grad_scale,
+                use_optim_scaling,
+                optim_scale,
+                found_inf_ptr,
                 step_size,
                 (adamMode_t) mode,
                 decay
@@ -207,6 +248,9 @@ void fused_adam_cuda(
                 beta2,
                 eps,
                 grad_scale,
+                use_optim_scaling,
+                optim_scale,
+                found_inf_ptr,
                 step_size,
                 (adamMode_t) mode,
                 decay
...
@@ -22,6 +22,23 @@ try:
         MEMORY_EFFICIENT_MIXED_PRECISION = auto()
         PURE_FP16 = auto()

+    class _MultiDeviceReplicator(object):
+        """
+        Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
+        """
+
+        def __init__(self, master_tensor: torch.Tensor):
+            assert master_tensor.is_cuda
+            self.master = master_tensor
+            self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+        def get(self, device: torch.device) -> torch.Tensor:
+            retval = self._per_device_tensors.get(device, None)
+            if retval is None:
+                retval = self.master.to(device=device, non_blocking=True, copy=True)
+                self._per_device_tensors[device] = retval
+            return retval
+
     class Adam(torch.optim.Optimizer):
         state: dict
         defaults: dict
@@ -81,7 +98,9 @@ try:
                 assert parameters[0].dtype == torch.float16

             self.optim_type = torch.float16 if precision is Precision.PURE_FP16 else torch.float32
+            self._optim_scale = float(2 ** 16) if precision is Precision.PURE_FP16 else 1.0
+            self._steps_since_optim_scale_change = 0
+            self._optim_scale_update_freq = 2000  # This is the value that GradScaler uses by default
             self._overflow_buf = torch.cuda.IntTensor([0])  # type: ignore

             if amsgrad:
@@ -145,8 +164,14 @@ try:
         def mixed_precision(self) -> bool:
             return self.precision is Precision.MIXED_PRECISION

+        def state_dict(self) -> Dict[str, Any]:
+            d = super().state_dict()
+            d["optim_scale"] = self._optim_scale
+            return d
+
         def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             super().load_state_dict(state_dict)
+            self._optim_scale = state_dict["optim_scale"]

             # TODO: Optimizer state gets cast to FP16 and back to FP32 for
             # mixed-precision and memory-efficient mixed-precision. Eventually
@@ -228,6 +253,9 @@ try:
                        for tl, t in zip(tensorlists[p.device], pl):
                            tl.append(t)

+                found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=list(tensorlists.keys())[0])
+                per_device_found_inf = _MultiDeviceReplicator(found_inf)
+
                 for tensordevice, tensorlist in tensorlists.items():
                     with torch.cuda.device(tensordevice):
                         fused_adam_cuda.adam(
@@ -239,12 +267,33 @@ try:
                             beta2,
                             group["eps"],
                             scale,
+                            self._optim_scale,
+                            per_device_found_inf.get(tensordevice),
                             state["step"],
                             self.eps_mode,
                             bias_correction,
                             group["weight_decay"],
                         )

+            if sum(v.item() for v in per_device_found_inf._per_device_tensors.values()):
+                self._steps_since_optim_scale_change = 0
+                self._optim_scale /= 2
+
+                if self._optim_scale < 1.0:
+                    raise RuntimeError("Optimizer state scale < 1. This may mean that gradients are exploding")
+
+                for group in self.param_groups:
+                    for p in group["params"]:
+                        self.state[p]["exp_avg"] = torch.zeros_like(p, dtype=self.optim_type)
+                        self.state[p]["exp_avg_sq"] = torch.zeros_like(p, dtype=self.optim_type)
+            else:
+                self._steps_since_optim_scale_change += 1
+
+            if self._steps_since_optim_scale_change == self._optim_scale_update_freq:
+                self._steps_since_optim_scale_change = 0
+                if self._optim_scale < 2 ** 16:
+                    self._optim_scale *= 2
+
             return loss
...
@@ -285,6 +285,38 @@ def test_state_dict_pure_fp16():
     state_dict_test(optimizer, weight, bias, input)


+@skip_if_no_cuda
+@skip_if_no_adam
+def test_update_optim_scale():
+    weight, bias, input = make_half_precision_params()
+    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
+    optimizer._optim_scale_update_freq = 1
+    optimizer._optim_scale = 2 ** 15
+
+    optimizer.zero_grad()
+    loss = (weight.mv(input) + bias).pow(2).sum()
+    loss.backward()
+    optimizer.step()
+
+    assert optimizer._optim_scale == 2 ** 16
+
+
+@skip_if_no_cuda
+@skip_if_no_adam
+def test_exploding_optimizer_state():
+    weight = torch.tensor([[float("inf")]]).half().cuda().requires_grad_()
+    input = torch.tensor([1.0]).half().cuda().requires_grad_()
+
+    optimizer = Adam([weight], lr=1e-3, precision=Precision.PURE_FP16)
+    optimizer._optim_scale = 1.0
+
+    optimizer.zero_grad()
+    loss = (weight.mv(input)).pow(2).sum()
+    loss.backward()
+
+    with pytest.raises(RuntimeError):
+        optimizer.step()
+
+
 @skip_if_no_cuda
 @skip_if_no_adam
 def test_build_fp32_params():
...