Unverified Commit e5369541 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

Support `store_param_remainders` feature from Apex in TE Fused Adam (#1408)



* Initial commit
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Fixed compilation errors
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Fixed syntax errors
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed NaN issue when initial param value is zero
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Removed 64 bit indexing instantiation
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Made this feature an opt-in
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Removed arg from unscaled state
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Fixed compilation error
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Cleaned up errors
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added support for checkpointing
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed checkpointing logic
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* Added tests
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added assert failure for capturable mode
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed pylint errors
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 96534aa5
......@@ -184,6 +184,7 @@ class TestFusedAdam(TestFusedOptimizer):
grad_dtype,
exp_avg_dtype,
exp_avg_sq_dtype,
store_param_remainders=False,
model_rtol=None,
model_atol=None,
master_rtol=None,
......@@ -220,6 +221,7 @@ class TestFusedAdam(TestFusedOptimizer):
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params,
......@@ -228,6 +230,7 @@ class TestFusedAdam(TestFusedOptimizer):
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)
......@@ -237,7 +240,7 @@ class TestFusedAdam(TestFusedOptimizer):
p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
ref_optimizer.step()
tst_optimizer.step()
if use_master_weights:
if use_master_weights and not store_param_remainders:
master_weights_to_fp32 = [
tst_optim.get_unscaled_state(p, "master_param") for p in model_params
]
......@@ -270,6 +273,7 @@ class TestFusedAdam(TestFusedOptimizer):
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)
tst_optim.load_state_dict(state_dict)
......@@ -300,6 +304,19 @@ class TestFusedAdam(TestFusedOptimizer):
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master_store_param_remainders(self):
    """Precision-aware FusedAdam with ``store_param_remainders=True``.

    BF16 params + FP32 master weights, where the master copy is kept as the
    16 trailing remainder bits (int16) instead of a full FP32 tensor; the
    FP32 master is reconstructed from the BF16 param plus the remainder.
    """
    # Fix: the skip reason read "bf16 if not supported" (typo for "is").
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        store_param_remainders=True,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
......
......@@ -479,6 +479,12 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const int step, const int mode, const int bias_correction,
const float weight_decay);
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay);
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
......
......@@ -179,6 +179,122 @@ struct AdamFunctorMaster {
}
};
// Adam update functor for the "param remainders" memory optimization:
// rather than keeping a full FP32 master copy of each parameter, the FP32
// master value is reconstructed on the fly from the BF16 parameter (the high
// 16 bits of the FP32 pattern) plus a stored int16 "remainder" (the low 16
// bits). Tensor-list layout (5 lists): [0]=grad, [1]=BF16 param (addressed
// as int16), [2]=exp_avg, [3]=exp_avg_sq, [4]=int16 remainder.
// NOTE(review): the int16[1]-is-high-half union trick below assumes a
// little-endian target (true for CUDA GPUs).
template <typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctorMasterParamRemainder {
  // Performs one fused Adam step over one chunk of one tensor in the list.
  // beta1_correction/beta2_correction are the (1 - beta^step) bias terms,
  // precomputed on the host; MATH_T is the file-level compute type
  // (presumably float — defined outside this view).
  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
                                             TensorListMetadata<5> &tl,  // NOLINT(*)
                                             const float beta1, const float beta2,
                                             const float beta1_correction,
                                             const float beta2_correction, const float epsilon,
                                             const float lr, adamMode_t mode, const float decay) {
    index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
    index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
    index_t n = tl.sizes[tensor_loc];

    GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
    g += chunk_idx * chunk_size;

    // BF16 params are addressed as raw int16 so they can be bit-recombined
    // with the remainder into an FP32 value.
    int16_t *p = reinterpret_cast<int16_t *>(tl.addresses[1][tensor_loc]);
    p += chunk_idx * chunk_size;

    FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
    m += chunk_idx * chunk_size;

    FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
    v += chunk_idx * chunk_size;

    int16_t *p_remainder = reinterpret_cast<int16_t *>(tl.addresses[4][tensor_loc]);
    p_remainder += chunk_idx * chunk_size;

    // Remaining elements in this chunk.
    n -= chunk_idx * chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
      // View one FP32 value as its two 16-bit halves for splitting/merging.
      union fp32_or_int162 {
        float fp32;
        int16_t int16[2];
      };
      fp32_or_int162 local_master_param[ILP];

      int16_t local_p[ILP];
      int16_t local_p_rem[ILP];

      MATH_T r_g[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];

      // Load an ILP-wide strip into registers; out-of-range lanes get zeros.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          r_g[ii] = static_cast<MATH_T>(g[i]);
          r_m[ii] = static_cast<MATH_T>(m[i]);
          r_v[ii] = static_cast<MATH_T>(v[i]);
          local_p[ii] = static_cast<int16_t>(p[i]);
          local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
        } else {
          r_g[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
          local_p[ii] = int16_t(0);
          local_p_rem[ii] = int16_t(0);
        }
      }

      // Reconstruct FP32 params
      // The BF16 param was produced by round-to-nearest, which bumps the high
      // half by one whenever the low half's top bit is set (remainder < 0 as
      // signed int16). Decrement to recover the truncated high half, then
      // glue high (int16[1]) and low (int16[0]) halves back together.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        if (local_p_rem[ii] < 0) local_p[ii]--;  // Undo rounding
        local_master_param[ii].int16[1] = local_p[ii];
        local_master_param[ii].int16[0] = local_p_rem[ii];
      }

      // Alias the reconstructed FP32 masters as the working parameter array.
      MATH_T *r_p = reinterpret_cast<MATH_T *>(local_master_param);

      // Standard Adam update on the reconstructed FP32 values.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        if (mode == ADAM_MODE_0) {  // L2
          r_g[ii] = r_g[ii] + (decay * r_p[ii]);
          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = next_m_unbiased / denom;
          r_p[ii] = r_p[ii] - (lr * update);
        } else {  // weight decay
          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
          r_p[ii] = r_p[ii] - (lr * update);
        }
      }

      // Split into BF16 params (rounded-to-nearest) and remainders
      // Inverse of the reconstruction: keep the low half as the remainder and
      // re-apply round-to-nearest to the high half (increment when the
      // remainder's sign bit is set) to produce the stored BF16 param.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        local_p[ii] = local_master_param[ii].int16[1];
        local_p_rem[ii] = local_master_param[ii].int16[0];
        if (local_p_rem[ii] < 0) local_p[ii]++;  // Round up
      }

      // Write back params, remainders, and moments for in-range lanes.
#pragma unroll
      for (int ii = 0; ii < ILP; ii++) {
        int i = i_start + threadIdx.x + ii * blockDim.x;
        if (i < n && i < chunk_size) {
          p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
          p[i] = static_cast<int16_t>(local_p[ii]);
          m[i] = static_cast<FULL_T>(r_m[ii]);
          v[i] = static_cast<FULL_T>(r_v[ii]);
        }
      }
    }
  }
};
template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
......@@ -548,6 +664,42 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
AT_CUDA_CHECK(cudaGetLastError());
}
// Host-side launcher for the param-remainder fused Adam kernel.
// tensor_lists layout (5 lists): grad, BF16 param, exp_avg, exp_avg_sq,
// int16 param-remainder. Params must be BF16; the FP32 master value is
// reconstructed in-kernel from param + remainder.
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
                                            std::vector<std::vector<at::Tensor>> tensor_lists,
                                            const float lr, const float beta1, const float beta2,
                                            const float epsilon, const int step, const int mode,
                                            const int bias_correction, const float weight_decay) {
  using namespace at;

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  const auto g_in_type = tensor_lists[0][0].scalar_type();
  const auto p_in_type = tensor_lists[1][0].scalar_type();
  auto tl_size = tensor_lists.size();

  // case 5: g, p, m, v, p_master
  TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
  TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
              "Adam with BF16 param remainders requires BF16 params");

  // g, p, m, v, p_master
  // NOTE(review): the outer dispatch on p_in_type is degenerate — p_in_type
  // is already constrained to BF16 above and scalar_t_0 is not used by the
  // functor (params are addressed as raw int16 in-kernel). Only the grad
  // dispatch (scalar_t_1) selects a template instantiation; FULL_T is fixed
  // to float and indexing to int64_t.
  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
      p_in_type, 0, "adam",
      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
          g_in_type, 1, "adam",
          multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
                                AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
                                beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
                                (adamMode_t)mode, weight_decay);));
  AT_CUDA_CHECK(cudaGetLastError());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
......
......@@ -213,6 +213,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"
"where the master parameters only store the remainder bits",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
......
......@@ -94,6 +94,13 @@ class FusedAdam(torch.optim.Optimizer):
instead of ".grad" for reading gradients. It's useful when the dtypes
of grad and param are different.
(default: False)
store_param_remainders (bool, optional): Whether to store entire FP32 master
params or just store the trailing 16 remainder bits. Whole FP32 master can be
reconstructed from BF16 params plus the trailing remainder bits. Works only
when param type is BF16 and master weight type is FP32, no effect otherwise.
Useful memory saving optimization.
(default: False)
.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
......@@ -118,6 +125,7 @@ class FusedAdam(torch.optim.Optimizer):
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
use_decoupled_grad=False,
store_param_remainders=False,
):
if amsgrad:
......@@ -142,6 +150,8 @@ class FusedAdam(torch.optim.Optimizer):
raise RuntimeError("Capturable mode only supports fp32 exp_avg.")
if capturable and exp_avg_sq_dtype != torch.float32:
raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq")
if capturable and store_param_remainders:
raise RuntimeError("Capturable mode doesn't support storing param remainders")
# If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
......@@ -172,6 +182,7 @@ class FusedAdam(torch.optim.Optimizer):
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
self.multi_tensor_adam = tex.multi_tensor_adam
self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder
self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8
self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
......@@ -192,6 +203,10 @@ class FusedAdam(torch.optim.Optimizer):
}
self._scales = {}
self.use_decoupled_grad = use_decoupled_grad
# Works only when master params is in FP32
self.store_param_remainders = (
store_param_remainders and master_weights and master_weight_dtype == torch.float32
)
def zero_grad(self):
# pylint: disable=missing-function-docstring
......@@ -261,6 +276,13 @@ class FusedAdam(torch.optim.Optimizer):
unscaled = state[state_name].float()
unscaled.mul_(self._scales[param][state_name])
elif dtype == torch.float32:
if (
self.store_param_remainders
and state_name == "master_param"
and param.dtype == torch.bfloat16
):
assert state[state_name].dtype == torch.int16
else:
assert state[state_name].dtype == torch.float32
unscaled = state[state_name]
else:
......@@ -279,10 +301,19 @@ class FusedAdam(torch.optim.Optimizer):
and 'master_param`.
unscaled_state (torch.Tensor): The original high-precision(FP32) state.
"""
store_param_remainders = (
self.store_param_remainders
and state_name == "master_param"
and param.dtype == torch.bfloat16
)
if store_param_remainders:
assert unscaled_state.dtype == torch.int16
else:
assert unscaled_state.dtype == torch.float32
state = self.state[param]
if state_name not in state:
self._initialize_state(param, state_name, False)
self._initialize_state(param, state_name, False, store_param_remainders)
dtype = self.name_to_dtype_map[state_name]
if dtype != torch.float32:
......@@ -291,7 +322,9 @@ class FusedAdam(torch.optim.Optimizer):
else:
state[state_name].copy_(unscaled_state)
def _initialize_state(self, param, state_name, zero_buffer: bool):
def _initialize_state(
self, param, state_name, zero_buffer: bool, store_param_remainders: bool = False
):
"""Initialize one of the optimizer states according to `state_name`.
Arguments:
......@@ -299,8 +332,12 @@ class FusedAdam(torch.optim.Optimizer):
state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
and 'master_param`.
zero_buffer (bool): Whether to initialize the optimizer state with zeros.
store_param_remainders (bool): Store only trailing remainder bits.
"""
dtype = self.name_to_dtype_map[state_name]
if store_param_remainders:
data = torch.zeros_like(param, dtype=torch.int16)
else:
data = torch.empty_like(param, dtype=dtype)
if zero_buffer:
data.zero_()
......@@ -322,16 +359,23 @@ class FusedAdam(torch.optim.Optimizer):
[1], dtype=torch.float32, device=param.device
)
def initialize_state(self, param):
def initialize_state(self, param, store_param_remainders):
    """Initialize optimizer states.

    Creates the `exp_avg` and `exp_avg_sq` buffers (zero-filled), and — when
    master weights are enabled — the `master_param` state. With
    `store_param_remainders` the master state is an int16 remainder buffer
    left uninitialized here (no FP32 copy is made); otherwise the master is
    seeded from a detached FP32 clone of the param.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
        store_param_remainders (bool): Store trailing remainder bits.
    """
    self._initialize_state(param, "exp_avg", zero_buffer=True)
    self._initialize_state(param, "exp_avg_sq", zero_buffer=True)
    if self.master_weights:
        self._initialize_state(
            param,
            "master_param",
            zero_buffer=False,
            store_param_remainders=store_param_remainders,
        )
        # Remainder storage needs no FP32 seed: the master value is
        # reconstructed from the BF16 param plus the int16 remainder.
        if not store_param_remainders:
            self.set_scaled_state(param, "master_param", param.clone().detach().float())
def state_dict(self):
......@@ -377,6 +421,14 @@ class FusedAdam(torch.optim.Optimizer):
param = id_map[k]
self.state[param] = {}
for name in v:
if (
self.store_param_remainders
and name == "master_param"
and param.dtype == torch.bfloat16
):
self.set_scaled_state(param, name, v[name])
assert v[name].dtype == torch.int16
else:
self.set_scaled_state(param, name, v[name].float())
def step(self, closure=None, grad_scaler=None):
......@@ -444,9 +496,11 @@ class FusedAdam(torch.optim.Optimizer):
for p in group["params"]:
state = self.state[p]
store_param_remainders = self.store_param_remainders and p.dtype == torch.bfloat16
# State initialization
if len(state) == 0:
self.initialize_state(p)
self.initialize_state(p, store_param_remainders)
if self.use_decoupled_grad:
p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None
......@@ -462,6 +516,10 @@ class FusedAdam(torch.optim.Optimizer):
unscaled_state = {}
for name in ["exp_avg", "exp_avg_sq", "master_param"]:
if name in state:
if name == "master_param" and store_param_remainders:
unscaled_state[name] = self.state[p][name]
assert unscaled_state[name].dtype == torch.int16
else:
unscaled = self.get_unscaled_state(p, name)
unscaled_state[name] = unscaled
if self.name_to_dtype_map[name] != torch.float32:
......@@ -506,6 +564,12 @@ class FusedAdam(torch.optim.Optimizer):
)
if has_fp16 and has_bf16:
if self.store_param_remainders:
raise RuntimeError(
"FusedAdam doesn't support a mix of FP16/BF16 weights + Store param"
" remainder."
)
# simple to add support for this, but not needed for now
raise RuntimeError(
"FusedAdam does not support a mix of float16 and bfloat16 model weights."
......@@ -599,6 +663,13 @@ class FusedAdam(torch.optim.Optimizer):
v_of_f16_model,
p_main_of_f16_model,
]
if self.store_param_remainders and has_bf16 and not has_fp16:
# When you have BF16 params and need FP32 master params, you can reconstruct
# the FP32 master params with BF16 params + int16 remainders
apply_multi_tensor_adam(
self.multi_tensor_adam_param_remainder, tensor_lists
)
else:
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
if len(p_fp8_model) > 0:
tensor_lists = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment