# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT model. """

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import logging
from .configuration_albert import AlbertConfig

Aymeric Augustin's avatar
Aymeric Augustin committed
53

Lysandre Debut's avatar
Lysandre Debut committed
54
logger = logging.get_logger(__name__)
Lysandre's avatar
Lysandre committed
55

56
_CHECKPOINT_FOR_DOC = "albert-base-v2"
57
_CONFIG_FOR_DOC = "AlbertConfig"
58
59
_TOKENIZER_FOR_DOC = "AlbertTokenizer"

Lysandre's avatar
Lysandre committed
60

ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
    # See all ALBERT models at https://huggingface.co/models?filter=albert
]


def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """Load TF checkpoints into a PyTorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        print(name)

    for name, array in zip(names, arrays):
        original_name = name

        # If saved from the TF HUB module
        name = name.replace("module/", "")

        # Renaming and simplifying
        name = name.replace("ffn_1", "ffn")
        name = name.replace("bert/", "albert/")
        name = name.replace("attention_1", "attention")
        name = name.replace("transform/", "")
        name = name.replace("LayerNorm_1", "full_layer_layer_norm")
        name = name.replace("LayerNorm", "attention/LayerNorm")
        name = name.replace("transformer/", "")

        # The feed forward layer had an 'intermediate' step which has been abstracted away
        name = name.replace("intermediate/dense/", "")
        name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")

        # ALBERT attention was split between self and output which have been abstracted away
        name = name.replace("/output/", "/")
        name = name.replace("/self/", "/")

        # The pooler is a linear layer
        name = name.replace("pooler/dense", "pooler")

        # The classifier was simplified to predictions from cls/predictions
        name = name.replace("cls/predictions", "predictions")
        name = name.replace("predictions/attention", "predictions")

        # Naming was changed to be more explicit
        name = name.replace("embeddings/attention", "embeddings")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")

        # Classifier
        if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
            name = "classifier/" + name

        # No ALBERT model currently handles the next sentence prediction task
        if "seq_relationship" in name:
            name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
            name = name.replace("weights", "weight")

        name = name.split("/")

        # Ignore the variables created by the LAMB/Adam optimizers (moment estimates) and the global step.
        if (
            "adam_m" in name
            or "adam_v" in name
            or "AdamWeightDecayOptimizer" in name
            or "AdamWeightDecayOptimizer_1" in name
            or "global_step" in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print(f"Initialize PyTorch weight {name} from {original_name}")
        pointer.data = torch.from_numpy(array)

    return model
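
# Minimal usage sketch for the converter above (the checkpoint path is a placeholder, not a real file):
#
#     config = AlbertConfig.from_pretrained("albert-base-v2")
#     model = AlbertModel(config)
#     load_tf_weights_in_albert(model, config, "/path/to/albert/model.ckpt")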


class AlbertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
        # If token_type_ids is not provided, use the all-zeros buffer registered in the constructor. Registering the
        # buffer helps users trace the model without passing token_type_ids and solves issue #5664.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class AlbertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads}"
272
            )

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
    def transpose_for_scores(self, x):
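        # Reshape from (batch_size, seq_len, all_head_size) to (batch_size, num_attention_heads, seq_len, attention_head_size)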
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the AlbertModel forward() function)
            attention_scores = attention_scores + attention_mask

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
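        # Merge the heads back together: (batch_size, num_heads, seq_len, head_size) -> (batch_size, seq_len, all_head_size)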
        context_layer = context_layer.transpose(2, 1).flatten(2)

        projected_context_layer = self.dense(context_layer)
        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)


class AlbertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
        ffn_output = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output[0],
        )
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
        return (hidden_states,) + attention_output[1:]  # add attentions if we output them
    def ff_chunk(self, attention_output):
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        return ffn_output

class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
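        # A layer group holds inner_group_num AlbertLayer modules; ALBERT's cross-layer parameter sharing comes from
        # AlbertTransformer reusing the same group at several depths of the stack.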

    def forward(
        self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.albert_layers):
            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)
        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None
        head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask

        for i in range(self.config.num_hidden_layers):
            # Number of layers in a hidden group
            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)

            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
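            # e.g. with the ALBERT default of 12 hidden layers and 1 hidden group, layers_per_group == 12 and
            # group_idx == 0 for every i, so the same (shared) group is applied 12 times.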

            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states,
                attention_mask,
                head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
                output_attentions,
                output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class AlbertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = AlbertConfig
    base_model_prefix = "albert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class AlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.AlbertForPreTraining`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Total loss as the sum of the masked language modeling loss and the sentence order prediction
            (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the sentence order prediction (classification) head (scores of original vs. swapped
            order before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    sop_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


ALBERT_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.
    Args:
        config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.AlbertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.__call__` and :meth:`transformers.PreTrainedTokenizer.encode` for
            details.
            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:
            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.
            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertModel(AlbertPreTrainedModel):

    config_class = AlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        if add_pooling_layer:
            self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
            self.pooler_activation = nn.Tanh()
        else:
            self.pooler = None
            self.pooler_activation = None
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. ALBERT
        has a different architecture in that its layers are shared across groups, which in turn contain inner groups.
        If an ALBERT model has 12 hidden layers and 2 hidden groups, each with two inner groups, there are a total of
        4 different layers.

        These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden group,
        while [2,3] correspond to the two inner groups of the second hidden group.

        Any layer with an index other than [0,1,2,3] will result in an error. See the base class PreTrainedModel for
        more information about head pruning.
        """
        for layer, heads in heads_to_prune.items():
            group_idx = int(layer / self.config.inner_group_num)
            inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
            self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
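
    # For instance, with the albert-base-v2 defaults (num_hidden_groups=1, inner_group_num=1) there is a single
    # shared layer, so pruning {0: [0, 1]} removes heads 0 and 1 from the one AlbertAttention module that every
    # hidden layer reuses.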

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
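        # Positions to attend to keep a value of 0.0 and masked positions get -10000.0; the result is added to the
        # raw attention scores before the softmax, so masked tokens receive ~0 attention weight.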
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        embedding_output = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForPreTraining(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        sentence_order_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
            Labels for computing the sentence order prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``. ``0`` indicates original order (sequence
            A, then sequence B), ``1`` indicates switched order (sequence B, then sequence A).
        Returns:
        Example::
            >>> from transformers import AlbertTokenizer, AlbertForPreTraining
            >>> import torch
            >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
            >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2')
            >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            >>> outputs = model(input_ids)
            >>> prediction_logits = outputs.prediction_logits
            >>> sop_logits = outputs.sop_logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]

        prediction_scores = self.predictions(sequence_output)
        sop_scores = self.sop_classifier(pooled_output)

        total_loss = None
        if labels is not None and sentence_order_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
            total_loss = masked_lm_loss + sentence_order_loss

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class AlbertMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states)

        prediction_scores = hidden_states

        return prediction_scores

    def _tie_weights(self):
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        self.bias = self.decoder.bias

class AlbertSOPHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pooled_output):
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits


@add_start_docstrings(
    "Albert Model with a `language modeling` head on top.",
    ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    def __init__(self, config):
        super().__init__(config)
        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)
        self.init_weights()
    def get_output_embeddings(self):
        return self.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.predictions.decoder = new_embeddings

    def get_input_embeddings(self):
        return self.albert.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_outputs = outputs[0]
        prediction_scores = self.predictions(sequence_outputs)
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassification(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
            config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss).
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
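            # Infer the problem type once from the label count and dtype: a single label
            # means regression, integer labels mean single-label classification, and
            # float labels mean multi-label classification; the result is cached on the config.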
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
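    # Minimal usage sketch for the classification head above (illustrative only; the
    # checkpoint, num_labels and label value are assumptions):
    #
    #     import torch
    #     from transformers import AlbertTokenizer, AlbertForSequenceClassification
    #
    #     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    #     model = AlbertForSequenceClassification.from_pretrained("albert-base-v2", num_labels=2)
    #     inputs = tokenizer("ALBERT shares parameters across its layers.", return_tensors="pt")
    #     outputs = model(**inputs, labels=torch.tensor([1]))
    #     outputs.loss.backward()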


@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
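            # attention_mask == 1 marks real tokens; logits and labels at padded
            # positions are filtered out so padding never contributes to the loss.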
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
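    # Minimal usage sketch for the token-classification head above (illustrative only;
    # the checkpoint and num_labels are assumptions):
    #
    #     from transformers import AlbertTokenizer, AlbertForTokenClassification
    #
    #     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    #     model = AlbertForTokenClassification.from_pretrained("albert-base-v2", num_labels=9)
    #     inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
    #     logits = model(**inputs).logits          # (batch_size, sequence_length, num_labels)
    #     predicted_tags = logits.argmax(-1)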


@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = AlbertModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
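        # qa_outputs maps every token to two scores; splitting along the last
        # dimension yields separate start- and end-position logits.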
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
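    # Minimal usage sketch for the span-prediction head above (illustrative only; the
    # checkpoint, question and context are assumptions):
    #
    #     from transformers import AlbertTokenizer, AlbertForQuestionAnswering
    #
    #     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    #     model = AlbertForQuestionAnswering.from_pretrained("albert-base-v2")
    #     inputs = tokenizer("Who proposed ALBERT?", "ALBERT was proposed by Lan et al.", return_tensors="pt")
    #     outputs = model(**inputs)
    #     start = int(outputs.start_logits.argmax())
    #     end = int(outputs.end_logits.argmax())
    #     print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))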


@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class AlbertForMultipleChoice(AlbertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
1311
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where `num_choices` is the size of the second dimension of the input tensors. (see
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
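        # Flatten (batch_size, num_choices, seq_len) into (batch_size * num_choices, seq_len)
        # so each choice is encoded independently; the logits are reshaped back below.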

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
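    # Minimal usage sketch for the multiple-choice head above (illustrative only; the
    # checkpoint, prompt and choices are assumptions):
    #
    #     from transformers import AlbertTokenizer, AlbertForMultipleChoice
    #
    #     tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
    #     model = AlbertForMultipleChoice.from_pretrained("albert-base-v2")
    #     prompt = "The sky is"
    #     choices = ["blue.", "made of cheese."]
    #     encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
    #     # add the num_choices dimension: (1, num_choices, seq_len)
    #     outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()})
    #     best_choice = outputs.logits.argmax(-1)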