Unverified Commit 05c0fb02 authored by Kunlun Li's avatar Kunlun Li Committed by GitHub
Browse files

Support using fp16 master weights and fp16/fp8 optimizer states in FusedAdam (#1078)



* Add precision aware fused adam
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Minor changes based on review comments.
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 23caab3f
......@@ -4,6 +4,7 @@
from itertools import product
import copy
from contextlib import nullcontext
import pytest
import torch
......@@ -174,6 +175,216 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param)
def gen_precision_aware_test(
    self,
    use_fp8_params,
    param_dtype,
    use_master_weights,
    master_weight_dtype,
    grad_dtype,
    exp_avg_dtype,
    exp_avg_sq_dtype,
    model_rtol=None,
    model_atol=None,
    master_rtol=None,
    master_atol=None,
    skip_assert=False,
):
    """Compare precision-aware FusedAdam against a plain fp32 torch.optim.Adam.

    Builds a MultiheadAttention module (optionally with FP8 parameters),
    keeps fp32 reference copies of its parameters, and steps both optimizers
    with identical random gradients. After each step the FusedAdam master
    weights (unscaled to fp32) and the model weights are checked against the
    reference. The FusedAdam state is then round-tripped through
    state_dict()/load_state_dict() and the comparison is repeated.

    Arguments:
        use_fp8_params (bool): build the model under fp8_model_init.
        param_dtype (torch.dtype): dtype of the model parameters.
        use_master_weights (bool): let FusedAdam keep master weights.
        master_weight_dtype (torch.dtype): dtype of the master weights.
        grad_dtype (torch.dtype): dtype of the gradients fed via
            ".decoupled_grad".
        exp_avg_dtype / exp_avg_sq_dtype (torch.dtype): dtypes of the
            Adam moments (torch.uint8 denotes FP8).
        model_rtol / model_atol: tolerances for the model-weight check.
        master_rtol / master_atol: tolerances for the master-weight check.
        skip_assert (bool): run the iterations without comparisons (used when
            a dtype combination is too lossy for a stable tolerance).
    """
    build_model_context = nullcontext
    build_model_context_args = {}
    if use_fp8_params:
        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True
    with build_model_context(**build_model_context_args):
        model = MultiheadAttention(
            hidden_size=1024,
            num_attention_heads=16,
            layer_number=1,
            params_dtype=param_dtype,
            fuse_qkv_params=True,
        ).cuda()

    # fp32 reference copies drive torch.optim.Adam; the live parameters
    # drive FusedAdam.
    ref_params = []
    model_params = []
    for p in model.parameters():
        if p.requires_grad:
            ref_params.append(p.detach().clone().float())
            model_params.append(p)

    options = {
        "lr": 1,
        "betas": (0.1, 0.25),
        "eps": 1e-08,
        "weight_decay": 0,
        "amsgrad": False,
    }
    ref_optim = torch.optim.Adam(ref_params, **options)
    tst_optim = te.optimizers.FusedAdam(
        model_params,
        master_weights=use_master_weights,
        master_weight_dtype=master_weight_dtype,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
        use_decoupled_grad=True,
        **options,
    )

    def test_one_iteration(ref_optimizer, tst_optimizer):
        # Feed both optimizers the same random gradient (FusedAdam reads it
        # from ".decoupled_grad" cast to grad_dtype), step, then compare.
        for p_ref, p in zip(ref_params, model_params):
            p_ref.grad = torch.rand_like(p_ref)
            p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
        ref_optimizer.step()
        tst_optimizer.step()
        if use_master_weights:
            master_weights_to_fp32 = [
                tst_optim.get_unscaled_state(p, "master_param") for p in model_params
            ]
            if not skip_assert:
                torch.testing.assert_close(
                    ref_params,
                    master_weights_to_fp32,
                    rtol=master_rtol,
                    atol=master_atol,
                    equal_nan=True,
                )
        ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params]
        if not skip_assert:
            torch.testing.assert_close(
                ref_params_to_model_dtype,
                model_params,
                rtol=model_rtol,
                atol=model_atol,
                equal_nan=True,
            )

    for i in range(self.iters):
        test_one_iteration(ref_optim, tst_optim)

    # Round-trip the optimizer state through a checkpoint: a freshly built
    # FusedAdam loaded from state_dict must continue producing the same
    # results as the uninterrupted reference optimizer.
    state_dict = tst_optim.state_dict()
    tst_optim = te.optimizers.FusedAdam(
        model_params,
        master_weights=use_master_weights,
        master_weight_dtype=master_weight_dtype,
        exp_avg_dtype=exp_avg_dtype,
        exp_avg_sq_dtype=exp_avg_sq_dtype,
        use_decoupled_grad=True,
        **options,
    )
    tst_optim.load_state_dict(state_dict)

    for i in range(self.iters):
        test_one_iteration(ref_optim, tst_optim)
def test_fp32_no_master(self):
    """Pure fp32 run: fp32 params, grads, and moments, no master weights."""
    dtype = torch.float32
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=dtype,
        use_master_weights=False,
        master_weight_dtype=dtype,
        grad_dtype=dtype,
        exp_avg_dtype=dtype,
        exp_avg_sq_dtype=dtype,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master(self):
    """BF16 params with fp32 master weights and fp32 optimizer states."""
    # Fixed the skipif reason typo: "bf16 if not supported" -> "is not".
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
    """BF16 params with fp16 master weights; loose tolerances on the master check."""
    # Fixed the skipif reason typo: "bf16 if not supported" -> "is not".
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.half,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_grad(self):
    """BF16 gradients fed through ".decoupled_grad" with fp32 master weights."""
    # Fixed the skipif reason typo: "bf16 if not supported" -> "is not".
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.bfloat16,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg(self):
    """FP16 first moment (exp_avg); everything else fp32."""
    # Fixed the skipif reason typo: "bf16 if not supported" -> "is not".
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.half,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
    """FP8 first moment (torch.uint8 denotes FP8); wider tolerances."""
    # Fixed the skipif reason typo: "bf16 if not supported" -> "is not".
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.uint8,
        exp_avg_sq_dtype=torch.float32,
        master_rtol=1e-2,
        master_atol=1e-2,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_exp_avg_sq(self):
    """FP16 second moment (exp_avg_sq); everything else fp32."""
    # Fixed the skipif reason typo: "bf16 if not supported" -> "is not".
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.half,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
    """FP8 second moment; too lossy for a tolerance, so only exercise the path."""
    # Fixed the skipif reason typo: "bf16 if not supported" -> "is not".
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.uint8,
        skip_assert=True,
    )
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
def test_bf16_model_weight_cast(self):
dtype = torch.bfloat16
......@@ -185,12 +396,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
......@@ -200,12 +409,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters):
self.gen_grad(ref_params, master_params)
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
......@@ -224,12 +438,10 @@ class TestFusedAdam(TestFusedOptimizer):
fuse_qkv_params=True,
).cuda()
ref_params = []
master_params = []
model_params = []
for p in model.parameters():
if p.requires_grad:
ref_params.append(p.detach().clone().float())
master_params.append(p.detach().clone().float())
model_params.append(p)
options = {
"lr": 5e-4,
......@@ -239,12 +451,17 @@ class TestFusedAdam(TestFusedOptimizer):
"amsgrad": False,
}
ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params, master_weights=True, use_decoupled_grad=True, **options
)
for i in range(self.iters):
self.gen_grad(ref_params, master_params)
for p_ref, p in zip(ref_params, model_params):
p_ref.grad = torch.rand_like(p_ref)
p.decoupled_grad = p_ref.grad.clone()
ref_optim.step()
tst_optim.step()
master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params]
torch.testing.assert_close(ref_params, master_params)
model_params_to_fp32 = [p.float() for p in model_params]
torch.testing.assert_close(
......
......@@ -179,7 +179,7 @@ struct AdamFunctorMaster {
}
};
template <typename T, typename FULL_T, typename index_t>
template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*)
......@@ -199,10 +199,10 @@ struct AdamFunctor {
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];
T *g = reinterpret_cast<T *>(tl.addresses[0][tensor_loc]);
GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;
T *p = reinterpret_cast<T *>(tl.addresses[1][tensor_loc]);
PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;
FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
......@@ -223,10 +223,10 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
r_g[ii] = static_cast<MATH_T>(g[i]);
r_p[ii] = static_cast<MATH_T>(p[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
......@@ -259,9 +259,9 @@ struct AdamFunctor {
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
p[i] = static_cast<PARAM_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
......@@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
......@@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
......@@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, float, int32_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);)
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
const auto g_in_type = tensor_lists[0][0].scalar_type();
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
......
......@@ -3,11 +3,15 @@
# See LICENSE for license information.
"""Fused Adam optimizer."""
from copy import deepcopy
from itertools import chain
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from .multi_tensor_apply import multi_tensor_applier
from ..float8_tensor import Float8Tensor
def get_fp8_meta(fp8_tensor):
......@@ -68,11 +72,28 @@ class FusedAdam(torch.optim.Optimizer):
method is called. (default: True)
capturable (bool, optional): whether to use the version of the optimizer
that can be used with CUDA Graphs. (default: False)
master_weights (list of torch.Tensor, optional): master weights to use
for mixed precision training. If provided, the optimizer will update
the master weights and then cast the master weights to the model weights.
If not provided, the optimizer will update the model weights directly.
(default: None)
master_weights (bool, optional): whether to maintain FP32 master weights
in the optimizer with FP16/BF16 mixed precision training.
(default: False)
master_weight_dtype (torch.dtype, optional): The dtype of master weights.
If master_weights is False, this will be ignored. It can be one of
[torch.float32, torch.float16]. If it's not torch.float32, the optimizer
will create a FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
exp_avg_dtype (torch.dtype, optional): The dtype of exp_avg. It can be
one of [torch.float32, torch.float16, torch.uint8], where torch.uint8
represents FP8. If it's not torch.float32, the optimizer will create
a FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
exp_avg_sq_dtype (torch.dtype, optional): The dtype of exp_avg_sq. It
can be one of [torch.float32, torch.float16, torch.uint8], where
torch.uint8 represents FP8. If it's not torch.float32, the optimizer
will create a FP32 scalar scaling factor to ensure precision.
(default: torch.float32)
use_decoupled_grad (bool, optional): Whether to use ".decoupled_grad"
instead of ".grad" for reading gradients. It's useful when the dtypes
of grad and param are different.
(default: False)
.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
......@@ -92,12 +113,36 @@ class FusedAdam(torch.optim.Optimizer):
amsgrad=False,
set_grad_none=True,
capturable=False,
master_weights=None,
master_weights=False,
master_weight_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
use_decoupled_grad=False,
):
if amsgrad:
raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
# Add constraints to dtypes of states.
if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
# Currently, capturable mode only supports fp32 master weights and optimizer states.
# The reason is, if the master weights or optimizer states are not in fp32 dtype,
# they will be copied to temporary fp32 buffers first. These fp32 buffers are then
# used as inputs for the kernel. Consequently, the pointer for each `.step()` differs,
# making CUDA Graph inapplicable in this scenario.
if capturable and master_weights and master_weight_dtype != torch.float32:
raise RuntimeError("Capturable mode only supports fp32 master weights.")
if capturable and exp_avg_dtype != torch.float32:
raise RuntimeError("Capturable mode only supports fp32 exp_avg.")
if capturable and exp_avg_sq_dtype != torch.float32:
raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq")
# If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = {
......@@ -112,9 +157,6 @@ class FusedAdam(torch.optim.Optimizer):
self.set_grad_none = set_grad_none
self.capturable = capturable
if master_weights is not None:
assert isinstance(master_weights, list), "master_weights must be a list if provided"
self.master_weights = master_weights
if capturable:
......@@ -134,14 +176,208 @@ class FusedAdam(torch.optim.Optimizer):
self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
self.master_weight_dtype = master_weight_dtype
self.exp_avg_dtype = exp_avg_dtype
self.exp_avg_sq_dtype = exp_avg_sq_dtype
self.name_to_dtype_map = {
"exp_avg": self.exp_avg_dtype,
"exp_avg_sq": self.exp_avg_sq_dtype,
"master_param": self.master_weight_dtype,
}
self.dtype_to_range_map = {
torch.float16: torch.full(
[1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32
),
torch.uint8: torch.full([1], 448.0, dtype=torch.float32),
}
self._scales = {}
self.use_decoupled_grad = use_decoupled_grad
def zero_grad(self):
    # pylint: disable=missing-function-docstring
    # NOTE(review): a leftover removed-diff line (`if self.set_grad_none:`)
    # was interleaved into this block; it is dropped here so the method is
    # syntactically valid again.
    # When grads live in ".grad" and are zeroed in place, the base-class
    # implementation already does the right thing.
    if not self.use_decoupled_grad and not self.set_grad_none:
        super().zero_grad()
        return
    for group in self.param_groups:
        for p in group["params"]:
            if self.use_decoupled_grad and self.set_grad_none:
                p.decoupled_grad = None
            elif self.use_decoupled_grad and not self.set_grad_none:
                p.decoupled_grad.zero_()
            elif not self.use_decoupled_grad and self.set_grad_none:
                p.grad = None
def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
    """Apply scaling on `unscaled_state`. `scaled_state` and `scale` will be written inplace.

    Arguments:
        state_name (string): Name of optimizer states, can be one of 'exp_avg',
            'exp_avg_sq', and 'master_param'.
        unscaled_state (torch.Tensor): An unscaled high-precision tensor.
        scaled_state (torch.Tensor): A scaled low-precision tensor.
        scale (torch.Tensor): A FP32 tensor representing the scaling factor.
    """
    # NOTE(review): a leftover removed-diff line (`super().zero_grad()`) was
    # interleaved into the dtype check below; it is dropped here.
    assert unscaled_state.dtype == torch.float32
    dtype = self.name_to_dtype_map[state_name]
    if dtype == torch.uint8:
        assert isinstance(scaled_state, Float8Tensor)
    else:
        assert scaled_state.dtype == dtype

    # Move the cached dtype range to the state's device once and keep it
    # there for subsequent calls.
    max_range = self.dtype_to_range_map[dtype]
    if max_range.device != scaled_state.device:
        max_range = max_range.to(scaled_state.device)
        self.dtype_to_range_map[scaled_state.dtype] = max_range
    if unscaled_state.device != scaled_state.device:
        unscaled_state = unscaled_state.to(scaled_state.device)

    # scale = absmax / max_range, so the scaled values span the representable
    # range of the low-precision dtype.
    min_val, max_val = torch.aminmax(unscaled_state)
    absmax = torch.maximum(-min_val, max_val)
    absmax = absmax.to(dtype=torch.float32, device=unscaled_state.device)
    torch.div(absmax, max_range, out=scale)

    if isinstance(scaled_state, Float8Tensor):
        # Float8Tensor quantizes on copy using its own scale_inv.
        scaled_state._scale_inv.copy_(scale)
        scaled_state.copy_(unscaled_state)
    else:
        # Divide by scale (guarding against scale == 0) and store.
        rscale = torch.where(scale > 0, scale.reciprocal(), 0.0)
        unscaled_state.mul_(rscale)
        scaled_state.copy_(unscaled_state)
def get_unscaled_state(self, param, state_name):
    """Return the unscaled state corresponding to the input `param` and `state_name`.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
        state_name (string): Name of optimizer states, can be one of 'exp_avg',
            'exp_avg_sq', and 'master_param'.
    """
    state = self.state[param]
    dtype = self.name_to_dtype_map[state_name]
    if dtype == torch.uint8:
        # FP8 state: Float8Tensor dequantizes through its own scale_inv, so a
        # plain cast to fp32 already yields the unscaled values.
        assert isinstance(state[state_name], Float8Tensor)
        unscaled = state[state_name].float()
    elif dtype == torch.float16:
        # FP16 state: undo the scaling applied in _apply_scale.
        assert state[state_name].dtype == torch.float16
        unscaled = state[state_name].float()
        unscaled.mul_(self._scales[param][state_name])
    elif dtype == torch.float32:
        # FP32 state is stored unscaled; returned directly (not a copy).
        assert state[state_name].dtype == torch.float32
        unscaled = state[state_name]
    else:
        raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
    return unscaled
def set_scaled_state(self, param, state_name, unscaled_state):
    """Set the optimizer state.

    If the dtype of the corresponding optimizer state is not FP32,
    it will do scaling automatically.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
        state_name (string): Name of optimizer states, can be one of 'exp_avg',
            'exp_avg_sq', and 'master_param'.
        unscaled_state (torch.Tensor): The original high-precision(FP32) state.
    """
    assert unscaled_state.dtype == torch.float32
    state = self.state[param]
    if state_name not in state:
        # Lazily allocate the low-precision buffer; no need to zero it since
        # it is fully overwritten below.
        self._initialize_state(param, state_name, False)

    dtype = self.name_to_dtype_map[state_name]
    if dtype != torch.float32:
        scale = self._scales[param]
        self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name])
    else:
        state[state_name].copy_(unscaled_state)
def _initialize_state(self, param, state_name, zero_buffer: bool):
    """Initialize one of the optimizer states according to `state_name`.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
        state_name (string): Name of optimizer states, can be one of 'exp_avg',
            'exp_avg_sq', and 'master_param'.
        zero_buffer (bool): Whether to initialize the optimizer state with zeros.
    """
    dtype = self.name_to_dtype_map[state_name]
    data = torch.empty_like(param, dtype=dtype)
    if zero_buffer:
        data.zero_()

    if dtype == torch.uint8:
        # torch.uint8 denotes FP8 here: wrap the raw bytes in a Float8Tensor
        # with an identity scale-inv (later overwritten by _apply_scale).
        self.state[param][state_name] = Float8Tensor(
            data=data,
            dtype=torch.float32,
            fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device),
        )
    else:
        self.state[param][state_name] = data

    # Create scale if necessary.
    if dtype != torch.float32:
        if param not in self._scales:
            self._scales[param] = {}
        self._scales[param][state_name] = torch.ones(
            [1], dtype=torch.float32, device=param.device
        )
def initialize_state(self, param):
    """Initialize optimizer states.

    Arguments:
        param (torch.nn.Parameter): One of parameters in this optimizer.
    """
    # Both moments start at zero; the master copy is seeded from the param.
    for moment in ("exp_avg", "exp_avg_sq"):
        self._initialize_state(param, moment, zero_buffer=True)
    if self.master_weights:
        self._initialize_state(param, "master_param", zero_buffer=False)
        self.set_scaled_state(param, "master_param", param.clone().detach().float())
def state_dict(self):
    """Override the state_dict() of pytorch. Before returning the state_dict, cast all
    non-fp32 states to fp32.
    """
    state_dict = super().state_dict()
    # Map the saved (integer) param ids back to the live parameter objects so
    # per-param states can be unscaled.
    saved_ids = chain.from_iterable(
        g["params"] for g in deepcopy(state_dict["param_groups"])
    )
    live_params = chain.from_iterable(g["params"] for g in self.param_groups)
    id_map = dict(zip(saved_ids, live_params))
    for key, states in state_dict["state"].items():
        if key not in id_map:
            continue
        param = id_map[key]
        # Replace every stored state with its fp32, unscaled equivalent.
        state_dict["state"][key] = {
            name: self.get_unscaled_state(param, name) for name in states
        }
    return state_dict
def load_state_dict(self, state_dict):
    """Override the load_state_dict() of pytorch. Since pytorch's load_state_dict forces the
    state to be the same dtype as param, We need to manully set the state again.
    """
    super().load_state_dict(state_dict)
    # Map the saved (integer) param ids back to the live parameter objects.
    saved_ids = chain.from_iterable(
        g["params"] for g in deepcopy(state_dict["param_groups"])
    )
    live_params = chain.from_iterable(g["params"] for g in self.param_groups)
    id_map = dict(zip(saved_ids, live_params))
    for key, states in state_dict["state"].items():
        if key not in id_map:
            continue
        param = id_map[key]
        # Rebuild each state from its fp32 value so low-precision states get
        # re-allocated and re-scaled instead of keeping pytorch's casts.
        self.state[param] = {}
        for name, tensor in states.items():
            self.set_scaled_state(param, name, tensor.float())
