import pytest
import torch

from basicsr.losses.basic_loss import CharbonnierLoss, L1Loss, MSELoss, WeightedTVLoss

# Every random input/weight tensor in this module uses this 4-D size.
_SHAPE = (1, 3, 4, 4)
# Shape of a 0-dim tensor; reduced ('mean' / 'sum') losses are checked against it.
_SCALAR = torch.Size([])


def _rand():
    """Return a new float32 tensor of size ``_SHAPE`` drawn from ``torch.rand``."""
    return torch.rand(_SHAPE, dtype=torch.float32)


def _check(result, expected_shape):
    """Assert that ``result`` is a ``torch.Tensor`` whose shape equals ``expected_shape``."""
    assert isinstance(result, torch.Tensor)
    assert result.shape == expected_shape


@pytest.mark.parametrize('loss_class', [L1Loss, MSELoss, CharbonnierLoss])
def test_pixellosses(loss_class):
    """Test loss: pixel losses

    For each pixel-wise loss class, verify the output type and shape under the
    'mean', 'none' (with and without a weight tensor) and 'sum' reductions, and
    verify that constructing the loss with an unknown reduction raises
    ValueError.
    """
    pred, target = _rand(), _rand()

    # reduction='mean' is expected to yield a scalar tensor.
    mean_loss = loss_class(loss_weight=1.0, reduction='mean')
    _check(mean_loss(pred, target, weight=None), _SCALAR)

    # reduction='none' is expected to keep the input shape, weighted or not.
    none_loss = loss_class(loss_weight=1.0, reduction='none')
    _check(none_loss(pred, target, weight=None), _SHAPE)
    spatial_weight = _rand()
    _check(none_loss(pred, target, weight=spatial_weight), _SHAPE)

    # reduction='sum' is expected to yield a scalar tensor.
    sum_loss = loss_class(loss_weight=1.0, reduction='sum')
    _check(sum_loss(pred, target, weight=None), _SCALAR)

    # An unsupported reduction must be rejected at construction time.
    with pytest.raises(ValueError):
        loss_class(loss_weight=1.0, reduction='unknown')


def test_weightedtvloss():
    """Test loss: WeightedTVLoss

    Under both 'mean' and 'sum' reductions the loss is expected to return a
    scalar tensor, both without weights and with a weight tensor. Constructing
    it with 'unknown' or 'none' as the reduction must raise ValueError.
    """
    pred = _rand()

    for reduction in ('mean', 'sum'):
        tv_loss = WeightedTVLoss(loss_weight=1.0, reduction=reduction)
        _check(tv_loss(pred, weight=None), _SCALAR)
        _check(tv_loss(pred, weight=_rand()), _SCALAR)

    # Unlike the pixel losses, WeightedTVLoss also refuses reduction='none'.
    for rejected in ('unknown', 'none'):
        with pytest.raises(ValueError):
            WeightedTVLoss(loss_weight=1.0, reduction=rejected)