# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT model. """

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_albert import AlbertConfig

Aymeric Augustin's avatar
Aymeric Augustin committed
53

Lysandre Debut's avatar
Lysandre Debut committed
54
logger = logging.get_logger(__name__)
Lysandre's avatar
Lysandre committed
55

56
_CHECKPOINT_FOR_DOC = "albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
    # See all ALBERT models at https://huggingface.co/models?filter=albert
]


def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the gradients applied by the LAMB/ADAM optimizers.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print(f"Initialize PyTorch weight {name} from {original_name}")
        pointer.data = torch.from_numpy(array)

    return model


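# Usage sketch (illustrative only, not part of the original module): converting a TF ALBERT
# checkpoint with `load_tf_weights_in_albert`. The checkpoint path is a placeholder and the
# default `AlbertConfig()` is assumed to match the architecture the checkpoint was trained with.
def _example_convert_tf_checkpoint(tf_checkpoint_path="/path/to/albert/model.ckpt-best"):
    config = AlbertConfig()  # must describe the same architecture as the TF checkpoint
    model = AlbertForPreTraining(config)  # defined later in this module
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)
    return model

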
class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # If token_type_ids was not passed, use the buffer registered in the constructor (all zeros). This lets
        # users trace the model without passing token_type_ids explicitly and solves issue #5664.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class AlbertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the AlbertModel forward() function)
            attention_scores = attention_scores + attention_mask

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(2, 1).flatten(2)

        projected_context_layer = self.dense(context_layer)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


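# Sketch (illustrative, not part of the library): `AlbertAttention` returns a tuple whose first
# element is the LayerNorm-ed attention output and, when `output_attentions=True`, whose second
# element is the attention probabilities of shape (batch, num_heads, seq_len, seq_len).
# The config values below are assumptions for the sketch.
def _example_attention_outputs():
    config = AlbertConfig(hidden_size=768, num_attention_heads=12)
    attention = AlbertAttention(config)
    hidden_states = torch.zeros(1, 4, config.hidden_size)
    context, probs = attention(hidden_states, output_attentions=True)
    assert context.shape == (1, 4, config.hidden_size)
    assert probs.shape == (1, config.num_attention_heads, 4, 4)
    return context

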
class AlbertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        ffn_output = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output[0],
        )
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

        return (hidden_states,) + attention_output[1:]  # add attentions if we output them

    def ff_chunk(self, attention_output):
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return ffn_output


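# Sketch (illustrative, not part of the library): when `config.chunk_size_feed_forward > 0`,
# `apply_chunking_to_forward` splits `ff_chunk` along the sequence dimension (dim 1) and yields
# the same result as a single full-width call, trading scheduling overhead for lower peak memory.
def _example_feed_forward_chunking():
    config = AlbertConfig(hidden_size=768, intermediate_size=3072)
    layer = AlbertLayer(config)
    attention_output = torch.zeros(1, 8, config.hidden_size)
    full = layer.ff_chunk(attention_output)
    chunked = apply_chunking_to_forward(layer.ff_chunk, 4, 1, attention_output)
    assert torch.allclose(full, chunked)
    return chunked

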
class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


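# Worked example (illustrative, not part of the library) of the index arithmetic used in
# `AlbertTransformer.forward`: with num_hidden_layers=12 and num_hidden_groups=2, the 12 layer
# indices map onto 2 shared parameter groups, each reused for 6 consecutive layers.
def _example_layer_to_group_mapping(num_hidden_layers=12, num_hidden_groups=2):
    layers_per_group = int(num_hidden_layers / num_hidden_groups)  # 6
    mapping = [int(i / (num_hidden_layers / num_hidden_groups)) for i in range(num_hidden_layers)]
    assert layers_per_group == 6
    assert mapping == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    return mapping

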
class AlbertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class AlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.AlbertForPreTraining`.

    Args:
        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
            Total loss as the sum of the masked language modeling loss and the sentence order prediction
            (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the sentence order prediction (classification) head (scores of original vs. swapped
            order before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    sop_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.

    Args:
        config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.AlbertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.__call__` and :meth:`transformers.PreTrainedTokenizer.encode` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):

    config_class = AlbertConfig
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
        a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT
        model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a total of 4 different layers.

        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
        while [2,3] correspond to the two inner groups of the second hidden layer.

        Any layer with an index other than [0,1,2,3] will result in an error. See the base class PreTrainedModel for
        more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


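# Usage sketch (illustrative, not part of the library): running the bare encoder on a toy batch.
# The token ids are made up; in practice they come from `AlbertTokenizer`. The config values are
# assumptions for the sketch.
def _example_albert_model_forward():
    config = AlbertConfig(embedding_size=128, hidden_size=768, num_attention_heads=12)
    model = AlbertModel(config)
    model.eval()
    input_ids = torch.tensor([[2, 15, 30, 45, 3]])
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    assert outputs.last_hidden_state.shape == (1, 5, config.hidden_size)
    assert outputs.pooler_output.shape == (1, config.hidden_size)
    return outputs

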
@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        sentence_order_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see the ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``. ``0`` indicates original order (sequence
            A, then sequence B), ``1`` indicates switched order (sequence B, then sequence A).

        Returns:

        Example::

            >>> from transformers import AlbertTokenizer, AlbertForPreTraining
            >>> import torch

            >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
            >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2')

            >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            >>> outputs = model(input_ids)

            >>> prediction_logits = outputs.prediction_logits
            >>> sop_logits = outputs.sop_logits

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


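# Sketch (illustrative, not part of the library): computing the combined pretraining loss with
# dummy MLM labels and a dummy sentence-order label. All tensors below are made-up placeholders.
def _example_pretraining_loss():
    config = AlbertConfig()
    model = AlbertForPreTraining(config)
    input_ids = torch.tensor([[2, 15, 30, 45, 3]])
    labels = torch.tensor([[-100, 15, -100, 45, -100]])  # -100 positions are ignored by the loss
    sentence_order_label = torch.tensor([0])  # 0 = original order, 1 = swapped
    outputs = model(input_ids=input_ids, labels=labels, sentence_order_label=sentence_order_label)
    return outputs.loss, outputs.prediction_logits, outputs.sop_logits

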
class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores

    def _tie_weights(self):
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        self.bias = self.decoder.bias


class AlbertSOPHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output):
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits


@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see the ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]

        prediction_scores = self.predictions(sequence_outputs)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


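# Sketch (illustrative, not part of the library): recovering the highest-scoring token at a
# masked position. The mask token id below (4 in the ALBERT SentencePiece vocabulary) is an
# assumption; in practice use `tokenizer.mask_token_id`.
def _example_masked_lm_prediction(mask_token_id=4):
    config = AlbertConfig()
    model = AlbertForMaskedLM(config)
    model.eval()
    input_ids = torch.tensor([[2, 15, mask_token_id, 45, 3]])
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    masked_index = (input_ids == mask_token_id).nonzero(as_tuple=True)[1]
    return logits[0, masked_index].argmax(dim=-1)  # predicted token id(s) for the masked slot

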
@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
            config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss);
            if ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
1026
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
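            # Infer the problem type once from `num_labels` and the label dtype: a single label
            # means regression, integer labels mean single-label classification, and float
            # (e.g. multi-hot) labels mean multi-label classification.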
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
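

# Illustrative usage sketch (not part of the upstream model code): single-sentence classification
# with AlbertForSequenceClassification. Assumes the public "albert-base-v2" checkpoint and the
# AlbertTokenizer from the `transformers` package; the 2-label head is freshly initialized.
def _example_albert_sequence_classification():
    import torch
    from transformers import AlbertForSequenceClassification, AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)

    inputs = tokenizer("ALBERT shares parameters across its layers.", return_tensors="pt")
    labels = torch.tensor([1])  # integer labels select the single-label classification path

    outputs = model(**inputs, labels=labels)
    # `outputs.loss` is the cross-entropy loss; `outputs.logits` has shape (1, 2).
    return outputs.loss, outputs.logits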


@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
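                # Padded positions are mapped to `ignore_index`, so the cross-entropy below skips them.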
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
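

# Illustrative usage sketch (not part of the upstream model code): token-level tagging with
# AlbertForTokenClassification. Assumes the public "albert-base-v2" checkpoint and the
# AlbertTokenizer from the `transformers` package; the tagging head is freshly initialized.
def _example_albert_token_classification():
    import torch
    from transformers import AlbertForTokenClassification, AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForTokenClassification.from_pretrained("albert-base-v2", num_labels=5)

    inputs = tokenizer("HuggingFace is based in New York City.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # One predicted label id per (sub)token; shape (1, sequence_length).
    return outputs.logits.argmax(dim=-1)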


@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds an extra dimension that we squeeze away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)
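            # Clamped (out-of-range) positions equal `ignored_index`, which the loss below ignores.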

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
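

# Illustrative usage sketch (not part of the upstream model code): extractive question answering
# with AlbertForQuestionAnswering. Assumes the AlbertTokenizer from the `transformers` package; a
# checkpoint fine-tuned on SQuAD-style data would be needed for meaningful answers.
def _example_albert_question_answering():
    import torch
    from transformers import AlbertForQuestionAnswering, AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")

    question, context = "Where is HuggingFace based?", "HuggingFace is based in New York City."
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Greedy span decoding: the most likely start and end token positions.
    start = outputs.start_logits.argmax(dim=-1).item()
    end = outputs.end_logits.argmax(dim=-1).item()
    return tokenizer.decode(inputs["input_ids"][0, start : end + 1])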


@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where `num_choices` is the size of the second dimension of the input tensors (see
            `input_ids` above).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
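        # Every input above is now flattened from (batch_size, num_choices, ...) to
        # (batch_size * num_choices, ...), so each (prompt, choice) pair is encoded as an
        # ordinary sequence by the shared ALBERT encoder.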
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
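        # Fold the per-pair scores back into (batch_size, num_choices) so the cross-entropy
        # below treats the choices of each example as competing classes.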
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
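

# Illustrative usage sketch (not part of the upstream model code): scoring two candidate
# continuations with AlbertForMultipleChoice. Assumes the public "albert-base-v2" checkpoint and
# the AlbertTokenizer from the `transformers` package; the choice head is freshly initialized.
def _example_albert_multiple_choice():
    import torch
    from transformers import AlbertForMultipleChoice, AlbertTokenizer

    tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    model = AlbertForMultipleChoice.from_pretrained("albert-base-v2")

    prompt = "The cat sat on"
    choices = ["the mat.", "the moon."]
    # Encode (prompt, choice) pairs, then add the num_choices dimension expected by forward():
    # every tensor becomes (batch_size=1, num_choices=2, sequence_length).
    encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
    inputs = {name: tensor.unsqueeze(0) for name, tensor in encoding.items()}

    with torch.no_grad():
        outputs = model(**inputs, labels=torch.tensor([0]))

    # `outputs.logits` has shape (1, 2): one score per choice.
    return outputs.loss, outputs.logits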