Unverified Commit 2b16299f authored by David Berard's avatar David Berard Committed by GitHub
Browse files

Remove torch.jit.fuser("fuser2") in test (#7069)

* [WIP] Remove torch.jit.fuser("fuser2") in test

Internally we're considering removing support for fuser2, so we need to remove this special case from the test.

* completely remove special-casing
parent 35f68a09
......@@ -1555,13 +1555,7 @@ class TestFocalLoss:
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
if device == "cpu":
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
else:
with torch.jit.fuser("fuser2"):
# Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476
# We may remove this condition once the bug is resolved
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment