@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
...
@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
Returns:
Returns:
torch.Tensor: The logits after applying the penalizers.
torch.Tensor: The logits after applying the penalizers.
"""
"""
if not self.is_required:
return
for penalizer in self.penalizers.values():
for penalizer in self.penalizers.values():
logits = penalizer.apply(logits)
logits = penalizer.apply(logits)
...
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
...
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
"""
"""
if not self.is_required:
return
empty_indices = len(indices_to_keep) == 0
empty_indices = len(indices_to_keep) == 0
is_required = False
for penalizer in self.penalizers.values():
for penalizer in self.penalizers.values():
if not penalizer.is_required() or empty_indices:
tmp_is_required = penalizer.is_required()
is_required = is_required or tmp_is_required
if not tmp_is_required or empty_indices:
penalizer.teardown()
penalizer.teardown()
else:
else:
# create tensor index only when it's needed
# create tensor index only when it's needed
...
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
...
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator: