chenpangpang / transformers · Commits

Commit 2a667b1e
authored Sep 05, 2019 by thomwolf

    split configuration and modeling files

parent 0be6a2a6
Changes: 33 files in the full commit. This page shows 13 changed files with 91 additions and 323 deletions (+91 / -323); the remaining files are on the second page of the diff.
pytorch_transformers/modeling_xlm.py                        +3   -163
pytorch_transformers/modeling_xlnet.py                      +3   -148
pytorch_transformers/tests/configuration_common_test.py     +63  -0
pytorch_transformers/tests/modeling_auto_test.py            +2   -1
pytorch_transformers/tests/modeling_bert_test.py            +2   -1
pytorch_transformers/tests/modeling_common_test.py          +3   -3
pytorch_transformers/tests/modeling_distilbert_test.py      +3   -1
pytorch_transformers/tests/modeling_gpt2_test.py            +2   -1
pytorch_transformers/tests/modeling_openai_test.py          +2   -1
pytorch_transformers/tests/modeling_roberta_test.py         +2   -1
pytorch_transformers/tests/modeling_transfo_xl_test.py      +2   -1
pytorch_transformers/tests/modeling_xlm_test.py             +2   -1
pytorch_transformers/tests/modeling_xlnet_test.py           +2   -1
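What the split means in practice: each model's configuration class moves out of its modeling_*.py file into a dedicated configuration_*.py module, and the shared ConfigTester helper moves out of modeling_common_test.py into the new configuration_common_test.py shown further down. A minimal sketch of the resulting import options for XLM; the configuration_xlm module path is implied by the relative import in the diff (the module itself is added elsewhere in this commit, not on this page), and the continued availability of the name from modeling_xlm is an assumption based on its new import line:

    # Package-level import: unchanged by this commit (the test files below still use it).
    from pytorch_transformers import XLMConfig

    # New canonical home of the class after the split.
    from pytorch_transformers.configuration_xlm import XLMConfig

    # Old location: presumably still importable, since modeling_xlm.py now does
    # `from .configuration_xlm import XLMConfig` and therefore re-exposes the name.
    from pytorch_transformers.modeling_xlm import XLMConfig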
pytorch_transformers/modeling_xlm.py

@@ -16,11 +16,8 @@
 """
 from __future__ import absolute_import, division, print_function, unicode_literals

-import json
 import logging
 import math
-import sys
-from io import open

 import itertools
 import numpy as np
@@ -30,8 +27,9 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

-from .modeling_utils import (PretrainedConfig, PreTrainedModel, add_start_docstrings,
-                             prune_linear_layer, SequenceSummary, SQuADHead)
+from .modeling_utils import PreTrainedModel, prune_linear_layer, SequenceSummary, SQuADHead
+from .configuration_xlm import XLMConfig
+from .file_utils import add_start_docstrings

 logger = logging.getLogger(__name__)
@@ -47,164 +45,6 @@ XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin",
     'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin",
 }

-XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
-    'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json",
-    'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json",
-    'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json",
-    'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json",
-    'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
-    'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
-    'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
-    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
-    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
-}
-
-
-class XLMConfig(PretrainedConfig):
-    """Configuration class to store the configuration of a `XLMModel`.
-
-    Args:
-        vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`.
-        d_model: Size of the encoder layers and the pooler layer.
-        n_layer: Number of hidden layers in the Transformer encoder.
-        n_head: Number of attention heads for each attention layer in
-            the Transformer encoder.
-        d_inner: The size of the "intermediate" (i.e., feed-forward)
-            layer in the Transformer encoder.
-        ff_activation: The non-linear activation function (function or string) in the
-            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
-        untie_r: untie relative position biases
-        attn_type: 'bi' for XLM, 'uni' for Transformer-XL
-
-        dropout: The dropout probabilitiy for all fully connected
-            layers in the embeddings, encoder, and pooler.
-        dropatt: The dropout ratio for the attention
-            probabilities.
-        max_position_embeddings: The maximum sequence length that this model might
-            ever be used with. Typically set this to something large just in case
-            (e.g., 512 or 1024 or 2048).
-        initializer_range: The sttdev of the truncated_normal_initializer for
-            initializing all weight matrices.
-        layer_norm_eps: The epsilon used by LayerNorm.
-
-        dropout: float, dropout rate.
-        dropatt: float, dropout rate on attention probabilities.
-        init: str, the initialization scheme, either "normal" or "uniform".
-        init_range: float, initialize the parameters with a uniform distribution
-            in [-init_range, init_range]. Only effective when init="uniform".
-        init_std: float, initialize the parameters with a normal distribution
-            with mean 0 and stddev init_std. Only effective when init="normal".
-        mem_len: int, the number of tokens to cache.
-        reuse_len: int, the number of tokens in the currect batch to be cached
-            and reused in the future.
-        bi_data: bool, whether to use bidirectional input pipeline.
-            Usually set to True during pretraining and False during finetuning.
-        clamp_len: int, clamp all relative distances larger than clamp_len.
-            -1 means no clamping.
-        same_length: bool, whether to use the same attention length for each token.
-    """
-    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
-
-    def __init__(self,
-                 vocab_size_or_config_json_file=30145,
-                 emb_dim=2048,
-                 n_layers=12,
-                 n_heads=16,
-                 dropout=0.1,
-                 attention_dropout=0.1,
-                 gelu_activation=True,
-                 sinusoidal_embeddings=False,
-                 causal=False,
-                 asm=False,
-                 n_langs=1,
-                 use_lang_emb=True,
-                 max_position_embeddings=512,
-                 embed_init_std=2048 ** -0.5,
-                 layer_norm_eps=1e-12,
-                 init_std=0.02,
-                 bos_index=0,
-                 eos_index=1,
-                 pad_index=2,
-                 unk_index=3,
-                 mask_index=5,
-                 is_encoder=True,
-                 finetuning_task=None,
-                 num_labels=2,
-                 summary_type='first',
-                 summary_use_proj=True,
-                 summary_activation=None,
-                 summary_proj_to_labels=True,
-                 summary_first_dropout=0.1,
-                 start_n_top=5,
-                 end_n_top=5,
-                 **kwargs):
-        """Constructs XLMConfig.
-        """
-        super(XLMConfig, self).__init__(**kwargs)
-
-        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
-                        and isinstance(vocab_size_or_config_json_file, unicode)):
-            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
-                json_config = json.loads(reader.read())
-            for key, value in json_config.items():
-                self.__dict__[key] = value
-        elif isinstance(vocab_size_or_config_json_file, int):
-            self.n_words = vocab_size_or_config_json_file
-            self.emb_dim = emb_dim
-            self.n_layers = n_layers
-            self.n_heads = n_heads
-            self.dropout = dropout
-            self.attention_dropout = attention_dropout
-            self.gelu_activation = gelu_activation
-            self.sinusoidal_embeddings = sinusoidal_embeddings
-            self.causal = causal
-            self.asm = asm
-            self.n_langs = n_langs
-            self.use_lang_emb = use_lang_emb
-            self.layer_norm_eps = layer_norm_eps
-            self.bos_index = bos_index
-            self.eos_index = eos_index
-            self.pad_index = pad_index
-            self.unk_index = unk_index
-            self.mask_index = mask_index
-            self.is_encoder = is_encoder
-            self.max_position_embeddings = max_position_embeddings
-            self.embed_init_std = embed_init_std
-            self.init_std = init_std
-            self.finetuning_task = finetuning_task
-            self.num_labels = num_labels
-            self.summary_type = summary_type
-            self.summary_use_proj = summary_use_proj
-            self.summary_activation = summary_activation
-            self.summary_proj_to_labels = summary_proj_to_labels
-            self.summary_first_dropout = summary_first_dropout
-            self.start_n_top = start_n_top
-            self.end_n_top = end_n_top
-        else:
-            raise ValueError("First argument must be either a vocabulary size (int)"
-                             " or the path to a pretrained model config file (str)")
-
-    @property
-    def vocab_size(self):
-        return self.n_words
-
-    @vocab_size.setter
-    def vocab_size(self, value):
-        self.n_words = value
-
-    @property
-    def hidden_size(self):
-        return self.emb_dim
-
-    @property
-    def num_attention_heads(self):
-        return self.n_heads
-
-    @property
-    def num_hidden_layers(self):
-        return self.n_layers
-
-
 def create_sinusoidal_embeddings(n_pos, dim, out):
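The removed XLMConfig keeps XLM's native hyperparameter names (n_words, emb_dim, n_layers, n_heads) as attributes and exposes the library-wide names (vocab_size, hidden_size, num_hidden_layers, num_attention_heads) as properties on top of them. A small sketch of that mapping, using illustrative values rather than the XLM defaults, and assuming the class behaves exactly as in the removed code above (it now lives in configuration_xlm.py):

    from pytorch_transformers import XLMConfig

    # Passing an int as the first argument takes the `elif isinstance(..., int)` branch above.
    config = XLMConfig(vocab_size_or_config_json_file=250, emb_dim=32, n_layers=2, n_heads=4)

    assert config.n_words == 250 and config.vocab_size == 250   # vocab_size is a property over n_words
    assert config.hidden_size == config.emb_dim == 32           # hidden_size aliases emb_dim
    assert config.num_hidden_layers == 2 and config.num_attention_heads == 4

    config.vocab_size = 300                                      # the setter writes through to n_words
    assert config.n_words == 300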
pytorch_transformers/modeling_xlnet.py

@@ -29,9 +29,9 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

-from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
-                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits,
-                             add_start_docstrings)
+from .modeling_utils import PreTrainedModel, prune_linear_layer, SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits
+from .configuration_xlnet import XLNetConfig
+from .file_utils import add_start_docstrings

 logger = logging.getLogger(__name__)
@@ -40,10 +40,6 @@ XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin",
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
 }

-XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
-    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
-}

 def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
@@ -192,147 +188,6 @@ def swish(x):
 ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

-class XLNetConfig(PretrainedConfig):
-    """Configuration class to store the configuration of a ``XLNetModel``.
-
-    Args:
-        vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``.
-        d_model: Size of the encoder layers and the pooler layer.
-        n_layer: Number of hidden layers in the Transformer encoder.
-        n_head: Number of attention heads for each attention layer in
-            the Transformer encoder.
-        d_inner: The size of the "intermediate" (i.e., feed-forward)
-            layer in the Transformer encoder.
-        ff_activation: The non-linear activation function (function or string) in the
-            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
-        untie_r: untie relative position biases
-        attn_type: 'bi' for XLNet, 'uni' for Transformer-XL
-
-        dropout: The dropout probabilitiy for all fully connected
-            layers in the embeddings, encoder, and pooler.
-        dropatt: The dropout ratio for the attention
-            probabilities.
-        initializer_range: The sttdev of the truncated_normal_initializer for
-            initializing all weight matrices.
-        layer_norm_eps: The epsilon used by LayerNorm.
-
-        dropout: float, dropout rate.
-        dropatt: float, dropout rate on attention probabilities.
-        init: str, the initialization scheme, either "normal" or "uniform".
-        init_range: float, initialize the parameters with a uniform distribution
-            in [-init_range, init_range]. Only effective when init="uniform".
-        init_std: float, initialize the parameters with a normal distribution
-            with mean 0 and stddev init_std. Only effective when init="normal".
-        mem_len: int, the number of tokens to cache.
-        reuse_len: int, the number of tokens in the currect batch to be cached
-            and reused in the future.
-        bi_data: bool, whether to use bidirectional input pipeline.
-            Usually set to True during pretraining and False during finetuning.
-        clamp_len: int, clamp all relative distances larger than clamp_len.
-            -1 means no clamping.
-        same_length: bool, whether to use the same attention length for each token.
-        finetuning_task: name of the glue task on which the model was fine-tuned if any
-    """
-    pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
-
-    def __init__(self,
-                 vocab_size_or_config_json_file=32000,
-                 d_model=1024,
-                 n_layer=24,
-                 n_head=16,
-                 d_inner=4096,
-                 ff_activation="gelu",
-                 untie_r=True,
-                 attn_type="bi",
-                 initializer_range=0.02,
-                 layer_norm_eps=1e-12,
-                 dropout=0.1,
-                 mem_len=None,
-                 reuse_len=None,
-                 bi_data=False,
-                 clamp_len=-1,
-                 same_length=False,
-                 finetuning_task=None,
-                 num_labels=2,
-                 summary_type='last',
-                 summary_use_proj=True,
-                 summary_activation='tanh',
-                 summary_last_dropout=0.1,
-                 start_n_top=5,
-                 end_n_top=5,
-                 **kwargs):
-        """Constructs XLNetConfig.
-        """
-        super(XLNetConfig, self).__init__(**kwargs)
-
-        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
-                        and isinstance(vocab_size_or_config_json_file, unicode)):
-            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
-                json_config = json.loads(reader.read())
-            for key, value in json_config.items():
-                self.__dict__[key] = value
-        elif isinstance(vocab_size_or_config_json_file, int):
-            self.n_token = vocab_size_or_config_json_file
-            self.d_model = d_model
-            self.n_layer = n_layer
-            self.n_head = n_head
-            assert d_model % n_head == 0
-            self.d_head = d_model // n_head
-            self.ff_activation = ff_activation
-            self.d_inner = d_inner
-            self.untie_r = untie_r
-            self.attn_type = attn_type
-            self.initializer_range = initializer_range
-            self.layer_norm_eps = layer_norm_eps
-            self.dropout = dropout
-            self.mem_len = mem_len
-            self.reuse_len = reuse_len
-            self.bi_data = bi_data
-            self.clamp_len = clamp_len
-            self.same_length = same_length
-            self.finetuning_task = finetuning_task
-            self.num_labels = num_labels
-            self.summary_type = summary_type
-            self.summary_use_proj = summary_use_proj
-            self.summary_activation = summary_activation
-            self.summary_last_dropout = summary_last_dropout
-            self.start_n_top = start_n_top
-            self.end_n_top = end_n_top
-        else:
-            raise ValueError("First argument must be either a vocabulary size (int)"
-                             " or the path to a pretrained model config file (str)")
-
-    @property
-    def max_position_embeddings(self):
-        return -1
-
-    @property
-    def vocab_size(self):
-        return self.n_token
-
-    @vocab_size.setter
-    def vocab_size(self, value):
-        self.n_token = value
-
-    @property
-    def hidden_size(self):
-        return self.d_model
-
-    @property
-    def num_attention_heads(self):
-        return self.n_head
-
-    @property
-    def num_hidden_layers(self):
-        return self.n_layer
-
-
 try:
     from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
 except (ImportError, AttributeError) as e:
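XLNetConfig, now removed from this file, derives the per-head dimension from d_model and n_head and enforces divisibility with an assert; its max_position_embeddings property is hard-coded to -1, signalling that there is no fixed positional limit. A short sketch of those derived values, using the defaults from the signature above and assuming the class behaves as in the removed code:

    from pytorch_transformers import XLNetConfig

    config = XLNetConfig(vocab_size_or_config_json_file=32000, d_model=1024, n_layer=24, n_head=16)

    assert config.d_model % config.n_head == 0    # asserted in __init__; a non-divisible pair would raise
    assert config.d_head == 1024 // 16            # 64, computed as d_model // n_head
    assert config.hidden_size == 1024             # property alias for d_model
    assert config.max_position_embeddings == -1   # no fixed sequence-length limit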
pytorch_transformers/tests/configuration_common_test.py (new file, mode 100644)

# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os
import shutil
import json
import random
import uuid

import unittest
import logging


class ConfigTester(object):
    def __init__(self, parent, config_class=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.inputs_dict = kwargs

    def create_and_test_config_common_properties(self):
        config = self.config_class(**self.inputs_dict)
        self.parent.assertTrue(hasattr(config, 'vocab_size'))
        self.parent.assertTrue(hasattr(config, 'hidden_size'))
        self.parent.assertTrue(hasattr(config, 'num_attention_heads'))
        self.parent.assertTrue(hasattr(config, 'num_hidden_layers'))

    def create_and_test_config_to_json_string(self):
        config = self.config_class(**self.inputs_dict)
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)

    def create_and_test_config_to_json_file(self):
        config_first = self.config_class(**self.inputs_dict)
        json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
        config_first.to_json_file(json_file_path)
        config_second = self.config_class.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def run_common_tests(self):
        self.create_and_test_config_common_properties()
        self.create_and_test_config_to_json_string()
        self.create_and_test_config_to_json_file()


if __name__ == "__main__":
    unittest.main()
\ No newline at end of file
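The per-model test files below only swap their ConfigTester import over to this new module; the way the helper is driven does not change. A minimal sketch of how a model test wires it up, assuming it sits alongside the other files in pytorch_transformers/tests/; the emb_dim value is illustrative, and only keyword arguments that map directly onto config attributes should be passed, since create_and_test_config_to_json_string checks each one against the serialized JSON:

    import unittest

    from pytorch_transformers import XLMConfig

    from .configuration_common_test import ConfigTester


    class XLMConfigTest(unittest.TestCase):

        def test_config(self):
            # ConfigTester builds XLMConfig(emb_dim=37), checks the four common properties,
            # and round-trips the config through to_json_string()/to_json_file().
            config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)
            config_tester.run_common_tests()


    if __name__ == "__main__":
        unittest.main()

Writing the temporary JSON to a uuid4-named file in the working directory, as create_and_test_config_to_json_file does, presumably avoids filename collisions when several config tests run at the same time.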
pytorch_transformers/tests/modeling_auto_test.py

@@ -28,7 +28,8 @@ from pytorch_transformers import (AutoConfig, BertConfig,
                                   AutoModelForQuestionAnswering, BertForQuestionAnswering)
 from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class AutoModelTest(unittest.TestCase):
pytorch_transformers/tests/modeling_bert_test.py

@@ -26,7 +26,8 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
                                   BertForTokenClassification, BertForMultipleChoice)
 from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class BertModelTest(CommonTestCases.CommonModelTester):
pytorch_transformers/tests/modeling_common_test.py

@@ -28,9 +28,9 @@ import logging
 import torch

-from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
-from pytorch_transformers.modeling_gpt2 import GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers import (PretrainedConfig, PreTrainedModel,
+                                  BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                  GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)


 def _config_zero_init(config):
pytorch_transformers/tests/modeling_distilbert_test.py

@@ -18,13 +18,15 @@ from __future__ import print_function
 import unittest
 import shutil
+import sys
 import pytest

 from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
                                   DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
 from pytorch_transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class DistilBertModelTest(CommonTestCases.CommonModelTester):
pytorch_transformers/tests/modeling_gpt2_test.py

@@ -24,7 +24,8 @@ import shutil
 from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
                                   GPT2LMHeadModel, GPT2DoubleHeadsModel)

-from .modeling_common_test import CommonTestCases, ConfigTester, ids_tensor
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class GPT2ModelTest(CommonTestCases.CommonModelTester):
pytorch_transformers/tests/modeling_openai_test.py

@@ -24,7 +24,8 @@ import shutil
 from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                   OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)

-from .modeling_common_test import CommonTestCases, ConfigTester, ids_tensor
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
pytorch_transformers/tests/modeling_roberta_test.py

@@ -24,7 +24,8 @@ import torch
 from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
 from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class RobertaModelTest(CommonTestCases.CommonModelTester):
pytorch_transformers/tests/modeling_transfo_xl_test.py

@@ -28,7 +28,8 @@ import torch
 from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
 from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class TransfoXLModelTest(CommonTestCases.CommonModelTester):
pytorch_transformers/tests/modeling_xlm_test.py

@@ -23,7 +23,8 @@ import pytest
 from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
 from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class XLMModelTest(CommonTestCases.CommonModelTester):
pytorch_transformers/tests/modeling_xlnet_test.py

@@ -28,7 +28,8 @@ import torch
 from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
 from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP

-from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor
+from .modeling_common_test import (CommonTestCases, ids_tensor)
+from .configuration_common_test import ConfigTester


 class XLNetModelTest(CommonTestCases.CommonModelTester):