Commit 24d80689 in chenpangpang / transformers, authored Jun 21, 2019 by thomwolf

    weights loading script ok

Parent: 32da7548
Showing 3 changed files with 193 additions and 90 deletions:

    pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py   (+17, -9)
    pytorch_pretrained_bert/modeling_xlnet.py                        (+173, -78)
    tests/modeling_xlnet_test.py                                     (+3, -3)
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py

@@ -18,23 +18,31 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import os
 import argparse
 import torch
 
-from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetRunConfig, XLNetLMHeadModel, load_tf_weights_in_xlnet
+from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
+                                                    XLNetConfig, XLNetRunConfig,
+                                                    XLNetLMHeadModel, load_tf_weights_in_xlnet)
 
-def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
+def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path):
     # Initialise PyTorch model
     config = XLNetConfig.from_json_file(bert_config_file)
     print("Building PyTorch model from configuration: {}".format(str(config)))
     model = XLNetLMHeadModel(config)
 
     # Load weights from tf checkpoint
-    load_tf_weights_in_xlnet(model, tf_checkpoint_path)
+    load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
 
     # Save pytorch-model
-    print("Save PyTorch model to {}".format(pytorch_dump_path))
-    torch.save(model.state_dict(), pytorch_dump_path)
+    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
+    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
+    print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
+    torch.save(model.state_dict(), pytorch_weights_dump_path)
+    print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
+    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+        f.write(config.to_json_string())
 
 
 if __name__ == "__main__":
@@ -50,13 +58,13 @@ if __name__ == "__main__":
                         type = str,
                         required = True,
                         help = "The config json file corresponding to the pre-trained XLNet model. \n"
                             "This specifies the model architecture.")
-    parser.add_argument("--pytorch_dump_path",
+    parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
-                        help = "Path to the output PyTorch model.")
+                        help = "Path to the folder to store the PyTorch model or dataset/vocab.")
     args = parser.parse_args()
     convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                         args.xlnet_config_file,
-                                        args.pytorch_dump_path)
+                                        args.pytorch_dump_folder_path)
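Note: a minimal usage sketch of the updated converter; the paths below are hypothetical placeholders (not values from this commit), and TensorFlow must be installed for the checkpoint to be read.

    # Hypothetical invocation; replace the placeholder paths with a real checkpoint/config/output folder.
    from pytorch_pretrained_bert.convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch

    convert_xlnet_checkpoint_to_pytorch(
        tf_checkpoint_path="/path/to/xlnet/model.ckpt",        # TF checkpoint prefix
        bert_config_file="/path/to/xlnet/xlnet_config.json",   # XLNet config JSON
        pytorch_dump_folder_path="/path/to/output_folder")     # receives the WEIGHTS_NAME and CONFIG_NAME files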
pytorch_pretrained_bert/modeling_xlnet.py

@@ -45,70 +45,122 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
 XLNET_CONFIG_NAME = 'xlnet_config.json'
 TF_WEIGHTS_NAME = 'model.ckpt'
 
-def load_tf_weights_in_xlnet(model, tf_checkpoint_path):
+def build_tf_xlnet_to_pytorch_map(model, config):
+    """ A map of modules from TF to PyTorch.
+        I use a map to keep the PyTorch model as
+        identical to the original PyTorch model as possible.
+    """
+    tf_to_pt_map = {}
+
+    if hasattr(model, 'transformer'):
+        # We are loading pre-trained weights in a XLNetLMHeadModel => we will load also the output bias
+        tf_to_pt_map['model/lm_loss/bias'] = model.lm_loss.bias
+        # Now load the rest of the transformer
+        model = model.transformer
+
+    # Embeddings and output
+    tf_to_pt_map.update({'model/transformer/word_embedding/lookup_table': model.word_embedding.weight,
+                         'model/transformer/mask_emb/mask_emb': model.mask_emb})
+
+    # Transformer blocks
+    for i, b in enumerate(model.layer):
+        layer_str = "model/transformer/layer_%d/" % i
+        tf_to_pt_map.update({
+            layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight,
+            layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias,
+            layer_str + "rel_attn/o/kernel": b.rel_attn.o,
+            layer_str + "rel_attn/q/kernel": b.rel_attn.q,
+            layer_str + "rel_attn/k/kernel": b.rel_attn.k,
+            layer_str + "rel_attn/r/kernel": b.rel_attn.r,
+            layer_str + "rel_attn/v/kernel": b.rel_attn.v,
+            layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight,
+            layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias,
+            layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight,
+            layer_str + "ff/layer_1/bias": b.ff.layer_1.bias,
+            layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight,
+            layer_str + "ff/layer_2/bias": b.ff.layer_2.bias,
+        })
+
+    # Relative positioning biases
+    if config.untie_r:
+        r_r_list = []
+        r_w_list = []
+        r_s_list = []
+        seg_embed_list = []
+        for b in model.layer:
+            r_r_list.append(b.rel_attn.r_r_bias)
+            r_w_list.append(b.rel_attn.r_w_bias)
+            r_s_list.append(b.rel_attn.r_s_bias)
+            seg_embed_list.append(b.rel_attn.seg_embed)
+    else:
+        r_r_list = [model.r_r_bias]
+        r_w_list = [model.r_w_bias]
+        r_s_list = [model.r_s_bias]
+        seg_embed_list = [model.seg_embed]
+    tf_to_pt_map.update({
+        'model/transformer/r_r_bias': r_r_list,
+        'model/transformer/r_w_bias': r_w_list,
+        'model/transformer/r_s_bias': r_s_list,
+        'model/transformer/seg_embed': seg_embed_list})
+    return tf_to_pt_map
+
+def load_tf_weights_in_xlnet(model, config, tf_path):
     """ Load tf checkpoints in a pytorch model
     """
     try:
         import re
         import numpy as np
         import tensorflow as tf
     except ImportError:
         print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
             "https://www.tensorflow.org/install/ for installation instructions.")
         raise
-    tf_path = os.path.abspath(tf_checkpoint_path)
     print("Converting TensorFlow checkpoint from {}".format(tf_path))
+    # Build TF to PyTorch weights loading map
+    tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config)
+
     # Load weights from TF model
     init_vars = tf.train.list_variables(tf_path)
-    names = []
-    arrays = []
+    tf_weights = {}
     for name, shape in init_vars:
         print("Loading TF weight {} with shape {}".format(name, shape))
         array = tf.train.load_variable(tf_path, name)
-        names.append(name)
-        arrays.append(array)
+        tf_weights[name] = array
 
-    for name, array in zip(names, arrays):
-        name = name.split('/')
+    for name, pointer in tf_to_pt_map.items():
+        print("Importing {}".format(name))
+        assert name in tf_weights
+        array = tf_weights[name]
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
-            print("Skipping {}".format("/".join(name)))
-            continue
-        pointer = model
-        for m_name in name:
-            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
-                l = re.split(r'_(\d+)', m_name)
-            else:
-                l = [m_name]
-            if l[0] == 'kernel' or l[0] == 'gamma':
-                pointer = getattr(pointer, 'weight')
-            elif l[0] == 'output_bias' or l[0] == 'beta':
-                pointer = getattr(pointer, 'bias')
-            elif l[0] == 'output_weights':
-                pointer = getattr(pointer, 'weight')
-            elif l[0] == 'squad':
-                pointer = getattr(pointer, 'classifier')
-            else:
-                try:
-                    pointer = getattr(pointer, l[0])
-                except AttributeError:
-                    print("Skipping {}".format("/".join(name)))
-                    continue
-            if len(l) >= 2:
-                num = int(l[1])
-                pointer = pointer[num]
-        if m_name[-11:] == '_embeddings':
-            pointer = getattr(pointer, 'weight')
-        elif m_name == 'kernel':
+        if 'kernel' in name and 'ff' in name:
+            print("Transposing")
             array = np.transpose(array)
-        try:
-            assert pointer.shape == array.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        print("Initialize PyTorch weight {}".format(name))
-        pointer.data = torch.from_numpy(array)
+        if isinstance(pointer, list):
+            # Here we will split the TF weigths
+            assert len(pointer) == array.shape[0]
+            for i, p_i in enumerate(pointer):
+                arr_i = array[i, ...]
+                try:
+                    assert p_i.shape == arr_i.shape
+                except AssertionError as e:
+                    e.args += (p_i.shape, arr_i.shape)
+                    raise
+                print("Initialize PyTorch weight {} for layer {}".format(name, i))
+                p_i.data = torch.from_numpy(arr_i)
+        else:
+            try:
+                assert pointer.shape == array.shape
+            except AssertionError as e:
+                e.args += (pointer.shape, array.shape)
+                raise
+            print("Initialize PyTorch weight {}".format(name))
+            pointer.data = torch.from_numpy(array)
+        tf_weights.pop(name, None)
+        tf_weights.pop(name + '/Adam', None)
+        tf_weights.pop(name + '/Adam_1', None)
+
+    print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
     return model
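Note: the new loader is essentially a dictionary from TF variable names to PyTorch parameters (or to lists of per-layer parameters that are stacked along the first axis in the checkpoint), filled in with torch.from_numpy. A self-contained toy illustration of the same pattern, using made-up names and modules rather than the library's classes:

    import numpy as np
    import torch
    import torch.nn as nn

    # Two toy "layers" standing in for XLNet blocks (hypothetical, not the real classes).
    layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

    # Name map in the spirit of build_tf_xlnet_to_pytorch_map:
    # a TF variable name maps either to one parameter or to a list of per-layer parameters.
    name_map = {
        "model/toy/layer_0/bias": layers[0].bias,
        "model/toy/layer_1/bias": layers[1].bias,
        "model/toy/stacked_weight": [l.weight for l in layers],   # stacked along axis 0 in the checkpoint
    }

    # Fake checkpoint contents with matching shapes.
    tf_weights = {
        "model/toy/layer_0/bias": np.zeros(4, dtype=np.float32),
        "model/toy/layer_1/bias": np.ones(4, dtype=np.float32),
        "model/toy/stacked_weight": np.zeros((2, 4, 4), dtype=np.float32),
    }

    for name, pointer in name_map.items():
        array = tf_weights[name]
        if isinstance(pointer, list):                 # split a stacked TF array across layers
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                p_i.data = torch.from_numpy(array[i, ...])
        else:
            assert tuple(pointer.shape) == array.shape
            pointer.data = torch.from_numpy(array)    # copy the checkpoint value into the parameter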
@@ -181,7 +233,18 @@ class XLNetConfig(XLNetBaseConfig):
                  max_position_embeddings=512,
                  initializer_range=0.02,
-                 layer_norm_eps=1e-12):
+                 layer_norm_eps=1e-12,
+                 dropout=0.1,
+                 dropatt=0.1,
+                 init="normal",
+                 init_range=0.1,
+                 init_std=0.02,
+                 mem_len=None,
+                 reuse_len=None,
+                 bi_data=False,
+                 clamp_len=-1,
+                 same_length=False):
         """Constructs XLNetConfig.
 
         Args:
@@ -207,6 +270,22 @@ class XLNetConfig(XLNetBaseConfig):
             initializer_range: The sttdev of the truncated_normal_initializer for
                 initializing all weight matrices.
             layer_norm_eps: The epsilon used by LayerNorm.
+            dropout: float, dropout rate.
+            dropatt: float, dropout rate on attention probabilities.
+            init: str, the initialization scheme, either "normal" or "uniform".
+            init_range: float, initialize the parameters with a uniform distribution
+                in [-init_range, init_range]. Only effective when init="uniform".
+            init_std: float, initialize the parameters with a normal distribution
+                with mean 0 and stddev init_std. Only effective when init="normal".
+            mem_len: int, the number of tokens to cache.
+            reuse_len: int, the number of tokens in the currect batch to be cached
+                and reused in the future.
+            bi_data: bool, whether to use bidirectional input pipeline.
+                Usually set to True during pretraining and False during finetuning.
+            clamp_len: int, clamp all relative distances larger than clamp_len.
+                -1 means no clamping.
+            same_length: bool, whether to use the same attention length for each token.
         """
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -215,7 +294,7 @@ class XLNetConfig(XLNetBaseConfig):
             for key, value in json_config.items():
                 self.__dict__[key] = value
         elif isinstance(vocab_size_or_config_json_file, int):
-            self.vocab_size = vocab_size_or_config_json_file
+            self.n_token = vocab_size_or_config_json_file
             self.d_model = d_model
             self.n_layer = n_layer
             self.n_head = n_head
@@ -225,9 +304,21 @@ class XLNetConfig(XLNetBaseConfig):
             self.d_inner = d_inner
             self.untie_r = untie_r
             self.attn_type = attn_type
 
             self.max_position_embeddings = max_position_embeddings
             self.initializer_range = initializer_range
             self.layer_norm_eps = layer_norm_eps
+
+            self.init = init
+            self.init_range = init_range
+            self.init_std = init_std
+            self.dropout = dropout
+            self.dropatt = dropatt
+            self.mem_len = mem_len
+            self.reuse_len = reuse_len
+            self.bi_data = bi_data
+            self.clamp_len = clamp_len
+            self.same_length = same_length
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
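Note: the new hyperparameters become plain instance attributes, so they travel through the existing JSON serialization. A small sketch, assuming the remaining constructor arguments keep the defaults shown in the signature above and that to_json_string dumps the instance attributes as the BERT-style configs do (mirroring the updated tests):

    import json
    from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig

    # Override a handful of arguments; everything else is left at its default.
    config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=64,
                         dropout=0.1, mem_len=None, clamp_len=-1)

    obj = json.loads(config.to_json_string())
    print(obj["n_token"])                      # 96 -- the vocabulary size is stored as n_token now
    print(obj.get("clamp_len"), obj.get("mem_len"))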
@@ -327,7 +418,7 @@ class XLNetRelativeAttention(nn.Module):
         self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
         self.seg_embed = nn.Parameter(torch.Tensor(2, self.n_head, self.d_head))
 
-        self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
+        self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.dropout)
 
     def prune_heads(self, heads):
@@ -385,7 +476,7 @@ class XLNetRelativeAttention(nn.Module):
         attn_out = self.dropout(attn_out)
         if residual:
             attn_out = attn_out + h
-        output = self.LayerNorm(attn_out)
+        output = self.layer_norm(attn_out)
 
         return output
@@ -483,7 +574,7 @@ class XLNetRelativeAttention(nn.Module):
 class XLNetFeedForward(nn.Module):
     def __init__(self, config):
         super(XLNetFeedForward, self).__init__()
-        self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
+        self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
         self.layer_1 = nn.Linear(config.d_model, config.d_inner)
         self.layer_2 = nn.Linear(config.d_inner, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
@@ -499,7 +590,7 @@ class XLNetFeedForward(nn.Module):
         output = self.dropout(output)
         output = self.layer_2(output)
         output = self.dropout(output)
-        output = self.LayerNorm(output + inp)
+        output = self.layer_norm(output + inp)
         return output
 
 class XLNetLayer(nn.Module):
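Note: the rename from self.LayerNorm to self.layer_norm matters for the loader above, which addresses parameters by attribute path (b.rel_attn.layer_norm.weight, b.ff.layer_norm.weight) while the checkpoint keys keep the TF-side names (rel_attn/LayerNorm/gamma, ff/LayerNorm/beta). A tiny, self-contained illustration of that split between checkpoint key and attribute name, using a toy module rather than the real classes:

    import torch.nn as nn

    class ToyFeedForward(nn.Module):
        """Hypothetical stand-in for XLNetFeedForward after the rename."""
        def __init__(self, d_model=8, d_inner=16):
            super(ToyFeedForward, self).__init__()
            self.layer_norm = nn.LayerNorm(d_model)      # PyTorch-side attribute name
            self.layer_1 = nn.Linear(d_model, d_inner)
            self.layer_2 = nn.Linear(d_inner, d_model)

    ff = ToyFeedForward()
    # Checkpoint keys keep the TF naming; the values point at the renamed attributes.
    ff_map = {
        "ff/LayerNorm/gamma": ff.layer_norm.weight,
        "ff/LayerNorm/beta": ff.layer_norm.bias,
    }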
@@ -691,11 +782,26 @@ class XLNetModel(XLNetPreTrainedModel):
         self.bi_data = config.bi_data
         self.clamp_len = config.clamp_len
 
+        self.word_embedding = nn.Embedding(config.n_token, config.d_model)
+        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
         layer = XLNetLayer(config, output_attentions=output_attentions,
                            keep_multihead_output=keep_multihead_output)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
+        self.dropout = nn.Dropout(config.dropout)
 
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.layer[layer].attention.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [layer.attention.self.multihead_output for layer in self.layer]
+
     def create_mask(self, qlen, mlen):
         """ create causal attention mask.
             float mask where 1.0 indicate masked, 0.0 indicated not-masked.
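Note: heads_to_prune is a plain dict keyed by layer index, as the docstring states. A self-contained toy of the same iteration pattern (hypothetical classes; with a real model the call is simply model.prune_heads(heads_to_prune)):

    class ToyAttention:
        def prune_heads(self, heads):
            print("pruning heads {} in this layer".format(sorted(heads)))

    class ToyLayer:
        def __init__(self):
            self.attention = ToyAttention()

    layers = [ToyLayer() for _ in range(6)]
    heads_to_prune = {0: [1, 2], 5: [0]}      # {layer_num: list of heads to prune}

    for layer_num, heads in heads_to_prune.items():
        layers[layer_num].attention.prune_heads(heads)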
@@ -783,12 +889,12 @@ class XLNetModel(XLNetPreTrainedModel):
         pos_emb = pos_emb.to(next(self.parameters()))
         return pos_emb
 
-    def forward(self, word_emb_k, seg_id=None, input_mask=None,
+    def forward(self, inp_k, seg_id=None, input_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
                 output_all_encoded_layers=True, head_mask=None):
         """
         Args:
-            word_emb_k: float32 Tensor in shape [len, bsz, d_model], the input token embeddings.
+            inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
             seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
             input_mask: [optional] float32 Tensor in shape [len, bsz], the input mask.
                 0 for real tokens and 1 for padding.
@@ -820,11 +926,12 @@ class XLNetModel(XLNetPreTrainedModel):
             summary_type: str, "last", "first", "mean", or "attn". The method
                 to pool the input to get a vector representation.
         """
-        qlen, bsz = word_emb_k.shape[0], word_emb_k.shape[1]
+        qlen, bsz = inp_k.shape[0], inp_k.shape[1]
         mlen = mems[0].shape[0] if mems is not None else 0
         klen = mlen + qlen
 
-        dtype_float = word_emb_k.dtype
-        device = word_emb_k.device
+        dtype_float = next(self.parameters()).dtype
+        device = next(self.parameters()).device
 
         ##### Attention mask
         # causal attention mask
@@ -865,7 +972,8 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             non_tgt_mask = None
 
-        ##### Process Word embeddings and prepare h & g hidden states
+        ##### Word embeddings and prepare h & g hidden states
+        word_emb_k = self.word_embedding(inp_k)
         output_h = self.dropout(word_emb_k)
         if inp_q is not None:
             if target_mapping is not None:
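Note: after this change the model consumes integer token IDs and does the embedding lookup itself, and inputs are laid out sequence-length-first, [len, bsz], per the docstring. A small sketch of the expected tensor shapes (sizes are hypothetical; the forward call is commented out because it needs a fully built XLNetModel):

    import torch

    seq_len, batch_size, vocab_size = 10, 2, 32000   # hypothetical sizes

    inp_k = torch.randint(0, vocab_size, (seq_len, batch_size), dtype=torch.long)  # token IDs, [len, bsz]
    seg_id = torch.zeros(seq_len, batch_size, dtype=torch.long)                    # segment IDs, [len, bsz]
    input_mask = torch.zeros(seq_len, batch_size)                                  # 0 = real token, 1 = padding

    # output, new_mems = model(inp_k, seg_id=seg_id, input_mask=input_mask)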
@@ -983,30 +1091,19 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.attn_type = config.attn_type
         self.same_length = config.same_length
 
-        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
-        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
         self.transformer = XLNetModel(config, output_attentions=output_attentions,
                                       keep_multihead_output=keep_multihead_output)
-        self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
-        self.dropout = nn.Dropout(config.dropout)
+        self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
 
         # Tie weights
-        self.lm_loss.weight = self.word_embedding.weight
         self.apply(self.init_xlnet_weights)
+        self.tie_weights()
 
-    def prune_heads(self, heads_to_prune):
-        """ Prunes heads of the model.
-            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+    def tie_weights(self):
+        """ Make sure we are sharing the embeddings
         """
-        for layer, heads in heads_to_prune.items():
-            self.encoder.layer[layer].attention.prune_heads(heads)
-
-    def get_multihead_outputs(self):
-        """ Gather all multi-head outputs.
-            Return: list (layers) of multihead module outputs with gradients
-        """
-        return [layer.attention.self.multihead_output for layer in self.encoder.layer]
+        self.lm_loss.weight = self.transformer.word_embedding.weight
 
     def forward(self, inp_k, seg_id=None, input_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
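Note: tie_weights makes the LM output projection and the input embedding share one Parameter object, which is why lm_loss is now built with config.n_token so its weight matches the embedding's [n_token, d_model] shape. A minimal, generic sketch of the same tying with plain nn modules (hypothetical sizes, not the XLNet classes):

    import torch.nn as nn

    n_token, d_model = 100, 16
    word_embedding = nn.Embedding(n_token, d_model)
    lm_loss = nn.Linear(d_model, n_token, bias=True)

    # Tie: both modules now hold the very same [n_token, d_model] weight Parameter.
    lm_loss.weight = word_embedding.weight
    assert lm_loss.weight is word_embedding.weight
    assert lm_loss.weight.data_ptr() == word_embedding.weight.data_ptr()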
@@ -1037,9 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             summary_type: str, "last", "first", "mean", or "attn". The method
                 to pool the input to get a vector representation.
         """
-        word_emb_k = self.word_embedding(inp_k)
-
-        output, new_mems = self.transformer(word_emb_k, seg_id, input_mask,
+        output, new_mems = self.transformer(inp_k, seg_id, input_mask,
                                             mems, perm_mask, target_mapping, inp_q,
                                             output_all_encoded_layers, head_mask)
@@ -1059,5 +1154,5 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         # if not output_all_encoded_layers:
         #     encoded_layers = encoded_layers[-1]
         # if self.output_attentions:
-        #     return all_attentions, encoded_layers, pooled_output
         return logits, new_mems
+        # return all_attentions, encoded_layers, pooled_output
tests/modeling_xlnet_test.py

@@ -186,13 +186,13 @@ class XLNetModelTest(unittest.TestCase):
         self.run_tester(XLNetModelTest.XLNetModelTester(self))
 
     def test_config_to_json_string(self):
-        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
+        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4)
         obj = json.loads(config.to_json_string())
         self.assertEqual(obj["n_token"], 96)
-        self.assertEqual(obj["d_model"], 37)
+        self.assertEqual(obj["d_model"], 16*4)
 
     def test_config_to_json_file(self):
-        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
+        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4)
         json_file_path = "/tmp/config.json"
         config_first.to_json_file(json_file_path)
         config_second = XLNetConfig.from_json_file(json_file_path)