# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BlenderbotSmallForConditionalGeneration,
    BlenderbotSmallTokenizer,
    Conversation,
    ConversationalPipeline,
    is_torch_available,
    pipeline,
)
from transformers.testing_utils import is_pipeline_test, require_torch, slow, torch_device

from .test_pipelines_common import MonoInputPipelineCommonMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel

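# For pipelines, device=-1 selects the CPU and a non-negative index selects that CUDA device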
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0


@is_pipeline_test
class SimpleConversationPipelineTests(unittest.TestCase):
    def get_pipeline(self):
        # Build a tiny, randomly-initialized GPT-2 so the test is fast and deterministic
        config = GPT2Config(
            vocab_size=263,
            n_ctx=128,
            max_length=128,
            n_embd=64,
            n_layer=1,
            n_head=8,
            bos_token_id=256,
            eos_token_id=257,
        )
        model = GPT2LMHeadModel(config)
        # Force the model to always generate "L": with the lm_head weights zeroed,
        # the logits reduce to the bias, whose argmax is token 76 ("L" in the
        # byte-level vocab built below)
        V, D = model.lm_head.weight.shape
        bias = torch.zeros(V)
        bias[76] = 1
        weight = torch.zeros((V, D), requires_grad=True)

        model.lm_head.bias = nn.Parameter(bias)
        model.lm_head.weight = nn.Parameter(weight)

        # Created with:
        # import tempfile

        # from tokenizers import Tokenizer, models
        # from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

        # vocab = [(chr(i), i) for i in range(256)]
        # tokenizer = Tokenizer(models.Unigram(vocab))
        # with tempfile.NamedTemporaryFile() as f:
        #     tokenizer.save(f.name)
        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, eos_token="<eos>", bos_token="<bos>")

        # real_tokenizer._tokenizer.save("dummy.json")
        # Special tokens are automatically added at load time.
        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_conversational_test")
        conversation_agent = pipeline(
            task="conversational", device=DEFAULT_DEVICE_NUM, model=model, tokenizer=tokenizer
        )
        return conversation_agent

    @require_torch
    def test_integration_torch_conversation(self):
        conversation_agent = self.get_pipeline()
        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")
        self.assertEqual(len(conversation_1.past_user_inputs), 0)
        self.assertEqual(len(conversation_2.past_user_inputs), 0)

        result = conversation_agent([conversation_1, conversation_2], max_length=48)

        # Two conversations in one pass
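        # (the rigged model can only ever generate the token "L")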
        self.assertEqual(result, [conversation_1, conversation_2])
        self.assertEqual(
            result,
            [
                Conversation(
                    None,
                    past_user_inputs=["Going to the movies tonight - any suggestions?"],
                    generated_responses=["L"],
                ),
                Conversation(
                    None, past_user_inputs=["What's the last book you have read?"], generated_responses=["L"]
                ),
            ],
        )

        # One conversation with history
        conversation_2.add_user_input("Why do you recommend it?")
        result = conversation_agent(conversation_2, max_length=64)

        self.assertEqual(result, conversation_2)
        self.assertEqual(
            result,
            Conversation(
                None,
                past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
                generated_responses=["L", "L"],
            ),
        )


class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
    pipeline_task = "conversational"
    small_models = []  # Models tested without the @slow decorator
    large_models = ["microsoft/DialoGPT-medium"]  # Models tested with the @slow decorator
    invalid_inputs = ["Hi there!", Conversation()]

    def _test_pipeline(
        self, conversation_agent
    ):  # override the default test method to check that the output is a `Conversation` object
        self.assertIsNotNone(conversation_agent)

        # We need to recreate the conversation for each successive check because
        # Conversation objects get *consumed* by the pipeline
        conversation = Conversation("Hi there!")
        mono_result = conversation_agent(conversation)
        self.assertIsInstance(mono_result, Conversation)

        conversations = [Conversation("Hi there!"), Conversation("How are you?")]
        multi_result = conversation_agent(conversations)
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], Conversation)
        # The conversations have been consumed and are no longer valid;
        # passing inactive conversations back to the pipeline raises a ValueError
        self.assertRaises(ValueError, conversation_agent, conversation)
        self.assertRaises(ValueError, conversation_agent, conversations)

        for bad_input in self.invalid_inputs:
            self.assertRaises(Exception, conversation_agent, bad_input)
        self.assertRaises(Exception, conversation_agent, self.invalid_inputs)

    @require_torch
    @slow
    def test_integration_torch_conversation(self):
        # When
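        # No model specified: the conversational task defaults to microsoft/DialoGPT-medium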
        conversation_agent = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")
        # Then
        self.assertEqual(len(conversation_1.past_user_inputs), 0)
        self.assertEqual(len(conversation_2.past_user_inputs), 0)
        # When
        result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
        # Then
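        # The pipeline mutates and returns the very same Conversation objects it was given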
        self.assertEqual(result, [conversation_1, conversation_2])
        self.assertEqual(len(result[0].past_user_inputs), 1)
        self.assertEqual(len(result[1].past_user_inputs), 1)
        self.assertEqual(len(result[0].generated_responses), 1)
        self.assertEqual(len(result[1].generated_responses), 1)
        self.assertEqual(result[0].past_user_inputs[0], "Going to the movies tonight - any suggestions?")
        self.assertEqual(result[0].generated_responses[0], "The Big Lebowski")
        self.assertEqual(result[1].past_user_inputs[0], "What's the last book you have read?")
        self.assertEqual(result[1].generated_responses[0], "The Last Question")
        # When
        conversation_2.add_user_input("Why do you recommend it?")
        result = conversation_agent(conversation_2, do_sample=False, max_length=1000)
        # Then
        self.assertEqual(result, conversation_2)
        self.assertEqual(len(result.past_user_inputs), 2)
        self.assertEqual(len(result.generated_responses), 2)
        self.assertEqual(result.past_user_inputs[1], "Why do you recommend it?")
        self.assertEqual(result.generated_responses[1], "It's a good book.")

    @require_torch
    @slow
    def test_integration_torch_conversation_truncated_history(self):
        # When
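        # min_length_for_response reserves room for the reply: the history is trimmed so
        # the prompt leaves at least 24 tokens free for generation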
        conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        # Then
        self.assertEqual(len(conversation_1.past_user_inputs), 0)
        # When
        result = conversation_agent(conversation_1, do_sample=False, max_length=36)
        # Then
        self.assertEqual(result, conversation_1)
        self.assertEqual(len(result.past_user_inputs), 1)
        self.assertEqual(len(result.generated_responses), 1)
        self.assertEqual(result.past_user_inputs[0], "Going to the movies tonight - any suggestions?")
        self.assertEqual(result.generated_responses[0], "The Big Lebowski")
        # When
        conversation_1.add_user_input("Is it an action movie?")
        result = conversation_agent(conversation_1, do_sample=False, max_length=36)
        # Then
        self.assertEqual(result, conversation_1)
        self.assertEqual(len(result.past_user_inputs), 2)
        self.assertEqual(len(result.generated_responses), 2)
        self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
        self.assertEqual(result.generated_responses[1], "It's a comedy.")

    @require_torch
    @slow
    def test_integration_torch_conversation_dialogpt_input_ids(self):
        tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
        model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

        conversation_1 = Conversation("hello")
        inputs = conversation_agent._parse_and_tokenize([conversation_1])
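        # "hello" -> 31373 in GPT-2's byte-level BPE; each turn is closed with the eos token (50256)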
        self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])

        conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
        inputs = conversation_agent._parse_and_tokenize([conversation_2])
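        # Layout: "hello" <eos> "Hi there!" <eos> "how are you ?" <eos>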
        self.assertEqual(
            inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
        )

        inputs = conversation_agent._parse_and_tokenize([conversation_1, conversation_2])
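        # The shorter conversation is padded to the batch's maximum length with the
        # eos token, which doubles as the pad token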
        self.assertEqual(
            inputs["input_ids"].tolist(),
            [
                [31373, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
                [31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256],
            ],
        )

    @require_torch
    @slow
    def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

        # test 1: a single user turn
        conversation_1 = Conversation("hello")
        inputs = conversation_agent._parse_and_tokenize([conversation_1])
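        # "hello" -> [1710, 86] in Blenderbot's BPE, closed with the eos token (2)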
        self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])

        # test 2: a new turn on top of existing history
        conversation_1 = Conversation(
            "I like lasagne.",
            past_user_inputs=["hello"],
            generated_responses=[
                " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
            ],
        )
        inputs = conversation_agent._parse_and_tokenize([conversation_1])
        self.assertEqual(
            inputs["input_ids"].tolist(),
            [
                # This should be compared with the same conversation on ParlAI `safe_interactive` demo.
                [
                    1710,  # hello
                    86,
                    228,  # Double space
                    228,
                    946,
                    304,
                    398,
                    6881,
                    558,
                    964,
                    38,
                    452,
                    315,
                    265,
                    6252,
                    452,
                    322,
                    968,
                    6884,
                    3146,
                    278,
                    306,
                    265,
                    617,
                    87,
                    388,
                    75,
                    341,
                    286,
                    521,
                    21,
                    228,  # Double space
                    228,
                    281,  # I like lasagne.
                    398,
                    6881,
                    558,
                    964,
                    21,
                    2,  # EOS
                ]
            ],
        )

    @require_torch
    @slow
    def test_integration_torch_conversation_blenderbot_400M(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)

        conversation_1 = Conversation("hello")
        result = conversation_agent(conversation_1)
        self.assertEqual(
            result.generated_responses[0],
            # The ParlAI reference implementation produces a different reply; ours is
            # its second-best candidate (verify with num_return_sequences=10):
            # " Hello! How are you? I'm just getting ready to go to work, how about you?",
            " Hello! How are you doing today? I just got back from a walk with my dog.",
        )

        conversation_1 = Conversation("Lasagne   hello")
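        # encoder_no_repeat_ngram_size=3 forbids the reply from reusing any 3-gram of the input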
        result = conversation_agent(conversation_1, encoder_no_repeat_ngram_size=3)
        self.assertEqual(
            result.generated_responses[0],
            " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie.",
        )

        conversation_1 = Conversation(
            "Lasagne   hello   Lasagne is my favorite Italian dish. Do you like lasagne?   I like lasagne."
        )
        result = conversation_agent(conversation_1, encoder_no_repeat_ngram_size=3)
        self.assertEqual(
            result.generated_responses[0],
            " Me too. I like how it can be topped with vegetables, meats, and condiments.",
        )

    @require_torch
    @slow
    def test_integration_torch_conversation_encoder_decoder(self):
        # When
        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M")
        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot_small-90M")
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer, device=DEFAULT_DEVICE_NUM)

        conversation_1 = Conversation("My name is Sarah and I live in London")
        conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
        # Then
        self.assertEqual(len(conversation_1.past_user_inputs), 0)
        self.assertEqual(len(conversation_2.past_user_inputs), 0)
        # When
        result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
        # Then
        self.assertEqual(result, [conversation_1, conversation_2])
        self.assertEqual(len(result[0].past_user_inputs), 1)
        self.assertEqual(len(result[1].past_user_inputs), 1)
        self.assertEqual(len(result[0].generated_responses), 1)
        self.assertEqual(len(result[1].generated_responses), 1)
        self.assertEqual(result[0].past_user_inputs[0], "My name is Sarah and I live in London")
        self.assertEqual(
            result[0].generated_responses[0],
            "hi sarah, i live in london as well. do you have any plans for the weekend?",
        )
        self.assertEqual(
            result[1].past_user_inputs[0], "Going to the movies tonight, What movie would you recommend? "
        )
        self.assertEqual(
            result[1].generated_responses[0], "i don't know... i'm not really sure. what movie are you going to see?"
        )
        # When
        conversation_1.add_user_input("Not yet, what about you?")
        conversation_2.add_user_input("What's your name?")
        result = conversation_agent([conversation_1, conversation_2], do_sample=False, max_length=1000)
        # Then
        self.assertEqual(result, [conversation_1, conversation_2])
        self.assertEqual(len(result[0].past_user_inputs), 2)
        self.assertEqual(len(result[1].past_user_inputs), 2)
        self.assertEqual(len(result[0].generated_responses), 2)
        self.assertEqual(len(result[1].generated_responses), 2)
        self.assertEqual(result[0].past_user_inputs[1], "Not yet, what about you?")
        self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
        self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
        self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")

    @require_torch
    @slow
    def test_from_pipeline_conversation(self):
        model_id = "facebook/blenderbot_small-90M"

        # from model id
        conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)

        # from model object
        model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
        tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
        conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)

        conversation = Conversation("My name is Sarah and I live in London")
        conversation_copy = Conversation("My name is Sarah and I live in London")

        result_model_id = conversation_agent_from_model_id([conversation])
        result_model = conversation_agent_from_model([conversation_copy])
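        # A single-conversation batch is returned as a bare Conversation, not a list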

        # both construction paths should produce the same response
        self.assertEqual(
            result_model_id.generated_responses[0],
            "hi sarah, i live in london as well. do you have any plans for the weekend?",
        )
        self.assertEqual(
            result_model_id.generated_responses[0],
            result_model.generated_responses[0],
        )