Unverified Commit 05c0fb02 authored by Kunlun Li's avatar Kunlun Li Committed by GitHub
Browse files

Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam (#1078)



* Add precision aware fused adam
Signed-off-by: default avatarkunlunl <kunlunl@nvidia.com>

* Minor changes based on review comments.
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarKunlun Li <94586211+kunlunl@users.noreply.github.com>

---------
Signed-off-by: default avatarkunlunl <kunlunl@nvidia.com>
Signed-off-by: default avatarKunlun Li <94586211+kunlunl@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 23caab3f
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from itertools import product from itertools import product
import copy import copy
from contextlib import nullcontext
import pytest import pytest
import torch import torch
...@@ -174,6 +175,216 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -174,6 +175,216 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param) torch.testing.assert_close(ref_param, tst_param)
def gen_precision_aware_test(
    self,
    use_fp8_params,
    param_dtype,
    use_master_weights,
    master_weight_dtype,
    grad_dtype,
    exp_avg_dtype,
    exp_avg_sq_dtype,
    model_rtol=None,
    model_atol=None,
    master_rtol=None,
    master_atol=None,
    skip_assert=False,
):
    """Compare precision-aware FusedAdam against an FP32 torch.optim.Adam reference.

    Builds a MultiheadAttention model, then steps a FusedAdam configured with the
    given parameter/gradient/optimizer-state dtypes alongside a plain FP32
    torch.optim.Adam on cloned parameters. After every step the (unscaled) master
    weights and the model weights are compared against the reference. Midway, the
    optimizer state is round-tripped through state_dict()/load_state_dict() into a
    freshly constructed FusedAdam to verify save/restore preserves the states.

    Arguments:
        use_fp8_params: build the model under fp8_model_init(enabled=True).
        param_dtype: dtype of the model parameters.
        use_master_weights: whether FusedAdam maintains master weights.
        master_weight_dtype: dtype of the master weights.
        grad_dtype: dtype of the .decoupled_grad tensors fed to FusedAdam.
        exp_avg_dtype: dtype of the first Adam moment (torch.uint8 means FP8).
        exp_avg_sq_dtype: dtype of the second Adam moment (torch.uint8 means FP8).
        model_rtol / model_atol: tolerances for the model-weight comparison.
        master_rtol / master_atol: tolerances for the master-weight comparison.
        skip_assert: run all iterations but skip the comparisons entirely.
    """
    build_model_context = nullcontext
    build_model_context_args = {}
    if use_fp8_params:
        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True
    with build_model_context(**build_model_context_args):
        model = MultiheadAttention(
            hidden_size=1024,
            num_attention_heads=16,
            layer_number=1,
            params_dtype=param_dtype,
            fuse_qkv_params=True,
        ).cuda()

    # FP32 clones of every trainable parameter serve as the reference weights.
    ref_params = []
    model_params = []
    for p in model.parameters():
        if p.requires_grad:
            ref_params.append(p.detach().clone().float())
            model_params.append(p)

    # Hyperparameters shared by the reference and test optimizers.
    options = {
        "lr": 1,
        "betas": (0.1, 0.25),
        "eps": 1e-08,
        "weight_decay": 0,
        "amsgrad": False,
    }
    ref_optim = torch.optim.Adam(ref_params, **options)
    tst_optim = te.optimizers.FusedAdam(
        model_params,
        master_weights=use_master_weights,
        master_weight_dtype=master_weight_dtype,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
        use_decoupled_grad=True,
        **options,
    )

    def test_one_iteration(ref_optimizer, tst_optimizer):
        # Feed both optimizers the same random gradient (cast to grad_dtype for
        # FusedAdam), step, then compare master weights and model weights.
        for p_ref, p in zip(ref_params, model_params):
            p_ref.grad = torch.rand_like(p_ref)
            p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
        ref_optimizer.step()
        tst_optimizer.step()
        if use_master_weights:
            master_weights_to_fp32 = [
                tst_optim.get_unscaled_state(p, "master_param") for p in model_params
            ]
            if not skip_assert:
                torch.testing.assert_close(
                    ref_params,
                    master_weights_to_fp32,
                    rtol=master_rtol,
                    atol=master_atol,
                    equal_nan=True,
                )
        ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
        if not skip_assert:
            torch.testing.assert_close(
                ref_params_to_model_dtype,
                model_params,
                rtol=model_rtol,
                atol=model_atol,
                equal_nan=True,
            )

    for i in range(self.iters):
        test_one_iteration(ref_optim, tst_optim)

    # Round-trip the optimizer state through state_dict()/load_state_dict() into a
    # brand-new FusedAdam, then keep training to check the restored states behave
    # identically (the reference optimizer keeps its own state throughout).
    state_dict = tst_optim.state_dict()
    tst_optim = te.optimizers.FusedAdam(
        model_params,
        master_weights=use_master_weights,
        master_weight_dtype=master_weight_dtype,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
        use_decoupled_grad=True,
        **options,
    )
    tst_optim.load_state_dict(state_dict)
    for i in range(self.iters):
        test_one_iteration(ref_optim, tst_optim)
