test_modeling_roberta.py 22.1 KB
Newer Older
1
# coding=utf-8
Sylvain Gugger's avatar
Sylvain Gugger committed
2
# Copyright 2020 The HuggingFace Team. All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16
17

import unittest
18
from copy import deepcopy
19

20
from transformers import RobertaConfig, is_torch_available
21
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
thomwolf's avatar
thomwolf committed
22

Yih-Dar's avatar
Yih-Dar committed
23
24
25
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
Aymeric Augustin's avatar
Aymeric Augustin committed
26
27


28
if is_torch_available():
thomwolf's avatar
thomwolf committed
29
    import torch
30

31
    from transformers import (
32
        RobertaForCausalLM,
33
        RobertaForMaskedLM,
34
35
        RobertaForMultipleChoice,
        RobertaForQuestionAnswering,
36
37
        RobertaForSequenceClassification,
        RobertaForTokenClassification,
38
39
        RobertaModel,
    )
Sylvain Gugger's avatar
Sylvain Gugger committed
40
    from transformers.models.roberta.modeling_roberta import (
41
42
43
        ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        RobertaEmbeddings,
        create_position_ids_from_input_ids,
44
    )
45

46
47
ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"

48

49
50
class RobertaModelTester:
    def __init__(self, parent):
        """Remember the owning test case and fix the sizes of the tiny test model.

        Args:
            parent: object whose ``assert*`` methods the ``create_and_check_*``
                helpers call to report their results.
        """
        self.parent = parent

        # Shape of every generated batch.
        self.batch_size = 13
        self.seq_length = 7

        # Flags read when preparing inputs.
        self.is_training = True
        self.use_input_mask = True
        self.use_token_type_ids = True
        self.use_labels = True

        # Hyper-parameters forwarded to ``RobertaConfig`` by ``get_config``.
        self.vocab_size = 99
        self.hidden_size = 32
        self.num_hidden_layers = 5
        self.num_attention_heads = 4
        self.intermediate_size = 37
        self.hidden_act = "gelu"
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 16

        # Label-space sizes and remaining settings.
        self.type_sequence_label_size = 2
        self.initializer_range = 0.02
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None

    def prepare_config_and_inputs(self):
        """Build a config together with one generated batch of inputs and labels.

        Returns:
            tuple: ``(config, input_ids, token_type_ids, input_mask,
            sequence_labels, token_labels, choice_labels)``. The mask, the
            token type ids and the three label tensors are ``None`` when the
            matching ``use_*`` flag is off.
        """
        # The tensors are created in the original order, in case the helper
        # functions draw from shared random state.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = (
            random_attention_mask([self.batch_size, self.seq_length]) if self.use_input_mask else None
        )

        token_type_ids = (
            ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            if self.use_token_type_ids
            else None
        )

        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)
        else:
            sequence_labels = token_labels = choice_labels = None

        return (
            self.get_config(),
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Return a ``RobertaConfig`` filled in from this tester's hyper-parameters."""
        # Each name is both the tester attribute read and the keyword passed on.
        forwarded_attributes = (
            "vocab_size",
            "hidden_size",
            "num_hidden_layers",
            "num_attention_heads",
            "intermediate_size",
            "hidden_act",
            "hidden_dropout_prob",
            "attention_probs_dropout_prob",
            "max_position_embeddings",
            "type_vocab_size",
            "initializer_range",
        )
        config_kwargs = {name: getattr(self, name) for name in forwarded_attributes}
        return RobertaConfig(**config_kwargs)

115
116
117
118
119
    def get_pipeline_config(self):
        """Return the config from ``get_config`` with ``vocab_size`` overridden to 300."""
        pipeline_config = self.get_config()
        pipeline_config.vocab_size = 300
        return pipeline_config

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    def prepare_config_and_inputs_for_decoder(self):
        """Extend ``prepare_config_and_inputs`` with what a decoder needs.

        The config is flagged with ``is_decoder = True`` and two extra tensors
        are appended to the returned tuple: generated encoder hidden states of
        shape ``(batch_size, seq_length, hidden_size)`` and a 0/1 encoder
        attention mask of shape ``(batch_size, seq_length)``.

        Returns:
            tuple: the seven values of ``prepare_config_and_inputs`` followed by
            ``encoder_hidden_states`` and ``encoder_attention_mask``.
        """
        config, *model_inputs = self.prepare_config_and_inputs()
        config.is_decoder = True

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (config, *model_inputs, encoder_hidden_states, encoder_attention_mask)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``RobertaModel`` with three argument combinations and check output shapes.

        The label arguments are accepted so the tuple from
        ``prepare_config_and_inputs`` can be unpacked into this call; they are
        not used here.
        """
        model = RobertaModel(config=config)
        model.to(torch_device)
        model.eval()

        # The first two calls only need to complete; the shapes are asserted on
        # the result of the last, ids-only call.
        model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        model(input_ids, token_type_ids=token_type_ids)
        outputs = model(input_ids)

        expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
        expected_pooled_shape = (self.batch_size, self.hidden_size)
        self.parent.assertEqual(outputs.last_hidden_state.shape, expected_hidden_shape)
        self.parent.assertEqual(outputs.pooler_output.shape, expected_pooled_shape)
159

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run ``RobertaModel`` with cross-attention enabled and check output shapes.

        The model is called with encoder states and mask, with encoder states
        only, and without any encoder input. The label arguments are unused.
        """
        config.add_cross_attention = True
        model = RobertaModel(config)
        model.to(torch_device)
        model.eval()

        # Keyword arguments shared by all three forward passes.
        shared_inputs = {"attention_mask": input_mask, "token_type_ids": token_type_ids}

        # Only the last call's result is inspected; the others must simply run.
        model(
            input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **shared_inputs,
        )
        model(input_ids, encoder_hidden_states=encoder_hidden_states, **shared_inputs)
        outputs = model(input_ids, **shared_inputs)

        expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
        expected_pooled_shape = (self.batch_size, self.hidden_size)
        self.parent.assertEqual(outputs.last_hidden_state.shape, expected_hidden_shape)
        self.parent.assertEqual(outputs.pooler_output.shape, expected_pooled_shape)

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run ``RobertaForCausalLM`` with labels and check the logits shape.

        Only ``config``, ``input_ids``, ``token_type_ids``, ``input_mask`` and
        ``token_labels`` are used; the other arguments exist so the decoder
        input tuple can be unpacked into this call.
        """
        model = RobertaForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        outputs = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
        )

        expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)

211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check that a forward pass using ``past_key_values`` matches one without.

        A first pass over ``input_ids`` with ``use_cache=True`` yields the
        cache. A few extra tokens are then fed twice: appended to the full
        sequence without a cache, and alone together with the cache. The first
        entry of ``hidden_states`` at the new positions must agree between
        the two passes within ``atol=1e-3`` on one randomly chosen feature.

        ``token_type_ids`` and the label arguments are not used; they are
        accepted so the decoder input tuple can be unpacked into this call.
        """
        # Number of tokens appended after the cached prefix.
        num_new_tokens = 3

        config.is_decoder = True
        config.add_cross_attention = True
        model = RobertaForCausalLM(config=config).to(torch_device).eval()

        # replace every id equal to config.pad_token_id with 0
        mask = input_ids.ne(config.pad_token_id).long()
        input_ids = input_ids * mask

        # first forward pass, which produces the cache
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create several hypothetical next tokens to extend input_ids with
        next_tokens = ids_tensor((self.batch_size, num_new_tokens), config.vocab_size)

        # same pad-token replacement as above, applied to the new tokens
        mask = next_tokens.ne(config.pad_token_id).long()
        next_tokens = next_tokens * mask
        next_mask = ids_tensor((self.batch_size, num_new_tokens), vocab_size=2)

        # append the new tokens and their mask to the original inputs
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select one random feature index and compare only the new positions
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -num_new_tokens:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        # assertEqual reports both lengths on failure, unlike assertTrue(a == b)
        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

279
    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``RobertaForMaskedLM`` with labels and check the logits shape."""
        model = RobertaForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()

        outputs = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
        )

        expected_logits_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)
287

288
    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``RobertaForTokenClassification`` with labels and check the logits shape."""
        # The config must carry the label count before the model is built.
        config.num_labels = self.num_labels
        model = RobertaForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()

        outputs = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
        )

        expected_logits_shape = (self.batch_size, self.seq_length, self.num_labels)
        self.parent.assertEqual(outputs.logits.shape, expected_logits_shape)
297

298
    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``RobertaForMultipleChoice`` with labels and check the logits shape.

        Each input tensor is repeated ``num_choices`` times along a new axis
        inserted at dimension 1 before being passed to the model.
        """
        config.num_choices = self.num_choices
        model = RobertaForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()

        def repeat_per_choice(tensor):
            # Insert a choice axis at dim 1 and repeat the tensor along it.
            return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

        choice_input_ids = repeat_per_choice(input_ids)
        choice_token_type_ids = repeat_per_choice(token_type_ids)
        choice_input_mask = repeat_per_choice(input_mask)

        outputs = model(
            choice_input_ids,
            attention_mask=choice_input_mask,
            token_type_ids=choice_token_type_ids,
            labels=choice_labels,
        )

        self.parent.assertEqual(outputs.logits.shape, (self.batch_size, self.num_choices))
315

316
    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``RobertaForQuestionAnswering`` and check start/end logits shapes.

        ``sequence_labels`` is passed as both the start and the end positions.
        """
        model = RobertaForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()

        outputs = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )

        expected_span_shape = (self.batch_size, self.seq_length)
        self.parent.assertEqual(outputs.start_logits.shape, expected_span_shape)
        self.parent.assertEqual(outputs.end_logits.shape, expected_span_shape)
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` built from ``prepare_config_and_inputs``.

        ``inputs_dict`` holds the ``input_ids``, ``token_type_ids`` and
        ``attention_mask`` entries; the generated labels are dropped.
        """
        config, input_ids, token_type_ids, input_mask, *_labels = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": input_mask,
        }
        return config, inputs_dict


347
@require_torch
348
class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
349

350
351
    # Model classes covered by this test case; empty when torch is unavailable,
    # because the model names are only imported under ``is_torch_available()``.
    all_model_classes = (
        (
            RobertaForCausalLM,
            RobertaForMaskedLM,
            RobertaModel,
            RobertaForSequenceClassification,
            RobertaForTokenClassification,
            RobertaForMultipleChoice,
            RobertaForQuestionAnswering,
        )
        if is_torch_available()
        else ()
    )
    # Subset of the classes above that is listed as generative.
    all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
    fx_compatible = True
365
366

    def setUp(self):
        # Fresh helper objects for every test: one builds models and inputs,
        # the other runs the shared configuration checks for RobertaConfig.
        self.model_tester = RobertaModelTester(parent=self)
        self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)

    def test_config(self):
        # Delegate to the common configuration checks set up in ``setUp``.
        config_tester = self.config_tester
        config_tester.run_common_tests()

373
    def test_model(self):
        # Base model: forward passes run and output shapes are as expected.
        tester = self.model_tester
        tester.create_and_check_model(*tester.prepare_config_and_inputs())

377
378
379
380
381
382
    def test_model_various_embeddings(self):
        # Run the base-model checks once per position-embedding type, reusing
        # the same prepared inputs and mutating the shared config in place.
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        config = config_and_inputs[0]
        for position_embedding_type in ("absolute", "relative_key", "relative_key_query"):
            # subTest reports each embedding type separately, so one failure
            # does not hide the outcome of the remaining types.
            with self.subTest(position_embedding_type=position_embedding_type):
                config.position_embedding_type = position_embedding_type
                self.model_tester.create_and_check_model(*config_and_inputs)

383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
    def test_model_as_decoder(self):
        # Base model configured as a decoder with cross-attention.
        tester = self.model_tester
        tester.create_and_check_model_as_decoder(*tester.prepare_config_and_inputs_for_decoder())

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            _generated_input_mask,
            *remaining_inputs,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        # Pass ``None`` in place of the generated attention mask so the model
        # has to fall back to its own default.
        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            None,
            *remaining_inputs,
        )

    def test_for_causal_lm(self):
        # Causal LM head on decoder inputs: logits shape check.
        tester = self.model_tester
        tester.create_and_check_for_causal_lm(*tester.prepare_config_and_inputs_for_decoder())
418

419
420
421
422
    def test_decoder_model_past_with_large_inputs(self):
        # Cached (past_key_values) and uncached decoding must agree.
        tester = self.model_tester
        decoder_inputs = tester.prepare_config_and_inputs_for_decoder()
        tester.create_and_check_decoder_model_past_large_inputs(*decoder_inputs)

423
424
    def test_for_masked_lm(self):
        # Masked LM head: logits shape check.
        tester = self.model_tester
        tester.create_and_check_for_masked_lm(*tester.prepare_config_and_inputs())
426

Lysandre's avatar
Lysandre committed
427
428
    def test_for_token_classification(self):
        # Token classification head: logits shape check.
        tester = self.model_tester
        tester.create_and_check_for_token_classification(*tester.prepare_config_and_inputs())
Lysandre's avatar
Lysandre committed
430
431
432

    def test_for_multiple_choice(self):
        # Multiple choice head: logits shape check.
        tester = self.model_tester
        tester.create_and_check_for_multiple_choice(*tester.prepare_config_and_inputs())
Lysandre's avatar
Lysandre committed
434
435
436

    def test_for_question_answering(self):
        # Question answering head: start/end logits shape check.
        tester = self.model_tester
        tester.create_and_check_for_question_answering(*tester.prepare_config_and_inputs())
Lysandre's avatar
Lysandre committed
438

439
    @slow
    def test_model_from_pretrained(self):
        # Loading the first checkpoint of the archive list must yield a model.
        first_checkpoints = ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
        for checkpoint_name in first_checkpoints:
            loaded_model = RobertaModel.from_pretrained(checkpoint_name)
            self.assertIsNotNone(loaded_model)

Dom Hudson's avatar
Dom Hudson committed
445
    def test_create_position_ids_respects_padding_index(self):
        """Ensure that default position ids are sequential for non-padding tokens only.

        This is a regression test for https://github.com/huggingface/transformers/issues/1761

        The position ids should be masked with the embedding object's padding index. Therefore, the
        first available non-padding position index is RobertaEmbeddings.padding_idx + 1, and a
        padding token keeps the padding index itself as its position id.
        """
        config = self.model_tester.prepare_config_and_inputs()[0]
        model = RobertaEmbeddings(config=config)

        # Three ordinary tokens followed by one padding token.
        input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
        expected_positions = torch.as_tensor(
            [[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]]
        )

        position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
        self.assertEqual(position_ids.shape, expected_positions.shape)
        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))

    def test_create_position_ids_from_inputs_embeds(self):
        """Ensure that position ids derived from ``inputs_embeds`` are sequential for every position.

        This is a regression test for https://github.com/huggingface/transformers/issues/1761

        When only embeddings are passed, every position of every sequence is expected to receive a
        sequential id, starting at RobertaEmbeddings.padding_idx + 1.
        """
        config = self.model_tester.prepare_config_and_inputs()[0]
        embeddings = RobertaEmbeddings(config=config)

        # Two sequences of four positions each; the values are never read by the
        # assertions below, only the resulting position ids are compared.
        inputs_embeds = torch.empty(2, 4, 30)
        expected_single_positions = [
            0 + embeddings.padding_idx + 1,
            1 + embeddings.padding_idx + 1,
            2 + embeddings.padding_idx + 1,
            3 + embeddings.padding_idx + 1,
        ]
        expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
        position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
        self.assertEqual(position_ids.shape, expected_positions.shape)
        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
485
486


Lysandre Debut's avatar
Lysandre Debut committed
487
@require_torch
488
class RobertaModelIntegrationTest(TestCasePlus):
489
    @slow
    def test_inference_masked_lm(self):
        # Compare masked-LM logits of the pretrained "roberta-base" checkpoint
        # against stored reference values.
        model = RobertaForMaskedLM.from_pretrained("roberta-base")
        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])

        with torch.no_grad():
            logits = model(input_ids)[0]

        self.assertEqual(logits.shape, torch.Size((1, 11, 50265)))

        # compare the actual values for a slice.
        reference_slice = torch.tensor(
            [[[33.8802, -4.3103, 22.7761], [4.6539, -2.8098, 13.6253], [1.8228, -3.6898, 8.8600]]]
        )

        # roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
        # roberta.eval()
        # expected_slice = roberta.model.forward(input_ids)[0][:, :3, :3].detach()

        actual_slice = logits[:, :3, :3]
        self.assertTrue(torch.allclose(actual_slice, reference_slice, atol=1e-4))
508

509
    @slow
    def test_inference_no_head(self):
        # Compare hidden states of the pretrained "roberta-base" checkpoint
        # against stored reference values.
        model = RobertaModel.from_pretrained("roberta-base")
        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])

        with torch.no_grad():
            hidden_states = model(input_ids)[0]

        # compare the actual values for a slice.
        reference_slice = torch.tensor(
            [[[-0.0231, 0.0782, 0.0074], [-0.1854, 0.0540, -0.0175], [0.0548, 0.0799, 0.1687]]]
        )

        # roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
        # roberta.eval()
        # expected_slice = roberta.extract_features(input_ids)[:, :3, :3].detach()

        actual_slice = hidden_states[:, :3, :3]
        self.assertTrue(torch.allclose(actual_slice, reference_slice, atol=1e-4))
526

527
    @slow
    def test_inference_classification_head(self):
        # Compare classification logits of the pretrained "roberta-large-mnli"
        # checkpoint against stored reference values.
        model = RobertaForSequenceClassification.from_pretrained("roberta-large-mnli")
        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])

        with torch.no_grad():
            logits = model(input_ids)[0]

        self.assertEqual(logits.shape, torch.Size((1, 3)))

        reference_logits = torch.tensor([[-0.9469, 0.3913, 0.5118]])

        # roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
        # roberta.eval()
        # expected_tensor = roberta.predict("mnli", input_ids, return_logits=True).detach()

        self.assertTrue(torch.allclose(logits, reference_logits, atol=1e-4))
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562

    # XXX: this might be a candidate for common tests if we have many of those
    def test_lm_head_ignore_keys(self):
        # ``_keys_to_ignore_on_save`` must depend on whether the word
        # embeddings are tied to the LM head.
        expected_keys_tied = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
        expected_keys_untied = [r"lm_head.decoder.bias"]

        base_config = RobertaConfig.from_pretrained(ROBERTA_TINY)
        tied_config = deepcopy(base_config)
        tied_config.tie_word_embeddings = True
        untied_config = deepcopy(base_config)
        untied_config.tie_word_embeddings = False

        for model_class in [RobertaForMaskedLM, RobertaForCausalLM]:
            tied_model = model_class(tied_config)
            self.assertEqual(tied_model._keys_to_ignore_on_save, expected_keys_tied, model_class)

            # the keys should be different when embeddings aren't tied
            untied_model = model_class(untied_config)
            self.assertEqual(untied_model._keys_to_ignore_on_save, expected_keys_untied, model_class)

            # test that saving works with updated ignore keys - just testing that it doesn't fail
            untied_model.save_pretrained(self.get_auto_remove_tmp_dir())