Commit f7cd7392 authored by thomwolf

fixed tests

parent e28d8bde
@@ -253,7 +253,7 @@ class BertEmbeddings(nn.Module):
         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+    def forward(self, input_ids, token_type_ids=None, position_ids=None):
         seq_length = input_ids.size(1)
         if position_ids is None:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
@@ -667,7 +667,7 @@ class BertModel(BertPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, head_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -703,7 +703,7 @@ class BertModel(BertPreTrainedModel):
         else:
             head_mask = [None] * self.config.num_hidden_layers
 
-        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids)
+        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
         encoder_outputs = self.encoder(embedding_output,
                                        extended_attention_mask,
                                        head_mask=head_mask)
@@ -772,9 +772,10 @@ class BertForPreTraining(BertPreTrainedModel):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
-                next_sentence_label=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
+                next_sentence_label=None, position_ids=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask, head_mask=head_mask)
 
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -841,8 +842,10 @@ class BertForMaskedLM(BertPreTrainedModel):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
+                position_ids=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask, head_mask=head_mask)
 
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
@@ -898,8 +901,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
+                position_ids=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask, head_mask=head_mask)
 
         pooled_output = outputs[1]
 
         seq_relationship_score = self.cls(pooled_output)
@@ -959,8 +964,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask, head_mask=head_mask)
 
         pooled_output = outputs[1]
 
         pooled_output = self.dropout(pooled_output)
@@ -1063,14 +1070,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None):
         num_choices = input_ids.shape[1]
 
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        outputs = self.bert(flat_input_ids, flat_position_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
+        outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
+                            attention_mask=flat_attention_mask, head_mask=head_mask)
 
         pooled_output = outputs[1]
 
         pooled_output = self.dropout(pooled_output)
@@ -1131,8 +1140,10 @@ class BertForTokenClassification(BertPreTrainedModel):
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask, head_mask=head_mask)
 
         sequence_output = outputs[0]
 
         sequence_output = self.dropout(sequence_output)
@@ -1205,9 +1216,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         self.apply(self.init_weights)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, start_positions=None,
-                end_positions=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
+                end_positions=None, position_ids=None, head_mask=None):
+        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                            attention_mask=attention_mask, head_mask=head_mask)
 
         sequence_output = outputs[0]
 
         logits = self.qa_outputs(sequence_output)
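Across the BERT classes above, position_ids moves toward the end of each forward signature and every internal self.bert(...) / self.embeddings(...) call switches from positional to keyword arguments. Why keyword calls matter here can be seen in a small, hypothetical sketch (the ToyEmbeddings module below is illustrative only, not part of the library): a positional call written against the old argument order would silently feed position_ids into the token_type_ids slot after the reorder, while a keyword call keeps working.

import torch
import torch.nn as nn

class ToyEmbeddings(nn.Module):
    # Stand-in with the *new* argument order from the diff: token_type_ids before position_ids.
    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        # Return the inputs unchanged so we can see which slot each tensor landed in.
        return input_ids, token_type_ids, position_ids

module = ToyEmbeddings()
input_ids = torch.tensor([[1, 2, 3]])
position_ids = torch.tensor([[0, 1, 2]])

# Keyword call: robust to the reordering, position_ids lands where it belongs.
_, tok, pos = module(input_ids, position_ids=position_ids)
assert pos is position_ids and tok is None

# Positional call written against the *old* order (input_ids, position_ids, ...):
# after the reorder it silently fills the token_type_ids slot instead.
_, tok, pos = module(input_ids, position_ids)
assert tok is position_ids and pos is None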
@@ -591,7 +591,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
                                    self.transformer.wte)
 
     def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, past=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
+        transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                                               past=past, head_mask=head_mask)
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
@@ -709,7 +710,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
 
     def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                 position_ids=None, past=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
+        transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                                               past=past, head_mask=head_mask)
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
@@ -582,7 +582,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
                                    self.transformer.tokens_embed)
 
     def forward(self, input_ids, position_ids=None, token_type_ids=None, labels=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
+        transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                                               head_mask=head_mask)
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
@@ -693,7 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
 
     def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                 position_ids=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
+        transformer_outputs = self.transformer(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                                               head_mask=head_mask)
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
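The GPT-2 and OpenAI GPT heads get the same treatment: the call into self.transformer(...) now passes position_ids, token_type_ids, past, and head_mask by keyword. Callers can follow the same convention; a minimal usage sketch, assuming the pytorch_transformers package name and the 'gpt2' checkpoint (both assumptions, not shown in this diff):

import torch
from pytorch_transformers import GPT2LMHeadModel  # assumed import path for this version of the repo

model = GPT2LMHeadModel.from_pretrained('gpt2')   # assumed checkpoint name
model.eval()

input_ids = torch.tensor([[464, 3290, 318]])      # arbitrary token ids, for illustration only
with torch.no_grad():
    # Optional inputs passed by keyword, matching the forward signature shown above.
    outputs = model(input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None)
lm_logits = outputs[0]  # with no labels given, the first element should be the LM logits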
@@ -1344,7 +1344,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         bsz = input_ids.size(0)
         tgt_len = input_ids.size(1)
 
-        transformer_outputs = self.transformer(input_ids, mems, head_mask)
+        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)
 
         last_hidden = transformer_outputs[0]
         pred_hid = last_hidden[:, -tgt_len:]
@@ -594,7 +594,7 @@ class SQuADHead(nn.Module):
         """
         outputs = ()
 
-        start_logits = self.start_logits(hidden_states, p_mask)
+        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
 
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, let's remove the dimension added by batch splitting
@@ -768,8 +768,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
 
     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, labels=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
-                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
+                                               token_type_ids=token_type_ids, langs=langs,
+                                               attention_mask=attention_mask, cache=cache, head_mask=head_mask)
 
         output = transformer_outputs[0]
         outputs = self.pred_layer(output, labels)
@@ -825,8 +826,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
 
     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, labels=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
-                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
+                                               token_type_ids=token_type_ids, langs=langs,
+                                               attention_mask=attention_mask, cache=cache, head_mask=head_mask)
 
         output = transformer_outputs[0]
         logits = self.sequence_summary(output)
@@ -905,8 +907,9 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, start_positions=None, end_positions=None,
                 cls_index=None, is_impossible=None, p_mask=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids, token_type_ids=token_type_ids,
-                                               langs=langs, attention_mask=attention_mask, cache=cache, head_mask=head_mask)
+        transformer_outputs = self.transformer(input_ids, lengths=lengths, position_ids=position_ids,
+                                               token_type_ids=token_type_ids, langs=langs,
+                                               attention_mask=attention_mask, cache=cache, head_mask=head_mask)
 
         output = transformer_outputs[0]
 
@@ -1049,8 +1049,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,
                 labels=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, head_mask)
+        transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids,
+                                               input_mask=input_mask, attention_mask=attention_mask,
+                                               mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
+                                               head_mask=head_mask)
 
         logits = self.lm_loss(transformer_outputs[0])
 
@@ -1119,8 +1121,10 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,
                 labels=None, head_mask=None):
-        transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, head_mask)
+        transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids,
+                                               input_mask=input_mask, attention_mask=attention_mask,
+                                               mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
+                                               head_mask=head_mask)
 
         output = transformer_outputs[0]
         output = self.sequence_summary(output)
@@ -1209,10 +1213,12 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
                 mems=None, perm_mask=None, target_mapping=None,
                 start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
                 head_mask=None):
-        transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
-                                               mems, perm_mask, target_mapping, head_mask)
+        transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids,
+                                               input_mask=input_mask, attention_mask=attention_mask,
+                                               mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
+                                               head_mask=head_mask)
 
         hidden_states = transformer_outputs[0]
-        start_logits = self.start_logits(hidden_states, p_mask)
+        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
 
         outputs = transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
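The XLNet heads follow suit, spelling out token_type_ids, input_mask, attention_mask, mems, perm_mask, target_mapping, and head_mask when calling self.transformer(...). A sketch of calling XLNetForSequenceClassification in the same all-keyword style, again assuming the pytorch_transformers package and the 'xlnet-base-cased' checkpoint (both assumptions, not shown in this diff):

import torch
from pytorch_transformers import XLNetForSequenceClassification  # assumed import path

model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased')  # assumed checkpoint name
model.eval()

input_ids = torch.tensor([[17, 21, 34, 5]])  # arbitrary token ids, for illustration only
labels = torch.tensor([1])

# Every optional input is named, mirroring the keyword-only calls introduced above.
outputs = model(input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, labels=labels)
loss = outputs[0]  # with labels supplied, the first element should be the classification loss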