modeling.py 20.6 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
thomwolf's avatar
thomwolf committed
15
"""PyTorch BERT model."""
thomwolf's avatar
thomwolf committed
16
17
18
19
20
21
22
23
24
25
26

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import math
import six
import torch
import torch.nn as nn
27
from torch.nn import CrossEntropyLoss
thomwolf's avatar
thomwolf committed
28

lukovnikov's avatar
lukovnikov committed
29
30
31
32

def gelu(x):
    """Apply the exact (erf-based) Gaussian Error Linear Unit activation.

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    Args:
        x: Input tensor.

    Returns:
        Tensor with the activation applied element-wise.
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    """Apply the swish activation, `x * sigmoid(x)`, element-wise."""
    return x * torch.sigmoid(x)


# Maps the string names accepted for `hidden_act` to activation callables.
# The dict is built after `gelu` and `swish` so both names exist at import
# time. "relu" uses the functional form: `torch.nn.ReLU` is a module class,
# so calling it on a tensor would construct a module rather than apply ReLU.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


thomwolf's avatar
thomwolf committed
45
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
    def __init__(self,
                 vocab_size,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02):
        """Constructs BertConfig.

        Args:
            vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        Every key of `json_object` is written onto the instance as an attribute,
        overriding the constructor defaults; keys that are absent keep their
        default value.

        Args:
            json_object: Mapping of attribute name to value.

        Returns:
            An instance of `cls` (so subclasses get an instance of themselves).
        """
        # Instantiate through `cls` rather than a hard-coded `BertConfig` so
        # that subclasses calling this alternate constructor get their own type.
        config = cls(vocab_size=None)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a json file of parameters.

        Args:
            json_file: Path of the JSON file to read.

        Returns:
            An instance of `cls` built via `from_dict`.
        """
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serializes this instance to a JSON string (indented, sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


thomwolf's avatar
thomwolf committed
121
class BERTLayerNorm(nn.Module):
    """Layer normalization over the last dimension, in the TF style."""

    def __init__(self, config, variance_epsilon=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).

        Args:
            config: Object exposing `hidden_size`, the length of the scale and
                shift vectors.
            variance_epsilon: Constant added to the variance before the square root.
        """
        super(BERTLayerNorm, self).__init__()
        width = config.hidden_size
        # Scale starts at one and shift at zero.
        self.gamma = nn.Parameter(torch.ones(width))
        self.beta = nn.Parameter(torch.zeros(width))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        """Normalize `x` along its last dimension, then apply scale and shift."""
        mean = x.mean(-1, keepdim=True)
        variance = (x - mean).pow(2).mean(-1, keepdim=True)
        normalized = (x - mean) / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalized + self.beta
thomwolf's avatar
thomwolf committed
135

thomwolf's avatar
thomwolf committed
136
class BERTEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then layer norm and dropout."""

    def __init__(self, config):
        """Construct the embedding module from word, position and token_type embeddings.

        Args:
            config: Object exposing `vocab_size`, `max_position_embeddings`,
                `type_vocab_size`, `hidden_size` and `hidden_dropout_prob`.
        """
        super(BERTEmbeddings, self).__init__()
        width = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, width)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """Embed `input_ids`.

        Args:
            input_ids: Integer tensor of token ids; dimension 1 is used as the
                sequence length.
            token_type_ids: Optional integer tensor shaped like `input_ids`;
                all zeros are used when omitted.

        Returns:
            The summed embeddings after layer norm and dropout.
        """
        # Positions 0..seq_length-1, broadcast to the shape of `input_ids`.
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)

        segments = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(positions)
                  + self.token_type_embeddings(segments))
        return self.dropout(self.LayerNorm(summed))


class BERTSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, config):
        """Build the query/key/value projections and the attention dropout.

        Args:
            config: Object exposing `hidden_size`, `num_attention_heads` and
                `attention_probs_dropout_prob`.

        Raises:
            ValueError: If `hidden_size` is not divisible by `num_attention_heads`.
        """
        super(BERTSelfAttention, self).__init__()
        hidden, heads = config.hidden_size, config.num_attention_heads
        if hidden % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden, heads))
        self.num_attention_heads = heads
        # Exact because divisibility was checked above.
        self.attention_head_size = hidden // heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden, self.all_head_size)
        self.key = nn.Linear(hidden, self.all_head_size)
        self.value = nn.Linear(hidden, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and swap dims 1 and 2."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Compute the attention context for `hidden_states`.

        Args:
            hidden_states: Input tensor whose last dimension is `hidden_size`.
            attention_mask: Tensor added to the raw attention scores before the
                softmax (precomputed for all layers in BertModel forward()).

        Returns:
            Context tensor with the heads merged back into the last dimension.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores: scaled dot product between queries and keys, plus the mask.
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask

        # Normalize the scores to probabilities over the last dimension.
        normalize = nn.Softmax(dim=-1)
        probs = normalize(scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        probs = self.dropout(probs)

        # Weighted sum of values, then merge the head dimensions back together.
        context = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return context.view(*merged_shape)


class BERTSelfOutput(nn.Module):
    """Projection, dropout and residual layer norm applied to the attention output."""

    def __init__(self, config):
        """Build the sub-modules from `config.hidden_size` and `config.hidden_dropout_prob`."""
        super(BERTSelfOutput, self).__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, add `input_tensor` and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection: the layer norm sees projection + original input.
        return self.LayerNorm(projected + input_tensor)


class BERTAttention(nn.Module):
    """Self-attention followed by its output projection with residual connection."""

    def __init__(self, config):
        """Build the `self` (attention) and `output` sub-modules from `config`."""
        super(BERTAttention, self).__init__()
        self.self = BERTSelfAttention(config)
        self.output = BERTSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Attend over `input_tensor` and combine the result with the input itself."""
        context = self.self(input_tensor, attention_mask)
        # `input_tensor` is passed again so the output block can add the residual.
        return self.output(context, input_tensor)


class BERTIntermediate(nn.Module):
    """Feed-forward expansion: linear layer followed by the configured activation."""

    def __init__(self, config):
        """Build the expansion layer and resolve the activation function.

        Args:
            config: Object exposing `hidden_size`, `intermediate_size` and
                `hidden_act` (either a key of `ACT2FN` or a callable).
        """
        super(BERTIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        if isinstance(activation, str):
            # String names are looked up in the module-level activation table.
            activation = ACT2FN[activation]
        self.intermediate_act_fn = activation

    def forward(self, hidden_states):
        """Apply the linear layer, then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))


class BERTOutput(nn.Module):
    """Feed-forward contraction back to `hidden_size`, with dropout and residual layer norm."""

    def __init__(self, config):
        """Build the sub-modules from `config`."""
        super(BERTOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BERTLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, apply dropout, add `input_tensor` and normalize."""
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection: the layer norm sees projection + block input.
        return self.LayerNorm(contracted + input_tensor)


class BERTLayer(nn.Module):
    """One transformer layer: attention block followed by the feed-forward block."""

    def __init__(self, config):
        """Build the attention, intermediate and output sub-modules from `config`."""
        super(BERTLayer, self).__init__()
        self.attention = BERTAttention(config)
        self.intermediate = BERTIntermediate(config)
        self.output = BERTOutput(config)

    def forward(self, hidden_states, attention_mask):
        """Run attention, then the feed-forward block on the attention output."""
        attended = self.attention(hidden_states, attention_mask)
        expanded = self.intermediate(attended)
        # The attention output is also the residual input of the output block.
        return self.output(expanded, attended)
thomwolf's avatar
thomwolf committed
283
284
285
286
287


class BERTEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` transformer layers applied in sequence."""

    def __init__(self, config):
        """Build the layer stack from `config`."""
        super(BERTEncoder, self).__init__()
        # A single layer is built and deep-copied, so every layer in the stack
        # starts out with identical parameter values.
        prototype = BERTLayer(config)
        copies = [copy.deepcopy(prototype) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(copies)

    def forward(self, hidden_states, attention_mask):
        """Run every layer in order.

        Returns:
            A list with the output of each layer, first layer first.
        """
        outputs = []
        current = hidden_states
        for block in self.layer:
            current = block(current, attention_mask)
            outputs.append(current)
        return outputs
thomwolf's avatar
thomwolf committed
297
298
299
300
301


class BERTPooler(nn.Module):
    """Pools a sequence by transforming the hidden state of its first token."""

    def __init__(self, config):
        """Build a `hidden_size` x `hidden_size` linear layer and a tanh activation."""
        super(BERTPooler, self).__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return tanh(dense(x)) where x is index 0 along dimension 1 of `hidden_states`."""
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        return self.activation(self.dense(hidden_states[:, 0]))
thomwolf's avatar
thomwolf committed
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326


class BertModel(nn.Module):
    """BERT model ("Bidirectional Encoder Representations from Transformers").

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])

    # hidden_size must be a multiple of num_attention_heads
    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
        num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config: BertConfig):
        """Constructor for BertModel.

        Args:
            config: `BertConfig` instance.
        """
        super(BertModel, self).__init__()
        self.embeddings = BERTEmbeddings(config)
        self.encoder = BERTEncoder(config)
        self.pooler = BERTPooler(config)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Encode `input_ids`.

        Args:
            input_ids: Integer tensor of token ids.
            token_type_ids: Optional tensor shaped like `input_ids`; defaults to
                all zeros.
            attention_mask: Optional tensor shaped like `input_ids`, 1 for
                positions to attend to and 0 for masked positions; defaults to
                all ones.

        Returns:
            A tuple `(all_encoder_layers, pooled_output)`: the list of outputs of
            every encoder layer, and the pooled output of the last layer.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Insert two singleton dimensions into the mask so its sizes are
        # [batch_size, 1, 1, to_seq_length] and it can broadcast to
        # [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal
        # attention used in OpenAI GPT; only the broadcast dimensions are needed.
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Cast to the dtype of the model parameters (fp16 compatibility).
        param_dtype = next(self.parameters()).dtype
        additive_mask = additive_mask.to(dtype=param_dtype)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this creates a tensor which is 0.0 for positions we
        # want to attend and -10000.0 for masked positions. Adding it to the raw
        # scores before the softmax is effectively the same as removing them.
        additive_mask = (1.0 - additive_mask) * -10000.0

        embedded = self.embeddings(input_ids, token_type_ids)
        all_encoder_layers = self.encoder(embedded, additive_mask)
        # The pooler only looks at the output of the last layer.
        pooled_output = self.pooler(all_encoder_layers[-1])
        return all_encoder_layers, pooled_output
