# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss2d(nn.Module):
    """Focal loss for dense 2D prediction, with binary (sigmoid) and
    multi-class (softmax) variants (Lin et al., "Focal Loss for Dense
    Object Detection")."""

    def __init__(self, gamma=2, size_average=True):
        super(FocalLoss2d, self).__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, logit, target, class_weight=None, type='sigmoid'):
        target = target.view(-1, 1).long()
        device = logit.device  # derive the device instead of hard-coding .cuda()

        if type == 'sigmoid':
            if class_weight is None:
                class_weight = [1] * 2  # e.g. [0.5, 0.5]
            # Treat binary logits as a two-class problem: columns are (1-p, p).
            prob = torch.sigmoid(logit).view(-1, 1)
            prob = torch.cat((1 - prob, prob), 1)
            # One-hot mask that selects the probability of the target class.
            select = torch.zeros(len(prob), 2, device=device)
            select.scatter_(1, target, 1.)
        elif type == 'softmax':
            B, C, H, W = logit.size()
            if class_weight is None:
                class_weight = [1] * C  # e.g. [1/C] * C
            # Flatten NCHW logits to (N*H*W, C) rows of per-pixel class scores.
            logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C)
            prob = F.softmax(logit, 1)
            select = torch.zeros(len(prob), C, device=device)
            select.scatter_(1, target, 1.)
        else:
            raise ValueError("type must be 'sigmoid' or 'softmax'")

        # Look up the weight of each pixel's target class.
        class_weight = torch.tensor(class_weight, dtype=torch.float, device=device).view(-1, 1)
        class_weight = torch.gather(class_weight, 0, target)

        # Probability of the correct class, clamped for numerical stability.
        prob = (prob * select).sum(1).view(-1, 1)
        prob = torch.clamp(prob, 1e-8, 1 - 1e-8)

        # The focal term (1-p)^gamma down-weights easy, well-classified pixels.
        batch_loss = -class_weight * torch.pow(1 - prob, self.gamma) * prob.log()

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss
        return loss


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    criterion = FocalLoss2d()

    # Binary (sigmoid) path: the target is a 0/1 map with the same shape as the logits.
    logits = torch.randn(2, 3, 3, device=device)
    target = (torch.sigmoid(logits) > 0.5).float()
    loss = criterion(logits, target)
    print(loss)
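    # A minimal additional sketch (not in the original file) exercising the
    # softmax branch, which the self-test above does not cover. The shapes are
    # assumptions inferred from forward(): logits are (B, C, H, W) and the
    # target holds per-pixel integer class indices in [0, C).
    logits_mc = torch.randn(2, 4, 8, 8, device=device)
    labels_mc = torch.randint(0, 4, (2, 8, 8), device=device)
    loss_mc = criterion(logits_mc, labels_mc, type='softmax')
    print(loss_mc)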