Unverified commit b18d8534 authored by Patrick von Platen, committed by GitHub

[Generate] Make generate multi-modal (#14784)



* finish refactor

* refactor

* add tests

* add more tests

* up

* finish tests

* finish

* up

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve docstring

* fix docs
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 48463ebb
This diff is collapsed.
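Since the implementation diff is collapsed, here is a minimal sketch of the dispatch idea the new tests exercise. It assumes each model exposes an attribute along the lines of `main_input_name` that names its primary input tensor ("input_ids", "input_values", "input_features", or "pixel_values"); the helper `_get_model_input` and its exact signature are hypothetical, not code from this commit.

    # Hedged sketch of the dispatch idea, not the implementation from this commit.
    # Assumes each model exposes `main_input_name`, naming its primary input tensor.
    def _get_model_input(model, inputs=None, **model_kwargs):
        input_name = model.main_input_name
        all_input_names = ("input_ids", "input_values", "input_features", "pixel_values")

        # A tensor passed under a keyword the model does not expect is an error,
        # e.g. GPT-2 called with `input_values` (test_generate_too_many_encoder_kwargs).
        unexpected = [name for name in all_input_names if name != input_name and name in model_kwargs]
        if unexpected:
            raise ValueError(f"Got unexpected generate inputs {unexpected}; expected `{input_name}`.")

        # Passing the input both positionally and as a keyword is ambiguous
        # (test_generate_inputs_and_encoder_kwargs).
        if inputs is not None and input_name in model_kwargs:
            raise ValueError(f"Pass the main input either positionally or as `{input_name}`, not both.")

        # Positional and keyword calls resolve to the same tensor, so
        # generate(x) and generate(<input_name>=x) behave identically.
        return inputs if inputs is not None else model_kwargs.pop(input_name, None)

The visible test diff below pins down exactly these behaviors: positional and keyword calls return identical sequences, and a mismatched or duplicated input keyword raises `ValueError`.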
@@ -20,6 +20,8 @@ import unittest
 from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device

+from .test_modeling_common import floats_tensor
+
 if is_torch_available():
     import torch
@@ -29,6 +31,9 @@ if is_torch_available():
         BartTokenizer,
         GPT2LMHeadModel,
         GPT2Tokenizer,
+        Speech2TextForConditionalGeneration,
+        SpeechEncoderDecoderModel,
+        VisionEncoderDecoderModel,
         top_k_top_p_filtering,
     )
     from transformers.generation_beam_search import BeamSearchScorer
@@ -1724,3 +1729,74 @@ class GenerationIntegrationTests(unittest.TestCase):
         # cannot generate from `inputs_embeds` for decoder only
         with self.assertRaises(ValueError):
             model.generate(inputs_embeds=inputs_embeds)
+
+    def test_generate_input_ids_as_kwarg(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+        output_sequences = model.generate(input_ids).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (1, 15))
+
+    def test_generate_input_ids_as_encoder_kwarg(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
+            torch_device
+        )
+        model.config.eos_token_id = None
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+        output_sequences = model.generate(input_ids).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (1, 5))
+
+    def test_generate_inputs_and_encoder_kwargs(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        with self.assertRaises(ValueError):
+            model.generate(input_ids, input_ids=input_ids)
+
+    def test_generate_too_many_encoder_kwargs(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        with self.assertRaises(ValueError):
+            model.generate(input_ids=input_ids, input_values=input_ids)
+
+    def test_generate_input_values_as_encoder_kwarg(self):
+        input_values = floats_tensor((2, 250))
+        model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(input_values=input_values, max_length=5).cpu()
+        output_sequences = model.generate(input_values, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (2, 5))
+
+    def test_generate_input_features_as_encoder_kwarg(self):
+        input_features = floats_tensor((3, 20, 24))
+        model = Speech2TextForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-speech_to_text")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(input_features=input_features, max_length=5).cpu()
+        output_sequences = model.generate(input_features, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (3, 5))
+
+    def test_generate_pixel_values_as_encoder_kwarg(self):
+        pixel_values = floats_tensor((2, 3, 30, 30))
+        model = VisionEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-vision-encoder-decoder")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5).cpu()
+        output_sequences = model.generate(pixel_values, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (2, 5))
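For reference, the calling convention these tests establish, shown as a usage sketch against one of the tiny checkpoints above (this snippet is not part of the commit):

    import torch
    from transformers import VisionEncoderDecoderModel

    model = VisionEncoderDecoderModel.from_pretrained(
        "hf-internal-testing/tiny-random-vision-encoder-decoder"
    )
    pixel_values = torch.rand(2, 3, 30, 30)  # batch of two tiny RGB images

    # After this change the main input can be passed positionally or under its
    # model-specific keyword; greedy decoding is deterministic, so both calls
    # return identical sequences.
    out_positional = model.generate(pixel_values, max_length=5)
    out_keyword = model.generate(pixel_values=pixel_values, max_length=5)
    assert out_positional.tolist() == out_keyword.tolist()

The same pattern holds for `input_values` (SpeechEncoderDecoderModel) and `input_features` (Speech2TextForConditionalGeneration).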