chenpangpang / transformers · Commits

Commit de713fa9 ("starting")
Authored Jun 20, 2019 by thomwolf
Parent: c304593d

Showing 5 changed files with 629 additions and 2 deletions:
    pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py   +62   -0
    pytorch_pretrained_bert/modeling.py                               +1    -1
    pytorch_pretrained_bert/modeling_transfo_xl.py                    +1    -1
    pytorch_pretrained_bert/modeling_xlnet.py                         +565  -0
    pytorch_pretrained_bert/tokenization_xlnet.py                     +0    -0
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py  (new file, mode 100755)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert an XLNet TensorFlow checkpoint to a PyTorch model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import torch

from pytorch_pretrained_bert.modeling_xlnet import (XLNetConfig, XLNetRunConfig,
                                                    XLNetModel, load_tf_weights_in_xlnet)


def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = XLNetConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = XLNetModel(config)

    # Load weights from the TensorFlow checkpoint
    load_tf_weights_in_xlnet(model, tf_checkpoint_path)

    # Save the PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--tf_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--xlnet_config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The config json file corresponding to the pre-trained XLNet model.\n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                        args.xlnet_config_file,
                                        args.pytorch_dump_path)
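
For context, a minimal usage sketch (not part of the commit) showing how the conversion above can also be driven from Python rather than the CLI; every path below is a hypothetical placeholder for an official XLNet TensorFlow release:

```python
# Hypothetical sketch -- all paths are placeholders, not files shipped with this commit.
from pytorch_pretrained_bert.convert_xlnet_checkpoint_to_pytorch import (
    convert_xlnet_checkpoint_to_pytorch,
)

convert_xlnet_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/xlnet/model.ckpt",         # TF checkpoint prefix
    bert_config_file="/path/to/xlnet/xlnet_config.json",    # XLNet config json
    pytorch_dump_path="/path/to/output/pytorch_model.bin",  # converted state_dict
)
```

The same three values map one-to-one onto the `--tf_checkpoint_path`, `--xlnet_config_file` and `--pytorch_dump_path` arguments defined above.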
pytorch_pretrained_bert/modeling.py
@@ -718,7 +718,7 @@ class BertPreTrainedModel(nn.Module):
             state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
             # Directly load from a TensorFlow checkpoint
-            return load_tf_weights_in_bert(model, weights_path)
+            return load_tf_weights_in_bert(model, resolved_archive_file)
         # Load from a PyTorch state_dict
         old_keys = []
         new_keys = []
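
The hunk above corrects the TensorFlow branch of `BertPreTrainedModel.from_pretrained` to hand `load_tf_weights_in_bert` the resolved archive file rather than `weights_path`. A minimal sketch of the call that exercises this branch, assuming a local directory (placeholder path) containing a TF checkpoint and config:

```python
# Hypothetical sketch of the code path touched by the fix above; the directory
# path is a placeholder. With from_tf=True, from_pretrained resolves the checkpoint
# file inside the directory and passes it to load_tf_weights_in_bert.
from pytorch_pretrained_bert.modeling import BertModel

model = BertModel.from_pretrained("/path/to/local_bert_tf_checkpoint_dir", from_tf=True)
```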
pytorch_pretrained_bert/modeling_transfo_xl.py
...
...
@@ -236,7 +236,7 @@ class TransfoXLConfig(object):
dropout: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
dropatt: The dropout ratio for the attention probabilities.
untie_r: untie relative position biases
untie_r: untie relative position biases
embd_pdrop: The dropout ratio for the embeddings.
init: parameter initializer to use
init_range: parameters initialized by U(-init_range, init_range).
...
...
pytorch_pretrained_bert/modeling_xlnet.py  (new file, mode 100644)
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch XLNet model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}

XLNET_CONFIG_NAME = 'xlnet_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_xlnet(model, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        print("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    print("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
def gelu(x):
    """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class XLNetBaseConfig(object):
    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `XLNetConfig` from a Python dictionary of parameters."""
        config = XLNetConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `XLNetConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())
class XLNetConfig(XLNetBaseConfig):
    """Configuration class to store the configuration of a `XLNetModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 d_model=768,
                 n_layer=12,
                 n_head=12,
                 d_inner=3072,
                 ff_activation="gelu",
                 untie_r=True,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12):
        """Constructs XLNetConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLNetModel`.
            d_model: Size of the encoder layers and the pooler layer.
            n_layer: Number of hidden layers in the Transformer encoder.
            n_head: Number of attention heads for each attention layer in
                the Transformer encoder.
            d_inner: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            ff_activation: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            untie_r: untie relative position biases
            dropout: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            dropatt: The dropout ratio for the attention probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `XLNetModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.d_model = d_model
            self.n_layer = n_layer
            self.n_head = n_head
            self.ff_activation = ff_activation
            self.d_inner = d_inner
            self.untie_r = untie_r
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")
class XLNetRunConfig(XLNetBaseConfig):
    """XLNetRunConfig contains hyperparameters that could be different
    between pretraining and finetuning.
    These hyperparameters can also be changed from run to run.
    We store them separately from XLNetConfig for flexibility.
    """
    def __init__(self,
                 dropout,
                 dropatt,
                 init="normal",
                 init_range=0.1,
                 init_std=0.02,
                 mem_len=None,
                 reuse_len=None,
                 bi_data=False,
                 clamp_len=-1,
                 same_length=False):
        """
        Args:
            dropout: float, dropout rate.
            dropatt: float, dropout rate on attention probabilities.
            init: str, the initialization scheme, either "normal" or "uniform".
            init_range: float, initialize the parameters with a uniform distribution
                in [-init_range, init_range]. Only effective when init="uniform".
            init_std: float, initialize the parameters with a normal distribution
                with mean 0 and stddev init_std. Only effective when init="normal".
            mem_len: int, the number of tokens to cache.
            reuse_len: int, the number of tokens in the current batch to be cached
                and reused in the future.
            bi_data: bool, whether to use bidirectional input pipeline.
                Usually set to True during pretraining and False during finetuning.
            clamp_len: int, clamp all relative distances larger than clamp_len.
                -1 means no clamping.
            same_length: bool, whether to use the same attention length for each token.
        """
        self.init = init
        self.init_range = init_range
        self.init_std = init_std
        self.dropout = dropout
        self.dropatt = dropatt
        self.mem_len = mem_len
        self.reuse_len = reuse_len
        self.bi_data = bi_data
        self.clamp_len = clamp_len
        self.same_length = same_length
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm
except ImportError:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")

    class XLNetLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(XLNetLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
class XLNetPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(XLNetPreTrainedModel, self).__init__()
        if not isinstance(config, XLNetConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__))
        self.config = config

    def init_xlnet_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, XLNetLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a XLNetPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `xlnet-large-cased`
                - a path or url to a pretrained model archive containing:
                    . `xlnet_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a XLNetForPreTraining instance
                - a path or url to a pretrained model archive containing:
                    . `xlnet_config.json` a configuration file for the model
                    . `model.ckpt` a TensorFlow checkpoint
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific XLNet class
                (ex: num_labels for XLNetForSequenceClassification)
        """
        state_dict = kwargs.get('state_dict', None)
        kwargs.pop('state_dict', None)
        cache_dir = kwargs.get('cache_dir', None)
        kwargs.pop('cache_dir', None)
        from_tf = kwargs.get('from_tf', False)
        kwargs.pop('from_tf', None)

        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
                config_file = os.path.join(pretrained_model_name_or_path, XLNET_CONFIG_NAME)
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained weights.".format(
                        archive_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                        archive_file))
            return None
        try:
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
                logger.error(
                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
                        config_file))
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name_or_path,
                        ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
                        config_file))
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = XLNetConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint
            return load_tf_weights_in_xlnet(model, resolved_archive_file)
        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        start_prefix = ''
        if not hasattr(model, 'xlnet') and any(s.startswith('xlnet.') for s in state_dict.keys()):
            start_prefix = 'xlnet.'
        load(model, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        return model
class XLNetModel(XLNetPreTrainedModel):
    """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").

    Params:
        `config`: a XLNetConfig class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
        `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
            This can be used to compute head importance metrics. Default: False

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see XLNet paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
            It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for XLNet-base, 24 for XLNet-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see XLNet's paper).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.XLNetModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(XLNetModel, self).__init__(config)
        self.output_attentions = output_attentions
        self.apply(self.init_xlnet_weights)

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_multihead_outputs(self):
        """ Gather all multi-head outputs.
            Return: list (layers) of multihead module outputs with gradients
        """
        return [layer.attention.self.multihead_output for layer in self.encoder.layer]

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                output_all_encoded_layers=True, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers,
                                      head_mask=head_mask)
        if self.output_attentions:
            all_attentions, encoded_layers = encoded_layers
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        if self.output_attentions:
            return all_attentions, encoded_layers, pooled_output
        return encoded_layers, pooled_output
\ No newline at end of file
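
For orientation, a minimal sketch of the interface this new file introduces (hypothetical usage, not part of the commit; at this revision the forward pass is still a copy of BERT's and references submodules that are not yet defined, so the sketch only covers configuration and weight loading):

```python
# Hypothetical usage sketch based only on the classes defined in modeling_xlnet.py above.
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetModel

# Build a model from an explicit configuration (arguments as defined in XLNetConfig;
# the sizes below are assumed XLNet-large dimensions, not values from the commit).
config = XLNetConfig(vocab_size_or_config_json_file=32000,
                     d_model=1024, n_layer=24, n_head=16, d_inner=4096)
model = XLNetModel(config)

# Or instantiate from the published 'xlnet-large-cased' weights registered in
# PRETRAINED_MODEL_ARCHIVE_MAP (downloaded and cached via cached_path).
model = XLNetModel.from_pretrained('xlnet-large-cased')
```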
pytorch_pretrained_bert/tokenization_xlnet.py  (new file, mode 100644; no content shown, +0 -0)