Unverified Commit 2d6a5349 authored by Thomas Wolf, committed by GitHub

Merge pull request #597 from huggingface/attention

GPT-2 (medium size model, special_tokens, fine-tuning, attention) + repo code coverage metric 
parents f9cde97b 35e6baab
.circleci/config.yml
@@ -7,9 +7,11 @@ jobs:
     steps:
       - checkout
       - run: sudo pip install --progress-bar off .
-      - run: sudo pip install pytest ftfy spacy
+      - run: sudo pip install pytest codecov pytest-cov
+      - run: sudo pip install spacy ftfy==4.4.3
       - run: sudo python -m spacy download en
-      - run: python -m pytest -sv tests/ --runslow
+      - run: python -m pytest -sv tests/ --runslow --cov
+      - run: codecov
   build_py2:
     working_directory: ~/pytorch-pretrained-BERT
     docker:
@@ -17,10 +19,11 @@ jobs:
     steps:
       - checkout
       - run: sudo pip install --progress-bar off .
-      - run: sudo pip install pytest spacy
-      - run: sudo pip install ftfy==4.4.3
+      - run: sudo pip install pytest codecov pytest-cov
+      - run: sudo pip install spacy ftfy==4.4.3
       - run: sudo python -m spacy download en
-      - run: python -m pytest -sv tests/ --runslow
+      - run: python -m pytest -sv tests/ --runslow --cov
+      - run: codecov
 workflows:
   version: 2
   build_and_test:
.coveragerc (new file)
[run]
source=pytorch_pretrained_bert
[report]
exclude_lines =
    pragma: no cover
    raise
    except
    register_parameter
\ No newline at end of file
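For context (this note is not part of the commit): the `exclude_lines` entries above are coverage.py regexes, so any source line matching them is left out of the reported metric. A minimal, purely illustrative sketch of how such lines look in code (the helper names are made up):

```python
# Hypothetical module content, shown only to illustrate the exclude_lines patterns above.
def load_optional_extension():
    try:
        import ftfy  # optional dependency; may be absent in a minimal install
        return True
    except ImportError:  # lines matching "except" are excluded from the coverage report
        return False

def debug_dump(state):  # pragma: no cover  (explicitly excluded branch)
    print(state)
```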
pytorch_pretrained_bert/modeling.py
@@ -278,12 +278,13 @@ class BertEmbeddings(nn.Module):

 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+        self.output_attentions = output_attentions
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -325,6 +326,8 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
+        if self.output_attentions:
+            return attention_probs, context_layer
         return context_layer
@@ -343,14 +346,19 @@ class BertSelfOutput(nn.Module):

 class BertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertAttention, self).__init__()
-        self.self = BertSelfAttention(config)
+        self.output_attentions = output_attentions
+        self.self = BertSelfAttention(config, output_attentions=output_attentions)
         self.output = BertSelfOutput(config)

     def forward(self, input_tensor, attention_mask):
         self_output = self.self(input_tensor, attention_mask)
+        if self.output_attentions:
+            attentions, self_output = self_output
         attention_output = self.output(self_output, input_tensor)
+        if self.output_attentions:
+            return attentions, attention_output
         return attention_output
@@ -384,33 +392,45 @@ class BertOutput(nn.Module):

 class BertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
-        self.attention = BertAttention(config)
+        self.output_attentions = output_attentions
+        self.attention = BertAttention(config, output_attentions=output_attentions)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
+        if self.output_attentions:
+            attentions, attention_output = attention_output
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
+        if self.output_attentions:
+            return attentions, layer_output
         return layer_output


 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config)
+        self.output_attentions = output_attentions
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
+        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
+            if self.output_attentions:
+                attentions, hidden_states = hidden_states
+                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if self.output_attentions:
+            return all_attentions, all_encoder_layers
         return all_encoder_layers
@@ -702,10 +722,11 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertModel, self).__init__(config)
+        self.output_attentions = output_attentions
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
@@ -734,10 +755,14 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        if self.output_attentions:
+            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
+        if self.output_attentions:
+            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
@@ -791,15 +816,20 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForPreTraining, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
-        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                                   output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, pooled_output = outputs
+        else:
+            sequence_output, pooled_output = outputs
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

         if masked_lm_labels is not None and next_sentence_label is not None:
@@ -808,7 +838,8 @@ class BertForPreTraining(BertPreTrainedModel):
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
             return total_loss
-        else:
-            return prediction_scores, seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, prediction_scores, seq_relationship_score
+        return prediction_scores, seq_relationship_score
@@ -854,22 +885,28 @@ class BertForMaskedLM(BertPreTrainedModel):
     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForMaskedLM, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
-                                       output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         prediction_scores = self.cls(sequence_output)

         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             return masked_lm_loss
-        else:
-            return prediction_scores
+        elif self.output_attentions:
+            return all_attentions, prediction_scores
+        return prediction_scores
@@ -916,22 +953,28 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForNextSentencePrediction, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                     output_all_encoded_layers=False)
-        seq_relationship_score = self.cls( pooled_output)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
+        seq_relationship_score = self.cls(pooled_output)

         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             return next_sentence_loss
-        else:
-            return seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, seq_relationship_score
+        return seq_relationship_score
@@ -980,16 +1023,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2):
+    def __init__(self, config, num_labels=2, output_attentions=False):
         super(BertForSequenceClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
@@ -997,7 +1045,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits
@@ -1045,10 +1094,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2):
+    def __init__(self, config, num_choices=2, output_attentions=False):
         super(BertForMultipleChoice, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_choices = num_choices
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_bert_weights)
@@ -1057,7 +1107,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)
@@ -1066,7 +1120,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
             return loss
-        else:
-            return reshaped_logits
+        elif self.output_attentions:
+            return all_attentions, reshaped_logits
+        return reshaped_logits
@@ -1115,16 +1170,21 @@ class BertForTokenClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2):
+    def __init__(self, config, num_labels=2, output_attentions=False):
         super(BertForTokenClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
@@ -1139,7 +1199,8 @@ class BertForTokenClassification(BertPreTrainedModel):
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits
@@ -1190,16 +1251,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForQuestionAnswering, self).__init__(config)
-        self.bert = BertModel(config)
-        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
-        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
@@ -1221,5 +1285,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
             return total_loss
-        else:
-            return start_logits, end_logits
+        elif self.output_attentions:
+            return all_attentions, start_logits, end_logits
+        return start_logits, end_logits
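As a usage note (not part of the commit): with the changes above, passing `output_attentions=True` prepends a list of per-layer attention maps to each model's return value. A minimal sketch of the new BERT return signature, where the toy config values and tensor shapes are assumptions for illustration:

```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

# Toy configuration; in practice a pretrained checkpoint would be loaded instead.
config = BertConfig(vocab_size_or_config_json_file=30522)
model = BertModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 30522, (1, 16))
with torch.no_grad():
    # With output_attentions=True the attention maps come first in the tuple.
    all_attentions, encoded_layers, pooled_output = model(input_ids)

# One tensor per layer, each shaped (batch, num_heads, seq_len, seq_len).
print(len(all_attentions), all_attentions[0].shape)
```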
pytorch_pretrained_bert/modeling_gpt2.py
@@ -39,8 +39,10 @@ from .modeling import BertLayerNorm as LayerNorm

 logger = logging.getLogger(__name__)

-PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"}
+PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
+                                "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
+PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
+                                 "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}

 def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model
@@ -107,18 +109,24 @@ class GPT2Config(object):
     def __init__(
         self,
         vocab_size_or_config_json_file=50257,
+        n_special=0,
         n_positions=1024,
         n_ctx=1024,
         n_embd=768,
         n_layer=12,
         n_head=12,
+        resid_pdrop=0.1,
+        embd_pdrop=0.1,
+        attn_pdrop=0.1,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
+        predict_special_tokens=True
     ):
         """Constructs GPT2Config.

         Args:
             vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
+            n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
             n_positions: Number of positional embeddings.
             n_ctx: Size of the causal mask (usually same as n_positions).
             n_embd: Dimensionality of the embeddings and hidden states.
@@ -126,8 +134,14 @@ class GPT2Config(object):
             n_head: Number of attention heads for each attention layer in
                 the Transformer encoder.
             layer_norm_epsilon: epsilon to use in the layer norm layers
+            resid_pdrop: The dropout probabilitiy for all fully connected
+                layers in the embeddings, encoder, and pooler.
+            attn_pdrop: The dropout ratio for the attention
+                probabilities.
+            embd_pdrop: The dropout ratio for the embeddings.
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
+            predict_special_tokens: should we predict special tokens (when the model has a LM head)
         """
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -137,19 +151,28 @@ class GPT2Config(object):
                 self.__dict__[key] = value
         elif isinstance(vocab_size_or_config_json_file, int):
             self.vocab_size = vocab_size_or_config_json_file
+            self.n_special = n_special
             self.n_ctx = n_ctx
             self.n_positions = n_positions
             self.n_embd = n_embd
             self.n_layer = n_layer
             self.n_head = n_head
+            self.resid_pdrop = resid_pdrop
+            self.embd_pdrop = embd_pdrop
+            self.attn_pdrop = attn_pdrop
             self.layer_norm_epsilon = layer_norm_epsilon
             self.initializer_range = initializer_range
+            self.predict_special_tokens = predict_special_tokens
         else:
             raise ValueError(
                 "First argument must be either a vocabulary size (int)"
                 "or the path to a pretrained model config file (str)"
             )

+    @property
+    def total_tokens_embeddings(self):
+        return self.vocab_size + self.n_special
+
     @classmethod
     def from_dict(cls, json_object):
         """Constructs a `GPT2Config` from a Python dictionary of parameters."""
@@ -200,7 +223,7 @@ class Conv1D(nn.Module):

 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -209,8 +232,11 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+        self.output_attentions = output_attentions
         self.c_attn = Conv1D(n_state * 3, nx)
         self.c_proj = Conv1D(n_state, nx)
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)

     def _attn(self, q, k, v):
         w = torch.matmul(q, k)
@@ -221,6 +247,9 @@ class Attention(nn.Module):
         w = w * b - 1e4 * (1 - b)

         w = nn.Softmax(dim=-1)(w)
+        w = self.attn_dropout(w)
+        if self.output_attentions:
+            return w, torch.matmul(w, v)
         return torch.matmul(w, v)

     def merge_heads(self, x):
@@ -248,8 +277,13 @@ class Attention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         a = self._attn(query, key, value)
+        if self.output_attentions:
+            attentions, a = a
         a = self.merge_heads(a)
         a = self.c_proj(a)
+        a = self.resid_dropout(a)
+        if self.output_attentions:
+            return attentions, a, present
         return a, present
@@ -260,27 +294,35 @@ class MLP(nn.Module):
         self.c_fc = Conv1D(n_state, nx)
         self.c_proj = Conv1D(nx, n_state)
         self.act = gelu
+        self.dropout = nn.Dropout(config.resid_pdrop)

     def forward(self, x):
         h = self.act(self.c_fc(x))
         h2 = self.c_proj(h)
-        return h2
+        return self.dropout(h2)


 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
         super(Block, self).__init__()
         nx = config.n_embd
+        self.output_attentions = output_attentions
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)

     def forward(self, x, layer_past=None):
-        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
+        output_attn = self.attn(self.ln_1(x), layer_past=layer_past)
+        if self.output_attentions:
+            attentions, a, present = output_attn
+        else:
+            a, present = output_attn
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
+        if self.output_attentions:
+            return attentions, x, present
         return x, present
@@ -290,17 +332,20 @@ class GPT2LMHead(nn.Module):
     def __init__(self, model_embeddings_weights, config):
         super(GPT2LMHead, self).__init__()
         self.n_embd = config.n_embd
-        self.set_embeddings_weights(model_embeddings_weights)
-
-    def set_embeddings_weights(self, model_embeddings_weights):
+        self.vocab_size = config.vocab_size
+        self.predict_special_tokens = config.predict_special_tokens
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
+        self.set_embeddings_weights(model_embeddings_weights)
+
+    def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
+        self.predict_special_tokens = predict_special_tokens
         self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
-        # Truncated Language modeling logits (we remove the last token)
-        # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
         lm_logits = self.decoder(hidden_state)
+        if not self.predict_special_tokens:
+            lm_logits = lm_logits[..., :self.vocab_size]
         return lm_logits
@@ -310,6 +355,7 @@ class GPT2MultipleChoiceHead(nn.Module):
     def __init__(self, config):
         super(GPT2MultipleChoiceHead, self).__init__()
         self.n_embd = config.n_embd
+        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
         self.linear = nn.Linear(config.n_embd, 1)

         nn.init.normal_(self.linear.weight, std=0.02)
@@ -323,6 +369,7 @@ class GPT2MultipleChoiceHead(nn.Module):
         # (bsz, num_choices, 1, hidden_size)
         multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
         # (bsz, num_choices, hidden_size)
+        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
         multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
         # (bsz, num_choices)
         return multiple_choice_logits
@@ -345,9 +392,6 @@ class GPT2PreTrainedModel(nn.Module):
             )
         self.config = config

-    def set_tied(self):
-        pass
-
     def init_weights(self, module):
         """ Initialize the weights.
         """
@@ -363,7 +407,7 @@ class GPT2PreTrainedModel(nn.Module):
     @classmethod
     def from_pretrained(
-        cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
+        cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
     ):
         """
         Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
@@ -475,14 +519,32 @@ class GPT2PreTrainedModel(nn.Module):
                 "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
             )

-        # Make sure we are still sharing the output and input embeddings after loading weights
-        model.set_tied()
+        # Add additional embeddings for special tokens if needed
+        # This step also make sure we are still sharing the output and input embeddings after loading weights
+        model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
         return model


 class GPT2Model(GPT2PreTrainedModel):
     """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").

+    GPT-2 use a single embedding matrix to store the word and special embeddings.
+    Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
+    Special tokens need to be trained during the fine-tuning if you use them.
+    The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    The embeddings are ordered as follow in the token embeddings matrice:
+        [0,                                               ----------------------
+         ...                                              -> word embeddings
+         config.vocab_size - 1,                           ______________________
+         config.vocab_size,
+         ...                                              -> special embeddings
+         config.vocab_size + config.n_special - 1]        ______________________
+
+    where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
+        total_tokens_embeddings = config.vocab_size + config.n_special
+    You should use the associate indices to index the embeddings.
+
     Params:
         config: a GPT2Config class instance with the configuration to build a new model
@@ -519,16 +581,32 @@ class GPT2Model(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2Model, self).__init__(config)
-        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        self.output_attentions = output_attentions
+        self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
-        block = Block(config.n_ctx, config, scale=True)
+        self.drop = nn.Dropout(config.embd_pdrop)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

         self.apply(self.init_weights)

+    def set_num_special_tokens(self, num_special_tokens):
+        " Update input embeddings with new embedding matrice if needed "
+        if self.config.n_special == num_special_tokens:
+            return
+        # Update config
+        self.config.n_special = num_special_tokens
+        # Build new embeddings and initialize all new embeddings (in particular the special tokens)
+        old_embed = self.wte
+        self.wte = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
+        self.wte.to(old_embed.weight.device)
+        self.init_weights(self.wte)
+        # Copy word embeddings from the previous weights
+        self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
+
     def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
         if past is None:
             past_length = 0
@@ -551,12 +629,21 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             token_type_embeds = 0
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
+        hidden_states = self.drop(hidden_states)
         presents = []
+        all_attentions = []
         for block, layer_past in zip(self.h, past):
-            hidden_states, present = block(hidden_states, layer_past)
+            if self.output_attentions:
+                attentions, hidden_states, present = block(hidden_states, layer_past)
+                all_attentions.append(attentions)
+            else:
+                hidden_states, present = block(hidden_states, layer_past)
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
+        if self.output_attentions:
+            return all_attentions, hidden_states.view(*output_shape), presents
         return hidden_states.view(*output_shape), presents
@@ -604,30 +691,38 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2LMHeadModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)

-    def set_tied(self):
-        """ Make sure we are sharing the embeddings
+    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
+        """ Update input and output embeddings with new embedding matrice
+            Make sure we are sharing the embeddings
         """
-        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
+        self.transformer.set_num_special_tokens(num_special_tokens)
+        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
-            shift_logits = lm_logits[:, :-1].contiguous()
-            shift_labels = lm_labels[:, 1:].contiguous()
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, presents
         return lm_logits, presents
@@ -680,32 +775,40 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2DoubleHeadsModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)

-    def set_tied(self):
-        """ Make sure we are sharing the embeddings
+    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
+        """ Update input and output embeddings with new embedding matrice
+            Make sure we are sharing the embeddings
         """
-        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
+        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
+        self.transformer.set_num_special_tokens(num_special_tokens)
+        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
         if lm_labels is not None:
-            shift_logits = lm_logits[:, :-1].contiguous()
-            shift_labels = lm_labels[:, 1:].contiguous()
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            losses.append(loss_fct(shift_logits.view(-1,
-                          shift_logits.size(-1)), shift_labels.view(-1)))
+            losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
         if losses:
             return losses
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, mc_logits, presents
         return lm_logits, mc_logits, presents
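As a usage note (not part of the commit): the `num_special_tokens` / `set_num_special_tokens` machinery added above resizes the GPT-2 token embedding matrix so fine-tuning can learn extra task tokens placed after the pretrained vocabulary, following the layout described in the GPT2Model docstring. A minimal sketch, assuming the matching tokenizer changes are handled separately:

```python
from pytorch_pretrained_bert import GPT2LMHeadModel

# Reserve two trainable special-token embeddings on top of the pretrained
# GPT-2 vocabulary; their ids are vocab_size and vocab_size + 1.
model = GPT2LMHeadModel.from_pretrained('gpt2', num_special_tokens=2)

# The count can also be set (or changed) after loading; this call keeps the
# input and output embeddings tied. Passing predict_special_tokens=False
# would stop the LM head from scoring the special-token ids.
model.set_num_special_tokens(2, predict_special_tokens=True)
```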
...@@ -143,6 +143,7 @@ class OpenAIGPTConfig(object): ...@@ -143,6 +143,7 @@ class OpenAIGPTConfig(object):
attn_pdrop=0.1, attn_pdrop=0.1,
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
initializer_range=0.02, initializer_range=0.02,
predict_special_tokens=True
): ):
"""Constructs OpenAIGPTConfig. """Constructs OpenAIGPTConfig.
...@@ -165,6 +166,7 @@ class OpenAIGPTConfig(object): ...@@ -165,6 +166,7 @@ class OpenAIGPTConfig(object):
layer_norm_epsilon: epsilon to use in the layer norm layers layer_norm_epsilon: epsilon to use in the layer norm layers
initializer_range: The sttdev of the truncated_normal_initializer for initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
predict_special_tokens: should we predict special tokens (when the model has a LM head)
""" """
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)): and isinstance(vocab_size_or_config_json_file, unicode)):
...@@ -186,6 +188,7 @@ class OpenAIGPTConfig(object): ...@@ -186,6 +188,7 @@ class OpenAIGPTConfig(object):
self.attn_pdrop = attn_pdrop self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens
else: else:
raise ValueError( raise ValueError(
"First argument must be either a vocabulary size (int)" "First argument must be either a vocabulary size (int)"
...@@ -253,7 +256,7 @@ class Conv1D(nn.Module): ...@@ -253,7 +256,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False): def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
super(Attention, self).__init__() super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...@@ -262,6 +265,7 @@ class Attention(nn.Module): ...@@ -262,6 +265,7 @@ class Attention(nn.Module):
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
self.output_attentions = output_attentions
self.c_attn = Conv1D(n_state * 3, 1, nx) self.c_attn = Conv1D(n_state * 3, 1, nx)
self.c_proj = Conv1D(n_state, 1, nx) self.c_proj = Conv1D(n_state, 1, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop) self.attn_dropout = nn.Dropout(config.attn_pdrop)
...@@ -278,6 +282,8 @@ class Attention(nn.Module): ...@@ -278,6 +282,8 @@ class Attention(nn.Module):
w = nn.Softmax(dim=-1)(w) w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w) w = self.attn_dropout(w)
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v) return torch.matmul(w, v)
def merge_heads(self, x): def merge_heads(self, x):
...@@ -300,9 +306,13 @@ class Attention(nn.Module): ...@@ -300,9 +306,13 @@ class Attention(nn.Module):
key = self.split_heads(key, k=True) key = self.split_heads(key, k=True)
value = self.split_heads(value) value = self.split_heads(value)
a = self._attn(query, key, value) a = self._attn(query, key, value)
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a) a = self.merge_heads(a)
a = self.c_proj(a) a = self.c_proj(a)
a = self.resid_dropout(a) a = self.resid_dropout(a)
if self.output_attentions:
return attentions, a
return a return a
...@@ -322,19 +332,24 @@ class MLP(nn.Module): ...@@ -322,19 +332,24 @@ class MLP(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False): def __init__(self, n_ctx, config, scale=False, output_attentions=False):
super(Block, self).__init__() super(Block, self).__init__()
nx = config.n_embd nx = config.n_embd
self.attn = Attention(nx, n_ctx, config, scale) self.output_attentions = output_attentions
self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config) self.mlp = MLP(4 * nx, config)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
def forward(self, x): def forward(self, x):
a = self.attn(x) a = self.attn(x)
if self.output_attentions:
attentions, a = a
n = self.ln_1(x + a) n = self.ln_1(x + a)
m = self.mlp(n) m = self.mlp(n)
h = self.ln_2(n + m) h = self.ln_2(n + m)
if self.output_attentions:
return attentions, h
return h return h
...@@ -344,17 +359,21 @@ class OpenAIGPTLMHead(nn.Module): ...@@ -344,17 +359,21 @@ class OpenAIGPTLMHead(nn.Module):
def __init__(self, model_embeddings_weights, config): def __init__(self, model_embeddings_weights, config):
super(OpenAIGPTLMHead, self).__init__() super(OpenAIGPTLMHead, self).__init__()
self.n_embd = config.n_embd self.n_embd = config.n_embd
self.vocab_size = config.vocab_size
self.predict_special_tokens = config.predict_special_tokens
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.set_embeddings_weights(model_embeddings_weights) self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights): def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
self.predict_special_tokens = predict_special_tokens
embed_shape = model_embeddings_weights.shape embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model_embeddings_weights # Tied weights self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, hidden_state): def forward(self, hidden_state):
# Truncated Language modeling logits (we remove the last token)
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
lm_logits = self.decoder(hidden_state) lm_logits = self.decoder(hidden_state)
if not self.predict_special_tokens:
lm_logits = lm_logits[..., :self.vocab_size]
return lm_logits return lm_logits
...@@ -364,7 +383,6 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): ...@@ -364,7 +383,6 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
def __init__(self, config): def __init__(self, config):
super(OpenAIGPTMultipleChoiceHead, self).__init__() super(OpenAIGPTMultipleChoiceHead, self).__init__()
self.n_embd = config.n_embd self.n_embd = config.n_embd
# self.multiple_choice_token = multiple_choice_token
self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(config.n_embd, 1) self.linear = nn.Linear(config.n_embd, 1)
...@@ -415,9 +433,6 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -415,9 +433,6 @@ class OpenAIGPTPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def set_num_special_tokens(self, num_special_tokens):
pass
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
""" """
...@@ -594,17 +609,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -594,17 +609,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTModel, self).__init__(config) super(OpenAIGPTModel, self).__init__(config)
num_tokens = config.vocab_size + config.n_special self.output_attentions = output_attentions
self.tokens_embed = nn.Embedding(num_tokens, config.n_embd) self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop) self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True) block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights) self.apply(self.init_weights)
# nn.init.normal_(self.embed.weight, std=0.02)
def set_num_special_tokens(self, num_special_tokens): def set_num_special_tokens(self, num_special_tokens):
" Update input embeddings with new embedding matrice if needed " " Update input embeddings with new embedding matrice if needed "
...@@ -640,12 +654,19 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -640,12 +654,19 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
token_type_embeds = self.tokens_embed(token_type_ids) token_type_embeds = self.tokens_embed(token_type_ids)
else: else:
token_type_embeds = 0 token_type_embeds = 0
# Add the position information to the input embeddings
# h = e.sum(dim=2)
hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states)
all_attentions = []
for block in self.h: for block in self.h:
if self.output_attentions:
attentions, hidden_states = block(hidden_states)
all_attentions.append(attentions)
else:
hidden_states = block(hidden_states) hidden_states = block(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
if self.output_attentions:
return all_attentions, hidden_states.view(*output_shape)
return hidden_states.view(*output_shape) return hidden_states.view(*output_shape)
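A hedged usage sketch (not part of this diff) for the new output_attentions flag on OpenAIGPTModel: the forward pass then returns the per-layer attention maps ahead of the hidden states. The top-level imports and the config keyword names are assumptions, mirroring the GPT2Config usage in the tests.
import torch
from pytorch_pretrained_bert import OpenAIGPTConfig, OpenAIGPTModel  # assumed exports

config = OpenAIGPTConfig(vocab_size_or_config_json_file=100, n_special=2,
                         n_positions=64, n_ctx=64, n_embd=32, n_layer=2, n_head=4)
model = OpenAIGPTModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 100, (1, 10))        # (batch, seq_len)
all_attentions, hidden_states = model(input_ids)
print(len(all_attentions))                        # one attention tensor per layer (n_layer)
print(hidden_states.shape)                        # (1, 10, n_embd)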
...@@ -705,21 +726,24 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -705,21 +726,24 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTLMHeadModel, self).__init__(config) super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights) self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens): def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
""" Update input and output embeddings with new embedding matrice """ Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings Make sure we are sharing the embeddings
""" """
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
self.transformer.set_num_special_tokens(num_special_tokens) self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
if lm_labels is not None: if lm_labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
...@@ -730,6 +754,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -730,6 +754,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)) shift_labels.view(-1))
return loss return loss
if self.transformer.output_attentions:
return all_attentions, lm_logits
return lm_logits return lm_logits
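A hedged fine-tuning sketch (not part of this diff) for the extended set_num_special_tokens: the special-token embeddings are added on the input side, while predict_special_tokens=False keeps them out of the output logits. Config values and the number of special tokens are illustrative; the imports are assumed top-level exports.
import torch
from pytorch_pretrained_bert import OpenAIGPTConfig, OpenAIGPTLMHeadModel  # assumed exports

config = OpenAIGPTConfig(vocab_size_or_config_json_file=100, n_special=0,
                         n_positions=64, n_ctx=64, n_embd=32, n_layer=2, n_head=4)
model = OpenAIGPTLMHeadModel(config)
model.set_num_special_tokens(3, predict_special_tokens=False)   # grow the input embeddings only
model.eval()

input_ids = torch.randint(0, 100, (1, 8))
lm_logits = model(input_ids)
print(lm_logits.shape[-1])   # 100: special-token logits are sliced away from the output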
...@@ -794,22 +820,25 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -794,22 +820,25 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config) super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config) self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights) self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens): def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
""" Update input and output embeddings with new embedding matrice """ Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings Make sure we are sharing the embeddings
""" """
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
self.transformer.set_num_special_tokens(num_special_tokens) self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = [] losses = []
...@@ -823,4 +852,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -823,4 +852,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses: if losses:
return losses return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits
return lm_logits, mc_logits return lm_logits, mc_logits
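A hedged shape sketch (not part of this diff) for the double-heads model: inputs carry one sequence per choice, and with output_attentions=True the attention maps come first in the returned tuple, as in the hunk above. Config values, imports, and the classification token positions are illustrative assumptions.
import torch
from pytorch_pretrained_bert import OpenAIGPTConfig, OpenAIGPTDoubleHeadsModel  # assumed exports

config = OpenAIGPTConfig(vocab_size_or_config_json_file=100, n_special=2,
                         n_positions=64, n_ctx=64, n_embd=32, n_layer=2, n_head=4)
model = OpenAIGPTDoubleHeadsModel(config, output_attentions=True)
model.eval()

batch, n_choices, seq_len = 2, 2, 7
input_ids = torch.randint(0, 100, (batch, n_choices, seq_len))
mc_token_ids = torch.full((batch, n_choices), seq_len - 1, dtype=torch.long)  # classify on the last token

all_attentions, lm_logits, mc_logits = model(input_ids, mc_token_ids)
print(lm_logits.shape)   # (batch, n_choices, seq_len, vocab_size + n_special)
print(mc_logits.shape)   # (batch, n_choices)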
...@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__) ...@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = { PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
} }
PRETRAINED_MERGES_ARCHIVE_MAP = { PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
} }
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024, 'gpt2': 1024,
...@@ -263,9 +265,14 @@ class GPT2Tokenizer(object): ...@@ -263,9 +265,14 @@ class GPT2Tokenizer(object):
def encode(self, text): def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text)) return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens): def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
text = ''.join([self.decoder[token] for token in tokens]) text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
if clean_up_tokenization_spaces:
text = text.replace('<unk>', '')
text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return text return text
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):
......
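A hedged usage sketch (not part of this diff) for the reworked GPT2Tokenizer.decode, which now routes through convert_ids_to_tokens and can optionally skip special tokens and clean up tokenization spaces. from_pretrained downloads the vocab/merges files listed above.
from pytorch_pretrained_bert import GPT2Tokenizer  # assumed top-level export

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode("Hello , world !")
text = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(text)   # roughly "Hello, world!" once the space clean-up rules have run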
...@@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object): ...@@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object):
out_string = ''.join(tokens).replace('</w>', ' ').strip() out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces: if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '') out_string = out_string.replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string return out_string
......
...@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
n_special=1,
n_positions=33, n_positions=33,
n_embd=32, n_embd=32,
n_layer=5, n_layer=5,
...@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
self.use_token_type_ids = use_token_type_ids self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels self.use_labels = use_labels
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_special = n_special
self.n_positions = n_positions self.n_positions = n_positions
self.n_embd = n_embd self.n_embd = n_embd
self.n_layer = n_layer self.n_layer = n_layer
...@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
self.scope = scope self.scope = scope
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size) total_num_tokens = self.vocab_size + self.n_special
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
position_ids = None position_ids = None
if self.use_position_ids: if self.use_position_ids:
...@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
config = GPT2Config( config = GPT2Config(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
n_special=self.n_special,
n_positions=self.n_positions, n_positions=self.n_positions,
n_embd=self.n_embd, n_embd=self.n_embd,
n_layer=self.n_layer, n_layer=self.n_layer,
...@@ -129,11 +133,29 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -129,11 +133,29 @@ class GPT2ModelTest(unittest.TestCase):
} }
return outputs return outputs
def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2LMHeadModel(config, output_attentions=True)
model.eval()
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"presents": presents,
"attentions": attentions,
}
return outputs
def check_gpt2_lm_head_output(self, result): def check_gpt2_lm_head_output(self, result):
total_voc = self.vocab_size total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc]) [self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertEqual(self.n_layer, len(result["presents"]))
self.parent.assertListEqual(
list(result["presents"][0].size()),
[2, self.batch_size * self.n_choices, self.n_head, self.seq_length, self.n_embd // self.n_head])
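A brief note (not from this diff) on the shape asserted above: each entry of `presents` stacks a layer's cached key and value, hence the leading dimension of 2 and the flattened batch_size * n_choices axis. Sizes below are illustrative, not the tester defaults.
import torch

batch, n_choices, n_head, seq_len, head_dim = 2, 3, 4, 7, 8   # head_dim = n_embd // n_head
key = torch.randn(batch * n_choices, n_head, seq_len, head_dim)
value = torch.randn(batch * n_choices, n_head, seq_len, head_dim)
present = torch.stack((key, value))
print(list(present.size()))   # [2, 6, 4, 7, 8]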
def check_gpt2_lm_head_loss_output(self, result): def check_gpt2_lm_head_loss_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
...@@ -156,8 +178,25 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -156,8 +178,25 @@ class GPT2ModelTest(unittest.TestCase):
} }
return outputs return outputs
def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2DoubleHeadsModel(config, output_attentions=True)
model.eval()
loss = model(input_ids, mc_token_ids,
lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"mc_logits": mc_logits,
"presents": presents,
"attentions": attentions,
}
return outputs
def check_gpt2_double_heads_output(self, result): def check_gpt2_double_heads_output(self, result):
total_voc = self.vocab_size total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc]) [self.batch_size, self.n_choices, self.seq_length, total_voc])
......
...@@ -28,7 +28,7 @@ import torch ...@@ -28,7 +28,7 @@ import torch
from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM, from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining, BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification) BertForTokenClassification, BertForMultipleChoice)
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
...@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase): ...@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
type_sequence_label_size=2, type_sequence_label_size=2,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4,
scope=None): scope=None):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase): ...@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
self.type_sequence_label_size = type_sequence_label_size self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.num_labels = num_labels self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope self.scope = scope
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
...@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase): ...@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
sequence_labels = None sequence_labels = None
token_labels = None token_labels = None
choice_labels = None
if self.use_labels: if self.use_labels:
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels) token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
config = BertConfig( config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
...@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase): ...@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range) initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result): def check_loss_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["loss"].size()),
[]) [])
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertModel(config=config) model = BertModel(config=config)
model.eval() model.eval()
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
...@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase): ...@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMaskedLM(config=config) model = BertForMaskedLM(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels) loss = model(input_ids, token_type_ids, input_mask, token_labels)
...@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase): ...@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
list(result["prediction_scores"].size()), list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForNextSentencePrediction(config=config) model = BertForNextSentencePrediction(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
...@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase): ...@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, 2]) [self.batch_size, 2])
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config) model = BertForPreTraining(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
...@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase): ...@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, 2]) [self.batch_size, 2])
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForQuestionAnswering(config=config) model = BertForQuestionAnswering(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
...@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase): ...@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.seq_length]) [self.batch_size, self.seq_length])
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForSequenceClassification(config=config, num_labels=self.num_labels) model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
...@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase): ...@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.num_labels]) [self.batch_size, self.num_labels])
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForTokenClassification(config=config, num_labels=self.num_labels) model = BertForTokenClassification(config=config, num_labels=self.num_labels)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels) loss = model(input_ids, token_type_ids, input_mask, token_labels)
...@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase): ...@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.num_labels]) [self.batch_size, self.seq_length, self.num_labels])
def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask,
choice_labels)
logits = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask)
outputs = {
"loss": loss,
"logits": logits,
}
return outputs
def check_bert_for_multiple_choice(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_choices])
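A minimal sketch (not part of this diff) of the tiling used above: BertForMultipleChoice consumes one sequence per choice, so a (batch, seq_len) tensor is expanded to (batch, num_choices, seq_len) before the forward pass.
import torch

batch_size, seq_len, num_choices = 2, 5, 4
input_ids = torch.randint(0, 99, (batch_size, seq_len))
multiple_choice_input_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
print(multiple_choice_input_ids.shape)   # torch.Size([2, 4, 5])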
def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
else:
model = model_class(config=config, output_attentions=True)
model.eval()
output = model(input_ids, token_type_ids, input_mask)
attentions = output[0]
self.parent.assertEqual(len(attentions), self.num_hidden_layers)
self.parent.assertListEqual(
list(attentions[0].size()),
[self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
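A hedged sketch (not part of this diff) of what the loop above exercises: constructing a BERT class with output_attentions=True puts the attention probabilities first in the output, one (batch, num_heads, seq_len, seq_len) tensor per hidden layer. The small config values are illustrative.
import torch
from pytorch_pretrained_bert import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4, intermediate_size=37)
model = BertModel(config=config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 99, (2, 7))
output = model(input_ids)
attentions = output[0]
print(len(attentions), attentions[0].shape)   # 2 torch.Size([2, 4, 7, 7])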
def test_default(self): def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self)) self.run_tester(BertModelTest.BertModelTester(self))
...@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase): ...@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase):
tester.check_bert_for_token_classification_output(output_result) tester.check_bert_for_token_classification_output(output_result)
tester.check_loss_output(output_result) tester.check_loss_output(output_result)
output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
tester.check_bert_for_multiple_choice(output_result)
tester.check_loss_output(output_result)
tester.create_and_check_bert_for_attentions(*config_and_inputs)
@classmethod @classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None): def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size.""" """Creates a random int32 tensor of the shape within the vocab size."""
......