# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT model. """

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
    # See all ALBERT models at https://huggingface.co/models?filter=albert
]


def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {} from {}".format(name, original_name))
        pointer.data = torch.from_numpy(array)

    return model
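
# Usage sketch (hypothetical checkpoint path), assuming a TF1-style ALBERT checkpoint that matches
# the configuration; kept as a comment so that importing this module has no side effects:
#
#     config = AlbertConfig.from_pretrained("albert-base-v2")
#     model = AlbertForPreTraining(config)
#     model = load_tf_weights_in_albert(model, config, "/path/to/albert/model.ckpt")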


class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class AlbertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

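        # The relative position variants ("relative_key" / "relative_key_query") add a learned bias that
        # depends only on the distance between query and key positions, looked up from self.distance_embedding;
        # the released ALBERT checkpoints use the default "absolute" position embeddings, so this branch is
        # normally skipped.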
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # Should find a better way to do this
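        # Reshaping the dense weight to (num_attention_heads, attention_head_size, hidden_size) lets a single
        # einsum ("bfnd,ndh->bfh") both merge the attention heads and apply the output projection, which is
        # equivalent to concatenating the heads and running the result through self.dense.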
        w = (
            self.dense.weight.t()
            .view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
            .to(context_layer.dtype)
        )
        b = self.dense.bias.to(context_layer.dtype)

        projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


class AlbertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
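        # apply_chunking_to_forward splits the feed-forward pass along the sequence dimension into chunks of
        # config.chunk_size_feed_forward to reduce peak memory; with the default chunk size of 0 it simply
        # runs ff_chunk on the full attention output.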

        ffn_output = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output[0],
        )
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

        return (hidden_states,) + attention_output[1:]  # add attentions if we output them

    def ff_chunk(self, attention_output):
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return ffn_output


class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
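            # e.g. with num_hidden_layers=12 and num_hidden_groups=1 (the released ALBERT configurations),
            # layers_per_group is 12 and group_idx is always 0, so the single group's parameters are reused
            # for every one of the twelve iterations.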

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class AlbertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class AlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.AlbertForPreTraining`.

    Args:
        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
            Total loss as the sum of the masked language modeling loss and the sentence order prediction
            (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the sentence order prediction (classification) head (scores of original/switched
            order before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    sop_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Args:
        config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.AlbertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.__call__` and :meth:`transformers.PreTrainedTokenizer.encode` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. ALBERT
        has a different architecture in that its layers are shared across groups, each of which contains inner groups.
        If an ALBERT model has 12 hidden layers and 2 hidden groups, each with two inner groups, there is a total of 4
        different layers.

        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
        while [2,3] correspond to the two inner groups of the second hidden layer.

        Any layer with an index other than [0, 1, 2, 3] will result in an error. See the base class PreTrainedModel
        for more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
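
        # Illustrative only (hypothetical head indices): with config.inner_group_num == 2, pruning head 0 of
        # flattened layers 0 and 1 would be requested as
        #     model.prune_heads({0: [0], 1: [0]})
        # which resolves to group 0 / inner layer 0 and group 0 / inner layer 1 above.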

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
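        # Turn the {0, 1} padding mask into an additive bias: kept positions contribute 0.0 and masked
        # positions -10000.0, which drives their attention weights to ~0 after the softmax.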
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        sentence_order_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``. ``0`` indicates original order (sequence
            A, then sequence B), ``1`` indicates switched order (sequence B, then sequence A).

        Returns:

        Example::

            >>> from transformers import AlbertTokenizer, AlbertForPreTraining
            >>> import torch

            >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
            >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2')

            >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            >>> outputs = model(input_ids)

            >>> prediction_logits = outputs.prediction_logits
            >>> sop_logits = outputs.sop_logits

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores


class AlbertSOPHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output):
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits
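
# Note: the SOP head is only dropout plus a linear layer over the pooled [CLS] representation; with the
# default config.num_labels of 2, its logits score original vs. switched sentence order, matching
# AlbertForPreTrainingOutput.sop_logits.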


@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]

        prediction_scores = self.predictions(sequence_outputs)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
            config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss).
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
1043
        )
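

def _usage_sketch_sequence_classification():
    # Hypothetical helper, illustrative only (not called anywhere in this module): a minimal
    # sketch of a classification forward pass, assuming the public "albert-base-v2" checkpoint
    # and a made-up binary label. With num_labels=1 the same head computes an MSE regression
    # loss instead. AlbertTokenizer is imported lazily to avoid a circular import at load time.
    from transformers import AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)
    inputs = tokenizer("ALBERT shares parameters across its transformer layers.", return_tensors="pt")
    outputs = model(**inputs, labels=torch.tensor([1]))  # labels shape: (batch_size,)
    return outputs.loss, outputs.logits  # logits shape: (batch_size, num_labels)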


@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
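

def _usage_sketch_token_classification():
    # Hypothetical helper, illustrative only (not called anywhere in this module): a minimal
    # sketch of a token-classification forward pass, assuming the public "albert-base-v2"
    # checkpoint and dummy all-zero labels for a made-up 2-tag scheme; real NER data would
    # supply one label id per token.
    from transformers import AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForTokenClassification.from_pretrained("albert-base-v2", num_labels=2)
    inputs = tokenizer("HuggingFace is based in New York City.", return_tensors="pt")
    labels = torch.zeros(inputs["input_ids"].shape, dtype=torch.long)  # shape: (batch_size, sequence_length)
    outputs = model(**inputs, labels=labels)
    return outputs.loss, outputs.logits  # logits shape: (batch_size, sequence_length, num_labels)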


@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
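

def _usage_sketch_question_answering():
    # Hypothetical helper, illustrative only (not called anywhere in this module): a minimal
    # sketch of span extraction, assuming the public "albert-base-v2" checkpoint; a checkpoint
    # fine-tuned on SQuAD would be needed for meaningful answers.
    from transformers import AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")
    question = "Who founded the company?"
    context = "The company was founded by Jane Doe in 2015."
    inputs = tokenizer(question, context, return_tensors="pt")
    outputs = model(**inputs)
    start_index = outputs.start_logits.argmax(dim=-1)  # most likely start token, per batch item
    end_index = outputs.end_logits.argmax(dim=-1)  # most likely end token, per batch item
    return start_index, end_index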


@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices - 1]`` where `num_choices` is the size of the second dimension of the input tensors (see
            `input_ids` above).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Flatten the (batch_size, num_choices, ...) inputs so the backbone sees (batch_size * num_choices, ...)
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
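

def _usage_sketch_multiple_choice():
    # Hypothetical helper, illustrative only (not called anywhere in this module): a minimal
    # sketch of a two-choice forward pass, assuming the public "albert-base-v2" checkpoint and
    # a made-up prompt/choice pair. The tokenizer encodes (prompt, choice) pairs and the result
    # is reshaped to (batch_size, num_choices, sequence_length) as expected by forward() above.
    from transformers import AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForMultipleChoice.from_pretrained("albert-base-v2")
    prompt = "ALBERT was designed to reduce memory consumption."
    choices = ["It shares parameters across layers.", "It doubles the vocabulary size."]
    encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
    inputs = {key: value.unsqueeze(0) for key, value in encoding.items()}  # add the batch dimension
    outputs = model(**inputs, labels=torch.tensor([0]))  # index of the correct choice, shape: (batch_size,)
    return outputs.loss, outputs.logits  # logits shape: (batch_size, num_choices)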