Unverified Commit f5afae50 authored by YosuaMichael, committed by GitHub

Add test for sigmoid_focal_loss (#5783)



* Add test for sigmoid_focal_loss

* Update test/test_ops.py

Improve code by using torch.testing.make_tensor
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Update test/test_ops.py

Remove unnecessary assert
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Update test/test_ops.py

Refactor code for generating inputs and targets
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Improve focal_loss test code as suggested in a comment by Philip

* Use fuser2 to avoid a fuser bug

* Combine the functions that generate inputs and targets; don't set the fuser when the device is cpu

* Add a link to the GitHub issue for the fuser problem
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 5e2db86c
@@ -9,6 +9,7 @@
import numpy as np
import pytest
import torch
import torch.fx
import torch.nn.functional as F
from common_utils import assert_equal, cpu_and_gpu, needs_cuda
from PIL import Image
from torch import nn, Tensor
@@ -1450,5 +1451,123 @@ class TestDropBlock:
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

class TestFocalLoss:
    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
        def logit(p: Tensor) -> Tensor:
            return torch.log(p / (1 - p))

        def generate_tensor_with_range_type(shape, range_type, **kwargs):
            if range_type != "random_binary":
                low, high = {
                    "small": (0.0, 0.2),
                    "big": (0.8, 1.0),
                    "zeros": (0.0, 0.0),
                    "ones": (1.0, 1.0),
                    "random": (0.0, 1.0),
                }[range_type]
                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
            else:
                return torch.randint(0, 2, shape, **kwargs)

        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
        inputs = []
        targets = []
        for input_range_type, target_range_type in [
            ("small", "zeros"),
            ("small", "ones"),
            ("small", "random_binary"),
            ("big", "zeros"),
            ("big", "ones"),
            ("big", "random_binary"),
            ("random", "zeros"),
            ("random", "ones"),
            ("random", "random_binary"),
        ]:
            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

        return torch.cat(inputs), torch.cat(targets)
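
    # A rough sketch of what the helper above yields, assuming the default
    # shape=(5, 2): nine (input_range, target_range) combinations are stacked,
    # so both returned tensors have shape (45, 2). Inputs are logits covering
    # confident-correct, confident-wrong, and random predictions; targets are
    # soft probabilities or hard {0, 1} labels.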
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [0, 1])
def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# For testing the ratio with manual calculation, we require the reduction to be "none"
reduction = "none"
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
assert torch.all(
focal_loss <= ce_loss
), "focal loss must be less or equal to cross entropy loss with same input"
loss_ratio = (focal_loss / ce_loss).squeeze()
prob = torch.sigmoid(inputs)
p_t = prob * targets + (1 - prob) * (1 - targets)
correct_ratio = (1.0 - p_t) ** gamma
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
correct_ratio = correct_ratio * alpha_t
tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)
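
    # Why the ratio check works: the focal loss is defined as
    # FL = alpha_t * (1 - p_t) ** gamma * CE (with the alpha_t factor skipped
    # when alpha < 0), so dividing the unreduced FL by CE elementwise leaves
    # exactly alpha_t * (1 - p_t) ** gamma, which is what correct_ratio above
    # reconstructs from sigmoid(inputs) and targets.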
@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [2, 3])
def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# focal loss should be equal ce_loss if alpha=-1 and gamma=0
alpha = -1
gamma = 0
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
inputs_fl = inputs.clone().requires_grad_()
targets_fl = targets.clone()
inputs_ce = inputs.clone().requires_grad_()
targets_ce = targets.clone()
focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol)
focal_loss.backward()
ce_loss.backward()
torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol)
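
    # Degenerate-case intuition: gamma=0 turns the modulating factor
    # (1 - p_t) ** 0 into 1, and alpha=-1 disables the alpha_t weighting,
    # so the focal loss collapses to plain binary cross entropy. The test
    # compares gradients as well as forward values, which is why the inputs
    # are cloned with requires_grad_() separately for each loss.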
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [4, 5])
def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
script_fn = torch.jit.script(ops.sigmoid_focal_loss)
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
if device == "cpu":
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
else:
with torch.jit.fuser("fuser2"):
# Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476
# We may remove this condition once the bug is resolved
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
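
    # A minimal usage sketch of the op under test (the tensors and parameter
    # values below are illustrative, not part of the test suite):
    #
    #     logits = torch.randn(8, 4)                    # raw predictions
    #     labels = torch.randint(0, 2, (8, 4)).float()  # hard binary targets
    #     loss = ops.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2, reduction="mean")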


if __name__ == "__main__":
    pytest.main([__file__])