def step(self, closure=None, grad_scaler=None):
"""Performs a single optimization step.
......@@ -156,8 +392,6 @@ class FusedAdam(torch.optim.Optimizer):
if closure is not None:
loss = closure()
master_param_idx = 0
for group in self.param_groups:
if len(group["params"]) == 0:
continue
......@@ -196,6 +430,11 @@ class FusedAdam(torch.optim.Optimizer):
amaxes = []
scale_invs = []
# Lists for scaling
unscaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []}
scaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []}
state_scales = {"exp_avg": [], "exp_avg_sq": [], "master_param": []}
# Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is.
out_dtype = tex.DType.kFloat32
......@@ -207,31 +446,29 @@ class FusedAdam(torch.optim.Optimizer):
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data).float()
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data).float()
# Master weights
if self.master_weights and p.dtype != torch.float32:
# model weights can be fp32/bf16/fp16/fp8
# If it's fp32, it has no corresponding master weights
state["master_param"] = self.master_weights[master_param_idx]
master_param_idx += 1
assert (
state["master_param"].shape == p.shape
), "Master weights shape must match model weights shape"
p_master = state.get("master_param", None)
p_grad = p.grad
self.initialize_state(p)
if self.master_weights and p_master is not None and p_master.grad is not None:
p_grad = p_master.grad
if self.use_decoupled_grad:
p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None
else:
p_grad = p.grad
if p_grad is None:
continue
if p_grad.data.is_sparse:
raise RuntimeError("FusedAdam does not support sparse gradients.")
# Unscaling
unscaled_state = {}
for name in ["exp_avg", "exp_avg_sq", "master_param"]:
if name in state:
unscaled = self.get_unscaled_state(p, name)
unscaled_state[name] = unscaled
if self.name_to_dtype_map[name] != torch.float32:
unscaled_lists[name].append(unscaled)
scaled_lists[name].append(state[name])
state_scales[name].append(self._scales[p][name])
if isinstance(p, Float8Tensor):
out_dtype = p._fp8_dtype
p_fp8_model.append(p._data.data)
......@@ -240,26 +477,28 @@ class FusedAdam(torch.optim.Optimizer):
amaxes.append(amax)
scale_invs.append(scale_inv)
if self.master_weights:
p_main_of_fp8_model.append(p_master.data)
p_main_of_fp8_model.append(unscaled_state["master_param"].data)
g_of_fp8_model.append(p_grad.data)
m_of_fp8_model.append(state["exp_avg"])
v_of_fp8_model.append(state["exp_avg_sq"])
m_of_fp8_model.append(unscaled_state["exp_avg"])
v_of_fp8_model.append(unscaled_state["exp_avg_sq"])
elif p.dtype in [torch.float16, torch.bfloat16]:
has_fp16 = has_fp16 or p.dtype == torch.float16
has_bf16 = has_bf16 or p.dtype == torch.bfloat16
p_f16_model.append(p.data)
if self.master_weights:
p_main_of_f16_model.append(p_master.data)
p_main_of_f16_model.append(unscaled_state["master_param"].data)
g_of_f16_model.append(p_grad.data)
m_of_f16_model.append(state["exp_avg"])
v_of_f16_model.append(state["exp_avg_sq"])
m_of_f16_model.append(unscaled_state["exp_avg"])
v_of_f16_model.append(unscaled_state["exp_avg_sq"])
elif p.dtype == torch.float32:
p_f32_model.append(p.data)
g_of_f32_model.append(p_grad.data)
m_of_f32_model.append(state["exp_avg"])
v_of_f32_model.append(state["exp_avg_sq"])
m_of_f32_model.append(unscaled_state["exp_avg"])
v_of_f32_model.append(unscaled_state["exp_avg_sq"])
else:
raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8")
raise RuntimeError(
"FusedAdam only support model weights in fp32, fp16, bf16 and fp8"
)
if self.capturable and len(p_fp8_model) > 0:
raise RuntimeError(
......@@ -389,4 +628,15 @@ class FusedAdam(torch.optim.Optimizer):
tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
# Scaling
for name in ["exp_avg", "exp_avg_sq", "master_param"]:
if len(unscaled_lists[name]) > 0:
for unscaled, scaled, scale in zip(
unscaled_lists[name], scaled_lists[name], state_scales[name]
):
self._apply_scale(name, unscaled, scaled, scale)
# Try to reclaim the temporary fp32 buffers.
del unscaled_lists
return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment