Unverified Commit d6286646 authored by Raushan Turganbay, committed by GitHub

Support batched input for decoder start ids (#28887)



* support batched input for decoder start ids

* Fix typos
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* minor changes

* fix: decoder_start_id as list

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

* empty commit

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent cc309fd4
@@ -233,8 +233,11 @@ class GenerationConfig(PushToHubMixin):
        encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
            If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
            `decoder_input_ids`.
-       decoder_start_token_id (`int`, *optional*):
-           If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
+       decoder_start_token_id (`Union[int, List[int]]`, *optional*):
+           If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token,
+           or a list of length `batch_size`. Passing a list enables different start ids for each element in the
+           batch (e.g. multilingual models with different target languages in one batch).

        > Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)
......
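
For reference, a minimal usage sketch of the new option (not part of the commit; the checkpoint and setup mirror the test added further below, and the repeated id is only illustrative, each entry could be a different token id):

```python
# Sketch: pass one decoder start id per batch element instead of a single int.
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")

inputs = tokenizer(["first article", "second article"], return_tensors="pt", padding=True)
start_id = model.generation_config.decoder_start_token_id

# A list of length batch_size is now accepted; with a multilingual model each
# entry could be a different target-language token id.
out = model.generate(inputs.input_ids, decoder_start_token_id=[start_id, start_id])
print(out.shape)
```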
@@ -501,7 +501,7 @@ class GenerationMixin:
        batch_size: int,
        model_input_name: str,
        model_kwargs: Dict[str, torch.Tensor],
-       decoder_start_token_id: int = None,
+       decoder_start_token_id: Union[int, List[int]] = None,
        bos_token_id: int = None,
        device: torch.device = None,
    ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
@@ -519,7 +519,17 @@ class GenerationMixin:
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        if device is None:
            device = self.device
-       decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
+       if isinstance(decoder_start_token_id, list):
+           if len(decoder_start_token_id) != batch_size:
+               raise ValueError(
+                   f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
+               )
+           decoder_input_ids_start = torch.tensor(decoder_start_token_id, dtype=torch.long, device=device)
+           decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
+       else:
+           decoder_input_ids_start = (
+               torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
+           )

        # no user input -> use decoder_start_token_id as decoder_input_ids
        if decoder_input_ids is None:
@@ -531,7 +541,13 @@ class GenerationMixin:
            pass
        # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
        # decoder_attention_mask if provided)
-       elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
+       elif (
+           isinstance(decoder_start_token_id, int)
+           and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
+       ) or (
+           isinstance(decoder_start_token_id, torch.Tensor)
+           and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
+       ):
            decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
@@ -543,7 +559,9 @@ class GenerationMixin:
        return decoder_input_ids, model_kwargs

-   def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
+   def _get_decoder_start_token_id(
+       self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
+   ) -> int:
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
......
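
Taken out of context, the branch added above amounts to the following standalone logic (an illustrative re-statement, not code from the repository; the helper name is made up): a list is validated against `batch_size` and reshaped into a `(batch_size, 1)` column, while a single int is broadcast as before.

```python
# Illustrative re-statement of the added branch; the helper name is hypothetical.
import torch


def build_decoder_start_column(decoder_start_token_id, batch_size, device="cpu"):
    if isinstance(decoder_start_token_id, list):
        if len(decoder_start_token_id) != batch_size:
            raise ValueError(
                f"`decoder_start_token_id` expected to have length {batch_size} "
                f"but got {len(decoder_start_token_id)}"
            )
        # one id per batch element -> column of shape (batch_size, 1)
        return torch.tensor(decoder_start_token_id, dtype=torch.long, device=device).view(-1, 1)
    # single id -> broadcast to every row, the pre-existing behaviour
    return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id


print(build_decoder_start_column(2, batch_size=3))          # tensor([[2], [2], [2]])
print(build_decoder_start_column([4, 5, 6], batch_size=3))  # tensor([[4], [5], [6]])
```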
@@ -3163,6 +3163,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
        with self.assertRaises(ValueError):
            model.generate(input_ids, force_words_ids=[[[-1]]])

+   def test_batched_decoder_start_id(self):
+       # PT-only test: TF doesn't support batched_decoder_start_id
+       articles = [
+           "Justin Timberlake and Jessica Biel, welcome to parenthood.",
+           "Michael Phelps is arguably the most decorated Olympian of all time.",
+       ]
+       bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+       bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
+           torch_device
+       )
+       input_ids = bart_tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
+       decoder_start_token_id = bart_model.generation_config.decoder_start_token_id
+       decoder_start_token_id_batch = [decoder_start_token_id] * input_ids.shape[0]
+
+       outputs = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id)
+       outputs_batched_ids = bart_model.generate(input_ids, decoder_start_token_id=decoder_start_token_id_batch)
+
+       self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())
+
    def test_contrastive_search_batched(self):
        # PT-only test: TF doesn't have constrained beam search
        # Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
......