Commit 4d47f498 in chenpangpang/transformers
authored Jun 26, 2019 by thomwolf

slight refactoring, add abstract class for model loading

parent 59cefd4f
Showing 8 changed files with 404 additions and 891 deletions.
pytorch_pretrained_bert/__init__.py             +2    -1
pytorch_pretrained_bert/model_utils.py          +177  -0
pytorch_pretrained_bert/modeling.py             +15   -162
pytorch_pretrained_bert/modeling_gpt2.py        +130  -134
pytorch_pretrained_bert/modeling_openai.py      +26   -140
pytorch_pretrained_bert/modeling_transfo_xl.py  +23   -151
pytorch_pretrained_bert/modeling_xlm.py         +10   -144
pytorch_pretrained_bert/modeling_xlnet.py       +21   -159
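
The public loading API is unchanged by the refactoring; only where the logic lives moves. A minimal sanity check of the call path, assuming the pytorch_pretrained_bert package is installed and network access is available:

from pytorch_pretrained_bert import BertModel

# 'bert-base-uncased' is one of the published model names; weights and config are downloaded and cached.
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()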
pytorch_pretrained_bert/__init__.py

@@ -28,4 +28,5 @@ from .optimization_openai import OpenAIAdam
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
-from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig)
+from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig,
+                          PreTrainedModel, prune_layer, Conv1D)
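
With the widened import above, the new helpers are reachable from the package root as well:

# Names newly re-exported by __init__.py in this commit.
from pytorch_pretrained_bert import PreTrainedModel, TF_WEIGHTS_NAME, prune_layer, Conv1D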
pytorch_pretrained_bert/model_utils.py

@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
+TF_WEIGHTS_NAME = 'model.ckpt'


class PretrainedConfig(object):

@@ -131,6 +132,169 @@ class PretrainedConfig(object):
            writer.write(self.to_json_string())

+class PreTrainedModel(nn.Module):
+    """ An abstract class to handle weights initialization and
+        a simple interface for dowloading and loading pretrained models.
+    """
+    config_class = PretrainedConfig
+    pretrained_model_archive_map = {}
+    pretrained_config_archive_map = {}
+    load_tf_weights = lambda model, config, path: None
+    base_model_prefix = ""
+
+    def __init__(self, config, *inputs, **kwargs):
+        super(PreTrainedModel, self).__init__()
+        if not isinstance(config, PretrainedConfig):
+            raise ValueError(
+                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
+                "To create a model from a pretrained model use "
+                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+                    self.__class__.__name__, self.__class__.__name__))
+        self.config = config
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+        """
+        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
+        Download and cache the pre-trained model file if needed.
+
+        Params:
+            pretrained_model_name_or_path: either:
+                - a str with the name of a pre-trained model to load, or
+                - a path or url to a pretrained model archive containing:
+                    . `config.json` a configuration file for the model
+                    . `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
+                - a path or url to a tensorflow pretrained model checkpoint containing:
+                    . `config.json` a configuration file for the model
+                    . `model.chkpt` a TensorFlow checkpoint
+            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
+            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
+            state_dict: an optional state dictionnary (collections.OrderedDict object) to use
+                instead of Google pre-trained models
+            *inputs, **kwargs: additional input for the specific XLNet class
+                (ex: num_labels for XLNetForSequenceClassification)
+        """
+        state_dict = kwargs.get('state_dict', None)
+        kwargs.pop('state_dict', None)
+        cache_dir = kwargs.get('cache_dir', None)
+        kwargs.pop('cache_dir', None)
+        from_tf = kwargs.get('from_tf', False)
+        kwargs.pop('from_tf', None)
+
+        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
+            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
+            config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
+        else:
+            if from_tf:
+                # Directly load from a TensorFlow checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
+                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
+            else:
+                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained weights.".format(
+                        archive_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find any file "
+                    "associated to this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(cls.pretrained_model_archive_map.keys()),
+                        archive_file))
+            return None
+        try:
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find any file "
+                    "associated to this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(cls.pretrained_config_archive_map.keys()),
+                        config_file))
+            return None
+        if resolved_archive_file == archive_file and resolved_config_file == config_file:
+            logger.info("loading weights file {}".format(archive_file))
+            logger.info("loading configuration file {}".format(config_file))
+        else:
+            logger.info("loading weights file {} from cache at {}".format(
+                archive_file, resolved_archive_file))
+            logger.info("loading configuration file {} from cache at {}".format(
+                config_file, resolved_config_file))
+
+        # Load config
+        config = cls.config_class.from_json_file(resolved_config_file)
+
+        # Update config with kwargs if needed
+        to_remove = []
+        for key, value in kwargs.items():
+            if hasattr(config, key):
+                setattr(config, key, value)
+                to_remove.append(key)
+        for key in to_remove:
+            kwargs.pop(key, None)
+        logger.info("Model config {}".format(config))
+
+        # Instantiate model.
+        model = cls(config, *inputs, **kwargs)
+
+        if state_dict is None and not from_tf:
+            state_dict = torch.load(resolved_archive_file, map_location='cpu')
+        if from_tf:
+            # Directly load from a TensorFlow checkpoint
+            return load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
+
+        # Load from a PyTorch state_dict
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+        # copy state_dict so _load_from_state_dict can modify it
+        metadata = getattr(state_dict, '_metadata', None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            state_dict._metadata = metadata
+
+        def load(module, prefix=''):
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            module._load_from_state_dict(
+                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, prefix + name + '.')
+
+        start_prefix = ''
+        if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
+            start_prefix = cls.base_model_prefix + '.'
+        # Used to be able to load base models as well as derived modesl (with heads)
+        load(model, prefix=start_prefix)
+        if len(missing_keys) > 0:
+            logger.info("Weights of {} not initialized from pretrained model: {}".format(
+                model.__class__.__name__, missing_keys))
+        if len(unexpected_keys) > 0:
+            logger.info("Weights from pretrained model not used in {}: {}".format(
+                model.__class__.__name__, unexpected_keys))
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                model.__class__.__name__, "\n\t".join(error_msgs)))
+        if hasattr(model, 'tie_weights'):
+            model.tie_weights()  # make sure word embedding weights are still tied
+
+        return model


def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameters) to keep only entries in index.
        Return the pruned layer as a new layer with requires_grad=True.
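
The class attributes at the top of PreTrainedModel are the extension points the model files below fill in. A hedged sketch of what a subclass supplies; MyConfig, load_tf_weights_in_my_model, and the URLs are illustrative placeholders, not part of the commit:

# Illustrative wiring only; the real values for BERT, GPT-2, etc. appear in the per-model files below.
class MyPreTrainedModel(PreTrainedModel):
    config_class = MyConfig                                  # hypothetical config class parsed from config.json
    pretrained_model_archive_map = {'my-model': 'https://example.com/my-model-pytorch_model.bin'}
    pretrained_config_archive_map = {'my-model': 'https://example.com/my-model-config.json'}
    load_tf_weights = load_tf_weights_in_my_model            # hypothetical converter with the (model, config, path) signature
    base_model_prefix = 'my_model'                           # attribute name of the base model on derived heads

# from_pretrained is then inherited as-is, e.g.:
# model = MyPreTrainedModel.from_pretrained('my-model', cache_dir='/tmp/my_cache')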
@@ -197,3 +361,16 @@ def prune_conv1d_layer(layer, index, dim=1):
    new_layer.bias.copy_(b.contiguous())
    new_layer.bias.requires_grad = True
    return new_layer
+
+def prune_layer(layer, index, dim=None):
+    """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
+        Return the pruned layer as a new layer with requires_grad=True.
+        Used to remove heads.
+    """
+    if isinstance(layer, nn.Linear):
+        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
+    elif isinstance(layer, Conv1D):
+        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
+    else:
+        raise ValueError("Can't prune layer of class {}".format(layer.__class__))
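
A quick usage sketch for the new dispatcher; the layer sizes and indices are arbitrary, and dim=0 for nn.Linear means the output features are pruned, if I read prune_linear_layer correctly:

import torch
from torch import nn
from pytorch_pretrained_bert.model_utils import prune_layer

layer = nn.Linear(in_features=8, out_features=6)
keep = torch.tensor([0, 2, 3, 5])      # entries to keep along the default dimension
pruned = prune_layer(layer, keep)      # dispatches to prune_linear_layer(layer, keep, dim=0)
print(pruned)                          # expected: Linear(in_features=8, out_features=4, bias=True)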
pytorch_pretrained_bert/modeling.py

@@ -30,7 +30,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path
-from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, prune_linear_layer
+from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer

logger = logging.getLogger(__name__)

@@ -64,11 +64,9 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
}
BERT_CONFIG_NAME = 'bert_config.json'
-TF_WEIGHTS_NAME = 'model.ckpt'

-def load_tf_weights_in_bert(model, tf_checkpoint_path):
+def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
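
The extra config argument aligns the BERT converter with the (model, config, path) signature that the base class stores in its load_tf_weights hook. A hedged sketch of a direct call with the new signature; both paths are illustrative, and normally from_pretrained(..., from_tf=True) drives this for you:

config = BertConfig.from_json_file('/path/to/bert_config.json')     # illustrative path
model = BertForPreTraining(config)
load_tf_weights_in_bert(model, config, '/path/to/bert_model.ckpt')  # illustrative path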
@@ -168,7 +166,8 @@ class BertConfig(PretrainedConfig):
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
-                 layer_norm_eps=1e-12):
+                 layer_norm_eps=1e-12,
+                 finetuning_task=None):
        """Constructs BertConfig.

        Args:

@@ -193,6 +192,7 @@ class BertConfig(PretrainedConfig):
            initializer_range: The sttdev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
+            finetuning_task: name of the glue task on which the model was fine-tuned if any
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):

@@ -213,6 +213,7 @@ class BertConfig(PretrainedConfig):
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
+            self.finetuning_task = finetuning_task
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")

@@ -539,20 +540,18 @@ class BertPreTrainingHeads(nn.Module):
        return prediction_scores, seq_relationship_score


-class BertPreTrainedModel(nn.Module):
+class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
-    def __init__(self, config, *inputs, **kwargs):
-        super(BertPreTrainedModel, self).__init__()
-        if not isinstance(config, BertConfig):
-            raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
-                "To create a model from a Google pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__
-                ))
-        self.config = config
+    config_class = BertConfig
+    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    load_tf_weights = load_tf_weights_in_bert
+    base_model_prefix = "bert"
+
+    def __init__(self, *inputs, **kwargs):
+        super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.
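
Because BertConfig now carries finetuning_task and the inherited from_pretrained copies any kwarg that names a config attribute onto the config, task metadata can be passed at load time. A hedged example:

# finetuning_task matches a config attribute and is absorbed into the config;
# num_labels does not, so it is forwarded to the BertForSequenceClassification constructor.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                       num_labels=2,
                                                       finetuning_task='mrpc')
print(model.config.finetuning_task)   # 'mrpc'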
@@ -567,152 +566,6 @@ class BertPreTrainedModel(nn.Module):
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-        """
-        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
-        Download and cache the pre-trained model file if needed.
-
-        Params:
-            pretrained_model_name_or_path: either:
-                - a str with the name of a pre-trained model to load selected in the list of:
-                    . `bert-base-uncased`
-                    . `bert-large-uncased`
-                    . `bert-base-cased`
-                    . `bert-large-cased`
-                    . `bert-base-multilingual-uncased`
-                    . `bert-base-multilingual-cased`
-                    . `bert-base-chinese`
-                    . `bert-base-german-cased`
-                    . `bert-large-uncased-whole-word-masking`
-                    . `bert-large-cased-whole-word-masking`
-                - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
-                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
-                - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
-                    . `model.chkpt` a TensorFlow checkpoint
-            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
-            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
-            *inputs, **kwargs: additional input for the specific Bert class
-                (ex: num_labels for BertForSequenceClassification)
-        """
-        state_dict = kwargs.get('state_dict', None)
-        kwargs.pop('state_dict', None)
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
-        from_tf = kwargs.get('from_tf', False)
-        kwargs.pop('from_tf', None)
-
-        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
-            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            if from_tf:
-                # Directly load from a TensorFlow checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
-                config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
-            else:
-                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-        # redirect to the cache, if necessary
-        try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained weights.".format(
-                        archive_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                        archive_file))
-            return None
-        try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find any file "
-                    "associated to this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
-                        config_file))
-            return None
-        if resolved_archive_file == archive_file and resolved_config_file == config_file:
-            logger.info("loading weights file {}".format(archive_file))
-            logger.info("loading configuration file {}".format(config_file))
-        else:
-            logger.info("loading weights file {} from cache at {}".format(
-                archive_file, resolved_archive_file))
-            logger.info("loading configuration file {} from cache at {}".format(
-                config_file, resolved_config_file))
-        # Load config
-        config = BertConfig.from_json_file(resolved_config_file)
-        logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
-        if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file, map_location='cpu')
-        if from_tf:
-            # Directly load from a TensorFlow checkpoint
-            return load_tf_weights_in_bert(model, resolved_archive_file)
-        # Load from a PyTorch state_dict
-        old_keys = []
-        new_keys = []
-        for key in state_dict.keys():
-            new_key = None
-            if 'gamma' in key:
-                new_key = key.replace('gamma', 'weight')
-            if 'beta' in key:
-                new_key = key.replace('beta', 'bias')
-            if new_key:
-                old_keys.append(key)
-                new_keys.append(new_key)
-        for old_key, new_key in zip(old_keys, new_keys):
-            state_dict[new_key] = state_dict.pop(old_key)
-
-        missing_keys = []
-        unexpected_keys = []
-        error_msgs = []
-        # copy state_dict so _load_from_state_dict can modify it
-        metadata = getattr(state_dict, '_metadata', None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-
-        def load(module, prefix=''):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            module._load_from_state_dict(
-                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + '.')
-
-        start_prefix = ''
-        if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
-            start_prefix = 'bert.'
-        load(model, prefix=start_prefix)
-        if len(missing_keys) > 0:
-            logger.info("Weights of {} not initialized from pretrained model: {}".format(
-                model.__class__.__name__, missing_keys))
-        if len(unexpected_keys) > 0:
-            logger.info("Weights from pretrained model not used in {}: {}".format(
-                model.__class__.__name__, unexpected_keys))
-        if len(error_msgs) > 0:
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                model.__class__.__name__, "\n\t".join(error_msgs)))
-        return model


class BertModel(BertPreTrainedModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer").
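
The base_model_prefix = "bert" attribute replaces the hard-coded 'bert.' prefix handling in the deleted method above. The practical effect, sketched with a made-up directory path:

# Suppose /tmp/finetuned_bert holds config.json and pytorch_model.bin saved from a model with a
# head, so the state dict keys look like "bert.embeddings...". Loading into a bare BertModel still
# works: the target has no `.bert` attribute, so the base class prepends the prefix while walking modules.
encoder = BertModel.from_pretrained('/tmp/finetuned_bert')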
pytorch_pretrained_bert/modeling_gpt2.py

@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .file_utils import cached_path
-from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
+from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)

@@ -42,7 +42,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
                                 "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}

-def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
+def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
@@ -356,22 +356,18 @@ class GPT2MultipleChoiceHead(nn.Module):
        return multiple_choice_logits


-class GPT2PreTrainedModel(nn.Module):
+class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
-    def __init__(self, config, *inputs, **kwargs):
-        super(GPT2PreTrainedModel, self).__init__()
-        if not isinstance(config, GPT2Config):
-            raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
-                "To create a model from a pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__
-                )
-            )
-        self.config = config
+    config_class = GPT2Config
+    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    load_tf_weights = load_tf_weights_in_gpt2
+    base_model_prefix = "transformer"
+
+    def __init__(self, *inputs, **kwargs):
+        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.

@@ -407,130 +403,130 @@ class GPT2PreTrainedModel(nn.Module):
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
            *inputs, **kwargs: additional input for the specific GPT2 class
        """
-        state_dict = kwargs.get('state_dict', None)
-        kwargs.pop('state_dict', None)
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
-        from_tf = kwargs.get('from_tf', False)
-        kwargs.pop('from_tf', None)
+        # state_dict = kwargs.get('state_dict', None)
+        # kwargs.pop('state_dict', None)
+        # cache_dir = kwargs.get('cache_dir', None)
+        # kwargs.pop('cache_dir', None)
+        # from_tf = kwargs.get('from_tf', False)
+        # kwargs.pop('from_tf', None)
        num_special_tokens = kwargs.get('num_special_tokens', None)
        kwargs.pop('num_special_tokens', None)

-        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
-            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-        # redirect to the cache, if necessary
-        try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained weights.".format(
-                        archive_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find file {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path, archive_file)
-                )
-            return None
-        try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find file {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path, config_file)
-                )
-            return None
-        if resolved_archive_file == archive_file and resolved_config_file == config_file:
-            logger.info("loading weights file {}".format(archive_file))
-            logger.info("loading configuration file {}".format(config_file))
-        else:
-            logger.info("loading weights file {} from cache at {}".format(
-                archive_file, resolved_archive_file))
-            logger.info("loading configuration file {} from cache at {}".format(
-                config_file, resolved_config_file))
-        # Load config
-        config = GPT2Config.from_json_file(resolved_config_file)
-        logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
-        if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file, map_location='cpu')
-        if from_tf:
-            # Directly load from a TensorFlow checkpoint (stored as NumPy array)
-            return load_tf_weights_in_gpt2(model, resolved_archive_file)
-
-        old_keys = []
-        new_keys = []
-        for key in state_dict.keys():
-            new_key = None
-            if key.endswith(".g"):
-                new_key = key[:-2] + ".weight"
-            elif key.endswith(".b"):
-                new_key = key[:-2] + ".bias"
-            elif key.endswith(".w"):
-                new_key = key[:-2] + ".weight"
-            if new_key:
-                old_keys.append(key)
-                new_keys.append(new_key)
-        for old_key, new_key in zip(old_keys, new_keys):
-            state_dict[new_key] = state_dict.pop(old_key)
-
-        missing_keys = []
-        unexpected_keys = []
-        error_msgs = []
-        # copy state_dict so _load_from_state_dict can modify it
-        metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-
-        def load(module, prefix=""):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            module._load_from_state_dict(
-                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
-            )
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + ".")
-
-        start_model = model
-        if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
-            start_model = model.transformer
-        load(start_model, prefix="")
-
-        if len(missing_keys) > 0:
-            logger.info(
-                "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
-            )
-        if len(unexpected_keys) > 0:
-            logger.info(
-                "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
-            )
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
-            )
+        # if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
+        #     archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
+        #     config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
+        # else:
+        #     archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+        #     config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
+        # # redirect to the cache, if necessary
+        # try:
+        #     resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
+        # except EnvironmentError:
+        #     if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
+        #         logger.error(
+        #             "Couldn't reach server at '{}' to download pretrained weights.".format(
+        #                 archive_file))
+        #     else:
+        #         logger.error(
+        #             "Model name '{}' was not found in model name list ({}). "
+        #             "We assumed '{}' was a path or url but couldn't find file {} "
+        #             "at this path or url.".format(
+        #                 pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+        #                 archive_file
+        #             )
+        #         )
+        #         return None
+        # try:
+        #     resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+        # except EnvironmentError:
+        #     if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
+        #         logger.error(
+        #             "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+        #                 config_file))
+        #     else:
+        #         logger.error(
+        #             "Model name '{}' was not found in model name list ({}). "
+        #             "We assumed '{}' was a path or url but couldn't find file {} "
+        #             "at this path or url.".format(
+        #                 pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+        #                 config_file
+        #             )
+        #         )
+        #         return None
+        # if resolved_archive_file == archive_file and resolved_config_file == config_file:
+        #     logger.info("loading weights file {}".format(archive_file))
+        #     logger.info("loading configuration file {}".format(config_file))
+        # else:
+        #     logger.info("loading weights file {} from cache at {}".format(
+        #         archive_file, resolved_archive_file))
+        #     logger.info("loading configuration file {} from cache at {}".format(
+        #         config_file, resolved_config_file))
+        # # Load config
+        # config = GPT2Config.from_json_file(resolved_config_file)
+        # logger.info("Model config {}".format(config))
+        # # Instantiate model.
+        # model = cls(config, *inputs, **kwargs)
+        # if state_dict is None and not from_tf:
+        #     state_dict = torch.load(resolved_archive_file, map_location='cpu')
+        # if from_tf:
+        #     # Directly load from a TensorFlow checkpoint (stored as NumPy array)
+        #     return load_tf_weights_in_gpt2(model, resolved_archive_file)
+        # old_keys = []
+        # new_keys = []
+        # for key in state_dict.keys():
+        #     new_key = None
+        #     if key.endswith(".g"):
+        #         new_key = key[:-2] + ".weight"
+        #     elif key.endswith(".b"):
+        #         new_key = key[:-2] + ".bias"
+        #     elif key.endswith(".w"):
+        #         new_key = key[:-2] + ".weight"
+        #     if new_key:
+        #         old_keys.append(key)
+        #         new_keys.append(new_key)
+        # for old_key, new_key in zip(old_keys, new_keys):
+        #     state_dict[new_key] = state_dict.pop(old_key)
+        # missing_keys = []
+        # unexpected_keys = []
+        # error_msgs = []
+        # # copy state_dict so _load_from_state_dict can modify it
+        # metadata = getattr(state_dict, "_metadata", None)
+        # state_dict = state_dict.copy()
+        # if metadata is not None:
+        #     state_dict._metadata = metadata
+        # def load(module, prefix=""):
+        #     local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        #     module._load_from_state_dict(
+        #         state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
+        #     )
+        #     for name, child in module._modules.items():
+        #         if child is not None:
+        #             load(child, prefix + name + ".")
+        # start_model = model
+        # if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
+        #     start_model = model.transformer
+        # load(start_model, prefix="")
+        # if len(missing_keys) > 0:
+        #     logger.info(
+        #         "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
+        #     )
+        # if len(unexpected_keys) > 0:
+        #     logger.info(
+        #         "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
+        #     )
+        # if len(error_msgs) > 0:
+        #     raise RuntimeError(
+        #         "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
+        #     )
        # Add additional embeddings for special tokens if needed
        # This step also make sure we are still sharing the output and input embeddings after loading weights
-        model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
+        model.set_num_special_tokens(num_special_tokens)
        return model

@@ -608,9 +604,9 @@ class GPT2Model(GPT2PreTrainedModel):
        self.apply(self.init_weights)

-    def set_num_special_tokens(self, num_special_tokens):
+    def set_num_special_tokens(self, num_special_tokens=None):
        " Update input embeddings with new embedding matrice if needed "
-        if self.config.n_special == num_special_tokens:
+        if num_special_tokens is None or self.config.n_special == num_special_tokens:
            return
        # Update config
        self.config.n_special = num_special_tokens
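
Making num_special_tokens optional lets shared loading code call set_num_special_tokens unconditionally; passing None is now a no-op. A small sketch, assuming GPT2Config's defaults:

config = GPT2Config()
model = GPT2Model(config)
model.set_num_special_tokens(None)   # new default: returns immediately, nothing changes
model.set_num_special_tokens(3)      # updates config.n_special and rebuilds the input embedding matrix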
pytorch_pretrained_bert/modeling_openai.py

@@ -32,7 +32,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .file_utils import cached_path
-from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, prune_conv1d_layer
+from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
from .modeling import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)

@@ -41,12 +41,17 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}

-def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
+def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
    """
    import re
    import numpy as np
-    print("Loading weights...")
+    if '.ckpt' in openai_checkpoint_folder_path:
+        openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
+    logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))

    names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
    shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
@@ -377,22 +382,18 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
        return multiple_choice_logits


-class OpenAIGPTPreTrainedModel(nn.Module):
+class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
-    def __init__(self, config, *inputs, **kwargs):
-        super(OpenAIGPTPreTrainedModel, self).__init__()
-        if not isinstance(config, OpenAIGPTConfig):
-            raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `OpenAIGPTConfig`. "
-                "To create a model from a pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__
-                )
-            )
-        self.config = config
+    config_class = OpenAIGPTConfig
+    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    load_tf_weights = load_tf_weights_in_openai_gpt
+    base_model_prefix = "transformer"
+
+    def __init__(self, *inputs, **kwargs):
+        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.

@@ -408,7 +409,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
            module.bias.data.zero_()

    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
@@ -416,140 +417,25 @@ class OpenAIGPTPreTrainedModel(nn.Module):
        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `openai-gpt`
                - a path or url to a pretrained model archive containing:
-                    . `openai_gpt_config.json` a configuration file for the model
+                    . `config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
                - a path or url to a pretrained model archive containing:
-                    . `openai-gpt-config.json` a configuration file for the model
+                    . `config.json` a configuration file for the model
                    . a series of NumPy files containing OpenAI TensorFlow trained weights
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
            *inputs, **kwargs: additional input for the specific OpenAI-GPT class
        """
-        state_dict = kwargs.get('state_dict', None)
-        kwargs.pop('state_dict', None)
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
-        from_tf = kwargs.get('from_tf', False)
-        kwargs.pop('from_tf', None)
-
-        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
-            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-        # redirect to the cache, if necessary
-        try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained weights.".format(
-                        archive_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find file {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path, archive_file)
-                )
-            return None
-        try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find file {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path, config_file)
-                )
-            return None
-        if resolved_archive_file == archive_file and resolved_config_file == config_file:
-            logger.info("loading weights file {}".format(archive_file))
-            logger.info("loading configuration file {}".format(config_file))
-        else:
-            logger.info("loading weights file {} from cache at {}".format(
-                archive_file, resolved_archive_file))
-            logger.info("loading configuration file {} from cache at {}".format(
-                config_file, resolved_config_file))
-        # Load config
-        config = OpenAIGPTConfig.from_json_file(resolved_config_file)
-        logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
-        if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file, map_location='cpu')
-        if from_tf:
-            # Directly load from a TensorFlow checkpoint (stored as NumPy array)
-            return load_tf_weights_in_openai_gpt(model, resolved_archive_file)
-        old_keys = []
-        new_keys = []
-        for key in state_dict.keys():
-            new_key = None
-            if key.endswith(".g"):
-                new_key = key[:-2] + ".weight"
-            elif key.endswith(".b"):
-                new_key = key[:-2] + ".bias"
-            elif key.endswith(".w"):
-                new_key = key[:-2] + ".weight"
-            if new_key:
-                old_keys.append(key)
-                new_keys.append(new_key)
-        for old_key, new_key in zip(old_keys, new_keys):
-            state_dict[new_key] = state_dict.pop(old_key)
-        missing_keys = []
-        unexpected_keys = []
-        error_msgs = []
-        # copy state_dict so _load_from_state_dict can modify it
-        metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-
-        def load(module, prefix=""):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            module._load_from_state_dict(
-                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
-            )
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + ".")
-
-        start_model = model
-        if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
-            start_model = model.transformer
-        load(start_model, prefix="")
-
-        if len(missing_keys) > 0:
-            logger.info(
-                "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
-            )
-        if len(unexpected_keys) > 0:
-            logger.info(
-                "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
-            )
-        if len(error_msgs) > 0:
-            raise RuntimeError(
-                "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
-            )
+        num_special_tokens = kwargs.get('num_special_tokens', None)
+        kwargs.pop('num_special_tokens', None)
+
+        model = PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
+
        # Add additional embeddings for special tokens if needed
        # This step also make sure we are still sharing the output and input embeddings after loading weights
-        model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
+        model.set_num_special_tokens(num_special_tokens)
        return model
@@ -621,9 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
        self.apply(self.init_weights)

-    def set_num_special_tokens(self, num_special_tokens):
+    def set_num_special_tokens(self, num_special_tokens=None):
        " Update input embeddings with new embedding matrice if needed "
-        if self.config.n_special == num_special_tokens:
+        if num_special_tokens is None or self.config.n_special == num_special_tokens:
            return
        # Update config
        self.config.n_special = num_special_tokens
pytorch_pretrained_bert/modeling_transfo_xl.py

@@ -38,7 +38,7 @@ from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
+from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel

logger = logging.getLogger(__name__)

@@ -49,8 +49,6 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
}

-TF_WEIGHTS_NAME = 'model.ckpt'

def build_tf_to_pytorch_map(model, config):
    """ A map of modules from TF to PyTorch.
        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
@@ -787,28 +785,26 @@ class AdaptiveEmbedding(nn.Module):
        return embed


-class TransfoXLPreTrainedModel(nn.Module):
+class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
-    def __init__(self, config, *inputs, **kwargs):
-        super(TransfoXLPreTrainedModel, self).__init__()
-        if not isinstance(config, TransfoXLConfig):
-            raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
-                "To create a model from a pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__))
-        self.config = config
-
-    def init_weight(self, weight):
+    config_class = TransfoXLConfig
+    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    load_tf_weights = load_tf_weights_in_transfo_xl
+    base_model_prefix = "transformer"
+
+    def __init__(self, *inputs, **kwargs):
+        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
+
+    def _init_weight(self, weight):
        if self.config.init == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif self.config.init == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

-    def init_bias(self, bias):
+    def _init_bias(self, bias):
        nn.init.constant_(bias, 0.0)

    def init_weights(self, m):
@@ -817,9 +813,9 @@ class TransfoXLPreTrainedModel(nn.Module):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            if hasattr(m, 'weight') and m.weight is not None:
-                self.init_weight(m.weight)
+                self._init_weight(m.weight)
            if hasattr(m, 'bias') and m.bias is not None:
-                self.init_bias(m.bias)
+                self._init_bias(m.bias)
        elif classname.find('AdaptiveEmbedding') != -1:
            if hasattr(m, 'emb_projs'):
                for i in range(len(m.emb_projs)):

@@ -827,12 +823,12 @@ class TransfoXLPreTrainedModel(nn.Module):
                    nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find('Embedding') != -1:
            if hasattr(m, 'weight'):
-                self.init_weight(m.weight)
+                self._init_weight(m.weight)
        elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
            if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
-                self.init_weight(m.cluster_weight)
+                self._init_weight(m.cluster_weight)
            if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
-                self.init_bias(m.cluster_bias)
+                self._init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for i in range(len(m.out_projs)):
                    if m.out_projs[i] is not None:
@@ -841,144 +837,20 @@ class TransfoXLPreTrainedModel(nn.Module):
            if hasattr(m, 'weight'):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if hasattr(m, 'bias') and m.bias is not None:
-                self.init_bias(m.bias)
+                self._init_bias(m.bias)
        elif classname.find('TransformerLM') != -1:
            if hasattr(m, 'r_emb'):
-                self.init_weight(m.r_emb)
+                self._init_weight(m.r_emb)
            if hasattr(m, 'r_w_bias'):
-                self.init_weight(m.r_w_bias)
+                self._init_weight(m.r_w_bias)
            if hasattr(m, 'r_r_bias'):
-                self.init_weight(m.r_r_bias)
+                self._init_weight(m.r_r_bias)
            if hasattr(m, 'r_bias'):
-                self.init_bias(m.r_bias)
+                self._init_bias(m.r_bias)

-    def set_num_special_tokens(self, num_special_tokens):
-        pass
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-        """
-        Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
-        Download and cache the pre-trained model file if needed.
-
-        Params:
-            pretrained_model_name_or_path: either:
-                - a str with the name of a pre-trained model to load selected in the list of:
-                    . `transfo-xl-wt103`
-                - a path or url to a pretrained model archive containing:
-                    . `transfo_xl_config.json` a configuration file for the model
-                    . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
-                - a path or url to a pretrained model archive containing:
-                    . `transfo_xl_config.json` a configuration file for the model
-                    . `model.chkpt` a TensorFlow checkpoint
-            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
-            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
-            *inputs, **kwargs: additional input for the specific TransformerXL class
-        """
-        state_dict = kwargs.get('state_dict', None)
-        kwargs.pop('state_dict', None)
-        cache_dir = kwargs.get('cache_dir', None)
-        kwargs.pop('cache_dir', None)
-        from_tf = kwargs.get('from_tf', False)
-        kwargs.pop('from_tf', None)
-
-        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
-            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-        # redirect to the cache, if necessary
-        try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained weights.".format(
-                        archive_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find file {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path, archive_file)
-                )
-            return None
-        try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
-                logger.error(
-                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                        config_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find file {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path, config_file)
-                )
-            return None
-        if resolved_archive_file == archive_file and resolved_config_file == config_file:
-            logger.info("loading weights file {}".format(archive_file))
-            logger.info("loading configuration file {}".format(config_file))
-        else:
-            logger.info("loading weights file {} from cache at {}".format(
-                archive_file, resolved_archive_file))
-            logger.info("loading configuration file {} from cache at {}".format(
-                config_file, resolved_config_file))
-        # Load config
-        config = TransfoXLConfig.from_json_file(resolved_config_file)
-        logger.info("Model config {}".format(config))
-        # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
-        if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file, map_location='cpu')
-        if from_tf:
-            # Directly load from a TensorFlow checkpoint
-            return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)
-
-        missing_keys = []
-        unexpected_keys = []
-        error_msgs = []
-        # copy state_dict so _load_from_state_dict can modify it
-        metadata = getattr(state_dict, '_metadata', None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-
-        def load(module, prefix=''):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            module._load_from_state_dict(
-                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + '.')
-
-        start_prefix = ''
-        if not hasattr(model, 'transformer') and any(s.startswith('transformer.') for s in state_dict.keys()):
-            start_prefix = 'transformer.'
-        load(model, prefix=start_prefix)
-        if len(missing_keys) > 0:
-            logger.info("Weights of {} not initialized from pretrained model: {}".format(
-                model.__class__.__name__, missing_keys))
-        if len(unexpected_keys) > 0:
-            logger.info("Weights from pretrained model not used in {}: {}".format(
-                model.__class__.__name__, unexpected_keys))
-        if len(error_msgs) > 0:
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                model.__class__.__name__, "\n\t".join(error_msgs)))
-        # Make sure we are still sharing the input and output embeddings
-        if hasattr(model, 'tie_weights'):
-            model.tie_weights()
-        return model


class TransfoXLModel(TransfoXLPreTrainedModel):
    """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
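
With the Transformer-XL specific loader removed, loading goes through the shared implementation, and the explicit re-tying the deleted code did is now triggered by the hasattr(model, 'tie_weights') check in the base class. A hedged one-liner, using the one published name for this file:

model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')   # inherited from PreTrainedModel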
pytorch_pretrained_bert/modeling_xlm.py

@@ -36,7 +36,7 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path
-from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
+from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel

logger = logging.getLogger(__name__)
@@ -390,20 +390,18 @@ class BeamHypotheses(object):
            return self.worst_score >= best_sum_logprobs / self.max_len ** self.length_penalty


-class XLMPreTrainedModel(nn.Module):
+class XLMPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
-    def __init__(self, config, *inputs, **kwargs):
-        super(XLMPreTrainedModel, self).__init__()
-        if not isinstance(config, XLMBaseConfig):
-            raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `XLMBaseConfig`. "
-                "To create a model from a Google pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__))
-        self.config = config
+    config_class = XLMConfig
+    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    load_tf_weights = None
+    base_model_prefix = "xlm"
+
+    def __init__(self, *inputs, **kwargs):
+        super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.
@@ -423,138 +421,6 @@ class XLMPreTrainedModel(nn.Module):
if
isinstance
(
module
,
nn
.
Linear
)
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
):
"""
Instantiate a XLMPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a pretrained model archive containing:
. `config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a XLMForPreTraining instance
- a path or url to a pretrained model archive containing:
. `xlnet_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific XLM class
(ex: num_labels for XLMForSequenceClassification)
"""
state_dict
=
kwargs
.
get
(
'state_dict'
,
None
)
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
get
(
'cache_dir'
,
None
)
kwargs
.
pop
(
'cache_dir'
,
None
)
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
archive_file
=
PRETRAINED_MODEL_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
config_file
=
PRETRAINED_CONFIG_ARCHIVE_MAP
[
pretrained_model_name_or_path
]
else
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
XLNET_CONFIG_NAME
)
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_MODEL_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained weights."
.
format
(
archive_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
()),
archive_file
))
return
None
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
except
EnvironmentError
:
if
pretrained_model_name_or_path
in
PRETRAINED_CONFIG_ARCHIVE_MAP
:
logger
.
error
(
"Couldn't reach server at '{}' to download pretrained model configuration file."
.
format
(
config_file
))
else
:
logger
.
error
(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url."
.
format
(
pretrained_model_name_or_path
,
', '
.
join
(
PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
()),
config_file
))
return
None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = XLMConfig.from_json_file(resolved_config_file)
        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        start_prefix = ''
        if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
            start_prefix = 'transformer.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        if isinstance(model, XLMLMHeadModel):
            model.tie_weights()  # make sure word embedding weights are still tied
        return model
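The `start_prefix` handling a few lines above is what lets one checkpoint serve both the bare base model and the head-carrying models. A minimal sketch of that rule, factored into a helper that is not part of the library (key names are illustrative only):

    def pick_start_prefix(model, state_dict, base_prefix='transformer'):
        # A checkpoint saved from a model that wraps the encoder in a `.transformer`
        # submodule carries keys such as 'transformer.word_embedding.weight'. When that
        # checkpoint is loaded into a bare base model with no `.transformer` attribute,
        # loading must start with prefix='transformer.' so the key names line up.
        if not hasattr(model, base_prefix) and any(k.startswith(base_prefix) for k in state_dict.keys()):
            return base_prefix + '.'
        return ''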
class XLMModel(XLMPreTrainedModel):
...
...
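With this commit the whole method removed above moves into `model_utils.PreTrainedModel`, so loading an XLM checkpoint still reads the same at the call site. A hedged usage sketch (the local directory is a placeholder and must contain `config.json` and `pytorch_model.bin`):

    model = XLMModel.from_pretrained('./my-xlm-checkpoint',
                                     cache_dir='/tmp/xlm-cache')   # optional cache location
    model.eval()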
pytorch_pretrained_bert/modeling_xlnet.py
View file @
4d47f498
...
...
@@ -33,7 +33,7 @@ from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel

logger = logging.getLogger(__name__)
...
...
@@ -44,11 +44,9 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}

XLNET_CONFIG_NAME = 'xlnet_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'

def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_task=None):
def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
    """ A map of modules from TF to PyTorch.
        I use a map to keep the PyTorch model as
        identical to the original PyTorch model as possible.
...
...
@@ -64,9 +62,9 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
    # We will load also the sequence summary
    tf_to_pt_map['model/sequnece_summary/summary/kernel'] = model.sequence_summary.summary.weight
    tf_to_pt_map['model/sequnece_summary/summary/bias'] = model.sequence_summary.summary.bias
    if hasattr(model, 'logits_proj') and finetuning_task is not None and 'model/regression_{}/logit/kernel'.format(finetuning_task) in tf_weights:
        tf_to_pt_map['model/regression_{}/logit/kernel'.format(finetuning_task)] = model.logits_proj.weight
        tf_to_pt_map['model/regression_{}/logit/bias'.format(finetuning_task)] = model.logits_proj.bias
    if hasattr(model, 'logits_proj') and config.finetuning_task is not None and 'model/regression_{}/logit/kernel'.format(finetuning_task) in tf_weights:
        tf_to_pt_map['model/regression_{}/logit/kernel'.format(config.finetuning_task)] = model.logits_proj.weight
        tf_to_pt_map['model/regression_{}/logit/bias'.format(config.finetuning_task)] = model.logits_proj.bias

    # Now load the rest of the transformer
    model = model.transformer
...
...
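The edit above stops threading `finetuning_task` through the TF-conversion helpers and reads it from the configuration instead, so the task name only has to be stated once. A hedged sketch of the resulting call pattern ('sts-b', the head class, and the checkpoint path are placeholders; whether the regression-head variables exist depends on how the TF checkpoint was exported):

    config = XLNetConfig(vocab_size_or_config_json_file=32000, finetuning_task='sts-b')
    model = XLNetForSequenceClassification(config, num_labels=1)    # illustrative head class
    load_tf_weights_in_xlnet(model, config, '/path/to/xlnet/model.ckpt')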
@@ -117,7 +115,7 @@ def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None, finetuning_tas
        'model/transformer/seg_embed': seg_embed_list})
    return tf_to_pt_map

def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
def load_tf_weights_in_xlnet(model, config, tf_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
...
...
@@ -138,7 +136,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path, finetuning_task=None):
        input("Press Enter to continue...")

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights, finetuning_task)
    tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)

    for name, pointer in tf_to_pt_map.items():
        print("Importing {}".format(name))
...
...
@@ -223,7 +221,8 @@ class XLNetConfig(PretrainedConfig):
                 reuse_len=None, bi_data=False, clamp_len=-1,
                 same_length=False):
                 same_length=False, finetuning_task=None):
        """Constructs XLNetConfig.

        Args:
...
...
@@ -265,6 +264,7 @@ class XLNetConfig(PretrainedConfig):
            clamp_len: int, clamp all relative distances larger than clamp_len.
                -1 means no clamping.
            same_length: bool, whether to use the same attention length for each token.
            finetuning_task: name of the GLUE task on which the model was fine-tuned, if any.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
...
...
@@ -298,6 +298,7 @@ class XLNetConfig(PretrainedConfig):
            self.bi_data = bi_data
            self.clamp_len = clamp_len
            self.same_length = same_length
            self.finetuning_task = finetuning_task
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
...
...
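Because `from_pretrained` folds any keyword argument that matches a configuration attribute into the config before the model is built (see the loop over `kwargs.items()` in the method body removed further down), the new `finetuning_task` field can also be supplied at load time. A hedged sketch, assuming the shared loader keeps that kwargs-to-config behaviour (the task name is a placeholder):

    model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', finetuning_task='sts-b')
    assert model.config.finetuning_task == 'sts-b'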
@@ -550,20 +551,19 @@ class XLNetLayer(nn.Module):
        # return attentions, layer_output
        return output_h, output_g


class XLNetPreTrainedModel(nn.Module):
class XLNetPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(XLNetPreTrainedModel, self).__init__()
        if not isinstance(config, XLNetConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__))
        self.config = config

    config_class = XLNetConfig
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_xlnet
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.
...
...
@@ -583,144 +583,6 @@ class XLNetPreTrainedModel(nn.Module):
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a XLNetPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `xlnet-large-cased`
                - a path or url to a pretrained model archive containing:
                    . `config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
                - a path or url to a pretrained model archive containing:
                    . `xlnet_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific XLNet class
                (ex: num_labels for XLNetForSequenceClassification)
        """
        state_dict = kwargs.get('state_dict', None)
        kwargs.pop('state_dict', None)
        cache_dir = kwargs.get('cache_dir', None)
        kwargs.pop('cache_dir', None)
        from_tf = kwargs.get('from_tf', False)
        kwargs.pop('from_tf', None)

        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
                config_file = os.path.join(pretrained_model_name_or_path, XLNET_CONFIG_NAME)
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                        archive_file))
            return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
                        config_file))
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = XLNetConfig.from_json_file(resolved_config_file)
        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        if from_tf:
            # Directly load from a TensorFlow checkpoint
            return load_tf_weights_in_xlnet(model, config, resolved_archive_file)
        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        start_prefix = ''
        if not hasattr(model, 'transformer') and any(s.startswith('transformer') for s in state_dict.keys()):
            start_prefix = 'transformer.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        if isinstance(model, XLNetLMHeadModel):
            model.tie_weights()  # make sure word embedding weights are still tied
        return model
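All of the removed logic above now lives once in `PreTrainedModel`, but the call paths it documents are unchanged. A hedged sketch of the two entry points (the cache and checkpoint paths are placeholders):

    # Download (or read from cache) config.json + pytorch_model.bin for a named model:
    model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', cache_dir='/tmp/xlnet-cache')

    # Or convert a locally saved TensorFlow checkpoint directory
    # (expects xlnet_config.json and model.ckpt inside it):
    model = XLNetLMHeadModel.from_pretrained('/path/to/tf_checkpoint_dir', from_tf=True)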
class XLNetModel(XLNetPreTrainedModel):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
...
...
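For completeness, a hedged sketch of building the base model directly from a configuration object rather than from a checkpoint, using the constructor flags in the signature above (assumes the remaining XLNetConfig arguments keep their defaults):

    config = XLNetConfig(vocab_size_or_config_json_file=32000)
    model = XLNetModel(config, output_attentions=True)
    model.eval()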