Unverified Commit cb966e64 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Generate Test] fix greedy generate test (#8293)

* fix greedy generate test

* delet ipdb
parent 734afa37
......@@ -140,10 +140,6 @@ class GenerationTesterMixin:
# check `generate()` and `greedy_search()` are equal
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
max_length = 4
output_ids_generate = model.generate(
......@@ -154,6 +150,13 @@ class GenerationTesterMixin:
max_length=max_length,
**logits_process_kwargs,
)
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
with torch.no_grad():
output_ids_greedy = model.greedy_search(
input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment