# first step must be input, which will be converted to Req
steps:List[Step]
eos_token_id:int=-1
def__post_init__(self):
ifself.steps[0].type!=StepType.INPUT:
raiseValueError("First step must be input")
# each steps should have the same expected_tensors.keys()
foriinrange(1,len(self.steps)):
ifself.tensor_keys(i)!=self.tensor_keys():
raiseValueError(
f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
# each test_subjects.steps should have the same expected_tensors.keys()
foriinrange(1,len(self.test_subjects)):
ifself.tensor_keys(i)!=self.tensor_keys():
raiseValueError(
f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
)
deftensor_keys(self,i:int=0)->List[str]:
returnset(self.test_subjects[i].tensor_keys())
classBaseBatchedPenalizerTest(unittest.TestCase):
Penalizer:Type[_BatchedPenalizer]
device="cuda"
vocab_size=5
enabled:Subject=None
disabled:Subject=None
defsetUp(self):
ifself.__class__==BaseBatchedPenalizerTest:
self.skipTest("Base class for penalizer tests")
self.create_test_subjects()
self.create_test_cases()
deftensor(self,data,**kwargs)->torch.Tensor:
"""
Shortcut to create a tensor with device=self.device.