# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT model."""

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
    # See all ALBERT models at https://huggingface.co/models?filter=albert
]


def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load TF checkpoints into a PyTorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {} from {}".format(name, original_name))
        pointer.data = torch.from_numpy(array)

    return model
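
# Illustrative use of the conversion function above (a sketch, not executed on import;
# the checkpoint path is hypothetical):
#
#     from transformers import AlbertConfig, AlbertForPreTraining
#
#     config = AlbertConfig.from_pretrained("albert-base-v2")
#     model = AlbertForPreTraining(config)
#     load_tf_weights_in_albert(model, config, "/path/to/albert/model.ckpt-best")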


class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class AlbertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the AlbertModel forward() function)
            attention_scores = attention_scores + attention_mask

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

        # Should find a better way to do this
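        # The output projection (self.dense) is applied head-wise via einsum instead of reshaping the context
        # layer back to (batch_size, seq_length, all_head_size) and calling the linear layer directly.
        # In "bfnd,ndh->bfh": b = batch, f = from-position, n = attention head, d = head dim, h = hidden dim.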
        w = (
            self.dense.weight.t()
            .view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
            .to(context_layer.dtype)
        )
        b = self.dense.bias.to(context_layer.dtype)

        projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


class AlbertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

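        # apply_chunking_to_forward runs self.ff_chunk over chunks of config.chunk_size_feed_forward along the
        # sequence dimension (self.seq_len_dim) when a chunk size is set, trading a little speed for less memory.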
        ffn_output = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output[0],
        )
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

        return (hidden_states,) + attention_output[1:]  # add attentions if we output them

    def ff_chunk(self, attention_output):
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return ffn_output


class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
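            # e.g. with the default ALBERT setting of 12 hidden layers and 1 hidden group, layers_per_group is 12
            # and group_idx is 0 for every i, so the single AlbertLayerGroup is reused for all 12 layers.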

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class AlbertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class AlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.AlbertForPreTraining`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Total loss as the sum of the masked language modeling loss and the sentence order prediction
            (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the sentence order prediction (classification) head (scores of original vs. switched
            order before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    sop_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Args:
        config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.AlbertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.__call__` and :meth:`transformers.PreTrainedTokenizer.encode` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

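# The inputs documented above are typically produced by the tokenizer. A minimal sketch (illustrative only,
# using the public "albert-base-v2" checkpoint):
#
#     from transformers import AlbertTokenizer, AlbertModel
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertModel.from_pretrained("albert-base-v2")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")  # input_ids, token_type_ids, attention_mask
#     outputs = model(**inputs)
#     last_hidden_state = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)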

@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. ALBERT
        has a different architecture in that its layers are shared across groups, which then have inner groups. If an
        ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there are a total of 4 different
        layers.

        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
        while [2,3] correspond to the two inner groups of the second hidden layer.

        Any layer with an index other than [0, 1, 2, 3] will result in an error. See base class PreTrainedModel for
        more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
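
    # Illustrative example of the flattened indexing described above: with num_hidden_layers=12,
    # num_hidden_groups=2 and inner_group_num=2, heads_to_prune={0: [0], 3: [1]} prunes head 0 of
    # albert_layer_groups[0].albert_layers[0] and head 1 of albert_layer_groups[1].albert_layers[1].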

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

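        # Make the 2D mask broadcastable to (batch_size, num_heads, seq_length, seq_length) and turn padding
        # positions into a large negative additive bias so that the attention softmax effectively ignores them.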
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        sentence_order_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``. ``0`` indicates original order (sequence
            A, then sequence B), ``1`` indicates switched order (sequence B, then sequence A).

        Returns:

        Example::

            >>> from transformers import AlbertTokenizer, AlbertForPreTraining
            >>> import torch

            >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
            >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2')

            >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            >>> outputs = model(input_ids)

            >>> prediction_logits = outputs.prediction_logits
            >>> sop_logits = outputs.sop_logits

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores


class AlbertSOPHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output):
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits


@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]

        prediction_scores = self.predictions(sequence_outputs)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
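
# Illustrative masked-LM usage (a sketch, using the public "albert-base-v2" checkpoint):
#
#     from transformers import AlbertTokenizer, AlbertForMaskedLM
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForMaskedLM.from_pretrained("albert-base-v2")
#     inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#     logits = model(**inputs).logits  # (batch_size, sequence_length, config.vocab_size)
#     mask_positions = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
#     print(tokenizer.decode(logits[0, mask_positions].argmax(dim=-1)))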


@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
            config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss).
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
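
# Illustrative usage with labels (a sketch; the classification head below is freshly initialized, so the loss is
# only meaningful after fine-tuning):
#
#     from transformers import AlbertTokenizer, AlbertForSequenceClassification
#     import torch
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)
#     inputs = tokenizer("A delightful movie.", return_tensors="pt")
#     outputs = model(**inputs, labels=torch.tensor([1]))
#     loss, logits = outputs.loss, outputs.logits  # logits: (batch_size, num_labels)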


@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
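

# A minimal usage sketch for the class above (illustrative, hedged): it assumes the public
# ``albert-base-v2`` checkpoint; real NER pipelines align labels to word pieces, which is
# skipped here, and the helper function name is illustrative only.
def _example_albert_token_classification():
    """Sketch: tag every token of a sentence with one of ``num_labels`` classes."""
    from transformers import AlbertTokenizer  # imported lazily to avoid a circular import

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForTokenClassification.from_pretrained("albert-base-v2", num_labels=5)
    inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
    # One label per token, including special tokens; all zeros keeps the sketch self-contained.
    labels = torch.zeros_like(inputs["input_ids"])
    outputs = model(**inputs, labels=labels)
    # logits have shape (batch_size, sequence_length, num_labels).
    return outputs.loss, outputs.logits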


@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

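        # qa_outputs projects every token's hidden state to two scores; splitting the last
        # dimension and squeezing yields start and end logits of shape (batch_size, sequence_length).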
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the gathered start/end positions may carry an extra dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions fall outside the model inputs; ignore those terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
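

# A minimal usage sketch for the class above (illustrative, hedged): with the generic
# ``albert-base-v2`` weights the span head is freshly initialized, so the decoded answer is only
# meaningful with a checkpoint fine-tuned for QA; the helper function name is illustrative only.
def _example_albert_question_answering():
    """Sketch: score start/end positions for a (question, context) pair and decode the span."""
    from transformers import AlbertTokenizer  # imported lazily to avoid a circular import

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")
    question = "Where is HuggingFace based?"
    context = "HuggingFace is based in New York City."
    inputs = tokenizer(question, context, return_tensors="pt")
    outputs = model(**inputs)
    # Take the most likely start/end token indices and decode the tokens between them.
    start_index = int(outputs.start_logits.argmax())
    end_index = int(outputs.end_logits.argmax())
    return tokenizer.decode(inputs["input_ids"][0, start_index : end_index + 1])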


@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="albert-base-v2",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors (see
            :obj:`input_ids` above).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

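        # Flatten the choices dimension: inputs of shape (batch_size, num_choices, seq_length)
        # are viewed as (batch_size * num_choices, seq_length) so each choice is encoded independently.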
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
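        # Each (example, choice) pair produces a single score; regroup the scores into
        # (batch_size, num_choices) so the cross-entropy below is taken over the choices.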
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
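

# A minimal usage sketch for the class above (illustrative, hedged): it assumes the public
# ``albert-base-v2`` checkpoint; each choice is encoded against the same prompt and the model
# scores the choices jointly. The helper function name is illustrative only.
def _example_albert_multiple_choice():
    """Sketch: score two candidate continuations and compute the loss against the correct index."""
    from transformers import AlbertTokenizer  # imported lazily to avoid a circular import

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForMultipleChoice.from_pretrained("albert-base-v2")
    prompt = "The cat sat on"
    choices = ["the mat.", "the moon."]
    encoding = tokenizer([prompt] * len(choices), choices, return_tensors="pt", padding=True)
    # Re-introduce the num_choices dimension: (batch_size=1, num_choices, sequence_length).
    inputs = {name: tensor.unsqueeze(0) for name, tensor in encoding.items()}
    labels = torch.tensor([0])  # index of the correct choice for the single example
    outputs = model(**inputs, labels=labels)
    # logits have shape (batch_size, num_choices).
    return outputs.loss, outputs.logits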