output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
output_ids (torch.Tensor): The output tokens.
"""
ifnotself.is_required:
return
...
...
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
deffilter(
self,
indices_to_keep:typing.List[int],
indices_to_keep:List[int],
indices_tensor_to_keep:torch.Tensor=None,
):
"""
Filter the penalizers based on the indices to keep in the batch.
Args:
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
indices_to_keep (List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
"""
ifnotself.is_required:
...
...
@@ -174,32 +160,18 @@ class _TokenIDs:
Attributes:
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
cached_counts (torch.Tensor): The cached occurrence count tensor.
# first step must be input, which will be converted to Req
steps:typing.List[Step]
steps:List[Step]
eos_token_id:int=-1
def__post_init__(self):
...
...
@@ -66,7 +66,7 @@ class Subject:
f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
)
def tensor_keys(self, i: int = 0) -> Set[str]:
    """Return the keys of ``expected_tensors`` for step ``i`` as a set.

    Args:
        i: Index into ``self.steps``; defaults to the first step.

    Returns:
        A new set built from ``self.steps[i].expected_tensors.keys()``.
    """
    return set(self.steps[i].expected_tensors.keys())
defto_req(self)->MockReq:
...
...
@@ -80,7 +80,7 @@ class Subject:
@dataclasses.dataclass
classCase:
enabled:bool
test_subjects:typing.List[Subject]
test_subjects:List[Subject]
def__post_init__(self):
# each test_subjects.steps should have the same expected_tensors.keys()
...
...
@@ -90,12 +90,12 @@ class Case:
f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
)
def tensor_keys(self, i: int = 0) -> Set[str]:
    """Return the tensor keys of test subject ``i`` as a set.

    Args:
        i: Index into ``self.test_subjects``; defaults to the first subject.

    Returns:
        A new set built from ``self.test_subjects[i].tensor_keys()``.
    """
    # The value returned is a set, so the annotation is Set[str] rather
    # than the List[str] the signature previously advertised.
    return set(self.test_subjects[i].tensor_keys())
classBaseBatchedPenalizerTest(unittest.TestCase):
Penalizer:typing.Type[_BatchedPenalizer]
Penalizer:Type[_BatchedPenalizer]
device="cuda"
vocab_size=5
...
...
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):