"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cafa6a9e29f3e99c67a1028f8ca779d439bc0689"
Commit c9bce181 authored by thomwolf's avatar thomwolf
Browse files

fixing model to add torchscript, embedding resizing, head pruning and masking + tests

parent 62df4ba5
...@@ -449,7 +449,7 @@ class BertEncoder(nn.Module): ...@@ -449,7 +449,7 @@ class BertEncoder(nn.Module):
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if self.output_attentions:
outputs = outputs + (all_attentions,) outputs = outputs + (all_attentions,)
return outputs # outputs, (hidden states), (attentions) return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module): class BertPooler(nn.Module):
......
This diff is collapsed.
...@@ -21,7 +21,7 @@ import shutil ...@@ -21,7 +21,7 @@ import shutil
import pytest import pytest
from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM, from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM,
DilBertForQuestionAnswering, DilBertForSequenceClassification) DilBertForQuestionAnswering, DilBertForSequenceClassification)
from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
...@@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering, all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
DilBertForSequenceClassification) DilBertForSequenceClassification)
test_pruning = False test_pruning = True
test_torchscript = False test_torchscript = True
test_resize_embeddings = False test_resize_embeddings = True
test_head_masking = False test_head_masking = True
class DilBertModelTester(object): class DilBertModelTester(object):
...@@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertModel(config=config) model = DilBertModel(config=config)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, input_mask) (sequence_output,) = model(input_ids, input_mask)
sequence_output, pooled_output = model(input_ids) (sequence_output,) = model(input_ids)
result = { result = {
"sequence_output": sequence_output, "sequence_output": sequence_output,
"pooled_output": pooled_output,
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size]) [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertForMaskedLM(config=config) model = DilBertForMaskedLM(config=config)
model.eval() model.eval()
loss, prediction_scores = model(input_ids, input_mask, token_labels) loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
result = { result = {
"loss": loss, "loss": loss,
"prediction_scores": prediction_scores, "prediction_scores": prediction_scores,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment