Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ed4e5422
Commit
ed4e5422
authored
Aug 05, 2019
by
thomwolf
Browse files
adding tests
parent
b90e29d5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
83 additions
and
3 deletions
+83
-3
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+2
-0
pytorch_transformers/modeling_auto.py
pytorch_transformers/modeling_auto.py
+25
-2
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+1
-1
pytorch_transformers/tests/modeling_auto_test.py
pytorch_transformers/tests/modeling_auto_test.py
+55
-0
No files found.
pytorch_transformers/__init__.py
View file @
ed4e5422
...
@@ -7,6 +7,8 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
...
@@ -7,6 +7,8 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from
.tokenization_xlm
import
XLMTokenizer
from
.tokenization_xlm
import
XLMTokenizer
from
.tokenization_utils
import
(
PreTrainedTokenizer
)
from
.tokenization_utils
import
(
PreTrainedTokenizer
)
from
.modeling_auto
import
(
AutoConfig
,
AutoModel
,
AutoModelForSequenceClassification
,
AutoModelWithLMHead
)
from
.modeling_bert
import
(
BertConfig
,
BertModel
,
BertForPreTraining
,
from
.modeling_bert
import
(
BertConfig
,
BertModel
,
BertForPreTraining
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForMaskedLM
,
BertForNextSentencePrediction
,
BertForSequenceClassification
,
BertForMultipleChoice
,
BertForSequenceClassification
,
BertForMultipleChoice
,
...
...
pytorch_transformers/modeling_auto.py
View file @
ed4e5422
...
@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel):
...
@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel):
def
__init__
(
self
,
base_model
):
def
__init__
(
self
,
base_model
):
super
(
AutoModelWithLMHead
,
self
).
__init__
(
base_model
)
super
(
AutoModelWithLMHead
,
self
).
__init__
(
base_model
)
config
=
base_model
.
config
self
.
lm_head
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
)
self
.
lm_head
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
...
@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel):
...
@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel):
return
outputs
# (loss), lm_logits, presents, (all hidden_states), (attentions)
return
outputs
# (loss), lm_logits, presents, (all hidden_states), (attentions)
AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS
=
{
'num_labels'
:
2
,
'summary_type'
:
'first'
,
'summary_use_proj'
:
True
,
'summary_activation'
:
None
,
'summary_proj_to_labels'
:
True
,
'summary_first_dropout'
:
0.1
}
class
AutoModelForSequenceClassification
(
DerivedAutoModel
):
class
AutoModelForSequenceClassification
(
DerivedAutoModel
):
r
"""
r
"""
:class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
:class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
...
@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel):
...
@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel):
def
__init__
(
self
,
base_model
):
def
__init__
(
self
,
base_model
):
super
(
AutoModelForSequenceClassification
,
self
).
__init__
(
base_model
)
super
(
AutoModelForSequenceClassification
,
self
).
__init__
(
base_model
)
self
.
num_labels
=
base_model
.
config
.
num_labels
# Complete configuration with defaults if necessary
self
.
sequence_summary
=
SequenceSummary
(
base_model
.
config
)
config
=
base_model
.
config
for
key
,
value
in
AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS
.
items
():
if
not
hasattr
(
config
,
key
):
setattr
(
config
,
key
,
value
)
# Update base model and derived model config
self
.
transformer
.
config
=
config
self
.
config
=
config
self
.
num_labels
=
config
.
num_labels
self
.
sequence_summary
=
SequenceSummary
(
config
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
...
...
pytorch_transformers/modeling_utils.py
View file @
ed4e5422
...
@@ -777,7 +777,7 @@ class SequenceSummary(nn.Module):
...
@@ -777,7 +777,7 @@ class SequenceSummary(nn.Module):
super
(
SequenceSummary
,
self
).
__init__
()
super
(
SequenceSummary
,
self
).
__init__
()
self
.
summary_type
=
config
.
summary_type
if
hasattr
(
config
,
'summary_use_proj'
)
else
'last'
self
.
summary_type
=
config
.
summary_type
if
hasattr
(
config
,
'summary_use_proj'
)
else
'last'
if
config
.
summary_type
==
'attn'
:
if
self
.
summary_type
==
'attn'
:
# We should use a standard multi-head attention module with absolute positional embedding for that.
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
...
...
pytorch_transformers/tests/modeling_auto_test.py
0 → 100644
View file @
ed4e5422
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
shutil
import
pytest
import
logging
from
pytorch_transformers
import
AutoConfig
,
BertConfig
,
AutoModel
,
BertModel
,
AutoModelForSequenceClassification
,
AutoModelWithLMHead
from
pytorch_transformers.modeling_bert
import
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from
.modeling_common_test
import
(
CommonTestCases
,
ConfigTester
,
ids_tensor
)
class
AutoModelTest
(
unittest
.
TestCase
):
def
test_model_from_pretrained
(
self
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
for
model_name
in
list
(
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
AutoModel
.
from_pretrained
(
model_name
)
model
,
loading_info
=
AutoModel
.
from_pretrained
(
model_name
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
BertModel
)
for
value
in
loading_info
.
values
():
self
.
assertEqual
(
len
(
value
),
0
)
model
=
AutoModelForSequenceClassification
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
getattr
(
model
,
model
.
base_model_prefix
),
BertModel
)
model
=
AutoModelWithLMHead
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
getattr
(
model
,
model
.
base_model_prefix
),
BertModel
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment