Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4d47f498
"pytorch_transformers/tokenization_openai.py" did not exist on "e8568a3b17454dd4e0b32b6cd80617aa662cc996"
Commit
4d47f498
authored
Jun 26, 2019
by
thomwolf
Browse files
slight refactoring, add abstract class for model loading
parent
59cefd4f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
404 additions
and
891 deletions
+404
-891
pytorch_pretrained_bert/__init__.py
pytorch_pretrained_bert/__init__.py
+2
-1
pytorch_pretrained_bert/model_utils.py
pytorch_pretrained_bert/model_utils.py
+177
-0
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+15
-162
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+130
-134
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+26
-140
pytorch_pretrained_bert/modeling_transfo_xl.py
pytorch_pretrained_bert/modeling_transfo_xl.py
+23
-151
pytorch_pretrained_bert/modeling_xlm.py
pytorch_pretrained_bert/modeling_xlm.py
+10
-144
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+21
-159
No files found.
pytorch_pretrained_bert/__init__.py
View file @
4d47f498
...
...
@@ -28,4 +28,5 @@ from .optimization_openai import OpenAIAdam
from
.file_utils
import
(
PYTORCH_PRETRAINED_BERT_CACHE
,
cached_path
)
from
.model_utils
import
(
WEIGHTS_NAME
,
CONFIG_NAME
,
PretrainedConfig
)
from
.model_utils
import
(
WEIGHTS_NAME
,
CONFIG_NAME
,
TF_WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_layer
,
Conv1D
)
pytorch_pretrained_bert/model_utils.py
View file @
4d47f498
...
...
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
CONFIG_NAME
=
"config.json"
WEIGHTS_NAME
=
"pytorch_model.bin"
TF_WEIGHTS_NAME
=
'model.ckpt'
class
PretrainedConfig
(
object
):
...
...
@@ -131,6 +132,169 @@ class PretrainedConfig(object):
writer
.
write
(
self
.
to_json_string
())
class
PreTrainedModel
(
nn
.
Module
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class
=
PretrainedConfig
pretrained_model_archive_map
=
{}
pretrained_config_archive_map
=
{}
load_tf_weights
=
lambda
model
,
config
,
path
:
None
base_model_prefix
=
""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
PreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
PretrainedConfig
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
self
.
config
=
config
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load, or
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
- a path or url to a tensorflow pretrained model checkpoint containing:
. `config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use
instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
get
(
'from_tf'
,
False
)
kwargs
.
pop
(
'from_tf'
,
None
)
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
config_file
=
cls
.
pretrained_config_archive_map
[
pretrained_model_name_or_path
]
else
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_model_archive_map
.
keys
()),
archive_file
))
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
cls
.
pretrained_config_archive_map
.
keys
()),
config_file
))
return
None
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
# Load config
config
=
cls
.
config_class
.
from_json_file
(
resolved_config_file
)
# Update config with kwargs if needed
to_remove
=
[]
for
key
,
value
in
kwargs
.
items
():
if
hasattr
(
config
,
key
):
setattr
(
config
,
key
,
value
)
to_remove
.
append
(
key
)
for
key
in
to_remove
:
kwargs
.
pop
(
key
,
None
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
return
load_tf_weights
(
model
,
config
,
resolved_archive_file
[:
-
6
])
# Remove the '.index'
# Load from a PyTorch state_dict
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
''
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
'.'
)
start_prefix
=
''
if
not
hasattr
(
model
,
cls
.
base_model_prefix
)
and
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
start_prefix
=
cls
.
base_model_prefix
+
'.'
# Used to be able to load base models as well as derived modesl (with heads)
load
(
model
,
prefix
=
start_prefix
)
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
))
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
))
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
if
hasattr
(
model
,
tie_weights
):
model
.
tie_weights
()
# make sure word embedding weights are still tied
return
model
def
prune_linear_layer
(
layer
,
index
,
dim
=
0
):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
...
...
@@ -197,3 +361,16 @@ def prune_conv1d_layer(layer, index, dim=1):
new_layer
.
bias
.
copy_
(
b
.
contiguous
())
new_layer
.
bias
.
requires_grad
=
True
return
new_layer
def
prune_layer
(
layer
,
index
,
dim
=
None
):
""" Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
if
isinstance
(
layer
,
nn
.
Linear
):
return
prune_linear_layer
(
layer
,
index
,
dim
=
0
if
dim
is
None
else
dim
)
elif
isinstance
(
layer
,
Conv1D
):
return
prune_conv1d_layer
(
layer
,
index
,
dim
=
1
if
dim
is
None
else
dim
)
else
:
raise
ValueError
(
"Can't prune layer of class {}"
.
format
(
layer
.
__class__
))
pytorch_pretrained_bert/modeling.py
View file @
4d47f498
...
...
@@ -30,7 +30,7 @@ from torch import nn
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.model_utils
import
WEIGHTS_NAME
,
CONFIG_NAME
,
PretrainedConfig
,
prune_linear_layer
from
.model_utils
import
WEIGHTS_NAME
,
CONFIG_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_linear_layer
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -64,11 +64,9 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-large-cased-whole-word-masking-finetuned-squad'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json"
,
'bert-base-cased-finetuned-mrpc'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json"
,
}
BERT_CONFIG_NAME
=
'bert_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
load_tf_weights_in_bert
(
model
,
tf_checkpoint_path
):
def
load_tf_weights_in_bert
(
model
,
config
,
tf_checkpoint_path
):
""" Load tf checkpoints in a pytorch model
"""
try
:
...
...
@@ -168,7 +166,8 @@ class BertConfig(PretrainedConfig):
max_position_embeddings
=
512
,
type_vocab_size
=
2
,
initializer_range
=
0.02
,
layer_norm_eps
=
1e-12
):
layer_norm_eps
=
1e-12
,
finetuning_task
=
None
):
"""Constructs BertConfig.
Args:
...
...
@@ -193,6 +192,7 @@ class BertConfig(PretrainedConfig):
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
...
...
@@ -213,6 +213,7 @@ class BertConfig(PretrainedConfig):
self
.
type_vocab_size
=
type_vocab_size
self
.
initializer_range
=
initializer_range
self
.
layer_norm_eps
=
layer_norm_eps
self
.
finetuning_task
=
finetuning_task
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
...
...
@@ -539,20 +540,18 @@ class BertPreTrainingHeads(nn.Module):
return
prediction_scores
,
seq_relationship_score
class
BertPreTrainedModel
(
nn
.
Module
):
class
BertPreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
BertPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
BertConfig
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
self
.
config
=
config
config_class
=
BertConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_bert
base_model_prefix
=
"bert"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
BertPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
init_weights
(
self
,
module
):
""" Initialize the weights.
...
...
@@ -567,152 +566,6 @@ class BertPreTrainedModel(nn.Module):
if
isinstance
(
module
,
nn
.
Linear
)
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
. `bert-base-german-cased`
. `bert-large-uncased-whole-word-masking`
. `bert-large-cased-whole-word-masking`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
get
(
'from_tf'
,
False
)
kwargs
.
pop
(
'from_tf'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
BERT_CONFIG_NAME
)
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
))
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_CONFIG_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
config_file
))
return
None
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
# Load config
config
=
BertConfig
.
from_json_file
(
resolved_config_file
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
return
load_tf_weights_in_bert
(
model
,
resolved_archive_file
)
# Load from a PyTorch state_dict
old_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
new_key
=
None
if
'gamma'
in
key
:
new_key
=
key
.
replace
(
'gamma'
,
'weight'
)
if
'beta'
in
key
:
new_key
=
key
.
replace
(
'beta'
,
'bias'
)
if
new_key
:
old_keys
.
append
(
key
)
new_keys
.
append
(
new_key
)
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
''
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
'.'
)
start_prefix
=
''
if
not
hasattr
(
model
,
'bert'
)
and
any
(
s
.
startswith
(
'bert.'
)
for
s
in
state_dict
.
keys
()):
start_prefix
=
'bert.'
load
(
model
,
prefix
=
start_prefix
)
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
))
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
))
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
return
model
class
BertModel
(
BertPreTrainedModel
):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
...
...
pytorch_pretrained_bert/modeling_gpt2.py
View file @
4d47f498
...
...
@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from
torch.nn.parameter
import
Parameter
from
.file_utils
import
cached_path
from
.model_utils
import
Conv1D
,
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
prune_conv1d_layer
from
.model_utils
import
Conv1D
,
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_conv1d_layer
from
.modeling
import
BertLayerNorm
as
LayerNorm
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -42,7 +42,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
"gpt2"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"
,
"gpt2-medium"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"
}
def
load_tf_weights_in_gpt2
(
model
,
gpt2_checkpoint_path
):
def
load_tf_weights_in_gpt2
(
model
,
config
,
gpt2_checkpoint_path
):
""" Load tf checkpoints in a pytorch model
"""
try
:
...
...
@@ -356,22 +356,18 @@ class GPT2MultipleChoiceHead(nn.Module):
return
multiple_choice_logits
class
GPT2PreTrainedModel
(
nn
.
Module
):
class
GPT2PreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class
=
GPT2Config
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_gpt2
base_model_prefix
=
"transformer"
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
GPT2PreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
GPT2Config
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
)
)
self
.
config
=
config
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
GPT2PreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
init_weights
(
self
,
module
):
""" Initialize the weights.
...
...
@@ -407,130 +403,130 @@ class GPT2PreTrainedModel(nn.Module):
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific GPT2 class
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
get
(
'from_tf'
,
False
)
kwargs
.
pop
(
'from_tf'
,
None
)
#
state_dict = kwargs.get('state_dict', None)
#
kwargs.pop('state_dict', None)
#
cache_dir = kwargs.get('cache_dir', None)
#
kwargs.pop('cache_dir', None)
#
from_tf = kwargs.get('from_tf', False)
#
kwargs.pop('from_tf', None)
num_special_tokens
=
kwargs
.
get
(
'num_special_tokens'
,
None
)
kwargs
.
pop
(
'num_special_tokens'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
archive_file
)
)
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_CONFIG_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
config_file
)
)
return
None
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
# Load config
config
=
GPT2Config
.
from_json_file
(
resolved_config_file
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return
load_tf_weights_in_gpt2
(
model
,
resolved_archive_file
)
old_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
new_key
=
None
if
key
.
endswith
(
".g"
):
new_key
=
key
[:
-
2
]
+
".weight"
elif
key
.
endswith
(
".b"
):
new_key
=
key
[:
-
2
]
+
".bias"
elif
key
.
endswith
(
".w"
):
new_key
=
key
[:
-
2
]
+
".weight"
if
new_key
:
old_keys
.
append
(
key
)
new_keys
.
append
(
new_key
)
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
"_metadata"
,
None
)
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
""
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
"."
)
start_model
=
model
if
hasattr
(
model
,
"transformer"
)
and
all
(
not
s
.
startswith
(
'transformer.'
)
for
s
in
state_dict
.
keys
()):
start_model
=
model
.
transformer
load
(
start_model
,
prefix
=
""
)
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
)
)
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
)
)
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
)
#
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
#
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
#
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
#
else:
#
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
#
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
#
#
redirect to the cache, if necessary
#
try:
#
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
#
except EnvironmentError:
#
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
#
logger.error(
#
"Couldn't reach server at '{}' to download pretrained weights.".format(
#
archive_file))
#
else:
#
logger.error(
#
"Model name '{}' was not found in model name list ({}). "
#
"We assumed '{}' was a path or url but couldn't find file {} "
#
"at this path or url.".format(
#
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
#
archive_file
#
)
#
)
#
return None
#
try:
#
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
#
except EnvironmentError:
#
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
#
logger.error(
#
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
#
config_file))
#
else:
#
logger.error(
#
"Model name '{}' was not found in model name list ({}). "
#
"We assumed '{}' was a path or url but couldn't find file {} "
#
"at this path or url.".format(
#
pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
#
config_file
#
)
#
)
#
return None
#
if resolved_archive_file == archive_file and resolved_config_file == config_file:
#
logger.info("loading weights file {}".format(archive_file))
#
logger.info("loading configuration file {}".format(config_file))
#
else:
#
logger.info("loading weights file {} from cache at {}".format(
#
archive_file, resolved_archive_file))
#
logger.info("loading configuration file {} from cache at {}".format(
#
config_file, resolved_config_file))
#
#
Load config
#
config = GPT2Config.from_json_file(resolved_config_file)
#
logger.info("Model config {}".format(config))
#
#
Instantiate model.
#
model = cls(config, *inputs, **kwargs)
#
if state_dict is None and not from_tf:
#
state_dict = torch.load(resolved_archive_file, map_location='cpu')
#
if from_tf:
#
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
#
return load_tf_weights_in_gpt2(model, resolved_archive_file)
#
old_keys = []
#
new_keys = []
#
for key in state_dict.keys():
#
new_key = None
#
if key.endswith(".g"):
#
new_key = key[:-2] + ".weight"
#
elif key.endswith(".b"):
#
new_key = key[:-2] + ".bias"
#
elif key.endswith(".w"):
#
new_key = key[:-2] + ".weight"
#
if new_key:
#
old_keys.append(key)
#
new_keys.append(new_key)
#
for old_key, new_key in zip(old_keys, new_keys):
#
state_dict[new_key] = state_dict.pop(old_key)
#
missing_keys = []
#
unexpected_keys = []
#
error_msgs = []
#
#
copy state_dict so _load_from_state_dict can modify it
#
metadata = getattr(state_dict, "_metadata", None)
#
state_dict = state_dict.copy()
#
if metadata is not None:
#
state_dict._metadata = metadata
#
def load(module, prefix=""):
#
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
#
module._load_from_state_dict(
#
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
#
)
#
for name, child in module._modules.items():
#
if child is not None:
#
load(child, prefix + name + ".")
#
start_model = model
#
if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
#
start_model = model.transformer
#
load(start_model, prefix="")
#
if len(missing_keys) > 0:
#
logger.info(
#
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
#
)
#
if len(unexpected_keys) > 0:
#
logger.info(
#
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
#
)
#
if len(error_msgs) > 0:
#
raise RuntimeError(
#
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
#
)
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
model
.
set_num_special_tokens
(
num_special_tokens
)
return
model
...
...
@@ -608,9 +604,9 @@ class GPT2Model(GPT2PreTrainedModel):
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
=
None
):
" Update input embeddings with new embedding matrice if needed "
if
self
.
config
.
n_special
==
num_special_tokens
:
if
num_special_tokens
is
None
or
self
.
config
.
n_special
==
num_special_tokens
:
return
# Update config
self
.
config
.
n_special
=
num_special_tokens
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
4d47f498
...
...
@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from
torch.nn.parameter
import
Parameter
from
.file_utils
import
cached_path
from
.model_utils
import
Conv1D
,
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
prune_conv1d_layer
from
.model_utils
import
Conv1D
,
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_conv1d_layer
from
.modeling
import
BertLayerNorm
as
LayerNorm
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -41,12 +41,17 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
}
def
load_tf_weights_in_openai_gpt
(
model
,
openai_checkpoint_folder_path
):
def
load_tf_weights_in_openai_gpt
(
model
,
config
,
openai_checkpoint_folder_path
):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
import
re
import
numpy
as
np
print
(
"Loading weights..."
)
if
'.ckpt'
in
openai_checkpoint_folder_path
:
openai_checkpoint_folder_path
=
os
.
path
.
dirname
(
openai_checkpoint_folder_path
)
logger
.
info
(
"Loading weights from {}"
.
format
(
openai_checkpoint_folder_path
))
names
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/parameters_names.json'
,
"r"
,
encoding
=
'utf-8'
))
shapes
=
json
.
load
(
open
(
openai_checkpoint_folder_path
+
'/params_shapes.json'
,
"r"
,
encoding
=
'utf-8'
))
offsets
=
np
.
cumsum
([
np
.
prod
(
shape
)
for
shape
in
shapes
])
...
...
@@ -377,22 +382,18 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
return
multiple_choice_logits
class
OpenAIGPTPreTrainedModel
(
nn
.
Module
):
class
OpenAIGPTPreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class
=
OpenAIGPTConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_openai_gpt
base_model_prefix
=
"transformer"
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
OpenAIGPTPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
OpenAIGPTConfig
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
)
)
self
.
config
=
config
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
OpenAIGPTPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
init_weights
(
self
,
module
):
""" Initialize the weights.
...
...
@@ -408,7 +409,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
module
.
bias
.
data
.
zero_
()
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
num_special_tokens
=
None
,
*
inputs
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
"""
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
...
...
@@ -416,140 +417,25 @@ class OpenAIGPTPreTrainedModel(nn.Module):
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `openai-gpt`
- a path or url to a pretrained model archive containing:
. `
openai_gpt_
config.json` a configuration file for the model
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
- a path or url to a pretrained model archive containing:
. `
openai-gpt-
config.json` a configuration file for the model
. `config.json` a configuration file for the model
. a series of NumPy files containing OpenAI TensorFlow trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific OpenAI-GPT class
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
get
(
'from_tf'
,
False
)
kwargs
.
pop
(
'from_tf'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
archive_file
)
)
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_CONFIG_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
config_file
)
)
return
None
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
# Load config
config
=
OpenAIGPTConfig
.
from_json_file
(
resolved_config_file
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
return
load_tf_weights_in_openai_gpt
(
model
,
resolved_archive_file
)
old_keys
=
[]
new_keys
=
[]
for
key
in
state_dict
.
keys
():
new_key
=
None
if
key
.
endswith
(
".g"
):
new_key
=
key
[:
-
2
]
+
".weight"
elif
key
.
endswith
(
".b"
):
new_key
=
key
[:
-
2
]
+
".bias"
elif
key
.
endswith
(
".w"
):
new_key
=
key
[:
-
2
]
+
".weight"
if
new_key
:
old_keys
.
append
(
key
)
new_keys
.
append
(
new_key
)
for
old_key
,
new_key
in
zip
(
old_keys
,
new_keys
):
state_dict
[
new_key
]
=
state_dict
.
pop
(
old_key
)
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
"_metadata"
,
None
)
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
""
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
"."
)
start_model
=
model
if
hasattr
(
model
,
"transformer"
)
and
all
(
not
s
.
startswith
(
'transformer.'
)
for
s
in
state_dict
.
keys
()):
start_model
=
model
.
transformer
load
(
start_model
,
prefix
=
""
)
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
)
)
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
)
)
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
)
num_special_tokens
=
kwargs
.
get
(
'num_special_tokens'
,
None
)
kwargs
.
pop
(
'num_special_tokens'
,
None
)
model
=
PreTrainedModel
.
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
# Add additional embeddings for special tokens if needed
# This step also make sure we are still sharing the output and input embeddings after loading weights
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
model
.
set_num_special_tokens
(
num_special_tokens
)
return
model
...
...
@@ -621,9 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
def
set_num_special_tokens
(
self
,
num_special_tokens
=
None
):
" Update input embeddings with new embedding matrice if needed "
if
self
.
config
.
n_special
==
num_special_tokens
:
if
num_special_tokens
is
None
or
self
.
config
.
n_special
==
num_special_tokens
:
return
# Update config
self
.
config
.
n_special
=
num_special_tokens
...
...
pytorch_pretrained_bert/modeling_transfo_xl.py
View file @
4d47f498
...
...
@@ -38,7 +38,7 @@ from torch.nn.parameter import Parameter
from
.modeling
import
BertLayerNorm
as
LayerNorm
from
.modeling_transfo_xl_utilities
import
ProjectedAdaptiveLogSoftmax
,
sample_logits
from
.file_utils
import
cached_path
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -49,8 +49,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'transfo-xl-wt103'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json"
,
}
TF_WEIGHTS_NAME
=
'model.ckpt'
def
build_tf_to_pytorch_map
(
model
,
config
):
""" A map of modules from TF to PyTorch.
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
...
...
@@ -787,28 +785,26 @@ class AdaptiveEmbedding(nn.Module):
return
embed
class
TransfoXLPreTrainedModel
(
nn
.
Module
):
class
TransfoXLPreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TransfoXLPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
TransfoXLConfig
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
self
.
config
=
config
def
init_weight
(
self
,
weight
):
config_class
=
TransfoXLConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_transfo_xl
base_model_prefix
=
"transformer"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
TransfoXLPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
_init_weight
(
self
,
weight
):
if
self
.
config
.
init
==
'uniform'
:
nn
.
init
.
uniform_
(
weight
,
-
self
.
config
.
init_range
,
self
.
config
.
init_range
)
elif
self
.
config
.
init
==
'normal'
:
nn
.
init
.
normal_
(
weight
,
0.0
,
self
.
config
.
init_std
)
def
init_bias
(
self
,
bias
):
def
_
init_bias
(
self
,
bias
):
nn
.
init
.
constant_
(
bias
,
0.0
)
def
init_weights
(
self
,
m
):
...
...
@@ -817,9 +813,9 @@ class TransfoXLPreTrainedModel(nn.Module):
classname
=
m
.
__class__
.
__name__
if
classname
.
find
(
'Linear'
)
!=
-
1
:
if
hasattr
(
m
,
'weight'
)
and
m
.
weight
is
not
None
:
self
.
init_weight
(
m
.
weight
)
self
.
_
init_weight
(
m
.
weight
)
if
hasattr
(
m
,
'bias'
)
and
m
.
bias
is
not
None
:
self
.
init_bias
(
m
.
bias
)
self
.
_
init_bias
(
m
.
bias
)
elif
classname
.
find
(
'AdaptiveEmbedding'
)
!=
-
1
:
if
hasattr
(
m
,
'emb_projs'
):
for
i
in
range
(
len
(
m
.
emb_projs
)):
...
...
@@ -827,12 +823,12 @@ class TransfoXLPreTrainedModel(nn.Module):
nn
.
init
.
normal_
(
m
.
emb_projs
[
i
],
0.0
,
self
.
config
.
proj_init_std
)
elif
classname
.
find
(
'Embedding'
)
!=
-
1
:
if
hasattr
(
m
,
'weight'
):
self
.
init_weight
(
m
.
weight
)
self
.
_
init_weight
(
m
.
weight
)
elif
classname
.
find
(
'ProjectedAdaptiveLogSoftmax'
)
!=
-
1
:
if
hasattr
(
m
,
'cluster_weight'
)
and
m
.
cluster_weight
is
not
None
:
self
.
init_weight
(
m
.
cluster_weight
)
self
.
_
init_weight
(
m
.
cluster_weight
)
if
hasattr
(
m
,
'cluster_bias'
)
and
m
.
cluster_bias
is
not
None
:
self
.
init_bias
(
m
.
cluster_bias
)
self
.
_
init_bias
(
m
.
cluster_bias
)
if
hasattr
(
m
,
'out_projs'
):
for
i
in
range
(
len
(
m
.
out_projs
)):
if
m
.
out_projs
[
i
]
is
not
None
:
...
...
@@ -841,144 +837,20 @@ class TransfoXLPreTrainedModel(nn.Module):
if
hasattr
(
m
,
'weight'
):
nn
.
init
.
normal_
(
m
.
weight
,
1.0
,
self
.
config
.
init_std
)
if
hasattr
(
m
,
'bias'
)
and
m
.
bias
is
not
None
:
self
.
init_bias
(
m
.
bias
)
self
.
_
init_bias
(
m
.
bias
)
elif
classname
.
find
(
'TransformerLM'
)
!=
-
1
:
if
hasattr
(
m
,
'r_emb'
):
self
.
init_weight
(
m
.
r_emb
)
self
.
_
init_weight
(
m
.
r_emb
)
if
hasattr
(
m
,
'r_w_bias'
):
self
.
init_weight
(
m
.
r_w_bias
)
self
.
_
init_weight
(
m
.
r_w_bias
)
if
hasattr
(
m
,
'r_r_bias'
):
self
.
init_weight
(
m
.
r_r_bias
)
self
.
_
init_weight
(
m
.
r_r_bias
)
if
hasattr
(
m
,
'r_bias'
):
self
.
init_bias
(
m
.
r_bias
)
self
.
_
init_bias
(
m
.
r_bias
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
pass
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl-wt103`
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific TransformerXL class
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
get
(
'from_tf'
,
False
)
kwargs
.
pop
(
'from_tf'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
archive_file
)
)
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_CONFIG_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url."
.
format
(
pretrained_model_name_or_path
,
", "
.
join
(
PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
pretrained_model_name_or_path
,
config_file
)
)
return
None
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
# Load config
config
=
TransfoXLConfig
.
from_json_file
(
resolved_config_file
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
if
from_tf
:
# Directly load from a TensorFlow checkpoint
return
load_tf_weights_in_transfo_xl
(
model
,
config
,
pretrained_model_name_or_path
)
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
''
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
'.'
)
start_prefix
=
''
if
not
hasattr
(
model
,
'transformer'
)
and
any
(
s
.
startswith
(
'transformer.'
)
for
s
in
state_dict
.
keys
()):
start_prefix
=
'transformer.'
load
(
model
,
prefix
=
start_prefix
)
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
))
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
))
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
# Make sure we are still sharing the input and output embeddings
if
hasattr
(
model
,
'tie_weights'
):
model
.
tie_weights
()
return
model
class
TransfoXLModel
(
TransfoXLPreTrainedModel
):
"""Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
...
...
pytorch_pretrained_bert/modeling_xlm.py
View file @
4d47f498
...
...
@@ -36,7 +36,7 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -390,20 +390,18 @@ class BeamHypotheses(object):
return
self
.
worst_score
>=
best_sum_logprobs
/
self
.
max_len
**
self
.
length_penalty
class
XLMPreTrainedModel
(
nn
.
Module
):
class
XLMPreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
XLMPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
XLMBaseConfig
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `XLMBaseConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
self
.
config
=
config
config_class
=
XLMConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
None
base_model_prefix
=
"xlm"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
XLMPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
init_weights
(
self
,
module
):
""" Initialize the weights.
...
...
@@ -423,138 +421,6 @@ class XLMPreTrainedModel(nn.Module):
if
isinstance
(
module
,
nn
.
Linear
)
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
"""
Instantiate a XLMPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLMForPreTraining instance
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLM class
(ex: num_labels for XLMForSequenceClassification)
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
XLNET_CONFIG_NAME
)
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
))
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_CONFIG_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
config_file
))
return
None
if
resolved_archive_file
==
archive_file
and
resolved_config_file
==
config_file
:
logger
.
info
(
"loading weights file {}"
.
format
(
archive_file
))
logger
.
info
(
"loading configuration file {}"
.
format
(
config_file
))
else
:
logger
.
info
(
"loading weights file {} from cache at {}"
.
format
(
archive_file
,
resolved_archive_file
))
logger
.
info
(
"loading configuration file {} from cache at {}"
.
format
(
config_file
,
resolved_config_file
))
# Load config
config
=
XLMConfig
.
from_json_file
(
resolved_config_file
)
# Update config with kwargs if needed
to_remove
=
[]
for
key
,
value
in
kwargs
.
items
():
if
hasattr
(
config
,
key
):
setattr
(
config
,
key
,
value
)
to_remove
.
append
(
key
)
for
key
in
to_remove
:
kwargs
.
pop
(
key
,
None
)
logger
.
info
(
"Model config {}"
.
format
(
config
))
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
if
state_dict
is
None
and
not
from_tf
:
state_dict
=
torch
.
load
(
resolved_archive_file
,
map_location
=
'cpu'
)
# Load from a PyTorch state_dict
missing_keys
=
[]
unexpected_keys
=
[]
error_msgs
=
[]
# copy state_dict so _load_from_state_dict can modify it
metadata
=
getattr
(
state_dict
,
'_metadata'
,
None
)
state_dict
=
state_dict
.
copy
()
if
metadata
is
not
None
:
state_dict
.
_metadata
=
metadata
def
load
(
module
,
prefix
=
''
):
local_metadata
=
{}
if
metadata
is
None
else
metadata
.
get
(
prefix
[:
-
1
],
{})
module
.
_load_from_state_dict
(
state_dict
,
prefix
,
local_metadata
,
True
,
missing_keys
,
unexpected_keys
,
error_msgs
)
for
name
,
child
in
module
.
_modules
.
items
():
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
'.'
)
start_prefix
=
''
if
not
hasattr
(
model
,
'transformer'
)
and
any
(
s
.
startswith
(
'transformer'
)
for
s
in
state_dict
.
keys
()):
start_prefix
=
'transformer.'
load
(
model
,
prefix
=
start_prefix
)
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
))
if
len
(
unexpected_keys
)
>
0
:
logger
.
info
(
"Weights from pretrained model not used in {}: {}"
.
format
(
model
.
__class__
.
__name__
,
unexpected_keys
))
if
len
(
error_msgs
)
>
0
:
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
if
isinstance
(
model
,
XLMLMHeadModel
):
model
.
tie_weights
()
# make sure word embedding weights are still tied
return
model
class
XLMModel
(
XLMPreTrainedModel
):
...
...
pytorch_pretrained_bert/modeling_xlnet.py
View file @
4d47f498
...
...
@@ -33,7 +33,7 @@ from torch.nn import functional as F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
cached_path
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
from
.model_utils
import
CONFIG_NAME
,
WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -44,11 +44,9 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
PRETRAINED_CONFIG_ARCHIVE_MAP
=
{
'xlnet-large-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json"
,
}
XLNET_CONFIG_NAME
=
'xlnet_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
=
None
,
finetuning_task
=
None
):
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
=
None
):
""" A map of modules from TF to PyTorch.
I use a map to keep the PyTorch model as
identical to the original PyTorch model as possible.
...
...
@@ -64,9 +62,9 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
# We will load also the sequence summary
tf_to_pt_map
[
'model/sequnece_summary/summary/kernel'
]
=
model
.
sequence_summary
.
summary
.
weight
tf_to_pt_map
[
'model/sequnece_summary/summary/bias'
]
=
model
.
sequence_summary
.
summary
.
bias
if
hasattr
(
model
,
'logits_proj'
)
and
finetuning_task
is
not
None
and
'model/regression_{}/logit/kernel'
.
format
(
finetuning_task
)
in
tf_weights
:
tf_to_pt_map
[
'model/regression_{}/logit/kernel'
.
format
(
finetuning_task
)]
=
model
.
logits_proj
.
weight
tf_to_pt_map
[
'model/regression_{}/logit/bias'
.
format
(
finetuning_task
)]
=
model
.
logits_proj
.
bias
if
hasattr
(
model
,
'logits_proj'
)
and
config
.
finetuning_task
is
not
None
and
'model/regression_{}/logit/kernel'
.
format
(
finetuning_task
)
in
tf_weights
:
tf_to_pt_map
[
'model/regression_{}/logit/kernel'
.
format
(
config
.
finetuning_task
)]
=
model
.
logits_proj
.
weight
tf_to_pt_map
[
'model/regression_{}/logit/bias'
.
format
(
config
.
finetuning_task
)]
=
model
.
logits_proj
.
bias
# Now load the rest of the transformer
model
=
model
.
transformer
...
...
@@ -117,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
'model/transformer/seg_embed'
:
seg_embed_list
})
return
tf_to_pt_map
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
,
finetuning_task
=
None
):
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
):
""" Load tf checkpoints in a pytorch model
"""
try
:
...
...
@@ -138,7 +136,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
input
(
"Press Enter to continue..."
)
# Build TF to PyTorch weights loading map
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
,
finetuning_task
)
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
)
for
name
,
pointer
in
tf_to_pt_map
.
items
():
print
(
"Importing {}"
.
format
(
name
))
...
...
@@ -223,7 +221,8 @@ class XLNetConfig(PretrainedConfig):
reuse_len
=
None
,
bi_data
=
False
,
clamp_len
=-
1
,
same_length
=
False
):
same_length
=
False
,
finetuning_task
=
None
):
"""Constructs XLNetConfig.
Args:
...
...
@@ -265,6 +264,7 @@ class XLNetConfig(PretrainedConfig):
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
...
...
@@ -298,6 +298,7 @@ class XLNetConfig(PretrainedConfig):
self
.
bi_data
=
bi_data
self
.
clamp_len
=
clamp_len
self
.
same_length
=
same_length
self
.
finetuning_task
=
finetuning_task
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
...
...
@@ -550,20 +551,19 @@ class XLNetLayer(nn.Module):
# return attentions, layer_output
return
output_h
,
output_g
class
XLNetPreTrainedModel
(
nn
.
Module
):
class
XLNetPreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
XLNetPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
XLNetConfig
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
self
.
config
=
config
config_class
=
XLNetConfig
pretrained_model_archive_map
=
PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map
=
PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_xlnet
base_model_prefix
=
"transformer"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
XLNetPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
init_weights
(
self
,
module
):
""" Initialize the weights.
...
...
@@ -583,144 +583,6 @@ class XLNetPreTrainedModel(nn.Module):
if
isinstance
(
module
,
nn
.
Linear
)
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
    """
    Instantiate a XLNetPreTrainedModel from a pre-trained model file or a pytorch state dict.
    Download and cache the pre-trained model file if needed.
    Params:
        pretrained_model_name_or_path: either:
            - a str with the name of a pre-trained model to load selected in the list of:
                . `xlnet-large-cased`
            - a path or url to a pretrained model archive containing:
                . `config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
            - a path or url to a pretrained model archive containing:
                . `xlnet_config.json` a configuration file for the model
                . `model.chkpt` a TensorFlow checkpoint
        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
        *inputs, **kwargs: additional input for the specific XLNet class
            (ex: num_labels for XLNetForSequenceClassification)
    """
    # Pop loading-control kwargs so that only model-construction kwargs remain
    # for the `cls(config, *inputs, **kwargs)` call below.
    state_dict = kwargs.get('state_dict', None)
    kwargs.pop('state_dict', None)
    cache_dir = kwargs.get('cache_dir', None)
    kwargs.pop('cache_dir', None)
    from_tf = kwargs.get('from_tf', False)
    kwargs.pop('from_tf', None)
    # Resolve the weights/config locations: either a known shortcut name
    # mapped to a download URL, or a local directory / URL given by the caller.
    if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
        archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
        config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, XLNET_CONFIG_NAME)
        else:
            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
    except EnvironmentError:
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            logger.error(
                "Couldn't reach server at '{}' to download pretrained weights.".format(
                    archive_file))
        else:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
        # NOTE(review): failure is signalled by returning None, not raising —
        # callers must check for None.
        return None
    try:
        resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
    except EnvironmentError:
        if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
            logger.error(
                "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                    config_file))
        else:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
                    config_file))
        return None
    if resolved_archive_file == archive_file and resolved_config_file == config_file:
        logger.info("loading weights file {}".format(archive_file))
        logger.info("loading configuration file {}".format(config_file))
    else:
        logger.info("loading weights file {} from cache at {}".format(
            archive_file, resolved_archive_file))
        logger.info("loading configuration file {} from cache at {}".format(
            config_file, resolved_config_file))
    # Load config
    config = XLNetConfig.from_json_file(resolved_config_file)
    # Update config with kwargs if needed: any kwarg matching a config
    # attribute overrides the loaded value and is consumed here.
    to_remove = []
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)
            to_remove.append(key)
    for key in to_remove:
        kwargs.pop(key, None)
    logger.info("Model config {}".format(config))
    # Instantiate model.
    model = cls(config, *inputs, **kwargs)
    if state_dict is None and not from_tf:
        # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts.
        state_dict = torch.load(resolved_archive_file, map_location='cpu')
    if from_tf:
        # Directly load from a TensorFlow checkpoint
        return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
    # Load from a PyTorch state_dict
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Recursively apply nn.Module._load_from_state_dict over the module
        # tree, collecting missing/unexpected keys and error messages.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # If the checkpoint was saved from a full model (keys prefixed with
    # 'transformer.') but we are loading into a bare transformer, strip the
    # prefix by loading with it pre-applied.
    start_prefix = ''
    if not hasattr(model, 'transformer') and any(
            s.startswith('transformer') for s in state_dict.keys()):
        start_prefix = 'transformer.'
    load(model, prefix=start_prefix)
    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            model.__class__.__name__, "\n\t".join(error_msgs)))
    if isinstance(model, XLNetLMHeadModel):
        model.tie_weights()  # make sure word embedding weights are still tied
    return model
class
XLNetModel
(
XLNetPreTrainedModel
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment