Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
24d80689
Commit
24d80689
authored
Jun 21, 2019
by
thomwolf
Browse files
weights loading script ok
parent
32da7548
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
193 additions
and
90 deletions
+193
-90
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+17
-9
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+173
-78
tests/modeling_xlnet_test.py
tests/modeling_xlnet_test.py
+3
-3
No files found.
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
24d80689
...
@@ -18,23 +18,31 @@ from __future__ import absolute_import
...
@@ -18,23 +18,31 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
argparse
import
argparse
import
torch
import
torch
from
pytorch_pretrained_bert.modeling_xlnet
import
XLNetConfig
,
XLNetRunConfig
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
from
pytorch_pretrained_bert.modeling_xlnet
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
XLNetConfig
,
XLNetRunConfig
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
)
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_
folder_
path
):
# Initialise PyTorch model
# Initialise PyTorch model
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
XLNetLMHeadModel
(
config
)
model
=
XLNetLMHeadModel
(
config
)
# Load weights from tf checkpoint
# Load weights from tf checkpoint
load_tf_weights_in_xlnet
(
model
,
tf_checkpoint_path
)
load_tf_weights_in_xlnet
(
model
,
config
,
tf_checkpoint_path
)
# Save pytorch-model
# Save pytorch-model
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_dump_path
))
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
torch
.
save
(
model
.
state_dict
(),
pytorch_dump_path
)
pytorch_config_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
CONFIG_NAME
)
print
(
"Save PyTorch model to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_weights_dump_path
)))
torch
.
save
(
model
.
state_dict
(),
pytorch_weights_dump_path
)
print
(
"Save configuration file to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_config_dump_path
)))
with
open
(
pytorch_config_dump_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
config
.
to_json_string
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
@@ -50,13 +58,13 @@ if __name__ == "__main__":
...
@@ -50,13 +58,13 @@ if __name__ == "__main__":
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
"This specifies the model architecture."
)
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
parser
.
add_argument
(
"--pytorch_dump_
folder_
path"
,
default
=
None
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"Path to the
output PyTorch model
."
)
help
=
"Path to the
folder to store the PyTorch model or dataset/vocab
."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
args
.
xlnet_config_file
,
args
.
xlnet_config_file
,
args
.
pytorch_dump_path
)
args
.
pytorch_dump_
folder_
path
)
pytorch_pretrained_bert/modeling_xlnet.py
View file @
24d80689
...
@@ -45,70 +45,122 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
...
@@ -45,70 +45,122 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
XLNET_CONFIG_NAME
=
'xlnet_config.json'
XLNET_CONFIG_NAME
=
'xlnet_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
load_tf_weights_in_xlnet
(
model
,
tf_checkpoint_path
):
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
):
""" A map of modules from TF to PyTorch.
I use a map to keep the PyTorch model as
identical to the original PyTorch model as possible.
"""
tf_to_pt_map
=
{}
if
hasattr
(
model
,
'transformer'
):
# We are loading pre-trained weights in a XLNetLMHeadModel => we will load also the output bias
tf_to_pt_map
[
'model/lm_loss/bias'
]
=
model
.
lm_loss
.
bias
# Now load the rest of the transformer
model
=
model
.
transformer
# Embeddings and output
tf_to_pt_map
.
update
({
'model/transformer/word_embedding/lookup_table'
:
model
.
word_embedding
.
weight
,
'model/transformer/mask_emb/mask_emb'
:
model
.
mask_emb
})
# Transformer blocks
for
i
,
b
in
enumerate
(
model
.
layer
):
layer_str
=
"model/transformer/layer_%d/"
%
i
tf_to_pt_map
.
update
({
layer_str
+
"rel_attn/LayerNorm/gamma"
:
b
.
rel_attn
.
layer_norm
.
weight
,
layer_str
+
"rel_attn/LayerNorm/beta"
:
b
.
rel_attn
.
layer_norm
.
bias
,
layer_str
+
"rel_attn/o/kernel"
:
b
.
rel_attn
.
o
,
layer_str
+
"rel_attn/q/kernel"
:
b
.
rel_attn
.
q
,
layer_str
+
"rel_attn/k/kernel"
:
b
.
rel_attn
.
k
,
layer_str
+
"rel_attn/r/kernel"
:
b
.
rel_attn
.
r
,
layer_str
+
"rel_attn/v/kernel"
:
b
.
rel_attn
.
v
,
layer_str
+
"ff/LayerNorm/gamma"
:
b
.
ff
.
layer_norm
.
weight
,
layer_str
+
"ff/LayerNorm/beta"
:
b
.
ff
.
layer_norm
.
bias
,
layer_str
+
"ff/layer_1/kernel"
:
b
.
ff
.
layer_1
.
weight
,
layer_str
+
"ff/layer_1/bias"
:
b
.
ff
.
layer_1
.
bias
,
layer_str
+
"ff/layer_2/kernel"
:
b
.
ff
.
layer_2
.
weight
,
layer_str
+
"ff/layer_2/bias"
:
b
.
ff
.
layer_2
.
bias
,
})
# Relative positioning biases
if
config
.
untie_r
:
r_r_list
=
[]
r_w_list
=
[]
r_s_list
=
[]
seg_embed_list
=
[]
for
b
in
model
.
layer
:
r_r_list
.
append
(
b
.
rel_attn
.
r_r_bias
)
r_w_list
.
append
(
b
.
rel_attn
.
r_w_bias
)
r_s_list
.
append
(
b
.
rel_attn
.
r_s_bias
)
seg_embed_list
.
append
(
b
.
rel_attn
.
seg_embed
)
else
:
r_r_list
=
[
model
.
r_r_bias
]
r_w_list
=
[
model
.
r_w_bias
]
r_s_list
=
[
model
.
r_s_bias
]
seg_embed_list
=
[
model
.
seg_embed
]
tf_to_pt_map
.
update
({
'model/transformer/r_r_bias'
:
r_r_list
,
'model/transformer/r_w_bias'
:
r_w_list
,
'model/transformer/r_s_bias'
:
r_s_list
,
'model/transformer/seg_embed'
:
seg_embed_list
})
return
tf_to_pt_map
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
):
""" Load tf checkpoints in a pytorch model
""" Load tf checkpoints in a pytorch model
"""
"""
try
:
try
:
import
re
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
except
ImportError
:
except
ImportError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
raise
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
# Build TF to PyTorch weights loading map
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
)
# Load weights from TF model
# Load weights from TF model
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
names
=
[]
tf_weights
=
{}
arrays
=
[]
for
name
,
shape
in
init_vars
:
for
name
,
shape
in
init_vars
:
print
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
print
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
names
.
append
(
name
)
tf_weights
[
name
]
=
array
arrays
.
append
(
array
)
for
name
,
array
in
zip
(
names
,
arrays
):
for
name
,
pointer
in
tf_to_pt_map
.
items
():
name
=
name
.
split
(
'/'
)
print
(
"Importing {}"
.
format
(
name
))
assert
name
in
tf_weights
array
=
tf_weights
[
name
]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
# which are not required for using pretrained model
if
any
(
n
in
[
"adam_v"
,
"adam_m"
,
"global_step"
]
for
n
in
name
):
if
'kernel'
in
name
and
'ff'
in
name
:
print
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
print
(
"Transposing"
)
continue
pointer
=
model
for
m_name
in
name
:
if
re
.
fullmatch
(
r
'[A-Za-z]+_\d+'
,
m_name
):
l
=
re
.
split
(
r
'_(\d+)'
,
m_name
)
else
:
l
=
[
m_name
]
if
l
[
0
]
==
'kernel'
or
l
[
0
]
==
'gamma'
:
pointer
=
getattr
(
pointer
,
'weight'
)
elif
l
[
0
]
==
'output_bias'
or
l
[
0
]
==
'beta'
:
pointer
=
getattr
(
pointer
,
'bias'
)
elif
l
[
0
]
==
'output_weights'
:
pointer
=
getattr
(
pointer
,
'weight'
)
elif
l
[
0
]
==
'squad'
:
pointer
=
getattr
(
pointer
,
'classifier'
)
else
:
try
:
pointer
=
getattr
(
pointer
,
l
[
0
])
except
AttributeError
:
print
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
continue
if
len
(
l
)
>=
2
:
num
=
int
(
l
[
1
])
pointer
=
pointer
[
num
]
if
m_name
[
-
11
:]
==
'_embeddings'
:
pointer
=
getattr
(
pointer
,
'weight'
)
elif
m_name
==
'kernel'
:
array
=
np
.
transpose
(
array
)
array
=
np
.
transpose
(
array
)
try
:
if
isinstance
(
pointer
,
list
):
assert
pointer
.
shape
==
array
.
shape
# Here we will split the TF weigths
except
AssertionError
as
e
:
assert
len
(
pointer
)
==
array
.
shape
[
0
]
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
for
i
,
p_i
in
enumerate
(
pointer
):
raise
arr_i
=
array
[
i
,
...]
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
try
:
pointer
.
data
=
torch
.
from_numpy
(
array
)
assert
p_i
.
shape
==
arr_i
.
shape
except
AssertionError
as
e
:
e
.
args
+=
(
p_i
.
shape
,
arr_i
.
shape
)
raise
print
(
"Initialize PyTorch weight {} for layer {}"
.
format
(
name
,
i
))
p_i
.
data
=
torch
.
from_numpy
(
arr_i
)
else
:
try
:
assert
pointer
.
shape
==
array
.
shape
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
tf_weights
.
pop
(
name
,
None
)
tf_weights
.
pop
(
name
+
'/Adam'
,
None
)
tf_weights
.
pop
(
name
+
'/Adam_1'
,
None
)
print
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
return
model
return
model
...
@@ -181,7 +233,18 @@ class XLNetConfig(XLNetBaseConfig):
...
@@ -181,7 +233,18 @@ class XLNetConfig(XLNetBaseConfig):
max_position_embeddings
=
512
,
max_position_embeddings
=
512
,
initializer_range
=
0.02
,
initializer_range
=
0.02
,
layer_norm_eps
=
1e-12
):
layer_norm_eps
=
1e-12
,
dropout
=
0.1
,
dropatt
=
0.1
,
init
=
"normal"
,
init_range
=
0.1
,
init_std
=
0.02
,
mem_len
=
None
,
reuse_len
=
None
,
bi_data
=
False
,
clamp_len
=-
1
,
same_length
=
False
):
"""Constructs XLNetConfig.
"""Constructs XLNetConfig.
Args:
Args:
...
@@ -207,6 +270,22 @@ class XLNetConfig(XLNetBaseConfig):
...
@@ -207,6 +270,22 @@ class XLNetConfig(XLNetBaseConfig):
initializer_range: The sttdev of the truncated_normal_initializer for
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
layer_norm_eps: The epsilon used by LayerNorm.
dropout: float, dropout rate.
dropatt: float, dropout rate on attention probabilities.
init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform".
init_std: float, initialize the parameters with a normal distribution
with mean 0 and stddev init_std. Only effective when init="normal".
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the currect batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
"""
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
...
@@ -215,7 +294,7 @@ class XLNetConfig(XLNetBaseConfig):
...
@@ -215,7 +294,7 @@ class XLNetConfig(XLNetBaseConfig):
for
key
,
value
in
json_config
.
items
():
for
key
,
value
in
json_config
.
items
():
self
.
__dict__
[
key
]
=
value
self
.
__dict__
[
key
]
=
value
elif
isinstance
(
vocab_size_or_config_json_file
,
int
):
elif
isinstance
(
vocab_size_or_config_json_file
,
int
):
self
.
vocab_size
=
vocab_size_or_config_json_file
self
.
n_token
=
vocab_size_or_config_json_file
self
.
d_model
=
d_model
self
.
d_model
=
d_model
self
.
n_layer
=
n_layer
self
.
n_layer
=
n_layer
self
.
n_head
=
n_head
self
.
n_head
=
n_head
...
@@ -225,9 +304,21 @@ class XLNetConfig(XLNetBaseConfig):
...
@@ -225,9 +304,21 @@ class XLNetConfig(XLNetBaseConfig):
self
.
d_inner
=
d_inner
self
.
d_inner
=
d_inner
self
.
untie_r
=
untie_r
self
.
untie_r
=
untie_r
self
.
attn_type
=
attn_type
self
.
attn_type
=
attn_type
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
initializer_range
=
initializer_range
self
.
initializer_range
=
initializer_range
self
.
layer_norm_eps
=
layer_norm_eps
self
.
layer_norm_eps
=
layer_norm_eps
self
.
init
=
init
self
.
init_range
=
init_range
self
.
init_std
=
init_std
self
.
dropout
=
dropout
self
.
dropatt
=
dropatt
self
.
mem_len
=
mem_len
self
.
reuse_len
=
reuse_len
self
.
bi_data
=
bi_data
self
.
clamp_len
=
clamp_len
self
.
same_length
=
same_length
else
:
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
"or the path to a pretrained model config file (str)"
)
...
@@ -327,7 +418,7 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -327,7 +418,7 @@ class XLNetRelativeAttention(nn.Module):
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
seg_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
2
,
self
.
n_head
,
self
.
d_head
))
self
.
seg_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
2
,
self
.
n_head
,
self
.
d_head
))
self
.
L
ayer
N
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
l
ayer
_n
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
def
prune_heads
(
self
,
heads
):
def
prune_heads
(
self
,
heads
):
...
@@ -385,7 +476,7 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -385,7 +476,7 @@ class XLNetRelativeAttention(nn.Module):
attn_out
=
self
.
dropout
(
attn_out
)
attn_out
=
self
.
dropout
(
attn_out
)
if
residual
:
if
residual
:
attn_out
=
attn_out
+
h
attn_out
=
attn_out
+
h
output
=
self
.
L
ayer
N
orm
(
attn_out
)
output
=
self
.
l
ayer
_n
orm
(
attn_out
)
return
output
return
output
...
@@ -483,7 +574,7 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -483,7 +574,7 @@ class XLNetRelativeAttention(nn.Module):
class
XLNetFeedForward
(
nn
.
Module
):
class
XLNetFeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
XLNetFeedForward
,
self
).
__init__
()
super
(
XLNetFeedForward
,
self
).
__init__
()
self
.
L
ayer
N
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
l
ayer
_n
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
layer_1
=
nn
.
Linear
(
config
.
d_model
,
config
.
d_inner
)
self
.
layer_1
=
nn
.
Linear
(
config
.
d_model
,
config
.
d_inner
)
self
.
layer_2
=
nn
.
Linear
(
config
.
d_inner
,
config
.
d_model
)
self
.
layer_2
=
nn
.
Linear
(
config
.
d_inner
,
config
.
d_model
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
...
@@ -499,7 +590,7 @@ class XLNetFeedForward(nn.Module):
...
@@ -499,7 +590,7 @@ class XLNetFeedForward(nn.Module):
output
=
self
.
dropout
(
output
)
output
=
self
.
dropout
(
output
)
output
=
self
.
layer_2
(
output
)
output
=
self
.
layer_2
(
output
)
output
=
self
.
dropout
(
output
)
output
=
self
.
dropout
(
output
)
output
=
self
.
L
ayer
N
orm
(
output
+
inp
)
output
=
self
.
l
ayer
_n
orm
(
output
+
inp
)
return
output
return
output
class
XLNetLayer
(
nn
.
Module
):
class
XLNetLayer
(
nn
.
Module
):
...
@@ -691,11 +782,26 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -691,11 +782,26 @@ class XLNetModel(XLNetPreTrainedModel):
self
.
bi_data
=
config
.
bi_data
self
.
bi_data
=
config
.
bi_data
self
.
clamp_len
=
config
.
clamp_len
self
.
clamp_len
=
config
.
clamp_len
self
.
word_embedding
=
nn
.
Embedding
(
config
.
n_token
,
config
.
d_model
)
self
.
mask_emb
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
1
,
config
.
d_model
))
layer
=
XLNetLayer
(
config
,
output_attentions
=
output_attentions
,
layer
=
XLNetLayer
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
keep_multihead_output
=
keep_multihead_output
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
n_layer
)])
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
n_layer
)])
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
def
prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
def
get_multihead_outputs
(
self
):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return
[
layer
.
attention
.
self
.
multihead_output
for
layer
in
self
.
layer
]
def
create_mask
(
self
,
qlen
,
mlen
):
def
create_mask
(
self
,
qlen
,
mlen
):
""" create causal attention mask.
""" create causal attention mask.
float mask where 1.0 indicate masked, 0.0 indicated not-masked.
float mask where 1.0 indicate masked, 0.0 indicated not-masked.
...
@@ -783,12 +889,12 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -783,12 +889,12 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb
=
pos_emb
.
to
(
next
(
self
.
parameters
()))
pos_emb
=
pos_emb
.
to
(
next
(
self
.
parameters
()))
return
pos_emb
return
pos_emb
def
forward
(
self
,
word_emb
_k
,
seg_id
=
None
,
input_mask
=
None
,
def
forward
(
self
,
inp
_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
"""
Args:
Args:
word_emb_k: floa
t32 Tensor in shape [len, bsz
, d_model
], the input token
embedding
s.
inp_k: in
t32 Tensor in shape [len, bsz], the input token
ID
s.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [len, bsz], the input mask.
input_mask: [optional] float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
0 for real tokens and 1 for padding.
...
@@ -820,11 +926,12 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -820,11 +926,12 @@ class XLNetModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
to pool the input to get a vector representation.
"""
"""
qlen
,
bsz
=
word_emb
_k
.
shape
[
0
],
word_emb
_k
.
shape
[
1
]
qlen
,
bsz
=
inp
_k
.
shape
[
0
],
inp
_k
.
shape
[
1
]
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
klen
=
mlen
+
qlen
dtype_float
=
word_emb_k
.
dtype
device
=
word_emb_k
.
device
dtype_float
=
next
(
self
.
parameters
()).
dtype
device
=
next
(
self
.
parameters
()).
device
##### Attention mask
##### Attention mask
# causal attention mask
# causal attention mask
...
@@ -865,7 +972,8 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -865,7 +972,8 @@ class XLNetModel(XLNetPreTrainedModel):
else
:
else
:
non_tgt_mask
=
None
non_tgt_mask
=
None
##### Process Word embeddings and prepare h & g hidden states
##### Word embeddings and prepare h & g hidden states
word_emb_k
=
self
.
word_embedding
(
inp_k
)
output_h
=
self
.
dropout
(
word_emb_k
)
output_h
=
self
.
dropout
(
word_emb_k
)
if
inp_q
is
not
None
:
if
inp_q
is
not
None
:
if
target_mapping
is
not
None
:
if
target_mapping
is
not
None
:
...
@@ -983,30 +1091,19 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -983,30 +1091,19 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self
.
attn_type
=
config
.
attn_type
self
.
attn_type
=
config
.
attn_type
self
.
same_length
=
config
.
same_length
self
.
same_length
=
config
.
same_length
self
.
word_embedding
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
d_model
)
self
.
mask_emb
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
1
,
config
.
d_model
))
self
.
transformer
=
XLNetModel
(
config
,
output_attentions
=
output_attentions
,
self
.
transformer
=
XLNetModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
keep_multihead_output
=
keep_multihead_output
)
self
.
lm_loss
=
nn
.
Linear
(
config
.
d_model
,
config
.
vocab_size
,
bias
=
True
)
self
.
lm_loss
=
nn
.
Linear
(
config
.
d_model
,
config
.
n_token
,
bias
=
True
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
# Tie weights
# Tie weights
self
.
lm_loss
.
weight
=
self
.
word_embedding
.
weight
self
.
apply
(
self
.
init_xlnet_weights
)
self
.
apply
(
self
.
init_xlnet_weights
)
self
.
tie_weights
()
def
prune_heads
(
self
,
heads_to_prune
):
def
tie_weights
(
self
):
""" Prunes heads of the model.
""" Make sure we are sharing the embeddings
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
lm_loss
.
weight
=
self
.
transformer
.
word_embedding
.
weight
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
def
get_multihead_outputs
(
self
):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return
[
layer
.
attention
.
self
.
multihead_output
for
layer
in
self
.
encoder
.
layer
]
def
forward
(
self
,
inp_k
,
seg_id
=
None
,
input_mask
=
None
,
def
forward
(
self
,
inp_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
...
@@ -1037,9 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1037,9 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
to pool the input to get a vector representation.
"""
"""
word_emb_k
=
self
.
word_embedding
(
inp_k
)
output
,
new_mems
=
self
.
transformer
(
inp_k
,
seg_id
,
input_mask
,
output
,
new_mems
=
self
.
transformer
(
word_emb_k
,
seg_id
,
input_mask
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
output_all_encoded_layers
,
head_mask
)
output_all_encoded_layers
,
head_mask
)
...
@@ -1059,5 +1154,5 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1059,5 +1154,5 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# if not output_all_encoded_layers:
# if not output_all_encoded_layers:
# encoded_layers = encoded_layers[-1]
# encoded_layers = encoded_layers[-1]
# if self.output_attentions:
# if self.output_attentions:
# return all_attentions, encoded_layers, pooled_output
return
logits
,
new_mems
return
logits
,
new_mems
# return all_attentions, encoded_layers, pooled_output
tests/modeling_xlnet_test.py
View file @
24d80689
...
@@ -186,13 +186,13 @@ class XLNetModelTest(unittest.TestCase):
...
@@ -186,13 +186,13 @@ class XLNetModelTest(unittest.TestCase):
self
.
run_tester
(
XLNetModelTest
.
XLNetModelTester
(
self
))
self
.
run_tester
(
XLNetModelTest
.
XLNetModelTester
(
self
))
def
test_config_to_json_string
(
self
):
def
test_config_to_json_string
(
self
):
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
37
)
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
16
*
4
)
obj
=
json
.
loads
(
config
.
to_json_string
())
obj
=
json
.
loads
(
config
.
to_json_string
())
self
.
assertEqual
(
obj
[
"n_token"
],
96
)
self
.
assertEqual
(
obj
[
"n_token"
],
96
)
self
.
assertEqual
(
obj
[
"d_model"
],
37
)
self
.
assertEqual
(
obj
[
"d_model"
],
16
*
4
)
def
test_config_to_json_file
(
self
):
def
test_config_to_json_file
(
self
):
config_first
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
37
)
config_first
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
16
*
4
)
json_file_path
=
"/tmp/config.json"
json_file_path
=
"/tmp/config.json"
config_first
.
to_json_file
(
json_file_path
)
config_first
.
to_json_file
(
json_file_path
)
config_second
=
XLNetConfig
.
from_json_file
(
json_file_path
)
config_second
=
XLNetConfig
.
from_json_file
(
json_file_path
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment