Unverified Commit 05c0fb02 authored by Kunlun Li's avatar Kunlun Li Committed by GitHub
Browse files

Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam (#1078)



* Add precision aware fused adam
Signed-off-by: default avatarkunlunl <kunlunl@nvidia.com>

* Minor changes based on review comments.
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarKunlun Li <94586211+kunlunl@users.noreply.github.com>

---------
Signed-off-by: default avatarkunlunl <kunlunl@nvidia.com>
Signed-off-by: default avatarKunlun Li <94586211+kunlunl@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 23caab3f
......@@ -4,6 +4,7 @@
from itertools import product
import copy
from contextlib import nullcontext
import pytest
import torch
......@@ -174,6 +175,216 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param)
def gen_precision_aware_test(
self,
use_fp8_params,
param_dtype,
use_master_weights,
master_weight_dtype,
grad_dtype,
exp_avg_dtype,
exp_avg_sq_dtype,
model_rtol=None,
model_atol=None,
master_rtol=None,
master_atol=None,
skip_assert=False,
):
build_model_context = nullcontext
build_model_context_args = {}
if use_fp8_params:
build_model_context = fp8_model_init
build_model_context_args["enabled"] = True
with build_model_context(**build_model_context_args):
model = MultiheadAttention(
hidden_size=1024,
num_attention_heads=16,
layer_number=1,
params_dtype=param_dtype,
fuse_qkv_params=True,
).cuda()
ref_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 1,
"betas": (0.1, 0.25),
"eps": 1e-08,
"weight_decay": 0,
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params,
master_weights=use_master_weights,
master_weight_dtype=master_weight_dtype,
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
**options,
)
def test_one_iteration(ref_optimizer, tst_optimizer):
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
ref_optimizer.step()
tst_optimizer.step()
if use_master_weights:
master_weights_to_fp32 = [
tst_optim.get_unscaled_state(p, "master_param") for p in model_params
]
if not skip_assert:
torch.testing.assert_close(
ref_params,
master_weights_to_fp32,
rtol=master_rtol,
atol=master_atol,
equal_nan=True,
)
ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
if not skip_assert:
torch.testing.assert_close(
ref_params_to_model_dtype,
model_params,
rtol=model_rtol,
atol=model_atol,
equal_nan=True,
)
for i in range(self.iters):
test_one_iteration(ref_optim, tst_optim)
state_dict = tst_optim.state_dict()
tst_optim = te.optimizers.FusedAdam(
model_params,
master_weights=use_master_weights,
master_weight_dtype=master_weight_dtype,
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
**options,
)
tst_optim.load_state_dict(state_dict)
for i in range(self.iters):
test_one_iteration(ref_optim, tst_optim)
def test_fp32_no_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.float32,
use_master_weights=False,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_fp32_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.half,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_bf16_grad(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.bfloat16,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_fp16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.half,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.uint8,
exp_avg_sq_dtype=torch.float32,
master_rtol=1e-2,
master_atol=1e-2,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_fp16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.half,
master_rtol=2e-3,
master_atol=2e-3,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.uint8,
skip_assert=True,
)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16
......@@ -185,12 +396,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
......@@ -200,12 +409,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters):
self.gen_grad(ref_params, master_params)
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
......@@ -224,12 +438,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
......@@ -239,12 +451,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters):
self.gen_grad(ref_params, master_params)
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
......
......@@ -179,7 +179,7 @@ struct AdamFunctorMaster {
}
};
template <typename T, typename FULL_T, typename index_t>
template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*)
......@@ -199,10 +199,10 @@ struct AdamFunctor {
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];
T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]);
GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;
T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]);
PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
......@@ -223,10 +223,10 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
r_g[ii] = static_cast<MATH_T>(g[i]);
r_p[ii] = static_cast<MATH_T>(p[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
......@@ -259,9 +259,9 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
p[i] = static_cast<PARAM_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
......@@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
......@@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
......@@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment