# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT model. """

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

from .activations import ACT2FN
from .configuration_albert import AlbertConfig
from .file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from .modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from .utils import logging


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
    # See all ALBERT models at https://huggingface.co/models?filter=albert
]


def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        logger.info(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {} from {}".format(name, original_name))
        pointer.data = torch.from_numpy(array)

    return model


class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    # Copied from transformers.modeling_bert.BertEmbeddings.forward
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class AlbertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

    # Copied from transformers.modeling_bert.BertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input_ids, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(input_ids)
        mixed_key_layer = self.key(input_ids)
        mixed_value_layer = self.value(input_ids)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in AlbertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # Should find a better way to do this
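        # The einsum below is equivalent to concatenating the attention heads and applying self.dense:
        # w reshapes the dense weight to (num_attention_heads, attention_head_size, hidden_size) so the
        # per-head context of shape (batch, seq_len, num_heads, head_size) is projected back to
        # hidden_size in one step.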
        w = (
            self.dense.weight.t()
            .view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
            .to(context_layer.dtype)
        )
        b = self.dense.bias.to(context_layer.dtype)

        projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


class AlbertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
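        # When config.chunk_size_feed_forward > 0, apply_chunking_to_forward below runs the
        # feed-forward block (ff_chunk) over the attention output in chunks along the sequence
        # dimension (self.seq_len_dim), trading some speed for a lower peak memory footprint.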

        ffn_output = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output[0],
        )
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

        return (hidden_states,) + attention_output[1:]  # add attentions if we output them

    def ff_chunk(self, attention_output):
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return ffn_output


class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
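            # e.g. with the standard ALBERT setting of 12 hidden layers and num_hidden_groups=1,
            # layers_per_group is 12 and group_idx is always 0, so every layer reuses the same
            # AlbertLayerGroup parameters.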

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class AlbertPreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"
    authorized_missing_keys = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class AlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.AlbertForPreTraining`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Total loss as the sum of the masked language modeling loss and the sentence order prediction
            (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the sentence order prediction (classification) head (scores of original/swapped order
            before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    sop_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Args:
        config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.AlbertTokenizer`.
            See :meth:`transformers.PreTrainedTokenizer.__call__` and
            :meth:`transformers.PreTrainedTokenizer.encode` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
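        # The pooler is a Linear + Tanh applied to the final hidden state of the first token;
        # heads that do not need it (e.g. AlbertForMaskedLM) construct the backbone with
        # add_pooling_layer=False so that no unused pooler weights are created.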
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        ALBERT has a different architecture in that its layers are shared across groups, which then has inner groups.
        If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there
        is a total of 4 different layers.

        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
        while [2,3] correspond to the two inner groups of the second hidden layer.

        Any layer with an index other than [0, 1, 2, 3] will result in an error.
        See base class PreTrainedModel for more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

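        # Build the additive attention mask, broadcastable to (batch, num_heads, seq_len, seq_len):
        # positions to attend to become 0.0 and masked positions become -10000.0, which is added to
        # the raw attention scores before the softmax.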
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """Albert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
    a `sentence order prediction (classification)` head. """,
    ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        sentence_order_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with
            labels in ``[0, ..., config.vocab_size]``.
        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``:
            ``0`` indicates original order (sequence A, then sequence B),
            ``1`` indicates switched order (sequence B, then sequence A).
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Example::

            >>> from transformers import AlbertTokenizer, AlbertForPreTraining
            >>> import torch

            >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
            >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2', return_dict=True)

            >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            >>> outputs = model(input_ids)

            >>> prediction_logits = outputs.prediction_logits
            >>> sop_logits = outputs.sop_logits

        """

        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
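        # The pre-training loss is the sum of the masked LM cross-entropy over the vocabulary and
        # the sentence order prediction cross-entropy over its two classes.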
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores


class AlbertSOPHead(nn.Module):
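    # Sentence order prediction head: dropout followed by a linear classifier over the pooled
    # first-token representation, scoring whether the two segments appear in the original or the
    # swapped order (config.num_labels defaults to 2).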
    def __init__(self, config):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output):
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits


@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):

    authorized_unexpected_keys = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with
            labels in ``[0, ..., config.vocab_size]``.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
        """
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]

        prediction_scores = self.predictions(sequence_outputs)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            if ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Albert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):

    authorized_unexpected_keys = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
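        # outputs[0] is the last hidden state, shape (batch_size, sequence_length, hidden_size);
        # each token's vector goes through dropout and the per-token classifier below.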

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
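                # attention_mask == 1 marks real tokens; padding positions are dropped so they
                # do not contribute to the cross-entropy loss.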
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):

    authorized_unexpected_keys = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
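        # qa_outputs yields two scores per token; splitting on the last dimension and squeezing
        # gives start and end logits, each of shape (batch_size, sequence_length).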
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
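
# A minimal usage sketch for AlbertForQuestionAnswering (illustrative only, not executed here; it
# assumes the public "albert-base-v2" checkpoint; a checkpoint fine-tuned on SQuAD would give
# meaningful spans):
#
#     from transformers import AlbertTokenizer, AlbertForQuestionAnswering
#     import torch
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")
#     inputs = tokenizer("Who proposed ALBERT?", "ALBERT was proposed by Google Research.", return_tensors="pt")
#     outputs = model(**inputs, return_dict=True)
#     start = int(torch.argmax(outputs.start_logits))
#     end = int(torch.argmax(outputs.end_logits))
#     answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])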


@add_start_docstrings(
    """Albert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
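        # Inputs arrive as (batch_size, num_choices, sequence_length); the views below flatten the
        # choice dimension so every candidate is encoded independently as its own sequence.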

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
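        # The classifier produces one score per flattened sequence; reshaping to
        # (batch_size, num_choices) lets cross-entropy treat the choices as competing classes.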
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
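
# A minimal usage sketch for AlbertForMultipleChoice (illustrative only, not executed here; it
# assumes the public "albert-base-v2" checkpoint and a single two-choice example):
#
#     from transformers import AlbertTokenizer, AlbertForMultipleChoice
#     import torch
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForMultipleChoice.from_pretrained("albert-base-v2")
#     prompt = "The cat sat on the"
#     choices = ["mat.", "cloud."]
#     encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
#     inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # (batch_size=1, num_choices=2, seq_len)
#     outputs = model(**inputs, labels=torch.tensor([0]), return_dict=True)
#     loss, logits = outputs.loss, outputs.logits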