368
369

class BertForSequenceClassification(nn.Module):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])

    # hidden_size must be a multiple of num_attention_heads
    config = BertConfig(vocab_size=32000, hidden_size=512,
        num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)

    num_labels = 2

    model = BertForSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels):
        """Build the BERT encoder and the classification head, then initialize weights.

        Args:
            config: `BertConfig` instance.
            num_labels: Number of output classes of the linear classifier.
        """
        super(BertForSequenceClassification, self).__init__()
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
            elif isinstance(module, BERTLayerNorm):
                # Layer norm must start as the identity transform (scale 1,
                # shift 0). Drawing gamma from N(0, std) would scale every
                # normalized activation to almost zero.
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Classify the input sequences.

        Args:
            input_ids: Integer tensor of token ids.
            token_type_ids: Optional tensor shaped like `input_ids`; forwarded to
                `BertModel`, which substitutes all zeros for `None`.
            attention_mask: Optional tensor shaped like `input_ids`; forwarded to
                `BertModel`, which substitutes all ones for `None`.
            labels: Optional class-index targets for the cross-entropy loss.

        Returns:
            `(loss, logits)` when `labels` is given, otherwise `logits`.
        """
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        logits = self.classifier(self.dropout(pooled_output))

        if labels is None:
            return logits
        loss = CrossEntropyLoss()(logits, labels)
        return loss, logits
419
420
421

class BertForQuestionAnswering(nn.Module):
    """BERT model for Question Answering (span extraction).
    This module is composed of the BERT model with a linear layer on top of
    the sequence output that computes start_logits and end_logits

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])

    # hidden_size must be a multiple of num_attention_heads
    config = BertConfig(vocab_size=32000, hidden_size=512,
        num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)

    model = BertForQuestionAnswering(config)
    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        """Build the BERT encoder and the span-prediction head, then initialize weights.

        Args:
            config: `BertConfig` instance.
        """
        super(BertForQuestionAnswering, self).__init__()
        self.bert = BertModel(config)
        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Two outputs per token: one start logit and one end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
            elif isinstance(module, BERTLayerNorm):
                # Layer norm must start as the identity transform (scale 1,
                # shift 0). Drawing gamma from N(0, std) would scale every
                # normalized activation to almost zero.
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                start_positions=None, end_positions=None):
        """Predict answer-span start and end logits for every token.

        Args:
            input_ids: Integer tensor of token ids.
            token_type_ids: Optional tensor shaped like `input_ids`; forwarded to
                `BertModel`, which substitutes all zeros for `None`.
            attention_mask: Optional tensor shaped like `input_ids`; forwarded to
                `BertModel`, which substitutes all ones for `None`.
            start_positions: Optional target start indices.
            end_positions: Optional target end indices.

        Returns:
            The mean of the start and end cross-entropy losses when both
            `start_positions` and `end_positions` are given, otherwise the tuple
            `(start_logits, end_logits)`.
        """
        all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
        logits = self.qa_outputs(all_encoder_layers[-1])
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if start_positions is None or end_positions is None:
            return start_logits, end_logits

        # Positions may carry an extra trailing dimension (the original code
        # attributes this to multi-GPU splitting); drop it.
        if start_positions.dim() > 1:
            start_positions = start_positions.squeeze(-1)
        if end_positions.dim() > 1:
            end_positions = end_positions.squeeze(-1)

        # Sometimes the start/end positions are outside our model inputs; they
        # are clamped onto `ignored_index`, which the loss then skips.
        # The out-of-place `clamp` leaves the caller's tensors unmodified.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        return (start_loss + end_loss) / 2