Unverified commit 86928e07, authored by Li Tao and committed by GitHub

Add adam bf16 state with original fp32 kernel (#1640)



* support adam bf16 state
Signed-off-by: XiaobingSuper <xiaobingzhangupc@gmail.com>

* use fp32 kernel but keep bf16 optimizer states to save memory
Signed-off-by: lit <lit@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: XiaobingSuper <xiaobingzhangupc@gmail.com>
Signed-off-by: lit <lit@nvidia.com>
Co-authored-by: XiaobingSuper <xiaobingzhangupc@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 66d6afbf
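
Note on the approach: the commit keeps the Adam moments (exp_avg, exp_avg_sq) stored in bf16 while the update math still runs through the existing fp32 kernel, roughly halving the memory used by those states. The snippet below is a minimal single-tensor sketch of that pattern, not code from this commit; the function name and loop-free structure are illustrative, and the real FusedAdam dispatches to fused multi-tensor CUDA kernels.

import torch

def adam_step_bf16_states_fp32_math(param, grad, exp_avg_bf16, exp_avg_sq_bf16,
                                    step, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    # Illustrative only: param is assumed to be an fp32 master weight.
    beta1, beta2 = betas

    # Upcast the bf16 optimizer states (and grad) to fp32 so the update math
    # matches the original fp32 kernel.
    exp_avg = exp_avg_bf16.float()
    exp_avg_sq = exp_avg_sq_bf16.float()
    grad32 = grad.float()

    # Standard Adam moment updates, all in fp32.
    exp_avg.mul_(beta1).add_(grad32, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad32, grad32, value=1 - beta2)

    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias2).sqrt_().add_(eps)
    param.data.add_(-(lr / bias1) * exp_avg / denom)

    # Store the moments back in bf16, halving their memory footprint vs. fp32.
    exp_avg_bf16.copy_(exp_avg.bfloat16())
    exp_avg_sq_bf16.copy_(exp_avg_sq.bfloat16())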
@@ -360,6 +360,20 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
    def test_bf16_exp_avg(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.bfloat16,
            exp_avg_sq_dtype=torch.float32,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg(self):
@@ -389,6 +403,20 @@ class TestFusedAdam(TestFusedOptimizer):
            master_atol=2e-3,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
    def test_bf16_exp_avg_sq(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.float32,
            exp_avg_sq_dtype=torch.bfloat16,
            master_rtol=2e-3,
            master_atol=2e-3,
        )

    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg_sq(self):
......
@@ -133,10 +133,10 @@ class FusedAdam(torch.optim.Optimizer):
         # Add constraints to dtypes of states.
         if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
             raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
-        if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
-        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
+        if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.")
+        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.")
         # Currently, capturable mode only supports fp32 master weights and optimizer states.
         # The reason is, if the master weights or optimizer states are not in fp32 dtype,
@@ -259,6 +259,10 @@ class FusedAdam(torch.optim.Optimizer):
            scale (torch.Tensor): A FP32 tensor representing the scaling factor.
        """
        assert unscaled_state.dtype == torch.float32
        if scaled_state.dtype == torch.bfloat16:
            scaled_state.copy_(unscaled_state.bfloat16())
            return
        dtype = self.name_to_dtype_map[state_name]
        if dtype == torch.uint8:
            assert isinstance(scaled_state, Float8Tensor)
@@ -313,8 +317,11 @@ class FusedAdam(torch.optim.Optimizer):
             else:
                 assert state[state_name].dtype == torch.float32
                 unscaled = state[state_name]
+        elif dtype == torch.bfloat16:
+            assert state[state_name].dtype == torch.bfloat16
+            unscaled = state[state_name].float()
         else:
-            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
+            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.")
         return unscaled

     def set_scaled_state(self, param, state_name, unscaled_state):
@@ -329,6 +336,7 @@ class FusedAdam(torch.optim.Optimizer):
                and 'master_param`.
            unscaled_state (torch.Tensor): The original high-precision(FP32) state.
        """
        store_param_remainders = (
            self.store_param_remainders
            and state_name == "master_param"
......
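
For reference, a hypothetical end-to-end usage of the new dtype options. The keyword arguments mirror the names checked in the constructor diff above (master_weights, master_weight_dtype, exp_avg_dtype, exp_avg_sq_dtype); treat the exact import path and constructor signature as an assumption rather than documentation.

import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # import path assumed

model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,                # fp32 master weights for accuracy
    master_weight_dtype=torch.float32,
    exp_avg_dtype=torch.bfloat16,       # bf16 first moment, now accepted
    exp_avg_sq_dtype=torch.bfloat16,    # bf16 second moment, now accepted
)

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()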