logits (torch.Tensor): The logits to apply the penalizers to.
Returns:
torch.Tensor: The logits after applying the penalizers.
"""
ifnotself.is_required:
return
forpenalizerinself.penalizers.values():
logits=penalizer.apply(logits)
returnlogits
deffilter(
self,
indices_to_keep:typing.List[int],
indices_tensor_to_keep:torch.Tensor=None,
):
"""
Filter the penalizers based on the indices to keep in the batch.
Args:
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.