def test_fp32_no_master(self):
    """All-FP32 baseline: no master weights; params, grads, and moments in FP32."""
    settings = dict(
        use_fp8_params=False,
        param_dtype=torch.float32,
        use_master_weights=False,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
    )
    self.gen_precision_aware_test(**settings)
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master(self):
    """BF16 params with FP32 master weights, grads, and optimizer states."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
    """BF16 params with FP16 master weights; loosened master-weight tolerances."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.half,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_grad(self):
    """BF16 params and BF16 gradients; FP32 master weights and optimizer states."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.bfloat16,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg(self):
    """FP16 first moment (exp_avg); everything else FP32/BF16 as usual."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.half,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
    """FP8 (torch.uint8) first moment; wider tolerances for quantization error."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.uint8,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=1e-2,
        master_atol=1e-2,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg_sq(self):
    """FP16 second moment (exp_avg_sq); everything else FP32/BF16 as usual."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.half,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
    """FP8 (torch.uint8) second moment; comparisons disabled via skip_assert."""
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.uint8,
        skip_assert=True,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_bf16_model_weight_cast(self): def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -185,12 +396,10 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -185,12 +396,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True, fuse_qkv_params=True,
).cuda() ).cuda()
ref_params = [] ref_params = []
master_params = []
model_params = [] model_params = []
for p in model.parameters(): for p in model.parameters():
if p.requires_grad: if p.requires_grad:
ref_params.append(p.detach().clone().float()) ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p) model_params.append(p)
options = { options = {
"lr": 5e-4, "lr": 5e-4,
...@@ -200,12 +409,17 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -200,12 +409,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False, "amsgrad": False,
} }
ref_optim = torch.optim.Adam(ref_params, **options) ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters): for i in range(self.iters):
self.gen_grad(ref_params, master_params) for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step() ref_optim.step()
tst_optim.step() tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params) torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params] model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close( torch.testing.assert_close(
...@@ -224,12 +438,10 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -224,12 +438,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True, fuse_qkv_params=True,
).cuda() ).cuda()
ref_params = [] ref_params = []
master_params = []
model_params = [] model_params = []
for p in model.parameters(): for p in model.parameters():
if p.requires_grad: if p.requires_grad:
ref_params.append(p.detach().clone().float()) ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p) model_params.append(p)
options = { options = {
"lr": 5e-4, "lr": 5e-4,
...@@ -239,12 +451,17 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -239,12 +451,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False, "amsgrad": False,
} }
ref_optim = torch.optim.Adam(ref_params, **options) ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters): for i in range(self.iters):
self.gen_grad(ref_params, master_params) for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step() ref_optim.step()
tst_optim.step() tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params) torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params] model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close( torch.testing.assert_close(
......
...@@ -179,7 +179,7 @@ struct AdamFunctorMaster { ...@@ -179,7 +179,7 @@ struct AdamFunctorMaster {
} }
}; };
template <typename T, typename FULL_T, typename index_t> template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor { struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*) TensorListMetadata<4> &tl, // NOLINT(*)
...@@ -199,10 +199,10 @@ struct AdamFunctor { ...@@ -199,10 +199,10 @@ struct AdamFunctor {
index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc]; index_t n = tl.sizes[tensor_loc];
T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]); GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size; g += chunk_idx * chunk_size;
T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]); PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size; p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]); FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
...@@ -223,10 +223,10 @@ struct AdamFunctor { ...@@ -223,10 +223,10 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
r_g[ii] = g[i]; r_g[ii] = static_cast<MATH_T>(g[i]);
r_p[ii] = p[i]; r_p[ii] = static_cast<MATH_T>(p[i]);
r_m[ii] = m[i]; r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = v[i]; r_v[ii] = static_cast<MATH_T>(v[i]);
} else { } else {
r_g[ii] = MATH_T(0); r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0); r_p[ii] = MATH_T(0);
...@@ -259,9 +259,9 @@ struct AdamFunctor { ...@@ -259,9 +259,9 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) { if (i < n && i < chunk_size) {
p[i] = r_p[ii]; p[i] = static_cast<PARAM_T>(r_p[ii]);
m[i] = r_m[ii]; m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = r_v[ii]; v[i] = static_cast<FULL_T>(r_v[ii]);
} }
} }
} }
...@@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
} }
} }
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type(); const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size(); auto tl_size = tensor_lists.size();
...@@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now // Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2, g_in_type, 1, "adam",
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
weight_decay);) tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else { } else {
// g, p, m, v, p_master // g, p, m, v, p_master
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
...@@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now // Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2, g_in_type, 1, "adam",
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
weight_decay);) AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else { } else {
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam", p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
......
...@@ -3,11 +3,15 @@ ...@@ -3,11 +3,15 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Fused Adam optimizer.""" """Fused Adam optimizer."""
from copy import deepcopy
from itertools import chain
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
from ..float8_tensor import Float8Tensor
def get_fp8_meta(fp8_tensor): def get_fp8_meta(fp8_tensor):
...@@ -68,11 +72,28 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -68,11 +72,28 @@ class FusedAdam(torch.optim.Optimizer):
method is called. (default: True) method is called. (default: True)
capturable (bool, optional): whether to use the version of the optimizer capturable (bool, optional): whether to use the version of the optimizer
that can be used with CUDA Graphs. (default: False) that can be used with CUDA Graphs. (default: False)
master_weights (list of torch.Tensor, optional): master weights to use master_weights (bool, optional): whether to maintain FP32 master weights
for mixed precision training. If provided, the optimizer will update in the optimizer with FP16/BF16 mixed precision training.
the master weights and then cast the master weights to the model weights. (default: False)
If not provided, the optimizer will update the model weights directly. master_weight_dtype (torch.dtype, optional): The dtype of master weights.
(default: None) If master_weights is False, this will be ignored. It can be one of
[torch.float32, torch.float16]. If it's not torch.float32, the optimizer
will create a FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
exp_avg_dtype (torch.dtype, optional): The dtype of exp_avg. It can be
one of [torch.float32, torch.float16, torch.uint8], where torch.uint8
represents FP8. If it's not torch.float32, the optimizer will create
a FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
exp_avg_sq_dtype (torch.dtype, optional): The dtype of exp_avg_sq. It
can be one of [torch.float32, torch.float16, torch.uint8], where
torch.uint8 represents FP8. If it's not torch.float32, the optimizer
will create a FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
use_decoupled_grad (bool, optional): Whether to use ".decoupled_grad"
instead of ".grad" for reading gradients. It's useful when the dtypes
of grad and param are different.
(default: False)
.. _Adam - A Method for Stochastic Optimization: .. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980 https://arxiv.org/abs/1412.6980
...@@ -92,12 +113,36 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -92,12 +113,36 @@ class FusedAdam(torch.optim.Optimizer):
amsgrad=False, amsgrad=False,
set_grad_none=True, set_grad_none=True,
capturable=False, capturable=False,
master_weights=None, master_weights=False,
master_weight_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
use_decoupled_grad=False,
): ):
if amsgrad: if amsgrad:
raise RuntimeError("FusedAdam does not support the AMSGrad variant.") raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
# Add constraints to dtypes of states.
if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
# Currently, capturable mode only supports fp32 master weights and optimizer states.
# The reason is, if the master weights or optimizer states are not in fp32 dtype,
# they will be copied to temporary fp32 buffers first. These fp32 buffers are then
# used as inputs for the kernel. Consequently, the pointer for each `.step()` differs,
# making CUDA Graph inapplicable in this scenario.
if capturable and master_weights and master_weight_dtype != torch.float32:
raise RuntimeError("Capturable mode only supports fp32 master weights.")
if capturable and exp_avg_dtype != torch.float32:
raise RuntimeError("Capturable mode only supports fp32 exp_avg.")
if capturable and exp_avg_sq_dtype != torch.float32:
raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq")
# If the optimizer is capturable then LR should be a tensor (on GPU) # If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = { defaults = {
...@@ -112,9 +157,6 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -112,9 +157,6 @@ class FusedAdam(torch.optim.Optimizer):
self.set_grad_none = set_grad_none self.set_grad_none = set_grad_none
self.capturable = capturable self.capturable = capturable
if master_weights is not None:
assert isinstance(master_weights, list), "master_weights must be a list if provided"
self.master_weights = master_weights self.master_weights = master_weights
if capturable: if capturable:
...@@ -134,14 +176,208 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -134,14 +176,208 @@ class FusedAdam(torch.optim.Optimizer):
self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
self.master_weight_dtype = master_weight_dtype
self.exp_avg_dtype = exp_avg_dtype
self.exp_avg_sq_dtype = exp_avg_sq_dtype
self.name_to_dtype_map = {
"exp_avg": self.exp_avg_dtype,
"exp_avg_sq": self.exp_avg_sq_dtype,
"master_param": self.master_weight_dtype,
}
self.dtype_to_range_map = {
torch.float16: torch.full(
[1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32
),
torch.uint8: torch.full([1], 448.0, dtype=torch.float32),
}
self._scales = {}
self.use_decoupled_grad = use_decoupled_grad
def zero_grad(self):
    # pylint: disable=missing-function-docstring
    # Reading from ".grad" with set_grad_none=False is exactly what the base
    # class already implements, so delegate that combination to it.
    if not self.use_decoupled_grad and not self.set_grad_none:
        super().zero_grad()
        return
    for group in self.param_groups:
        for p in group["params"]:
            if self.use_decoupled_grad and self.set_grad_none:
                p.decoupled_grad = None
            elif self.use_decoupled_grad and not self.set_grad_none:
                p.decoupled_grad.zero_()
            elif not self.use_decoupled_grad and self.set_grad_none:
                p.grad = None
def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
    """Apply scaling on `unscaled_state`. `scaled_state` and `scale` will be written inplace.

    Arguments:
        state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
            and 'master_param'.
        unscaled_state (torch.Tensor): An unscaled high-precision tensor.
        scaled_state (torch.Tensor): A scaled low-precision tensor.
        scale (torch.Tensor): A FP32 tensor representing the scaling factor.
    """
    assert unscaled_state.dtype == torch.float32
    dtype = self.name_to_dtype_map[state_name]
    if dtype == torch.uint8:
        assert isinstance(scaled_state, Float8Tensor)
    else:
        assert scaled_state.dtype == dtype
    max_range = self.dtype_to_range_map[dtype]
    if max_range.device != scaled_state.device:
        max_range = max_range.to(scaled_state.device)
        # Cache the device-resident copy under the same key it was fetched with
        # (`dtype`), so the host-to-device transfer happens only once per dtype.
        # Keying by `scaled_state.dtype` would miss for Float8Tensor states,
        # whose nominal dtype is torch.float32, and the entry would never be
        # read back.
        self.dtype_to_range_map[dtype] = max_range
    if unscaled_state.device != scaled_state.device:
        unscaled_state = unscaled_state.to(scaled_state.device)
    # scale = absmax / max_range maps the state into the representable range.
    min_val, max_val = torch.aminmax(unscaled_state)
    absmax = torch.maximum(-min_val, max_val)
    absmax = absmax.to(dtype=torch.float32, device=unscaled_state.device)
    torch.div(absmax, max_range, out=scale)
    if isinstance(scaled_state, Float8Tensor):
        # Float8Tensor performs the quantization itself; just hand it the
        # inverse scale and copy.
        scaled_state._scale_inv.copy_(scale)
        scaled_state.copy_(unscaled_state)
    else:
        # Guard against an all-zero state producing scale == 0 (avoid inf).
        rscale = torch.where(scale > 0, scale.reciprocal(), 0.0)
        unscaled_state.mul_(rscale)
        scaled_state.copy_(unscaled_state)
def get_unscaled_state(self, param, state_name):
    """Return the unscaled state corresponding to the input `param` and `state_name`.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
        state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
            and 'master_param'.
    """
    stored = self.state[param][state_name]
    target_dtype = self.name_to_dtype_map[state_name]
    if target_dtype == torch.float32:
        # FP32 states are kept unscaled; return them as-is.
        assert stored.dtype == torch.float32
        return stored
    if target_dtype == torch.float16:
        # Undo the per-state scaling factor after widening to FP32.
        assert stored.dtype == torch.float16
        widened = stored.float()
        widened.mul_(self._scales[param][state_name])
        return widened
    if target_dtype == torch.uint8:
        # Float8Tensor dequantizes on .float() using its own scale_inv.
        assert isinstance(stored, Float8Tensor)
        return stored.float()
    raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
def set_scaled_state(self, param, state_name, unscaled_state):
    """Set the optimizer state.

    If the dtype of the corresponding optimizer state is not FP32,
    it will do scaling automatically.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
        state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
            and 'master_param'.
        unscaled_state (torch.Tensor): The original high-precision(FP32) state.
    """
    assert unscaled_state.dtype == torch.float32
    state = self.state[param]
    if state_name not in state:
        # Allocate the state buffer lazily; no need to zero it, we overwrite below.
        self._initialize_state(param, state_name, False)
    if self.name_to_dtype_map[state_name] == torch.float32:
        state[state_name].copy_(unscaled_state)
    else:
        self._apply_scale(
            state_name,
            unscaled_state,
            state[state_name],
            self._scales[param][state_name],
        )
def _initialize_state(self, param, state_name, zero_buffer: bool):
"""Initialize one of the optimizer states according to `state_name`.
Arguments:
param (torch.nn.Parameter): One of parameters in this optimizer.
state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
and 'master_param`.
zero_buffer (bool): Whether to initialize the optimizer state with zeros.
"""
dtype = self.name_to_dtype_map[state_name]
data = torch.empty_like(param, dtype=dtype)
if zero_buffer:
data.zero_()
if dtype == torch.uint8:
self.state[param][state_name] = Float8Tensor(
data=data,
dtype=torch.float32,
fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device),
)
else:
self.state[param][state_name] = data
# Create scale if necessary.
if dtype != torch.float32:
if param not in self._scales:
self._scales[param] = {}
self._scales[param][state_name] = torch.ones(
[1], dtype=torch.float32, device=param.device
)
def initialize_state(self, param):
    """Initialize optimizer states.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
    """
    for moment in ("exp_avg", "exp_avg_sq"):
        self._initialize_state(param, moment, zero_buffer=True)
    if not self.master_weights:
        return
    # Master weights start as an FP32 copy of the parameter; buffer contents
    # are immediately overwritten by set_scaled_state, so no zeroing needed.
    self._initialize_state(param, "master_param", zero_buffer=False)
    self.set_scaled_state(param, "master_param", param.clone().detach().float())
def state_dict(self):
    """Override the state_dict() of pytorch. Before returning the state_dict, cast all
    non-fp32 states to fp32.
    """
    sd = super().state_dict()
    # Map saved parameter ids back to the live parameter objects.
    saved_ids = chain.from_iterable(g["params"] for g in deepcopy(sd["param_groups"]))
    live_params = chain.from_iterable(g["params"] for g in self.param_groups)
    id_map = dict(zip(saved_ids, live_params))
    for key, entry in sd["state"].items():
        if key not in id_map:
            continue
        param = id_map[key]
        # Replace each stored state with its unscaled (FP32) equivalent.
        sd["state"][key] = {name: self.get_unscaled_state(param, name) for name in entry}
    return sd
def load_state_dict(self, state_dict):
    """Override the load_state_dict() of pytorch. Since pytorch's load_state_dict forces the
    state to be the same dtype as param, we need to manually set the state again.
    """
    super().load_state_dict(state_dict)
    # Map saved parameter ids back to the live parameter objects.
    saved_ids = chain.from_iterable(
        g["params"] for g in deepcopy(state_dict["param_groups"])
    )
    live_params = chain.from_iterable(g["params"] for g in self.param_groups)
    id_map = dict(zip(saved_ids, live_params))
    for key, entry in state_dict["state"].items():
        if key not in id_map:
            continue
        param = id_map[key]
        # Rebuild each state in its configured dtype, re-applying scaling.
        self.state[param] = {}
        for name, value in entry.items():
            self.set_scaled_state(param, name, value.float())
def step(self, closure=None, grad_scaler=None): def step(self, closure=None, grad_scaler=None):
"""Performs a single optimization step. """Performs a single optimization step.
...@@ -156,8 +392,6 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -156,8 +392,6 @@ class FusedAdam(torch.optim.Optimizer):
if closure is not None: if closure is not None:
loss = closure() loss = closure()
master_param_idx = 0
for group in self.param_groups: for group in self.param_groups:
if len(group["params"]) == 0: if len(group["params"]) == 0:
continue continue
...@@ -196,6 +430,11 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -196,6 +430,11 @@ class FusedAdam(torch.optim.Optimizer):
amaxes = [] amaxes = []
scale_invs = [] scale_invs = []
# Lists for scaling
unscaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []}
scaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []}
state_scales = {"exp_avg": [], "exp_avg_sq": [], "master_param": []}
# Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is. # Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is.
out_dtype = tex.DType.kFloat32 out_dtype = tex.DType.kFloat32
...@@ -207,31 +446,29 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -207,31 +446,29 @@ class FusedAdam(torch.optim.Optimizer):
# State initialization # State initialization
if len(state) == 0: if len(state) == 0:
# Exponential moving average of gradient values self.initialize_state(p)
state["exp_avg"] = torch.zeros_like(p.data).float()
# Exponential moving average of squared gradient values if self.use_decoupled_grad:
state["exp_avg_sq"] = torch.zeros_like(p.data).float() p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None
# Master weights else:
if self.master_weights and p.dtype != torch.float32: p_grad = p.grad
# model weights can be fp32/bf16/fp16/fp8
# If it's fp32, it has no corresponding master weights
state["master_param"] = self.master_weights[master_param_idx]
master_param_idx += 1
assert (
state["master_param"].shape == p.shape
), "Master weights shape must match model weights shape"
p_master = state.get("master_param", None)
p_grad = p.grad
if self.master_weights and p_master is not None and p_master.grad is not None:
p_grad = p_master.grad
if p_grad is None: if p_grad is None:
continue continue
if p_grad.data.is_sparse: if p_grad.data.is_sparse:
raise RuntimeError("FusedAdam does not support sparse gradients.") raise RuntimeError("FusedAdam does not support sparse gradients.")
# Unscaling
unscaled_state = {}
for name in ["exp_avg", "exp_avg_sq", "master_param"]:
if name in state:
unscaled = self.get_unscaled_state(p, name)
unscaled_state[name] = unscaled
if self.name_to_dtype_map[name] != torch.float32:
unscaled_lists[name].append(unscaled)
scaled_lists[name].append(state[name])
state_scales[name].append(self._scales[p][name])
if isinstance(p, Float8Tensor): if isinstance(p, Float8Tensor):
out_dtype = p._fp8_dtype out_dtype = p._fp8_dtype
p_fp8_model.append(p._data.data) p_fp8_model.append(p._data.data)
...@@ -240,26 +477,28 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -240,26 +477,28 @@ class FusedAdam(torch.optim.Optimizer):
amaxes.append(amax) amaxes.append(amax)
scale_invs.append(scale_inv) scale_invs.append(scale_inv)
if self.master_weights: if self.master_weights:
p_main_of_fp8_model.append(p_master.data) p_main_of_fp8_model.append(unscaled_state["master_param"].data)
g_of_fp8_model.append(p_grad.data) g_of_fp8_model.append(p_grad.data)
m_of_fp8_model.append(state["exp_avg"]) m_of_fp8_model.append(unscaled_state["exp_avg"])
v_of_fp8_model.append(state["exp_avg_sq"]) v_of_fp8_model.append(unscaled_state["exp_avg_sq"])
elif p.dtype in [torch.float16, torch.bfloat16]: elif p.dtype in [torch.float16, torch.bfloat16]:
has_fp16 = has_fp16 or p.dtype == torch.float16 has_fp16 = has_fp16 or p.dtype == torch.float16
has_bf16 = has_bf16 or p.dtype == torch.bfloat16 has_bf16 = has_bf16 or p.dtype == torch.bfloat16
p_f16_model.append(p.data) p_f16_model.append(p.data)
if self.master_weights: if self.master_weights:
p_main_of_f16_model.append(p_master.data) p_main_of_f16_model.append(unscaled_state["master_param"].data)
g_of_f16_model.append(p_grad.data) g_of_f16_model.append(p_grad.data)
m_of_f16_model.append(state["exp_avg"]) m_of_f16_model.append(unscaled_state["exp_avg"])
v_of_f16_model.append(state["exp_avg_sq"]) v_of_f16_model.append(unscaled_state["exp_avg_sq"])
elif p.dtype == torch.float32: elif p.dtype == torch.float32:
p_f32_model.append(p.data) p_f32_model.append(p.data)
g_of_f32_model.append(p_grad.data) g_of_f32_model.append(p_grad.data)
m_of_f32_model.append(state["exp_avg"]) m_of_f32_model.append(unscaled_state["exp_avg"])
v_of_f32_model.append(state["exp_avg_sq"]) v_of_f32_model.append(unscaled_state["exp_avg_sq"])
else: else:
raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8") raise RuntimeError(
"FusedAdam only support model weights in fp32, fp16, bf16 and fp8"
)
if self.capturable and len(p_fp8_model) > 0: if self.capturable and len(p_fp8_model) > 0:
raise RuntimeError( raise RuntimeError(
...@@ -389,4 +628,15 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -389,4 +628,15 @@ class FusedAdam(torch.optim.Optimizer):
tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
# Scaling
for name in ["exp_avg", "exp_avg_sq", "master_param"]:
if len(unscaled_lists[name]) > 0:
for unscaled, scaled, scale in zip(
unscaled_lists[name], scaled_lists[name], state_scales[name]
):
self._apply_scale(name, unscaled, scaled, scale)
# Try to reclaim the temporary fp32 buffers.
del unscaled_lists
return loss return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment