Commit ed4e5422 authored by thomwolf's avatar thomwolf
Browse files

adding tests

parent b90e29d5
...@@ -7,6 +7,8 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE ...@@ -7,6 +7,8 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer from .tokenization_xlm import XLMTokenizer
from .tokenization_utils import (PreTrainedTokenizer) from .tokenization_utils import (PreTrainedTokenizer)
from .modeling_auto import (AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelWithLMHead)
from .modeling_bert import (BertConfig, BertModel, BertForPreTraining, from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction, BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice, BertForSequenceClassification, BertForMultipleChoice,
......
...@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel): ...@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel):
def __init__(self, base_model): def __init__(self, base_model):
super(AutoModelWithLMHead, self).__init__(base_model) super(AutoModelWithLMHead, self).__init__(base_model)
config = base_model.config
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.apply(self.init_weights) self.apply(self.init_weights)
...@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel): ...@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel):
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions) return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = {
'num_labels': 2,
'summary_type': 'first',
'summary_use_proj': True,
'summary_activation': None,
'summary_proj_to_labels': True,
'summary_first_dropout': 0.1
}
class AutoModelForSequenceClassification(DerivedAutoModel): class AutoModelForSequenceClassification(DerivedAutoModel):
r""" r"""
:class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification :class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
...@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel): ...@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel):
def __init__(self, base_model): def __init__(self, base_model):
super(AutoModelForSequenceClassification, self).__init__(base_model) super(AutoModelForSequenceClassification, self).__init__(base_model)
self.num_labels = base_model.config.num_labels # Complete configuration with defaults if necessary
self.sequence_summary = SequenceSummary(base_model.config) config = base_model.config
for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items():
if not hasattr(config, key):
setattr(config, key, value)
# Update base model and derived model config
self.transformer.config = config
self.config = config
self.num_labels = config.num_labels
self.sequence_summary = SequenceSummary(config)
self.apply(self.init_weights) self.apply(self.init_weights)
......
...@@ -777,7 +777,7 @@ class SequenceSummary(nn.Module): ...@@ -777,7 +777,7 @@ class SequenceSummary(nn.Module):
super(SequenceSummary, self).__init__() super(SequenceSummary, self).__init__()
self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last' self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
if config.summary_type == 'attn': if self.summary_type == 'attn':
# We should use a standard multi-head attention module with absolute positional embedding for that. # We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0 # We can probably just use the multi-head attention module of PyTorch >=1.1.0
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
import logging
from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel, AutoModelForSequenceClassification, AutoModelWithLMHead
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
class AutoModelTest(unittest.TestCase):
    """Smoke tests for the Auto* factory classes.

    Verifies that a BERT checkpoint name is resolved by each Auto* entry
    point (AutoConfig, AutoModel, AutoModelForSequenceClassification,
    AutoModelWithLMHead) to the corresponding BERT implementation.
    """

    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        # Only test the first archive key: every iteration downloads and
        # loads a full set of pretrained weights.
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            # Load once with output_loading_info=True; the extra dict reports
            # missing/unexpected/error keys from weight loading, all of which
            # must be empty for a clean checkpoint.
            # (The original test loaded the model a second time just to throw
            # the first result away — that redundant load is removed here.)
            model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertModel)
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

            model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.assertIsNotNone(model)
            # The derived model wraps the base transformer under
            # `base_model_prefix`; check it resolved to a BertModel.
            self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel)

            model = AutoModelWithLMHead.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel)
# Allow running this test module directly (`python <file>`) in addition to
# collection via pytest/unittest discovery.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment