"configs/vscode:/vscode.git/clone" did not exist on "4cac91eb7ba7ce303260cf22841525595e935985"
Unverified commit 95119ad7, authored by Patrick von Platen, committed by GitHub

[Generate] Correct input_ids detection (#14815)

* [Generate] Correct input_ids detection

* correct
parent bdbe3df8
@@ -457,7 +457,7 @@ class GenerationMixin:
         pad_token_id: int,
         eos_token_id: int,
     ) -> torch.LongTensor:
-        is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2
+        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
         is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
             (eos_token_id is not None) and (pad_token_id != eos_token_id)
...
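For context, a minimal sketch of why the dtype-based check is more robust than the `isinstance` check it replaces (this reasoning is inferred from the diff, not stated in the commit message): `torch.LongTensor` names the *CPU* int64 tensor type, so ids that were tokenized as `int32` or already moved to a GPU fail the old test even though they are valid input ids.

```python
import torch

# Illustrative sketch, not library code: why isinstance-based detection is fragile.
ids = torch.tensor([[0, 4, 5, 2]], dtype=torch.long)
print(isinstance(ids, torch.LongTensor))  # True: CPU int64 tensor

ids_int32 = ids.to(torch.int)
print(isinstance(ids_int32, torch.LongTensor))  # False, although these are valid ids

if torch.cuda.is_available():
    # torch.LongTensor is the CPU type, so ids on the GPU also fail the old check
    print(isinstance(ids.to("cuda"), torch.LongTensor))  # False

# The check introduced by this commit looks at shape and dtype instead,
# so it holds on any device and for both int32 and int64 ids:
is_input_ids = len(ids.shape) == 2 and ids.dtype in [torch.int, torch.long]
print(is_input_ids)  # True
```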
@@ -1719,6 +1719,31 @@ class GenerationIntegrationTests(unittest.TestCase):
         # make sure model generated correctly until `max_length`
         self.assertEqual(output_sequences.shape, (1, 5))
 
+    def test_encoder_decoder_generate_attention_mask(self):
+        articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
+        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        # need extreme generation values here to force this test
+        # to fail when `attention_mask` is not correctly treated in generate
+        model = BartForConditionalGeneration.from_pretrained(
+            "hf-internal-testing/tiny-random-bart", max_length=50, num_beams=5, num_return_sequences=5
+        ).to(torch_device)
+        model.config.eos_token_id = None
+        input_ids = tokenizer(articles[0], return_tensors="pt").input_ids.to(torch_device)
+        input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device)
+
+        output_sequences_batched = model.generate(
+            input_ids=input_ids_batched, return_dict_in_generate=True, output_scores=True
+        )
+        output_sequences = model.generate(input_ids=input_ids, return_dict_in_generate=True, output_scores=True)
+
+        batched_out = output_sequences_batched.sequences_scores
+        out = output_sequences.sequences_scores
+
+        diff = (batched_out[:5].sum() - out.sum()).abs()
+        self.assertTrue(diff < 1e-4)
+
     def test_decoder_generate_with_inputs_embeds(self):
         article = """I need input_ids to generate"""
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
...
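The test above passes only the padded `input_ids` and relies on `generate` inferring the attention mask from `pad_token_id`, which is exactly the path the corrected `is_input_ids` check gates. As a rough sketch of what that inference looks like (illustrative values, not the exact library internals), positions equal to the pad token are masked out:

```python
import torch

# Illustrative sketch of deriving an attention mask from padded ids;
# pad_token_id and the id values are made up for the example.
pad_token_id = 1
padded_ids = torch.tensor(
    [[0, 4, 2, 1, 1],   # short article, right-padded with pad_token_id
     [0, 7, 8, 9, 2]]   # full-length article
)
attention_mask = padded_ids.ne(pad_token_id).long()
print(attention_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])
```

With that mask applied, beam scores for the padded sample match the unbatched run, which is what the `diff < 1e-4` assertion checks.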