Unverified Commit b1aa4982 authored by Nicolas Patry, committed by GitHub

Cleaning up `ConversationalPipeline` to support more than DialoGPT. (#10002)

* Cleaning up `ConversationalPipeline` to support more than DialoGPT.

Currently, `ConversationalPipeline` is heavily biased towards DialoGPT,
which is the default model for this pipeline.

This PR proposes changes to move the DialoGPT-specific modifications back
into tokenizer-specific behavior wherever possible, by creating a
`_build_conversation_input_ids` method that takes a conversation as input
and returns a list of ints corresponding to the tokens. It feels natural
to put this on the tokenizer because different models probably have
different strategies for building `input_ids` from the full conversation,
and it's the tokenizer's job to transform strings into tokens (and
vice versa).

If `_build_conversation_input_ids` is missing, the previous behavior is
used, so nothing is broken so far (except for Blenderbot, where this is a fix).
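
As a rough sketch, the contract a tokenizer has to fulfil looks like the
following (hypothetical subclass name, mirroring the DialoGPT-style
implementation this PR adds for GPT-2):

```python
from typing import TYPE_CHECKING, List

from transformers import GPT2Tokenizer

if TYPE_CHECKING:
    from transformers.pipelines.conversational import Conversation


class MyDialogTokenizer(GPT2Tokenizer):  # hypothetical name, for illustration only
    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
        # Concatenate every turn, each terminated by EOS (DialoGPT-style),
        # then truncate from the left so the most recent turns are kept.
        input_ids = []
        for is_user, text in conversation.iter_texts():
            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
        if len(input_ids) > self.model_max_length:
            input_ids = input_ids[-self.model_max_length :]
        return input_ids
```

The pipeline simply checks `hasattr(self.tokenizer, "_build_conversation_input_ids")`
and falls back to the legacy path otherwise.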

This PR also contains a fix for inputs that are too long. There used to be
dead code that tried to limit the size of the incoming input. The fix
introduced here is that we truncate within `_build_conversation_input_ids`
to `tokenizer.model_max_length`. This matches the intent of the removed
dead code and is actually better, because it relies on `model_max_length`,
which is different from `max_length` (a default parameter of `generate`).
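
For illustration, the two limits live on different objects (a minimal
sketch; the exact values depend on the checkpoint and are not asserted here):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

# Hard input-size limit of the model, used for truncation in `_build_conversation_input_ids`.
print(tokenizer.model_max_length)
# Default `max_length` used by `generate()`, a separate knob that can be overridden per call.
print(model.config.max_length)
```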

- Removed the `history` logic from `Conversation`, as it is no longer
relevant now that the tokenization logic has moved to the tokenizer.
The tokenizer cannot save any cache, and the conversation cannot know
what is relevant or not.
The cache is also unusable for `blenderbot`, because its `input_ids` are
not append-only (the EOS token is always at the end); see the sketch below.
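
Roughly why that matters for Blenderbot (a minimal sketch; token values are
not asserted):

```python
from transformers import BlenderbotTokenizer

tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")

one_turn = tokenizer(" hello").input_ids
two_turns = tokenizer(" hello  Hi there!").input_ids

# The EOS token sits at the end of `one_turn`, so `one_turn` is not a prefix
# of `two_turns`; tokens from earlier turns cannot simply be cached and appended to.
print(one_turn, two_turns)
```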

- Added an `iter_texts` method on `Conversation`, because the code was
littered with variations of this iteration over past user inputs and
generated responses.
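
A quick usage sketch of the new helper (the texts are made up):

```python
from transformers import Conversation

conversation = Conversation(
    "Why do you recommend it?",
    past_user_inputs=["What's the last book you have read?"],
    generated_responses=["The Lord of the Rings."],
)

for is_user, text in conversation.iter_texts():
    speaker = "user" if is_user else "bot"
    print(f"{speaker} >> {text}")
# user >> What's the last book you have read?
# bot >> The Lord of the Rings.
# user >> Why do you recommend it?
```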

* Removing torch mention in types.

* Adding type checking to `_build_conversation_input_ids`.

* Fixing import in strings.
parent ae37ceac
......@@ -14,12 +14,15 @@
# limitations under the License.
"""Tokenization class for Blenderbot."""
from typing import List
from typing import TYPE_CHECKING, List
from ...utils import logging
from ..roberta.tokenization_roberta import RobertaTokenizer
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
logger = logging.get_logger(__name__)
......@@ -74,6 +77,23 @@ class BlenderbotTokenizer(RobertaTokenizer):
"""
return token_ids_0 + [self.eos_token_id]
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
inputs = []
for is_user, text in conversation.iter_texts():
if is_user:
# We need to prefix with a space, as is done within blenderbot
inputs.append(" " + text)
else:
# Generated responses already contain the space prefix.
inputs.append(text)
full_string = " ".join(inputs)
input_ids = self.encode(full_string)
if len(input_ids) > self.model_max_length:
input_ids = input_ids[-self.model_max_length :]
logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.")
return input_ids
def get_pairs(word):
"""
......
......@@ -18,7 +18,7 @@
import json
import os
from functools import lru_cache
from typing import Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
import regex as re
......@@ -26,6 +26,9 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
......@@ -296,3 +299,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):
if is_split_into_words or add_prefix_space:
text = " " + text
return (text, kwargs)
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
if len(input_ids) > self.model_max_length:
input_ids = input_ids[-self.model_max_length :]
return input_ids
......@@ -16,7 +16,7 @@
import json
from typing import Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from tokenizers import pre_tokenizers
......@@ -26,6 +26,10 @@ from ...utils import logging
from .tokenization_gpt2 import GPT2Tokenizer
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
......@@ -171,3 +175,13 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
return tuple(files)
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
"""This corresponds to DialoGPT variants of models."""
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
if len(input_ids) > self.model_max_length:
input_ids = input_ids[-self.model_max_length :]
return input_ids
import uuid
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..tokenization_utils import TruncationStrategy
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
......@@ -70,8 +69,6 @@ class Conversation:
self.past_user_inputs: List[str] = past_user_inputs
self.generated_responses: List[str] = generated_responses
self.new_user_input: Optional[str] = text
self._index: int = 0
self._history: List[int] = []
def __eq__(self, other):
if not isinstance(other, Conversation):
......@@ -128,6 +125,19 @@ class Conversation:
"""
self.generated_responses.append(response)
def iter_texts(self):
"""
Iterates over all blobs of the conversation.
Returns: Iterator of (is_user, text_chunk) in chronological order of the conversation. ``is_user`` is a
:obj:`bool`, ``text_chunk`` is a :obj:`str`.
"""
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
yield True, user_input
yield False, generated_response
if self.new_user_input:
yield True, self.new_user_input
def __repr__(self):
"""
Generates a string representation of the conversation.
......@@ -139,11 +149,9 @@ class Conversation:
suggestions? bot >> The Big Lebowski
"""
output = "Conversation id: {} \n".format(self.uuid)
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
output += "user >> {} \n".format(user_input)
output += "bot >> {} \n".format(generated_response)
if self.new_user_input is not None:
output += "user >> {} \n".format(self.new_user_input)
for is_user, text in self.iter_texts():
name = "user" if is_user else "bot"
output += "{} >> {} \n".format(name, text)
return output
......@@ -191,34 +199,6 @@ class ConversationalPipeline(Pipeline):
self.min_length_for_response = min_length_for_response
def _get_history(self, conversation):
"""
Private function (subject to change) that simply tokenizes and concatenates past inputs. Also saves that
tokenization into the conversation state.
Args:
conversation (:class:`~transformers.Conversation`)
Returns:
:obj:`List[int]`: The list of tokens for the past input of that conversation.
"""
# Make a copy to prevent messing cache up if there's an error
# within this function
history = conversation._history.copy()
index = conversation._index
new_index = index
for i, (past_user_input, generated_response) in enumerate(
zip(conversation.past_user_inputs[index:], conversation.generated_responses[index:])
):
for el in (past_user_input, generated_response):
new_history = self._parse_and_tokenize([el])[0]
history.extend(new_history)
new_index = i + index + 1
conversation._index = new_index
conversation._history = history
# Hand back a copy to caller so they can't accidently modify our cache.
return history.copy()
def __call__(
self,
conversations: Union[Conversation, List[Conversation]],
......@@ -249,7 +229,7 @@ class ConversationalPipeline(Pipeline):
for conversation in conversations:
assert isinstance(
conversation, Conversation
), "DialoguePipeline expects a Conversation or list of Conversations as an input"
), "ConversationalPipeline expects a Conversation or list of Conversations as an input"
if conversation.new_user_input is None:
raise ValueError(
"Conversation with UUID {} does not contain new user input to process. "
......@@ -261,14 +241,11 @@ class ConversationalPipeline(Pipeline):
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
else:
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")
raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input")
with self.device_placement():
inputs = self._parse_and_tokenize([conversation.new_user_input for conversation in conversations])
histories = [self._get_history(conversation) for conversation in conversations]
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
inputs = self._concat_inputs_history(inputs, histories, max_length)
inputs = self._parse_and_tokenize(conversations)
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
......@@ -277,11 +254,6 @@ class ConversationalPipeline(Pipeline):
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
if input_length > 0.9 * max_length:
logger.warning(
"Longest conversation length: {} is bigger than 0.9 * max_length: {}. "
"You might consider trimming the early phase of the conversation".format(input_length, max_length)
)
generated_responses = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
......@@ -318,18 +290,6 @@ class ConversationalPipeline(Pipeline):
else:
return output
def _parse_and_tokenize(
self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
):
"""
Parse arguments and tokenize, adding an EOS token at the end of the user input
"""
# Parse arguments
inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
for input in inputs:
input.append(self.tokenizer.eos_token_id)
return inputs
def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
"""
Cleans the padding history. Padding may be generated in two places when multiple conversations are provided as
......@@ -363,28 +323,23 @@ class ConversationalPipeline(Pipeline):
outputs.append(sequence_tokens)
return outputs
def _concat_inputs_history(self, inputs: List[List[int]], histories: List[Optional[List[int]]], max_length: int):
"""
Builds an input prepended by the history for this conversation, allowing multi-turn conversation with context
"""
outputs = []
for new_input, history in zip(inputs, histories):
if history is not None:
new_input = history + new_input
if len(new_input) > max_length - self.min_length_for_response:
cutoff_eos_index = 0
while len(new_input) - cutoff_eos_index > max_length - self.min_length_for_response:
if cutoff_eos_index >= len(new_input):
break
cutoff_eos_index = new_input[cutoff_eos_index:].index(self.tokenizer.eos_token_id)
if cutoff_eos_index == 0 or cutoff_eos_index == len(new_input) - 1:
break
def _legacy_parse_and_tokenize(self, conversation: List[Conversation]) -> List[int]:
eos_token_id = self.tokenizer.eos_token_id
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
if len(input_ids) > self.tokenizer.model_max_length:
input_ids = input_ids[-self.tokenizer.model_max_length :]
return input_ids
def _parse_and_tokenize(self, conversations: List[Conversation]) -> Dict[str, Any]:
if hasattr(self.tokenizer, "_build_conversation_input_ids"):
input_ids = [self.tokenizer._build_conversation_input_ids(conversation) for conversation in conversations]
else:
logger.warning(
f"Cutting history off because it's too long ({len(new_input)} > {max_length - self.min_length_for_response}) for underlying model"
)
outputs.append(new_input)
padded_outputs = self.tokenizer.pad(
{"input_ids": outputs}, padding="longest", return_attention_mask=True, return_tensors=self.framework
# If the tokenizer cannot handle conversations, we default to only the old version
input_ids = [self._legacy_parse_and_tokenize(conversation) for conversation in conversations]
inputs = self.tokenizer.pad(
{"input_ids": input_ids}, padding="longest", return_attention_mask=True, return_tensors="pt"
)
return padded_outputs
return inputs
......@@ -15,6 +15,7 @@
import unittest
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
Conversation,
......@@ -87,11 +88,7 @@ class SimpleConversationPipelineTests(unittest.TestCase):
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
with self.assertLogs("transformers", level="WARNING") as log:
result = conversation_agent([conversation_1, conversation_2], max_length=48)
self.assertEqual(len(log.output), 2)
self.assertIn("You might consider trimming the early phase of the conversation", log.output[0])
self.assertIn("Setting `pad_token_id`", log.output[1])
# Two conversations in one pass
self.assertEqual(result, [conversation_1, conversation_2])
......@@ -111,12 +108,7 @@ class SimpleConversationPipelineTests(unittest.TestCase):
# One conversation with history
conversation_2.add_user_input("Why do you recommend it?")
with self.assertLogs("transformers", level="WARNING") as log:
result = conversation_agent(conversation_2, max_length=64)
self.assertEqual(len(log.output), 3)
self.assertIn("Cutting history off because it's too long", log.output[0])
self.assertIn("You might consider trimming the early phase of the conversation", log.output[1])
self.assertIn("Setting `pad_token_id`", log.output[2])
self.assertEqual(result, conversation_2)
self.assertEqual(
......@@ -128,65 +120,6 @@ class SimpleConversationPipelineTests(unittest.TestCase):
),
)
@require_torch
def test_history_cache(self):
conversation_agent = self.get_pipeline()
conversation = Conversation(
"Why do you recommend it?",
past_user_inputs=["What's the last book you have read?"],
generated_responses=["b"],
)
with self.assertLogs("transformers", level="WARNING") as log:
_ = conversation_agent(conversation, max_length=64)
self.assertEqual(len(log.output), 3)
self.assertIn("Cutting history off because it's too long (63 > 32) for underlying model", log.output[0])
self.assertIn("63 is bigger than 0.9 * max_length: 64", log.output[1])
self.assertIn("Setting `pad_token_id`", log.output[2])
self.assertEqual(conversation._index, 1)
self.assertEqual(
conversation._history,
[
87,
104,
97,
116,
39,
115,
32,
116,
104,
101,
32,
108,
97,
115,
116,
32,
98,
111,
111,
107,
32,
121,
111,
117,
32,
104,
97,
118,
101,
32,
114,
101,
97,
100,
63,
259, # EOS
98, # b
259, # EOS
],
)
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "conversational"
......@@ -276,6 +209,102 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
self.assertEqual(result.generated_responses[1], "It's a comedy.")
@require_torch
@slow
def test_integration_torch_conversation_dialogpt_input_ids(self):
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
conversation_1 = Conversation("hello")
inputs = nlp._parse_and_tokenize([conversation_1])
self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])
conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
inputs = nlp._parse_and_tokenize([conversation_2])
self.assertEqual(
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
)
inputs = nlp._parse_and_tokenize([conversation_1, conversation_2])
self.assertEqual(
inputs["input_ids"].tolist(),
[
[31373, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256],
],
)
@require_torch
@slow
def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
# test1
conversation_1 = Conversation("hello")
inputs = nlp._parse_and_tokenize([conversation_1])
self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])
# test2
conversation_1 = Conversation(
"I like lasagne.",
past_user_inputs=["hello"],
generated_responses=[
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
],
)
inputs = nlp._parse_and_tokenize([conversation_1])
self.assertEqual(
inputs["input_ids"].tolist(),
[
# This should be compared with the same conversation on ParlAI `safe_interactive` demo.
[
1710, # hello
86,
228, # Double space
228,
946,
304,
398,
6881,
558,
964,
38,
452,
315,
265,
6252,
452,
322,
968,
6884,
3146,
278,
306,
265,
617,
87,
388,
75,
341,
286,
521,
21,
228, # Double space
228,
281, # I like lasagne.
398,
6881,
558,
964,
21,
2, # EOS
]
],
)
@require_torch
@slow
def test_integration_torch_conversation_blenderbot_400M(self):
......@@ -295,11 +324,11 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
" Hello! How are you doing today? I just got back from a walk with my dog.",
)
conversation_1 = Conversation(" Lasagne hello")
conversation_1 = Conversation("Lasagne hello")
result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
self.assertEqual(
result.generated_responses[0],
" Lasagne is my favorite Italian dish. Do you like lasagne?",
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.",
)
conversation_1 = Conversation(
......@@ -311,10 +340,7 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
)
self.assertEqual(
result.generated_responses[0],
# ParlAI implementation output, we have a different one, but it's our
# second best, you can check by using num_return_sequences=10
# " Hello! How are you? I'm just getting ready to go to work, how about you?",
" Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.",
" Me too. I like how it can be topped with vegetables, meats, and condiments.",
)
@require_torch
......