Unverified commit 831590f6, authored by Joao Gante, committed by GitHub

Generate: contrastive search with full optional outputs (#19963)

* Use beam search functionality; Add extra outputs and test

* Add full tests for contrastive search

* Add error message on unconventional cache format
parent ab74ac11
@@ -338,7 +338,7 @@ To install Flax, PyTorch, or TensorFlow with conda, follow their
1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
1. **[ERNIE](https://huggingface.co/docs/transformers/model_doc/ernie)** (from Baidu) released with the paper [ERNIE: Enhanced Representation through Knowledge Integration](https://arxiv.org/abs/1904.09223) by Yu Sun, Shuohuan Wang, Yukun Li, Shikun Feng, Xuyi Chen, Han Zhang, Xin Tian, Danxiang Zhu, Hao Tian, Hua Wu.
1. **[ESM](https://huggingface.co/docs/transformers/model_doc/esm)** (from Meta AI) are transformer protein language models. **ESM-1b** was released with the paper [Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences](https://www.pnas.org/content/118/15/e2016239118) by Alexander Rives, Joshua Meier, Tom Sercu, Siddharth Goyal, Zeming Lin, Jason Liu, Demi Guo, Myle Ott, C. Lawrence Zitnick, Jerry Ma, and Rob Fergus. **ESM-1v** was released with the paper [Language models enable zero-shot prediction of the effects of mutations on protein function](https://doi.org/10.1101/2021.07.09.450648) by Joshua Meier, Roshan Rao, Robert Verkuil, Jason Liu, Tom Sercu and Alexander Rives. **ESM-2** was released with the paper [Language models of protein sequences at the scale of evolution enable accurate structure prediction](https://doi.org/10.1101/2022.07.20.500902) by Zeming Lin, Halil Akin, Roshan Rao, Brian Hie, Zhongkai Zhu, Wenting Lu, Allan dos Santos Costa, Maryam Fazel-Zarandi, Tom Sercu, Sal Candido, Alexander Rives.
1. **[FLAN-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)** (from Google AI) released in the repository [google-research/t5x](https://github.com/google-research/t5x/blob/main/docs/models.md#flan-t5-checkpoints) by Hyung Won Chung, Le Hou, Shayne Longpre, Barret Zoph, Yi Tay, William Fedus, Eric Li, Xuezhi Wang, Mostafa Dehghani, Siddhartha Brahma, Albert Webson, Shixiang Shane Gu, Zhuyun Dai, Mirac Suzgun, Xinyun Chen, Aakanksha Chowdhery, Sharan Narang, Gaurav Mishra, Adams Yu, Vincent Zhao, Yanping Huang, Andrew Dai, Hongkun Yu, Slav Petrov, Ed H. Chi, Jeff Dean, Jacob Devlin, Adam Roberts, Denny Zhou, Quoc V. Le, and Jason Wei
1. **[FlauBERT](https://huggingface.co/docs/transformers/model_doc/flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab.
1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela.
1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
@@ -360,7 +360,7 @@ To install Flax, PyTorch, or TensorFlow with conda, follow their
1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutxlm)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LeViT](https://huggingface.co/docs/transformers/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
1. **[LiLT](https://huggingface.co/docs/transformers/model_doc/lilt)** (from South China University of Technology) released with the paper [LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding](https://arxiv.org/abs/2202.13669) by Jiapeng Wang, Lianwen Jin, Kai Ding.
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LongT5](https://huggingface.co/docs/transformers/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
@@ -412,7 +412,7 @@ To install Flax, PyTorch, or TensorFlow with conda, follow their
1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo.
1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[Table Transformer](https://huggingface.co/docs/transformers/model_doc/table-transformer)** (from Microsoft Research) released with the paper [PubTables-1M: Towards Comprehensive Table Extraction From Unstructured Documents](https://arxiv.org/abs/2110.00061) by Brandon Smock, Rohith Pesala, Robin Abraham.
1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[TAPEX](https://huggingface.co/docs/transformers/model_doc/tapex)** (from Microsoft Research) released with the paper [TAPEX: Table Pre-training via Learning a Neural SQL Executor](https://arxiv.org/abs/2107.07653) by Qian Liu, Bei Chen, Jiaqi Guo, Morteza Ziyadi, Zeqi Lin, Weizhu Chen, Jian-Guang Lou.
1. **[Time Series Transformer](https://huggingface.co/docs/transformers/model_doc/time_series_transformer)** (from HuggingFace).
...
@@ -106,19 +106,33 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -136,6 +150,9 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
            passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
@@ -144,6 +161,7 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
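These dataclasses mirror the greedy/sample output classes, so downstream code can consume contrastive search results the same way. A minimal usage sketch (the checkpoint and generation arguments are illustrative, not taken from this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Technologies is", return_tensors="pt")

# `penalty_alpha > 0` together with `top_k > 1` routes `generate()` to contrastive search
output = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=16,
    output_scores=True,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

# decoder-only models return a ContrastiveSearchDecoderOnlyOutput
print(output.sequences.shape)  # (batch_size, sequence_length)
print(len(output.scores))      # one score tensor per generated token
print(len(output.attentions))  # one tuple of per-layer attentions per generated token
```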
@@ -628,47 +646,47 @@ class GenerationMixin:
    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        if model_kwargs.get("token_type_ids") is not None:
            model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"].repeat_interleave(expand_size, dim=0)

        if model_kwargs.get("attention_mask") is not None:
            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)

        if is_encoder_decoder:
            encoder_outputs = model_kwargs.get("encoder_outputs")
            if encoder_outputs is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
                expand_size, dim=0
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs

    @staticmethod
    def _extract_past_from_model_output(outputs: ModelOutput):
        past = None
        if "past_key_values" in outputs:
            past = outputs.past_key_values
        elif "mems" in outputs:
            past = outputs.mems
        elif "past_buckets_states" in outputs:
            past = outputs.past_buckets_states
        return past

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        # update past
        model_kwargs["past"] = self._extract_past_from_model_output(outputs)

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
@@ -1533,7 +1551,7 @@ class GenerationMixin:
            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
@@ -1571,7 +1589,10 @@ class GenerationMixin:
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 12. run beam search
            return self.beam_search(
@@ -1611,7 +1632,7 @@ class GenerationMixin:
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
@@ -1658,7 +1679,10 @@ class GenerationMixin:
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 12. run beam search
            return self.group_beam_search(
@@ -1739,7 +1763,10 @@ class GenerationMixin:
            )
            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 12. run beam search
            return self.constrained_beam_search(
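Beyond these call sites, making `input_ids` keyword-only and optional lets contrastive search (below) expand just the model kwargs between forward passes. A sketch of that calling pattern, reusing the `model` from the sketch above and illustrative shapes (`_expand_inputs_for_generation` is a private helper, shown here only for illustration):

```python
import torch

# hypothetical batch of 2 sequences of length 5, expanded top_k=3 times
attention_mask = torch.ones(2, 5, dtype=torch.long)

_, model_kwargs = model._expand_inputs_for_generation(
    expand_size=3,
    is_encoder_decoder=False,
    attention_mask=attention_mask,  # picked up through **model_kwargs
)
print(model_kwargs["attention_mask"].shape)  # torch.Size([6, 5])
```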
@@ -1861,12 +1888,22 @@ class GenerationMixin:
        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

        this_peer_finished = False  # used by synced_gpus only
        batch_size = input_ids.shape[0]

        while True:
            if synced_gpus:
@@ -1879,27 +1916,20 @@ class GenerationMixin:
                if this_peer_finished_flag.item() == 0.0:
                    break

            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
            # (2) last_hidden_states; (3) logit_for_next_step; and (4) update model kwargs for the next step
            if model_kwargs.get("past") is None:

                # prepare inputs
                model_kwargs["use_cache"] = True
                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

                # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
                # the `encoder_outputs`
                outputs = self(
                    **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
                )

                # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
                # previous tokens)
                if self.config.is_encoder_decoder:
@@ -1913,23 +1943,42 @@ class GenerationMixin:
                    outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
                )

                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
                _, model_kwargs = self._expand_inputs_for_generation(
                    expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
                )

                past = model_kwargs.get("past")
                if past is None:
                    raise ValueError(
                        f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
                        "for contrastive search."
                    )
                elif not isinstance(past[0], (tuple, torch.Tensor)) or past[0][0].shape[0] != batch_size:
                    raise ValueError(
                        f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
                        "used for contrastive search without further modifications."
                    )

            # contrastive_search main logic start:
            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
            # degeneration penalty
            logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
            logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
            next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)

            _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=top_k)
            top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (logit_for_next_step,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
@@ -1938,47 +1987,22 @@ class GenerationMixin:
                        else (outputs.hidden_states,)
                    )

            # Replicates the new past_key_values to match the `top_k` candidates
            new_key_values = []
            for layer in model_kwargs["past"]:
                items = []
                # item is either the key or the value matrix
                for item in layer:
                    items.append(item.repeat_interleave(top_k, dim=0))
                new_key_values.append(items)
            model_kwargs["past"] = new_key_values

            # compute the candidate tokens by the language model and collects their hidden_states
            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
            outputs = self(
                **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
            )
            next_past_key_values = self._extract_past_from_model_output(outputs)

            logits = outputs.logits[:, -1, :]
            # name is different for encoder-decoder and decoder-only models
@@ -1988,9 +2012,7 @@ class GenerationMixin:
            else:
                next_hidden = outputs.hidden_states[-1]
                full_hidden_states = outputs.hidden_states
            context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

            # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
            # model confidence
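The `selected_idx` consumed below comes from the ranking helper, which lies outside these hunks. A minimal sketch of the re-ranking it performs, scoring each candidate as `(1 - alpha) * model_confidence - alpha * degeneration_penalty` (names are illustrative):

```python
import torch

def rank_candidates(context_hidden, next_hidden, top_k_probs, alpha, top_k):
    # context_hidden: [B*K, S, H], previous hidden states repeated once per candidate
    # next_hidden:    [B*K, 1, H], hidden state of each candidate token
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)

    # degeneration penalty: max cosine similarity between a candidate and any previous token
    cosine = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]
    degeneration_penalty, _ = torch.max(cosine, dim=-1)  # [B*K]

    score = (1 - alpha) * top_k_probs.view(-1) - alpha * degeneration_penalty
    score = torch.stack(torch.split(score, top_k))  # [B, K]
    _, selected_idx = score.max(dim=-1)  # [B], winning candidate per batch item
    return selected_idx
```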
@@ -2001,42 +2023,55 @@ class GenerationMixin:
            # (model confidence minus degeneration penalty); (6) decoder hidden_states
            next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
            next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
            next_hidden = next_hidden[range(batch_size), selected_idx, :]
            last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)

            next_decoder_hidden_states = ()
            for layer in full_hidden_states:
                layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
                next_decoder_hidden_states += (layer,)

            # select the past_key_value
            new_key_values = ()
            for layer in next_past_key_values:
                items = ()
                # item is either the key or the value matrix
                for item in layer:
                    item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
                    item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
                    items += (item,)
                new_key_values += (items,)
            next_past_key_values = new_key_values

            logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]

            # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
            if self.config.is_encoder_decoder:
                next_step_cross_attentions = ()
                next_step_decoder_attentions = ()
                if output_attentions:
                    for layer in outputs.cross_attentions:
                        layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                        next_step_cross_attentions += (layer,)
                    for layer in outputs.decoder_attentions:
                        layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                        next_step_decoder_attentions += (layer,)
                outputs = Seq2SeqLMOutput(
                    past_key_values=next_past_key_values,
                    decoder_hidden_states=next_decoder_hidden_states,
                    decoder_attentions=next_step_decoder_attentions or None,
                    cross_attentions=next_step_cross_attentions or None,
                )
            else:
                next_step_attentions = ()
                if output_attentions:
                    for layer in outputs.attentions:
                        layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                        next_step_attentions += (layer,)
                outputs = CausalLMOutputWithPast(
                    past_key_values=next_past_key_values,
                    hidden_states=next_decoder_hidden_states,
                    attentions=next_step_attentions or None,
                )
            # contrastive_search main logic end
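The `stack`/`split`/index pattern used throughout this block relies on the standard cache layout that the validation at the top of the loop checks for: per layer, a `(key, value)` pair of tensors whose first dimension is `batch_size * top_k`. A shape-only walkthrough with dummy sizes:

```python
import torch

batch_size, top_k, num_heads, seq_len, head_dim = 2, 3, 4, 5, 8
item = torch.randn(batch_size * top_k, num_heads, seq_len, head_dim)  # one key or value matrix
selected_idx = torch.tensor([1, 0])  # winning candidate per batch item

stacked = torch.stack(torch.split(item, top_k, dim=0))    # [B, K, num_heads, seq_len, head_dim]
selected = stacked[range(batch_size), selected_idx, ...]  # [B, num_heads, seq_len, head_dim]
print(selected.shape)  # torch.Size([2, 4, 5, 8])
```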
@@ -2071,12 +2106,17 @@ class GenerationMixin:
                return ContrastiveSearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return ContrastiveSearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
...
@@ -18,7 +18,7 @@ import inspect
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from ..test_modeling_common import floats_tensor, ids_tensor
@@ -35,7 +35,6 @@ if is_torch_available():
        GPT2LMHeadModel,
        GPT2Tokenizer,
        ImageGPTForCausalImageModeling,
        Speech2TextForConditionalGeneration,
        SpeechEncoderDecoderModel,
        T5ForConditionalGeneration,
@@ -623,6 +622,76 @@ class GenerationTesterMixin:
        )
        return output_generate, output_group_beam_search

    def _contrastive_generate(
        self,
        model,
        input_ids,
        attention_mask,
        max_length,
        output_scores=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict_in_generate=False,
    ):
        contrastive_search_kwargs = {
            "penalty_alpha": 0.6,
            "top_k": 5,
        }

        if model.config.is_encoder_decoder:
            max_length = 4

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1],
            eos_token_id=model.config.eos_token_id,
            forced_bos_token_id=model.config.forced_bos_token_id,
            forced_eos_token_id=model.config.forced_eos_token_id,
            max_length=max_length,
        )

        kwargs = {}
        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
        output_generate = model.generate(
            input_ids,
            do_sample=False,
            num_beams=1,
            max_length=max_length,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_scores=output_scores,
            return_dict_in_generate=return_dict_in_generate,
            remove_invalid_values=True,
            **logits_process_kwargs,
            **model_kwargs,
            **contrastive_search_kwargs,
        )

        if model.config.is_encoder_decoder:
            encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
                model,
                input_ids,
                attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            kwargs["encoder_outputs"] = encoder_outputs

        with torch.no_grad():
            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
            stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
            output_contrastive = model.contrastive_search(
                input_ids,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_scores=output_scores,
                return_dict_in_generate=return_dict_in_generate,
                **kwargs,
                **model_kwargs,
                **contrastive_search_kwargs,
            )
        return output_contrastive, output_generate
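A property implied by the helper above: with `top_k=1` there is only one candidate per step, so contrastive search should collapse to greedy search. A sanity sketch along those lines (not part of the suite; the checkpoint is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=16)])
greedy = model.greedy_search(input_ids, stopping_criteria=criteria)
contrastive = model.contrastive_search(input_ids, top_k=1, stopping_criteria=criteria)
assert greedy.tolist() == contrastive.tolist()
```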
    def test_greedy_generate(self):
        # check `generate()` and `greedy_search()` are equal
        for model_class in self.all_generative_model_classes:
@@ -1336,6 +1405,64 @@ class GenerationTesterMixin:
            for output in (output_beam_search, output_generate):
                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)

    def test_contrastive_generate(self):
        # check `generate()` and `contrastive_search()` are equal
        for model_class in self.all_generative_model_classes:

            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
            # won't fix: FSMT and Reformer have a different cache variable type (and format).
            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
                return

            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # NOTE: contrastive search only works with cache on at the moment.
            if not hasattr(config, "use_cache"):
                return
            config.use_cache = True
            config.is_decoder = True

            # test old generation output for backwards compatibility
            model = model_class(config).to(torch_device).eval()
            output_contrastive, output_generate = self._contrastive_generate(
                model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
            )
            self.assertListEqual(output_contrastive.tolist(), output_generate.tolist())

    def test_contrastive_generate_dict_outputs_use_cache(self):
        for model_class in self.all_generative_model_classes:

            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
            # won't fix: FSMT and Reformer have a different cache variable type (and format).
            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
                return

            # enable cache
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

            # NOTE: contrastive search only works with cache on at the moment.
            if not hasattr(config, "use_cache"):
                return
            config.use_cache = True
            config.is_decoder = True
            model = model_class(config).to(torch_device).eval()
            output_contrastive, output_generate = self._contrastive_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            self.assertListEqual(output_generate.sequences.tolist(), output_contrastive.sequences.tolist())

            for output in (output_contrastive, output_generate):
                self._check_outputs(output, input_ids, model.config, use_cache=True)

    def test_generate_with_head_masking(self):
        """Test designed for encoder-decoder models to ensure the attention head masking is used."""
        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
@@ -1696,197 +1823,6 @@ class GenerationIntegrationTests(unittest.TestCase):
            ],
        )
@slow
def test_contrastive_search_bart(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)
@slow
def test_contrastive_search_t5(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
article = "summarize: " + article.strip()
t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm")
t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device)
input_ids = t5_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for "
"permanent residence after the marriages, prosecutors say."
],
)
@slow
def test_contrastive_search_opt(self):
article = (
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
"Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
"there?"
)
opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b")
opt_model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I "
"am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have "
"you lived there?\nStatue: A hundred years.\nHuman: And you’re from what country?\nStatue: The United "
"States of America.\nHuman: Why did you come to America?\nStatue: I came to escape the tyranny of my "
"country.\nHuman: What tyranny?\nStatue: They didn’t let me speak my mind.\nHuman: What was your "
"country?\nStatue: It was a country of immigrants.\nHuman: Who were the immigrants?\nStatue: They "
"were from all over the world.\nHuman: What language did they speak?\nStatue: French, Spanish, "
"Italian, German, English—you name it.\nHuman: And where did they come from?\nStatue: They came from "
"every country in the world.\nHuman: And you were born in what country?\nStatue: I was born in "
"France.\nHuman: And your parents were French?\nStatue"
],
)

@tooslow
def test_contrastive_search_gptj(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and "
"research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, "
"Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google's "
"parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating "
"a company that would apply deep learning to problems in healthcare, energy, transportation, and "
"other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 "
"million in cash and stock.[3] The acquisition was seen as a way for Google to enter the "
"fast-growing field of artificial intelligence (AI), which it had so far avoided due to concerns "
'about ethical and social implications.[4] Google co-founder Sergey Brin said that he was "thrilled" '
'to have acquired DeepMind, and that it would "help us push the boundaries of AI even further."'
"[5]\n\nDeepMind's founders, Demis Hassabis and Mustafa Suleyman, were joined by a number of Google "
"employees"
],
)

@slow
def test_contrastive_search_gpt2(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)

def test_max_length_backward_compat_greedy(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
...@@ -3045,6 +2981,31 @@ class GenerationIntegrationTests(unittest.TestCase):
with self.assertRaises(ValueError):
model.generate(input_ids, force_words_ids=[[[-1]]])

def test_contrastive_search_batched(self):
# Tests that contrastive search works with batched inputs, i.e. that batched generation matches the equivalent non-batched output
articles = ["Foo", "Bar Baz"]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
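# With eos_token_id unset, generation never stops early, so the batched and
# non-batched runs below produce sequences of the same length and stay comparable.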
model.config.eos_token_id = None
input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids.to(torch_device)
input_ids = tokenizer(articles[1], return_tensors="pt").input_ids.to(torch_device)
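# penalty_alpha > 0 with top_k > 1 triggers contrastive search: at each step the top_k
# most probable tokens are re-ranked by (1 - penalty_alpha) * p(token) minus
# penalty_alpha times their max cosine similarity to previously generated hidden states.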
output_sequences_batched = model.generate(
input_ids=input_ids_batched, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True
)
output_sequences = model.generate(
input_ids=input_ids, penalty_alpha=0.6, top_k=4, return_dict_in_generate=True, output_scores=True
)
batched_out = tokenizer.decode(output_sequences_batched.sequences[1], skip_special_tokens=True)
out = tokenizer.decode(output_sequences.sequences[0], skip_special_tokens=True)
self.assertEqual(batched_out, out)
# output_sequences_batched.scores[0][1] -> 1st set of logits, 2nd sequence
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
self.assertTrue(max_score_diff < 1e-5)
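
# For reference, a minimal sketch of the re-ranking rule the test above exercises,
# with hypothetical tensor names (probs: [k], candidate_hidden: [k, d],
# context_hidden: [t, d]); illustrative only, not part of this diff:
#
#   def contrastive_rank(probs, candidate_hidden, context_hidden, alpha):
#       # max cosine similarity of each candidate state to every context state
#       sim = torch.nn.functional.cosine_similarity(
#           candidate_hidden.unsqueeze(1), context_hidden.unsqueeze(0), dim=-1
#       ).max(dim=1).values
#       # degeneration penalty: high-confidence tokens win unless they echo the context
#       return (1 - alpha) * probs - alpha * sim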

def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
...
...@@ -1181,6 +1181,52 @@ class BartModelIntegrationTests(unittest.TestCase):
)
assert generated_summaries == EXPECTED

@slow
def test_contrastive_search_bart(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
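# Contrastive search with a milder degeneration penalty (penalty_alpha=0.5) over a
# 5-token candidate pool; the generated summary is capped at 64 tokens.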
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)

class BartStandaloneDecoderModelTester:
def __init__(
...
...@@ -763,3 +763,37 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

@slow
def test_contrastive_search_gpt2(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(torch_device)
input_ids = gpt2_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
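# Contrastive search is deterministic (no sampling), so a fixed checkpoint should
# reproduce the exact continuation asserted below.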
outputs = gpt2_model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)

...@@ -572,3 +572,38 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

@tooslow
def test_contrastive_search_gptj(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and "
"research laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom with offices in Mountain View, San Francisco, New York City, Paris, Tokyo, Seoul, "
"Beijing, Singapore, Tel Aviv, Dublin, Sydney, and Melbourne.[1]\n\nContents\n\nIn 2010, Google's "
"parent company, Alphabet, announced a $500 million investment in DeepMind, with the aim of creating "
"a company that would apply deep learning to problems in healthcare, energy, transportation, and "
"other areas.[2]\n\nOn April 23, 2014, Google announced that it had acquired DeepMind for $400 "
"million in cash and stock.[3] The acquisition was seen as a way for Google to enter the "
"fast-growing field of artificial intelligence (AI), which it had so far avoided due to concerns "
'about ethical and social implications.[4] Google co-founder Sergey Brin said that he was "thrilled" '
'to have acquired DeepMind, and that it would "help us push the boundaries of AI even further."'
"[5]\n\nDeepMind's founders, Demis Hassabis and Mustafa Suleyman, were joined by a number of Google "
"employees"
],
)

...@@ -490,3 +490,34 @@ class OPTGenerationTest(unittest.TestCase):
self.assertFalse(
torch.isnan(outputs.logits[0]).any().item()
) # the first logits could contain NaNs if it fails

@slow
def test_contrastive_search_opt(self):
article = (
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
"Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
"there?"
)
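# OPT was trained with GPT-2's byte-level BPE tokenizer, which is why GPT2Tokenizer is used here.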
opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b")
opt_model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b").to(torch_device)
input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I "
"am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have "
"you lived there?\nStatue: A hundred years.\nHuman: And you’re from what country?\nStatue: The United "
"States of America.\nHuman: Why did you come to America?\nStatue: I came to escape the tyranny of my "
"country.\nHuman: What tyranny?\nStatue: They didn’t let me speak my mind.\nHuman: What was your "
"country?\nStatue: It was a country of immigrants.\nHuman: Who were the immigrants?\nStatue: They "
"were from all over the world.\nHuman: What language did they speak?\nStatue: French, Spanish, "
"Italian, German, English—you name it.\nHuman: And where did they come from?\nStatue: They came from "
"every country in the world.\nHuman: And you were born in what country?\nStatue: I was born in "
"France.\nHuman: And your parents were French?\nStatue"
],
)

...@@ -30,7 +30,14 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch

from transformers import (
AutoTokenizer,
ByT5Tokenizer,
T5EncoderModel,
T5ForConditionalGeneration,
T5Model,
T5Tokenizer,
)
from transformers.models.t5.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST

...@@ -1216,6 +1223,51 @@ class T5ModelIntegrationTests(unittest.TestCase):
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(translation, expected_translation)

@slow
def test_contrastive_search_t5(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
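# T5 checkpoints are text-to-text; the "summarize: " prefix below selects the
# summarization task for this fine-tuned model.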
article = "summarize: " + article.strip()
t5_tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-base-cnn-dm")
t5_model = T5ForConditionalGeneration.from_pretrained("flax-community/t5-base-cnn-dm").to(torch_device)
input_ids = t5_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
).input_ids.to(torch_device)
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for "
"permanent residence after the marriages, prosecutors say."
],
)

@require_torch
class TestAsymmetricT5(unittest.TestCase):
...