Commit c9bce181 authored by thomwolf's avatar thomwolf
Browse files

fixing model to add torchscript, embedding resizing, head pruning and masking + tests

parent 62df4ba5
......@@ -449,7 +449,7 @@ class BertEncoder(nn.Module):
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):
......
This diff is collapsed.
......@@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
DilBertForSequenceClassification)
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
class DilBertModelTester(object):
......@@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, input_mask)
sequence_output, pooled_output = model(input_ids)
(sequence_output,) = model(input_ids, input_mask)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, input_mask, token_labels)
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment