"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6b1ff250842f52136d5159bb67a26b50ba01485d"
Commit aceb3fba authored by Patrick von Platen's avatar Patrick von Platen
Browse files

only do output_past=True for language generation in bart

parent 7cba11fb
......@@ -64,7 +64,6 @@ class ModelTester:
self.eos_token_id = 2
self.pad_token_id = 1
self.bos_token_id = 0
self.output_past = True
torch.manual_seed(0)
def prepare_config_and_inputs_for_common(self):
......@@ -86,7 +85,6 @@ class ModelTester:
eos_token_ids=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
output_past=self.output_past,
)
inputs_dict = prepare_bart_inputs_dict(config, input_ids)
return config, inputs_dict
......
......@@ -628,6 +628,9 @@ class ModelTesterMixin:
"input_ids", None
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
if self.is_encoder_decoder:
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
for model_class in self.all_generative_model_classes:
model = model_class(config)
model.to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment