# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT model. """

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"


ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
    # See all ALBERT models at https://huggingface.co/models?filter=albert
]


def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the optimizer state (the LAMB/Adam moment accumulators) and the global step stored in the checkpoint.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print(f"Initialize PyTorch weight {name} from {original_name}")
        pointer.data = torch.from_numpy(array)

    return model

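# A minimal usage sketch for the conversion helper above (the checkpoint and dump paths are hypothetical,
# and it is kept as a comment so importing this module stays side-effect free):
#
#     config = AlbertConfig.from_pretrained("albert-base-v2")
#     model = AlbertForPreTraining(config)
#     load_tf_weights_in_albert(model, config, "/path/to/tf_checkpoint/model.ckpt-best")
#     model.save_pretrained("/path/to/pytorch_dump")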

class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # When token_type_ids is not provided, fall back to the all-zeros buffer registered in the constructor.
        # The registered buffer lets users trace the model without passing token_type_ids (see issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

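# Shape note for the factorized embedding parameterization: AlbertEmbeddings returns tensors of size
# config.embedding_size rather than config.hidden_size; the projection up to hidden_size happens later in
# AlbertTransformer.embedding_hidden_mapping_in. A rough sketch, assuming the published albert-base-v2 sizes
# (embedding_size=128, hidden_size=768) purely for illustration:
#
#     embeddings = AlbertEmbeddings(config)
#     hidden = embeddings(input_ids)        # (batch_size, seq_len, 128)
#     hidden = nn.Linear(128, 768)(hidden)  # what embedding_hidden_mapping_in does -> (batch_size, seq_len, 768)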

class AlbertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in AlbertModel's forward() function)
            attention_scores = attention_scores + attention_mask

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(2, 1).flatten(2)

        projected_context_layer = self.dense(context_layer)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)

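# Shape walkthrough for AlbertAttention.forward (a sketch; b = batch_size, s = seq_len, h = num_attention_heads,
# d = attention_head_size, so hidden_size = h * d):
#
#     hidden_states:                               (b, s, h * d)
#     query/key/value after transpose_for_scores:  (b, h, s, d)
#     attention_scores = Q @ K^T / sqrt(d):        (b, h, s, s)
#     context = softmax(scores) @ V:               (b, h, s, d), flattened back to (b, s, h * d)
#     output = LayerNorm(hidden_states + dropout(dense(context)))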

class AlbertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        ffn_output = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output[0],
        )
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

        return (hidden_states,) + attention_output[1:]  # add attentions if we output them

    def ff_chunk(self, attention_output):
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return ffn_output

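# Note: `apply_chunking_to_forward` above splits the feed-forward pass into slices along the sequence dimension
# when config.chunk_size_feed_forward > 0, trading a little compute time for lower peak memory; with the default
# chunk size of 0 it is equivalent to calling self.ff_chunk on the whole attention output at once.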

class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

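# Worked example of the group indexing above (numbers chosen for illustration): with num_hidden_layers=12 and
# num_hidden_groups=1 (the standard ALBERT setting), layers_per_group is 12 and every iteration reuses
# albert_layer_groups[0], i.e. the same parameters are applied 12 times. With num_hidden_groups=4 instead,
# layers_per_group would be 3 and layer i would be routed to group i // 3.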

class AlbertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class AlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.AlbertForPreTraining`.

    Args:
        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    sop_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


ALBERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Args:
        config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.AlbertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.__call__` and :meth:`transformers.PreTrainedTokenizer.encode` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):

    config_class = AlbertConfig
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. ALBERT
        has a different architecture in that its layers are shared across groups, which in turn contain inner groups.
        If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there are a total of 4
        different layers.

        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
        while [2,3] correspond to the two inner groups of the second hidden layer.

        Any layer with an index other than [0, 1, 2, 3] will result in an error. See base class PreTrainedModel for
        more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
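        # A hedged example of the flattened indexing above: with num_hidden_groups=1 and inner_group_num=1
        # (the standard ALBERT setting) there is a single shared layer at index 0, so
        # model._prune_heads({0: [0, 1]}) prunes heads 0 and 1 of that layer; because the parameters are
        # shared, the pruning applies to every one of the num_hidden_layers repetitions.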

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
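        # After this transform, positions that may be attended to (mask value 1.0) add 0.0 to the raw attention
        # scores, while padded positions (mask value 0.0) add -10000.0, driving their softmax weight to ~0.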
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

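# A minimal usage sketch for the bare encoder above (mirrors the documented `albert-base-v2` checkpoint; kept as
# a comment so the module stays import-safe):
#
#     from transformers import AlbertTokenizer, AlbertModel
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertModel.from_pretrained("albert-base-v2")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs)
#     last_hidden_state = outputs.last_hidden_state  # (1, seq_len, hidden_size)
#     pooled_output = outputs.pooler_output          # (1, hidden_size)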

@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        sentence_order_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``. ``0`` indicates original order (sequence
            A, then sequence B), ``1`` indicates switched order (sequence B, then sequence A).

        Returns:

        Example::

            >>> from transformers import AlbertTokenizer, AlbertForPreTraining
            >>> import torch

            >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
            >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2')

            >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            >>> outputs = model(input_ids)

            >>> prediction_logits = outputs.prediction_logits
            >>> sop_logits = outputs.sop_logits

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores

    def _tie_weights(self):
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        self.bias = self.decoder.bias


class AlbertSOPHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output):
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits

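# Note on the two heads above: AlbertMLMHead projects hidden states back down to embedding_size before the
# (weight-tied) decoder maps them to vocab_size logits, while AlbertSOPHead is a plain dropout + linear
# classifier over the pooled [CLS] representation used for the sentence-order-prediction objective.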

@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]

        prediction_scores = self.predictions(sequence_outputs)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

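# A minimal fill-mask sketch for the AlbertForMaskedLM head above (comment-only example; the prompt text is arbitrary):
#
#     from transformers import AlbertTokenizer, AlbertForMaskedLM
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForMaskedLM.from_pretrained("albert-base-v2")
#     inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#     logits = model(**inputs).logits
#     mask_positions = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
#     predicted_ids = logits[0, mask_positions].argmax(dim=-1)
#     print(tokenizer.decode(predicted_ids))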

@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
            config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
1030
            inputs_embeds=inputs_embeds,
1031
            output_attentions=output_attentions,
Joseph Liu's avatar
Joseph Liu committed
1032
            output_hidden_states=output_hidden_states,
1033
            return_dict=return_dict,
LysandreJik's avatar
LysandreJik committed
1034
        )
Lysandre's avatar
Lysandre committed
1035
1036
1037
1038
1039
1040

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

1041
        loss = None
Lysandre's avatar
Lysandre committed
1042
        if labels is not None:
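            # The problem type is inferred once from ``num_labels`` and the label dtype,
            # then cached on the config so later calls take the matching loss branch directly.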
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
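# Usage sketch (illustrative only; kept as comments so nothing runs at import time).
# The checkpoint name "albert-base-v2" and the two-label setup are assumptions; with
# ``num_labels=1`` the same call would take the regression branch above instead.
#
#     from transformers import AlbertTokenizer
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)
#     inputs = tokenizer("ALBERT shares parameters across layers.", return_tensors="pt")
#     outputs = model(**inputs, labels=torch.tensor([1]))
#     loss, logits = outputs.loss, outputs.logits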


@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
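                # Labels at padded positions are replaced with ``ignore_index`` so that
                # CrossEntropyLoss leaves them out of the average.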
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
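# Usage sketch (illustrative only; kept as comments so nothing runs at import time).
# The checkpoint name "albert-base-v2" and the nine-tag label space are assumptions;
# the all-zero labels are placeholders, not a real NER annotation.
#
#     from transformers import AlbertTokenizer
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForTokenClassification.from_pretrained("albert-base-v2", num_labels=9)
#     inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
#     labels = torch.zeros_like(inputs["input_ids"])
#     outputs = model(**inputs, labels=labels)
#     predicted_tags = outputs.logits.argmax(dim=-1)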


@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
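        # ``qa_outputs`` gives each token ``config.num_labels`` scores; the split below
        # assumes two of them: one for span start, one for span end.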
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
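# Usage sketch (illustrative only; kept as comments so nothing runs at import time).
# The checkpoint name "albert-base-v2" and the question/context pair are assumptions;
# a checkpoint fine-tuned on SQuAD would give meaningful spans.
#
#     from transformers import AlbertTokenizer
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")
#     question, context = "Who proposed ALBERT?", "ALBERT was proposed by Google Research."
#     inputs = tokenizer(question, context, return_tensors="pt")
#     outputs = model(**inputs)
#     start = int(outputs.start_logits.argmax(dim=-1))
#     end = int(outputs.end_logits.argmax(dim=-1))
#     answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])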


@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where `num_choices` is the size of the second dimension of the input tensors (see
            `input_ids` above).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
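        # The reshapes below fold the choice dimension into the batch dimension, so the
        # encoder sees (batch_size * num_choices, seq_len); logits are unfolded per choice
        # after the classifier.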

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
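# Usage sketch (illustrative only; kept as comments so nothing runs at import time).
# The checkpoint name "albert-base-v2", the prompt, and the two candidate endings are
# assumptions; note the extra choice dimension added before calling the model.
#
#     from transformers import AlbertTokenizer
#
#     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
#     model = AlbertForMultipleChoice.from_pretrained("albert-base-v2")
#     prompt = "The weather turned cold, so"
#     choices = ["she put on a coat.", "she went for a swim."]
#     encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
#     inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # (batch_size=1, num_choices=2, seq_len)
#     outputs = model(**inputs, labels=torch.tensor([0]))
#     best_choice = int(outputs.logits.argmax(dim=-1))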