"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "4b1ac23f53d0e714a4a48d2c8058438405c0fd07"
Unverified Commit 05c0fb02 authored by Kunlun Li's avatar Kunlun Li Committed by GitHub
Browse files

Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam (#1078)



* Add precision aware fused adam
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Minor changes based on review comments.
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 23caab3f
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from itertools import product from itertools import product
import copy import copy
from contextlib import nullcontext
import pytest import pytest
import torch import torch
...@@ -174,6 +175,216 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -174,6 +175,216 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param) torch.testing.assert_close(ref_param, tst_param)
def gen_precision_aware_test(
    self,
    use_fp8_params,
    param_dtype,
    use_master_weights,
    master_weight_dtype,
    grad_dtype,
    exp_avg_dtype,
    exp_avg_sq_dtype,
    model_rtol=None,
    model_atol=None,
    master_rtol=None,
    master_atol=None,
    skip_assert=False,
):
    """Compare precision-aware FusedAdam against a fp32 torch.optim.Adam reference.

    Builds a small MultiheadAttention model (optionally with FP8 parameters),
    then steps both optimizers with identical random gradients for
    ``self.iters`` iterations, checking after each step that FusedAdam's
    master weights match the fp32 reference and that the model weights match
    the reference cast to ``param_dtype``. Afterwards the optimizer state is
    round-tripped through ``state_dict()``/``load_state_dict()`` into a fresh
    FusedAdam and the same checks are repeated.

    Args:
        use_fp8_params: build the model inside ``fp8_model_init`` when True.
        param_dtype: dtype of the model parameters.
        use_master_weights: whether FusedAdam keeps master weights.
        master_weight_dtype: storage dtype of the master weights.
        grad_dtype: dtype the decoupled gradients are cast to.
        exp_avg_dtype: storage dtype of the first-moment state.
        exp_avg_sq_dtype: storage dtype of the second-moment state.
        model_rtol/model_atol: tolerances for the model-weight comparison.
        master_rtol/master_atol: tolerances for the master-weight comparison.
        skip_assert: run the iterations but skip the numerical comparisons.
    """
    build_model_context = nullcontext
    build_model_context_args = {}
    if use_fp8_params:
        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True
    with build_model_context(**build_model_context_args):
        model = MultiheadAttention(
            hidden_size=1024,
            num_attention_heads=16,
            layer_number=1,
            params_dtype=param_dtype,
            fuse_qkv_params=True,
        ).cuda()
    ref_params = []
    model_params = []
    for p in model.parameters():
        if p.requires_grad:
            # fp32 clones serve as the reference the torch Adam optimizes.
            ref_params.append(p.detach().clone().float())
            model_params.append(p)
    options = {
        "lr": 1,
        "betas": (0.1, 0.25),
        "eps": 1e-08,
        "weight_decay": 0,
        "amsgrad": False,
    }
    ref_optim = torch.optim.Adam(ref_params, **options)
    tst_optim = te.optimizers.FusedAdam(
        model_params,
        master_weights=use_master_weights,
        master_weight_dtype=master_weight_dtype,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
        use_decoupled_grad=True,
        **options,
    )

    def test_one_iteration(ref_optimizer, tst_optimizer):
        # One step with identical random gradients on both optimizers,
        # followed by master-weight and model-weight comparisons.
        for p_ref, p in zip(ref_params, model_params):
            p_ref.grad = torch.rand_like(p_ref)
            p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
        ref_optimizer.step()
        tst_optimizer.step()
        if use_master_weights:
            # Bug fix: query the optimizer passed as an argument rather than
            # the enclosing `tst_optim` closure variable, so the check is
            # guaranteed to target the optimizer that just stepped (the
            # original only worked because `tst_optim` was rebound before the
            # post-reload iterations).
            master_weights_to_fp32 = [
                tst_optimizer.get_unscaled_state(p, "master_param")
                for p in model_params
            ]
            if not skip_assert:
                torch.testing.assert_close(
                    ref_params,
                    master_weights_to_fp32,
                    rtol=master_rtol,
                    atol=master_atol,
                    equal_nan=True,
                )
        ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
        if not skip_assert:
            torch.testing.assert_close(
                ref_params_to_model_dtype,
                model_params,
                rtol=model_rtol,
                atol=model_atol,
                equal_nan=True,
            )

    for i in range(self.iters):
        test_one_iteration(ref_optim, tst_optim)

    # Round-trip the optimizer state through a state dict into a fresh
    # FusedAdam and verify training continues identically.
    state_dict = tst_optim.state_dict()
    tst_optim = te.optimizers.FusedAdam(
        model_params,
        master_weights=use_master_weights,
        master_weight_dtype=master_weight_dtype,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
        use_decoupled_grad=True,
        **options,
    )
    tst_optim.load_state_dict(state_dict)

    for i in range(self.iters):
        test_one_iteration(ref_optim, tst_optim)
def test_fp32_no_master(self):
    """Pure fp32 run: no master weights, all tensors kept in fp32."""
    fp32 = torch.float32
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=fp32,
        use_master_weights=False,
        master_weight_dtype=fp32,
        grad_dtype=fp32,
        exp_avg_dtype=fp32,
        exp_avg_sq_dtype=fp32,
    )
# Fix: skip-reason typo "bf16 if not supported" -> "bf16 is not supported".
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master(self):
    """bf16 model weights with fp32 master weights; fp32 grads and states."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
    )
# Fix: skip-reason typo "bf16 if not supported" -> "bf16 is not supported".
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
    """fp16 master weights: looser tolerances on the master-weight check."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.half,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
# Fix: skip-reason typo "bf16 if not supported" -> "bf16 is not supported".
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_grad(self):
    """bf16 decoupled gradients: looser tolerances on the master-weight check."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.bfloat16,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
# Fix: skip-reason typo "bf16 if not supported" -> "bf16 is not supported".
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg(self):
    """fp16 first-moment (exp_avg) state: looser master-weight tolerances."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.half,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
# Fix: skip-reason typo "bf16 if not supported" -> "bf16 is not supported".
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
    """fp8 first-moment state (stored as uint8): widest master tolerances."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.uint8,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=1e-2,
        master_atol=1e-2,
    )
# Fix: skip-reason typo "bf16 if not supported" -> "bf16 is not supported".
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg_sq(self):
    """fp16 second-moment (exp_avg_sq) state: looser master tolerances."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.half,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
# Fix: skip-reason typo "bf16 if not supported" -> "bf16 is not supported".
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
    """fp8 second-moment state: too lossy to compare, so run assert-free."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.uint8,
        skip_assert=True,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_bf16_model_weight_cast(self): def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -185,12 +396,10 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -185,12 +396,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True, fuse_qkv_params=True,
).cuda() ).cuda()
ref_params = [] ref_params = []
master_params = []
model_params = [] model_params = []
for p in model.parameters(): for p in model.parameters():
if p.requires_grad: if p.requires_grad:
ref_params.append(p.detach().clone().float()) ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p) model_params.append(p)
options = { options = {
"lr": 5e-4, "lr": 5e-4,
...@@ -200,12 +409,17 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -200,12 +409,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False, "amsgrad": False,
} }
ref_optim = torch.optim.Adam(ref_params, **options) ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters): for i in range(self.iters):
self.gen_grad(ref_params, master_params) for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step() ref_optim.step()
tst_optim.step() tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params) torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params] model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close( torch.testing.assert_close(
...@@ -224,12 +438,10 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -224,12 +438,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True, fuse_qkv_params=True,
).cuda() ).cuda()
ref_params = [] ref_params = []
master_params = []
model_params = [] model_params = []
for p in model.parameters(): for p in model.parameters():
if p.requires_grad: if p.requires_grad:
ref_params.append(p.detach().clone().float()) ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p) model_params.append(p)
options = { options = {
"lr": 5e-4, "lr": 5e-4,
...@@ -239,12 +451,17 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -239,12 +451,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False, "amsgrad": False,
} }
ref_optim = torch.optim.Adam(ref_params, **options) ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters): for i in range(self.iters):
self.gen_grad(ref_params, master_params) for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step() ref_optim.step()
tst_optim.step() tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params) torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params] model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close( torch.testing.assert_close(
......
...@@ -179,7 +179,7 @@ struct AdamFunctorMaster { ...@@ -179,7 +179,7 @@ struct AdamFunctorMaster {
} }
}; };
template <typename T, typename FULL_T, typename index_t> template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor { struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*) TensorListMetadata<4> &tl, // NOLINT(*)
...@@ -199,10 +199,10 @@ struct AdamFunctor { ...@@ -199,10 +199,10 @@ struct AdamFunctor {
index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc]; index_t n = tl.sizes[tensor_loc];
T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]); GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size; g += chunk_idx * chunk_size;
T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]); PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size; p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]); FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
...@@ -223,10 +223,10 @@ struct AdamFunctor { ...@@ -223,10 +223,10 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
r_g[ii] = g[i]; r_g[ii] = static_cast<MATH_T>(g[i]);
r_p[ii] = p[i]; r_p[ii] = static_cast<MATH_T>(p[i]);
r_m[ii] = m[i]; r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = v[i]; r_v[ii] = static_cast<MATH_T>(v[i]);
} else { } else {
r_g[ii] = MATH_T(0); r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0); r_p[ii] = MATH_T(0);
...@@ -259,9 +259,9 @@ struct AdamFunctor { ...@@ -259,9 +259,9 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
p[i] = r_p[ii]; p[i] = static_cast<PARAM_T>(r_p[ii]);
m[i] = r_m[ii]; m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = r_v[ii]; v[i] = static_cast<FULL_T>(r_v[ii]);
} }
} }
} }
...@@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
} }
} }
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type(); const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size(); auto tl_size = tensor_lists.size();
...@@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now // Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2, g_in_type, 1, "adam",
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
weight_decay);) tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else { } else {
// g, p, m, v, p_master // g, p, m, v, p_master
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
...@@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now // Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2, AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, beta2, bias_correction1, bias_correction2, epsilon, lr,
weight_decay);) (adamMode_t)mode, weight_decay);));
} else { } else {
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment