Unverified Commit 7424b284 authored by ziliwang's avatar ziliwang Committed by GitHub
Browse files

Merge pull request #1 from huggingface/master

merge from original repo
parents 6060b2f8 364920e2
...@@ -17,14 +17,12 @@ from __future__ import division ...@@ -17,14 +17,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
import pytest
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM, from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
DistilBertForQuestionAnswering, DistilBertForSequenceClassification) DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
from pytorch_transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
class DistilBertModelTest(CommonTestCases.CommonModelTester): class DistilBertModelTest(CommonTestCases.CommonModelTester):
...@@ -148,7 +146,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -148,7 +146,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertForQuestionAnswering(config=config) model = DistilBertForQuestionAnswering(config=config)
model.eval() model.eval()
loss, start_logits, end_logits = model(input_ids, input_mask, sequence_labels, sequence_labels) loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels)
result = { result = {
"loss": loss, "loss": loss,
"start_logits": start_logits, "start_logits": start_logits,
...@@ -166,7 +164,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -166,7 +164,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
config.num_labels = self.num_labels config.num_labels = self.num_labels
model = DistilBertForSequenceClassification(config) model = DistilBertForSequenceClassification(config)
model.eval() model.eval()
loss, logits = model(input_ids, input_mask, sequence_labels) loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
result = { result = {
"loss": loss, "loss": loss,
"logits": logits, "logits": logits,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment