# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def check_logits_losses(logits_list, losses):
    """Verify that there is exactly one configured loss per logits entry.

    Args:
        logits_list: Sized collection of model outputs.
        losses: Mapping whose ``'types'`` entry is a sized collection of
            loss objects.

    Raises:
        RuntimeError: If ``len(logits_list)`` differs from
            ``len(losses['types'])``.
    """
    len_logits = len(logits_list)
    len_losses = len(losses['types'])
    if len_logits != len_losses:
        raise RuntimeError(
            'The length of logits_list should equal to the types of loss config: {} != {}.'
            .format(len_logits, len_losses))


def loss_computation(logits_list, labels, losses, edges=None):
    """Compute the weighted loss for every logits entry.

    The i-th entry of ``logits_list`` is paired with ``losses['types'][i]``
    and scaled by ``losses['coef'][i]``. How the loss object is called
    depends on its class name:

    * ``BCELoss`` / ``FocalLoss`` with a truthy ``edge_label`` attribute:
      called with ``edges`` as the target instead of ``labels``.
    * ``MixedLoss``: returns ``(list_of_losses, per_channel_dice)``; every
      element of the list is scaled and appended individually.
    * ``KLLoss``: called with ``logits_list[0]`` and the detached
      ``logits_list[1]``; ``labels`` is not used.
    * ``DiceLoss`` / ``MultipleLoss``: returns ``(loss, per_channel_dice)``.
    * anything else: called as ``loss(logits, labels)``.

    Args:
        logits_list: Sequence of model outputs, one per configured loss.
        labels: Ground-truth target forwarded to most loss objects.
        losses: Mapping with ``'types'`` (loss callables) and ``'coef'``
            (per-loss weights), both indexable by position.
        edges: Target used by edge-label losses. Defaults to ``None``.

    Returns:
        Tuple ``(loss_list, per_channel_dice)``. ``loss_list`` holds the
        weighted losses in processing order. ``per_channel_dice`` is the
        value produced by the last ``MixedLoss`` / ``DiceLoss`` /
        ``MultipleLoss`` encountered, or ``None`` if there was none.

    Raises:
        RuntimeError: If the number of logits and loss types differ.
    """
    check_logits_losses(logits_list, losses)
    loss_list = []
    per_channel_dice = None
    for i, logits in enumerate(logits_list):
        loss_i = losses['types'][i]
        coef_i = losses['coef'][i]
        loss_name = loss_i.__class__.__name__

        if loss_name in ('BCELoss', 'FocalLoss') and loss_i.edge_label:
            # This loss is configured to be supervised by the edge map
            # rather than by the regular labels.
            loss_list.append(coef_i * loss_i(logits, edges))
        elif loss_name == 'MixedLoss':
            mixed_loss_list, per_channel_dice = loss_i(logits, labels)
            for mixed_loss in mixed_loss_list:
                loss_list.append(coef_i * mixed_loss)
        elif loss_name in ("KLLoss", ):
            # Always compares the first two outputs, whatever the current
            # index is; the second one is detached before being passed in.
            loss_list.append(
                coef_i * loss_i(logits_list[0], logits_list[1].detach()))
        elif loss_name in ["DiceLoss", "MultipleLoss"]:
            loss, per_channel_dice = loss_i(logits, labels)
            loss_list.append(coef_i * loss)
        else:
            loss_list.append(coef_i * loss_i(logits, labels))
    return loss_list, per_channel_dice