"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "314cdf7c2583ba523a2184124cbabdc537bb38fc"
Unverified Commit aeb18b92 authored by Nicolas Patry, committed by GitHub

Adding new `encoder_no_repeat_ngram_size` to `generate`. (#9984)

Adding new `encoder_no_repeat_ngram_size` to `generate`.

Blenderbot results seemed off compared to the original ParlAI script
(`https://parl.ai/projects/recipes/`). Notably, the model seemed to repeat
a lot of what was said earlier in the conversation.

The actual problem is that ParlAI's `no_repeat_ngram_size` applies
to the `encoder_input_ids`, whereas HF's `no_repeat_ngram_size` applies
to the previously generated ids (within the decoder). Blenderbot keeps the
conversation history in the `encoder` input, which explains why HF's
implementation produced the repetitions.
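As a rough illustration of the difference, the sketch below (simplified pseudologic, not the actual transformers code; the helper name `banned_next_tokens` is made up for this sketch) collects the n-grams of the encoder input and bans any token that would let the decoder complete one of them:

```python
# Simplified sketch of the idea (not the transformers implementation): gather all
# n-grams from the encoder input, then ban any token that would complete one of
# those n-grams given the last (n - 1) decoder tokens.
from typing import Dict, List, Tuple


def banned_next_tokens(encoder_ids: List[int], decoder_ids: List[int], n: int) -> List[int]:
    ngrams: Dict[Tuple[int, ...], List[int]] = {}
    for i in range(len(encoder_ids) - n + 1):
        prefix = tuple(encoder_ids[i : i + n - 1])
        ngrams.setdefault(prefix, []).append(encoder_ids[i + n - 1])
    if len(decoder_ids) < n - 1:
        return []
    return ngrams.get(tuple(decoder_ids[-(n - 1):]), [])


# The history "1 2 1 1" contains the 3-gram (2, 1, 1); if the decoder just produced "2 1",
# token 1 is banned for the next step.
print(banned_next_tokens([1, 2, 1, 1], [2, 1], 3))  # [1]
```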

This fix focuses on Blenderbot (*not* Blenderbot Small) and adds tests for it,
because the two differ quite a bit in configuration.

This change includes:

- Adding a new `EncoderNoRepeatNGramLogitsProcessor`.
- Adding 1 new arg to `generate` (`encoder_no_repeat_ngram_size`); a usage sketch follows this list.
- Adding 1 new config parameter, `encoder_no_repeat_ngram_size`.
- Adding 2 tests: one for the pipeline (high level, with inputs that exhibited the
  repeat behavior) and one low-level test for `EncoderNoRepeatNGramLogitsProcessor`.
- Factoring the n-gram logic out of `NoRepeatNGramLogitsProcessor` so it can be reused.
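A minimal usage sketch of the new argument (hedged: the checkpoint is the one used in the new pipeline test below, the prompt is only an example):

```python
# Hedged usage sketch of the new `generate` argument; the prompt is illustrative.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")

inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="pt")
# Forbid the decoder from reproducing any 3-gram that already occurs in the encoder input
reply_ids = model.generate(**inputs, encoder_no_repeat_ngram_size=3)
print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))
```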

Further work:

- The Blenderbot conversational pipeline still does not behave correctly,
  as the way input is prepared within the pipeline is still incorrect
  (follow-up PR).
- Blenderbot allows the bot to have personas, which is done by
  prepending "your persona: XXXX" to the input; this could be explored
  in a follow-up PR too.

@patrickvonplaten
@LysandreJik

* Update src/transformers/generation_logits_process.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/generation_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/generation_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Doc quality.

* Fixing test.

* Last fixes.

* Fixing to account for batch_size.

* Update src/transformers/configuration_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/generation_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent e89c959a
@@ -117,6 +117,9 @@ class PretrainedConfig(object):
- **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default in the
:obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of that size
can only occur once.
- **encoder_no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by
default in the :obj:`generate` method of the model for ``encoder_no_repeat_ngram_size``. If set to int > 0,
all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the ``decoder_input_ids``.
- **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be generated
that will be used by default in the :obj:`generate` method of the model. In order to get the tokens of the
words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word,
@@ -205,6 +208,7 @@ class PretrainedConfig(object):
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
...
@@ -235,6 +235,41 @@ class TopKLogitsWarper(LogitsWarper):
        return scores
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
@@ -253,36 +288,53 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores

(`__call__` now delegates to the module-level `_calc_banned_ngram_tokens`; the previous in-class `_calc_banned_ngram_tokens` helper of `NoRepeatNGramLogitsProcessor` is removed, its logic having moved into the module-level helpers above.)


class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids.
    See `ParlAI <https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350>`__.

    Args:
        encoder_ngram_size (:obj:`int`):
            All ngrams of size :obj:`ngram_size` can only occur within the encoder input ids.
        encoder_input_ids (:obj:`int`):
            The encoder_input_ids that should not be repeated within the decoder ids.
    """

    def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
        if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
            raise ValueError(
                f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
            )
        self.ngram_size = encoder_ngram_size
        if len(encoder_input_ids.shape) == 1:
            encoder_input_ids = encoder_input_ids.unsqueeze(0)
        self.batch_size = encoder_input_ids.shape[0]
        self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # B x num_beams
        num_hypos = scores.shape[0]
        num_beams = num_hypos // self.batch_size
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = [
            _get_generated_ngrams(
                self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
            )
            for hypo_idx in range(num_hypos)
        ]

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores
class NoBadWordsLogitsProcessor(LogitsProcessor):
...
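For concreteness, here is a hedged sketch (not part of the diff) that runs the new processor on toy tensors; the ids, shapes, and vocabulary size are made up for illustration:

```python
import torch
from transformers.generation_logits_process import EncoderNoRepeatNGramLogitsProcessor

encoder_input_ids = torch.tensor([[0, 1, 2, 1, 3]])  # one sequence over a toy vocab of 4 tokens
decoder_input_ids = torch.tensor([[2, 1]])           # decoding so far ends with "2 1"
scores = torch.zeros(1, 4)                           # uniform logits over the toy vocab

processor = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
filtered = processor(decoder_input_ids, scores)
# The encoder contains the 3-gram (2, 1, 3) and the decoder currently ends with (2, 1),
# so token 3 is banned (set to -inf) for the next step.
print(torch.isinf(filtered))  # tensor([[False, False, False, True]])
```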
@@ -23,6 +23,7 @@ from torch.nn import functional as F
from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
@@ -537,6 +538,8 @@ class GenerationMixin:
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
encoder_no_repeat_ngram_size: int,
encoder_input_ids: torch.LongTensor,
bad_words_ids: List[List[int]],
min_length: int,
eos_token_id: int,
@@ -555,6 +558,11 @@ class GenerationMixin:
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
encoder_no_repeat_ngram_size = (
encoder_no_repeat_ngram_size
if encoder_no_repeat_ngram_size is not None
else self.config.encoder_no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
min_length = min_length if min_length is not None else self.config.min_length
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -574,6 +582,13 @@ class GenerationMixin:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
if self.config.is_encoder_decoder:
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
else:
raise ValueError(
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
)
if bad_words_ids is not None:
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
@@ -601,6 +616,7 @@ class GenerationMixin:
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
@@ -661,6 +677,9 @@ class GenerationMixin:
sequences.
no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
``decoder_input_ids``.
bad_words_ids(:obj:`List[List[int]]`, `optional`):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use :obj:`tokenizer(bad_word,
@@ -820,6 +839,9 @@ class GenerationMixin:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
# Storing encoder_input_ids for logits_processor that could use them
encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
@@ -862,6 +884,8 @@ class GenerationMixin:
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=encoder_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
@@ -1638,6 +1662,7 @@ class GenerationMixin:
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
cur_len = cur_len + 1
model_kwargs = self._update_model_kwargs_for_generation(
...
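The new check in `_get_logits_processor` rejects the argument for decoder-only models. A hedged sketch (the checkpoint is only an example, not part of this PR):

```python
# Hedged sketch: per the check added above, a decoder-only model such as GPT-2
# cannot use the new argument and raises a ValueError.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")
try:
    model.generate(**inputs, encoder_no_repeat_ngram_size=2)
except ValueError as err:
    # "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
    print(err)
```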
@@ -128,6 +128,7 @@ class BlenderbotConfig(PretrainedConfig):
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
encoder_no_repeat_ngram_size=3,
**kwargs
):
super().__init__(
@@ -136,6 +137,7 @@ class BlenderbotConfig(PretrainedConfig):
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
**kwargs,
)
...
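Since `BlenderbotConfig` now carries a default of 3, `generate` picks the value up automatically. A small hedged sketch (not part of the diff):

```python
# With this change, BlenderbotConfig defaults encoder_no_repeat_ngram_size to 3,
# so generate() applies the encoder n-gram ban without it being passed explicitly.
from transformers import BlenderbotConfig

config = BlenderbotConfig()
print(config.encoder_no_repeat_ngram_size)  # 3
```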
@@ -27,6 +27,7 @@ if is_torch_available():
    import torch.nn.functional as F
    from transformers.generation_logits_process import (
        EncoderNoRepeatNGramLogitsProcessor,
        HammingDiversityLogitsProcessor,
        LogitsProcessorList,
        MinLengthLogitsProcessor,
@@ -208,6 +209,68 @@ class LogitsProcessorTest(unittest.TestCase):
            torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
        )
    def test_encoder_no_repeat_ngram_dist_processor(self):
        vocab_size = 3
        num_beams = 2
        batch_size = 1

        encoder_input_ids = torch.tensor([1, 2, 1, 1], device=torch_device, dtype=torch.long)

        input_ids = torch.tensor([[1, 2, 1], [8, 0, 2]], device=torch_device, dtype=torch.long)
        scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)

        no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
        no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)

        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())

        # 2-gram would forbid 1st and 2nd token at 1st beam and 1st token (0) at 2nd beam
        self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [False, True, False]])

        # 3-gram would forbid 1st token at 1st beam and no token at 2nd beam
        self.assertListEqual(
            torch.isinf(filtered_scores_3_gram).tolist(), [[False, True, False], [False, False, False]]
        )

        # Batched input
        vocab_size = 3
        num_beams = 2
        batch_size = 2
        encoder_input_ids = torch.tensor([[1, 2, 1, 1], [0, 0, 2, 1]], device=torch_device, dtype=torch.long)

        input_ids = torch.tensor([[1, 2, 1], [1, 0, 2], [0, 0, 0], [0, 2, 2]], device=torch_device, dtype=torch.long)
        scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)

        no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
        no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)

        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())

        # 2gram
        # Batch 1
        #   - Beam 1: tokens (1, 2) forbidden
        #   - Beam 2: tokens (1) forbidden
        # Batch 2
        #   - Beam 1: tokens (0, 2) forbidden
        #   - Beam 2: tokens (1) forbidden
        self.assertListEqual(
            torch.isinf(filtered_scores_2_gram).tolist(),
            [[False, True, True], [False, True, False], [True, False, True], [False, True, False]],
        )

        # Batch 1
        #   - Beam 1: tokens (1) forbidden
        #   - Beam 2: tokens () forbidden
        # Batch 2
        #   - Beam 1: tokens (2) forbidden
        #   - Beam 2: tokens () forbidden
        self.assertListEqual(
            torch.isinf(filtered_scores_3_gram).tolist(),
            [[False, True, False], [False, False, False], [False, False, True], [False, False, False]],
        )
    def test_no_bad_words_dist_processor(self):
        vocab_size = 5
        batch_size = 2
...
@@ -276,6 +276,47 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
        self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
        self.assertEqual(result.generated_responses[1], "It's a comedy.")
    @require_torch
    @slow
    def test_integration_torch_conversation_blenderbot_400M(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
        nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)

        conversation_1 = Conversation("hello")
        result = nlp(
            conversation_1,
        )
        self.assertEqual(
            result.generated_responses[0],
            # ParlAI implementation output, we have a different one, but it's our
            # second best, you can check by using num_return_sequences=10
            # " Hello! How are you? I'm just getting ready to go to work, how about you?",
            " Hello! How are you doing today? I just got back from a walk with my dog.",
        )

        conversation_1 = Conversation(" Lasagne hello")
        result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
        self.assertEqual(
            result.generated_responses[0],
            " Lasagne is my favorite Italian dish. Do you like lasagne?",
        )

        conversation_1 = Conversation(
            "Lasagne hello Lasagne is my favorite Italian dish. Do you like lasagne? I like lasagne."
        )
        result = nlp(
            conversation_1,
            encoder_no_repeat_ngram_size=3,
        )
        self.assertEqual(
            result.generated_responses[0],
            # ParlAI implementation output, we have a different one, but it's our
            # second best, you can check by using num_return_sequences=10
            # " Hello! How are you? I'm just getting ready to go to work, how about you?",
            " Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.",
        )
    @require_torch
    @slow
    def test_integration_torch_conversation_encoder_decoder(self):
...