Unverified Commit a51ff542 authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Update FP8 recipe test to handle recipe changes (#834)



Update FP8 recipe test to handle recipe changes
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent aad4e173
......@@ -29,7 +29,7 @@ class TestFP8Recipe:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
@pytest.mark.parametrize("amax_history_len", [31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
@pytest.mark.parametrize("is_first_microbatch", [None, True, False])
def test_amax_and_scale_update(
......@@ -51,7 +51,10 @@ class TestFP8Recipe:
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
module = te.Linear(16, 16)
y = module(torch.zeros([16, 16], device="cuda"))
y = module(
torch.randn([16, 16], device="cuda"),
is_first_microbatch=True,
)
y.backward(torch.zeros_like(y))
# Get amax history and scaling factors
......@@ -67,101 +70,96 @@ class TestFP8Recipe:
# Tweak amax history and scaling factors
amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
if amax_history_len > 1:
amax_history_forward[1, 0].fill_(3)
amax_history_forward[0, :].zero_()
scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
scale_inv_forward.copy_(torch.reciprocal(scale_forward))
amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5)
scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5)
scale_inv_backward.copy_(torch.reciprocal(scale_backward))
amax_history_backward[0, :].zero_()
# Expected amax history after update
ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0)
ref_amax_history_forward[0].zero_()
ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0)
ref_amax_history_backward[0].zero_()
# Note: amax history is only updated when amax is updated
update_weight_amax = is_first_microbatch is None or is_first_microbatch
ref_amax_history_forward = amax_history_forward.clone()
ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1))
if update_weight_amax:
ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1))
ref_amax_history_forward[0, :].zero_()
ref_amax_history_backward = amax_history_backward.clone()
ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1))
ref_amax_history_backward[0, :].zero_()
# Expected scale and scale inverse
if amax_compute_algo == "max":
ref_amax_forward = amax_history_forward.max(dim=0).values
ref_amax_backward = amax_history_backward.max(dim=0).values
elif amax_compute_algo == "most_recent":
ref_amax_forward = amax_history_forward[0]
ref_amax_backward = amax_history_backward[0]
ref_amax_forward = amax_history_forward[-1]
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if not update_weight_scale_inv:
update_weight_amax = is_first_microbatch is None or is_first_microbatch
if not update_weight_amax:
ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
# Make sure we are not trivially passing tests
if amax_history_len > 1:
with pytest.raises(AssertionError):
torch.testing.assert_close(
amax_history_forward[1:],
ref_amax_history_forward[1:],
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
scale_forward,
ref_scale_forward,
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
scale_inv_forward,
ref_scale_inv_forward,
)
if amax_history_len > 1:
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].amax_history[1:],
ref_amax_history_backward[1:],
)
with pytest.raises(AssertionError):
# Perform forward, backward, and optimizer steps to update fp8_meta
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
x = torch.randn([16, 16], device="cuda")
y = module(x, is_first_microbatch=is_first_microbatch)
y.backward(torch.randn_like(y))
# Check that amax history matches expected values
torch.testing.assert_close(
fp8_meta[backward_key].scale,
ref_scale_backward,
amax_history_forward[:-1],
ref_amax_history_forward[:-1],
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].scale_inv,
ref_scale_inv_backward,
amax_history_backward[:-1],
ref_amax_history_backward[:-1],
)
# Perform forward and backward pass to update fp8_meta
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
x = torch.zeros([16, 16], device="cuda")
y = module(x, is_first_microbatch=is_first_microbatch)
y.backward(torch.zeros_like(y))
# Expected scale and scale inverse
if amax_compute_algo == "max":
ref_amax_forward = amax_history_forward.max(dim=0).values
ref_amax_backward = amax_history_backward.max(dim=0).values
elif amax_compute_algo == "most_recent":
ref_amax_forward = amax_history_forward[-1]
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
# Check that fp8_meta matches expected values
# Check that scale and scale inverse match expected values
# Note: scale and scale inverse are only updated when amax is updated
torch.testing.assert_close(
fp8_meta[forward_key].amax_history[1:],
ref_amax_history_forward[1:],
scale_forward[0],
ref_scale_forward[0],
)
torch.testing.assert_close(
fp8_meta[forward_key].scale,
ref_scale_forward,
scale_inv_forward[0],
ref_scale_inv_forward[0],
)
if update_weight_amax:
torch.testing.assert_close(
fp8_meta[forward_key].scale_inv,
ref_scale_inv_forward,
scale_forward[1],
ref_scale_forward[1],
)
torch.testing.assert_close(
fp8_meta[backward_key].amax_history[1:],
ref_amax_history_backward[1:],
scale_inv_forward[1],
ref_scale_inv_forward[1],
)
torch.testing.assert_close(
fp8_meta[backward_key].scale,
ref_scale_backward,
scale_backward[0],
ref_scale_backward[0],
)
torch.testing.assert_close(
fp8_meta[backward_key].scale_inv,
ref_scale_inv_backward,
scale_inv_backward[0],
ref_scale_inv_backward[0],
)
@pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment