Unverified commit a51ff542, authored by Tim Moon, committed by GitHub

[PyTorch] Update FP8 recipe test to handle recipe changes (#834)



Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent aad4e173
@@ -29,7 +29,7 @@ class TestFP8Recipe:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)

-    @pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
+    @pytest.mark.parametrize("amax_history_len", [31, 1024])
     @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
     @pytest.mark.parametrize("is_first_microbatch", [None, True, False])
     def test_amax_and_scale_update(
@@ -51,7 +51,10 @@ class TestFP8Recipe:
         )
         with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
             module = te.Linear(16, 16)
-            y = module(torch.zeros([16, 16], device="cuda"))
+            y = module(
+                torch.randn([16, 16], device="cuda"),
+                is_first_microbatch=True,
+            )
         y.backward(torch.zeros_like(y))

         # Get amax history and scaling factors
@@ -67,101 +70,96 @@ class TestFP8Recipe:
         # Tweak amax history and scaling factors
         amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
-        if amax_history_len > 1:
-            amax_history_forward[1, 0].fill_(3)
+        amax_history_forward[0, :].zero_()
         scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
         scale_inv_forward.copy_(torch.reciprocal(scale_forward))
         amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5)
+        amax_history_backward[0, :].zero_()
         scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5)
         scale_inv_backward.copy_(torch.reciprocal(scale_backward))

         # Expected amax history after update
-        ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0)
-        ref_amax_history_forward[0].zero_()
-        ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0)
-        ref_amax_history_backward[0].zero_()
-
-        # Expected scale and scale inverse
-        if amax_compute_algo == "max":
-            ref_amax_forward = amax_history_forward.max(dim=0).values
-            ref_amax_backward = amax_history_backward.max(dim=0).values
-        elif amax_compute_algo == "most_recent":
-            ref_amax_forward = amax_history_forward[0]
-            ref_amax_backward = amax_history_backward[0]
-        else:
-            raise ValueError(f"{amax_compute_algo=} is not supported")
-        ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
-        ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
-        ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
-        update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
-        if not update_weight_scale_inv:
-            ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
-        ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
-
-        # Make sure we are not trivially passing tests
-        if amax_history_len > 1:
-            with pytest.raises(AssertionError):
-                torch.testing.assert_close(
-                    amax_history_forward[1:],
-                    ref_amax_history_forward[1:],
-                )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                scale_forward,
-                ref_scale_forward,
-            )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                scale_inv_forward,
-                ref_scale_inv_forward,
-            )
-        if amax_history_len > 1:
-            with pytest.raises(AssertionError):
-                torch.testing.assert_close(
-                    fp8_meta[backward_key].amax_history[1:],
-                    ref_amax_history_backward[1:],
-                )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                fp8_meta[backward_key].scale,
-                ref_scale_backward,
-            )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                fp8_meta[backward_key].scale_inv,
-                ref_scale_inv_backward,
-            )
-
-        # Perform forward and backward pass to update fp8_meta
+        # Note: amax history is only updated when amax is updated
+        update_weight_amax = is_first_microbatch is None or is_first_microbatch
+        ref_amax_history_forward = amax_history_forward.clone()
+        ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1))
+        if update_weight_amax:
+            ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1))
+        ref_amax_history_forward[0, :].zero_()
+        ref_amax_history_backward = amax_history_backward.clone()
+        ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1))
+        ref_amax_history_backward[0, :].zero_()
+
+        # Perform forward, backward, and optimizer steps to update fp8_meta
         with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
-            x = torch.zeros([16, 16], device="cuda")
+            x = torch.randn([16, 16], device="cuda")
             y = module(x, is_first_microbatch=is_first_microbatch)
-        y.backward(torch.zeros_like(y))
+        y.backward(torch.randn_like(y))

-        # Check that fp8_meta matches expected values
+        # Check that amax history matches expected values
         torch.testing.assert_close(
-            fp8_meta[forward_key].amax_history[1:],
-            ref_amax_history_forward[1:],
+            amax_history_forward[:-1],
+            ref_amax_history_forward[:-1],
         )
         torch.testing.assert_close(
-            fp8_meta[forward_key].scale,
-            ref_scale_forward,
+            amax_history_backward[:-1],
+            ref_amax_history_backward[:-1],
         )
+
+        # Expected scale and scale inverse
+        if amax_compute_algo == "max":
+            ref_amax_forward = amax_history_forward.max(dim=0).values
+            ref_amax_backward = amax_history_backward.max(dim=0).values
+        elif amax_compute_algo == "most_recent":
+            ref_amax_forward = amax_history_forward[-1]
+            ref_amax_backward = amax_history_backward[-1]
+        else:
+            raise ValueError(f"{amax_compute_algo=} is not supported")
+        ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
+        ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
+        ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
+        ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
+
+        # Check that scale and scale inverse match expected values
+        # Note: scale and scale inverse are only updated when amax is updated
         torch.testing.assert_close(
-            fp8_meta[forward_key].scale_inv,
-            ref_scale_inv_forward,
+            scale_forward[0],
+            ref_scale_forward[0],
         )
         torch.testing.assert_close(
-            fp8_meta[backward_key].amax_history[1:],
-            ref_amax_history_backward[1:],
+            scale_inv_forward[0],
+            ref_scale_inv_forward[0],
         )
+        if update_weight_amax:
+            torch.testing.assert_close(
+                scale_forward[1],
+                ref_scale_forward[1],
+            )
+            torch.testing.assert_close(
+                scale_inv_forward[1],
+                ref_scale_inv_forward[1],
+            )
         torch.testing.assert_close(
-            fp8_meta[backward_key].scale,
-            ref_scale_backward,
+            scale_backward[0],
+            ref_scale_backward[0],
         )
         torch.testing.assert_close(
-            fp8_meta[backward_key].scale_inv,
-            ref_scale_inv_backward,
+            scale_inv_backward[0],
+            ref_scale_inv_backward[0],
         )

     @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
...
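Note on what the test encodes: under delayed scaling, each FP8 tensor keeps a rolling window of amax values; the recipe reduces that window to a single amax, derives a scale from it, and records the newest amax. The `update_weight_amax` predicate reflects that the weight tensor (column 1 of the forward history) only sees a new amax when the weights could have changed, i.e. on the first microbatch or when microbatching is not used. The sketch below is a minimal, self-contained illustration of that bookkeeping, not TransformerEngine's implementation: `delayed_scaling_update`, `fp8_max`, and `new_amax` are invented names for this example, while `margin` and `amax_compute_algo` mirror the recipe parameters in the diff, and which history slot holds the newest amax is a convention that the recipe change tracked by this PR moved around.

import torch

def delayed_scaling_update(
    amax_history: torch.Tensor,   # [history_len, num_tensors] rolling amax window
    new_amax: torch.Tensor,       # [num_tensors] amax observed this step
    fp8_max: float,               # max representable FP8 value, cf. fp8_format.value.max_fwd
    margin: int = 2,
    amax_compute_algo: str = "max",
):
    """One delayed-scaling step (illustrative): pick an amax, derive scales, roll history."""
    # Reduce the history to one amax per tensor, as in the test's reference code
    if amax_compute_algo == "max":
        amax = amax_history.max(dim=0).values
    elif amax_compute_algo == "most_recent":
        amax = amax_history[-1]
    else:
        raise ValueError(f"{amax_compute_algo=} is not supported")

    # Choose the scale so that amax lands near fp8_max, backed off by 2 ** margin,
    # matching ref_scale = (fp8_max / amax) / (2 ** margin) in the test
    scale = (fp8_max / amax) / (2 ** margin)
    scale_inv = torch.reciprocal(scale)

    # Shift the window one step and record the newest amax
    # (assumed slot convention for this sketch only)
    amax_history = torch.roll(amax_history, -1, dims=0)
    amax_history[-1] = new_amax
    return amax_history, scale, scale_inv

# Tiny worked example: with E4M3 (max 448), amax 2.0, margin 2,
# scale = (448 / 2) / 4 = 56, so a value equal to amax quantizes to 112.
history = torch.tensor([[1.0], [2.0]])
history, scale, scale_inv = delayed_scaling_update(history, torch.tensor([0.5]), fp8_max=448.0)
print(scale)      # tensor([56.])
print(scale_inv)  # tensor([0.0179]), approximately 1/56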