chenpangpang / transformers
Commit ed4e5422 authored Aug 05, 2019 by thomwolf

adding tests

parent b90e29d5
Showing 4 changed files with 83 additions and 3 deletions (+83, -3):
pytorch_transformers/__init__.py (+2, -0)
pytorch_transformers/modeling_auto.py (+25, -2)
pytorch_transformers/modeling_utils.py (+1, -1)
pytorch_transformers/tests/modeling_auto_test.py (+55, -0)
pytorch_transformers/__init__.py

@@ -7,6 +7,8 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_utils import (PreTrainedTokenizer)
+from .modeling_auto import (AutoConfig, AutoModel, AutoModelForSequenceClassification,
+                            AutoModelWithLMHead)
 from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
                             BertForSequenceClassification, BertForMultipleChoice,
 ...
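The two lines added to __init__.py re-export the new Auto* classes at the package top level. As a minimal usage sketch (not part of the commit, but mirroring what the new test file below asserts), the Auto* classes resolve a checkpoint name to the matching concrete classes:

from pytorch_transformers import AutoConfig, AutoModel

# 'bert-base-uncased' is a BERT checkpoint name, so both calls resolve to the
# BERT implementations; the weights are downloaded on first use.
config = AutoConfig.from_pretrained('bert-base-uncased')  # a BertConfig instance
model = AutoModel.from_pretrained('bert-base-uncased')    # a BertModel instance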
pytorch_transformers/modeling_auto.py

@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel):
     def __init__(self, base_model):
         super(AutoModelWithLMHead, self).__init__(base_model)

+        config = base_model.config
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

         self.apply(self.init_weights)
 ...

@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel):
         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)

+AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = {'num_labels': 2,
+                                        'summary_type': 'first',
+                                        'summary_use_proj': True,
+                                        'summary_activation': None,
+                                        'summary_proj_to_labels': True,
+                                        'summary_first_dropout': 0.1}
+
 class AutoModelForSequenceClassification(DerivedAutoModel):
     r"""
         :class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
 ...

@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel):
     def __init__(self, base_model):
         super(AutoModelForSequenceClassification, self).__init__(base_model)

-        self.num_labels = base_model.config.num_labels
-        self.sequence_summary = SequenceSummary(base_model.config)
+        # Complete configuration with defaults if necessary
+        config = base_model.config
+        for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items():
+            if not hasattr(config, key):
+                setattr(config, key, value)
+        # Update base model and derived model config
+        self.transformer.config = config
+        self.config = config
+
+        self.num_labels = config.num_labels
+        self.sequence_summary = SequenceSummary(config)

         self.apply(self.init_weights)
 ...
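The last hunk fills in any summary-related attribute missing from the base model's config before handing it to SequenceSummary, without clobbering values the config already carries. A standalone sketch of that hasattr/setattr pattern (DummyConfig is a hypothetical stand-in, not a class from this commit):

AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = {'num_labels': 2,
                                        'summary_type': 'first',
                                        'summary_use_proj': True,
                                        'summary_activation': None,
                                        'summary_proj_to_labels': True,
                                        'summary_first_dropout': 0.1}

class DummyConfig(object):
    """Hypothetical pretrained config lacking most summary attributes."""
    def __init__(self):
        self.hidden_size = 768
        self.num_labels = 3  # already set, so the loop must leave it alone

config = DummyConfig()
for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items():
    if not hasattr(config, key):   # fill gaps only, never overwrite
        setattr(config, key, value)

assert config.num_labels == 3          # preserved
assert config.summary_type == 'first'  # filled in from the defaults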
pytorch_transformers/modeling_utils.py

@@ -777,7 +777,7 @@ class SequenceSummary(nn.Module):
         super(SequenceSummary, self).__init__()

         self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
-        if config.summary_type == 'attn':
+        if self.summary_type == 'attn':
             # We should use a standard multi-head attention module with absolute positional embedding for that.
             # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
 ...
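This one-line fix matters because the assignment just above it only guards the value stored on self: when the config lacks the attribute, self.summary_type falls back to 'last', but reading config.summary_type directly would still raise AttributeError. A minimal sketch of the failure mode (BareConfig is hypothetical):

class BareConfig(object):
    """Hypothetical config with no summary attributes at all."""

config = BareConfig()

# What the guarded assignment effectively computes:
summary_type = getattr(config, 'summary_type', 'last')
assert summary_type == 'last'  # safe to branch on, as the fixed line now does

# What the removed line would have attempted:
try:
    config.summary_type == 'attn'
except AttributeError:
    print("old check raises: config has no summary_type attribute")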
pytorch_transformers/tests/modeling_auto_test.py (new file, 0 → 100644)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import shutil
import pytest
import logging

from pytorch_transformers import (AutoConfig, BertConfig,
                                  AutoModel, BertModel,
                                  AutoModelForSequenceClassification, AutoModelWithLMHead)
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP

from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)


class AutoModelTest(unittest.TestCase):
    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = AutoModel.from_pretrained(model_name)
            model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertModel)
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

            model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel)

            model = AutoModelWithLMHead.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel)


if __name__ == "__main__":
    unittest.main()
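Since test_model_from_pretrained calls from_pretrained on a real checkpoint name, running it downloads pretrained BERT weights on first use. A minimal sketch (assuming pytorch_transformers and its test modules are importable) of invoking just this case through the unittest API the file already relies on:

import unittest

from pytorch_transformers.tests.modeling_auto_test import AutoModelTest

# Load and run only the test case added by this commit.
suite = unittest.TestLoader().loadTestsFromTestCase(AutoModelTest)
unittest.TextTestRunner(verbosity=2).run(suite)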