Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
24d80689
Commit
24d80689
authored
Jun 21, 2019
by
thomwolf
Browse files
weights loading script ok
parent
32da7548
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
193 additions
and
90 deletions
+193
-90
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+17
-9
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+173
-78
tests/modeling_xlnet_test.py
tests/modeling_xlnet_test.py
+3
-3
No files found.
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
24d80689
...
...
@@ -18,23 +18,31 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
argparse
import
torch
from
pytorch_pretrained_bert.modeling_xlnet
import
XLNetConfig
,
XLNetRunConfig
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
from
pytorch_pretrained_bert.modeling_xlnet
import
(
CONFIG_NAME
,
WEIGHTS_NAME
,
XLNetConfig
,
XLNetRunConfig
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
)
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_
folder_
path
):
# Initialise PyTorch model
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
XLNetLMHeadModel
(
config
)
# Load weights from tf checkpoint
load_tf_weights_in_xlnet
(
model
,
tf_checkpoint_path
)
load_tf_weights_in_xlnet
(
model
,
config
,
tf_checkpoint_path
)
# Save pytorch-model
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_dump_path
))
torch
.
save
(
model
.
state_dict
(),
pytorch_dump_path
)
pytorch_weights_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
WEIGHTS_NAME
)
pytorch_config_dump_path
=
os
.
path
.
join
(
pytorch_dump_folder_path
,
CONFIG_NAME
)
print
(
"Save PyTorch model to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_weights_dump_path
)))
torch
.
save
(
model
.
state_dict
(),
pytorch_weights_dump_path
)
print
(
"Save configuration file to {}"
.
format
(
os
.
path
.
abspath
(
pytorch_config_dump_path
)))
with
open
(
pytorch_config_dump_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
config
.
to_json_string
())
if
__name__
==
"__main__"
:
...
...
@@ -50,13 +58,13 @@ if __name__ == "__main__":
type
=
str
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained XLNet model.
\n
"
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_
folder_
path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the
output PyTorch model
."
)
help
=
"Path to the
folder to store the PyTorch model or dataset/vocab
."
)
args
=
parser
.
parse_args
()
convert_xlnet_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
args
.
xlnet_config_file
,
args
.
pytorch_dump_path
)
args
.
pytorch_dump_
folder_
path
)
pytorch_pretrained_bert/modeling_xlnet.py
View file @
24d80689
...
...
@@ -45,70 +45,122 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
XLNET_CONFIG_NAME
=
'xlnet_config.json'
TF_WEIGHTS_NAME
=
'model.ckpt'
def
load_tf_weights_in_xlnet
(
model
,
tf_checkpoint_path
):
def
build_tf_xlnet_to_pytorch_map
(
model
,
config
):
""" A map of modules from TF to PyTorch.
I use a map to keep the PyTorch model as
identical to the original PyTorch model as possible.
"""
tf_to_pt_map
=
{}
if
hasattr
(
model
,
'transformer'
):
# We are loading pre-trained weights in a XLNetLMHeadModel => we will load also the output bias
tf_to_pt_map
[
'model/lm_loss/bias'
]
=
model
.
lm_loss
.
bias
# Now load the rest of the transformer
model
=
model
.
transformer
# Embeddings and output
tf_to_pt_map
.
update
({
'model/transformer/word_embedding/lookup_table'
:
model
.
word_embedding
.
weight
,
'model/transformer/mask_emb/mask_emb'
:
model
.
mask_emb
})
# Transformer blocks
for
i
,
b
in
enumerate
(
model
.
layer
):
layer_str
=
"model/transformer/layer_%d/"
%
i
tf_to_pt_map
.
update
({
layer_str
+
"rel_attn/LayerNorm/gamma"
:
b
.
rel_attn
.
layer_norm
.
weight
,
layer_str
+
"rel_attn/LayerNorm/beta"
:
b
.
rel_attn
.
layer_norm
.
bias
,
layer_str
+
"rel_attn/o/kernel"
:
b
.
rel_attn
.
o
,
layer_str
+
"rel_attn/q/kernel"
:
b
.
rel_attn
.
q
,
layer_str
+
"rel_attn/k/kernel"
:
b
.
rel_attn
.
k
,
layer_str
+
"rel_attn/r/kernel"
:
b
.
rel_attn
.
r
,
layer_str
+
"rel_attn/v/kernel"
:
b
.
rel_attn
.
v
,
layer_str
+
"ff/LayerNorm/gamma"
:
b
.
ff
.
layer_norm
.
weight
,
layer_str
+
"ff/LayerNorm/beta"
:
b
.
ff
.
layer_norm
.
bias
,
layer_str
+
"ff/layer_1/kernel"
:
b
.
ff
.
layer_1
.
weight
,
layer_str
+
"ff/layer_1/bias"
:
b
.
ff
.
layer_1
.
bias
,
layer_str
+
"ff/layer_2/kernel"
:
b
.
ff
.
layer_2
.
weight
,
layer_str
+
"ff/layer_2/bias"
:
b
.
ff
.
layer_2
.
bias
,
})
# Relative positioning biases
if
config
.
untie_r
:
r_r_list
=
[]
r_w_list
=
[]
r_s_list
=
[]
seg_embed_list
=
[]
for
b
in
model
.
layer
:
r_r_list
.
append
(
b
.
rel_attn
.
r_r_bias
)
r_w_list
.
append
(
b
.
rel_attn
.
r_w_bias
)
r_s_list
.
append
(
b
.
rel_attn
.
r_s_bias
)
seg_embed_list
.
append
(
b
.
rel_attn
.
seg_embed
)
else
:
r_r_list
=
[
model
.
r_r_bias
]
r_w_list
=
[
model
.
r_w_bias
]
r_s_list
=
[
model
.
r_s_bias
]
seg_embed_list
=
[
model
.
seg_embed
]
tf_to_pt_map
.
update
({
'model/transformer/r_r_bias'
:
r_r_list
,
'model/transformer/r_w_bias'
:
r_w_list
,
'model/transformer/r_s_bias'
:
r_s_list
,
'model/transformer/seg_embed'
:
seg_embed_list
})
return
tf_to_pt_map
def
load_tf_weights_in_xlnet
(
model
,
config
,
tf_path
):
""" Load tf checkpoints in a pytorch model
"""
try
:
import
re
import
numpy
as
np
import
tensorflow
as
tf
except
ImportError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
# Build TF to PyTorch weights loading map
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
)
# Load weights from TF model
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
names
=
[]
arrays
=
[]
tf_weights
=
{}
for
name
,
shape
in
init_vars
:
print
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
names
.
append
(
name
)
arrays
.
append
(
array
)
tf_weights
[
name
]
=
array
for
name
,
array
in
zip
(
names
,
arrays
):
name
=
name
.
split
(
'/'
)
for
name
,
pointer
in
tf_to_pt_map
.
items
():
print
(
"Importing {}"
.
format
(
name
))
assert
name
in
tf_weights
array
=
tf_weights
[
name
]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if
any
(
n
in
[
"adam_v"
,
"adam_m"
,
"global_step"
]
for
n
in
name
):
print
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
continue
pointer
=
model
for
m_name
in
name
:
if
re
.
fullmatch
(
r
'[A-Za-z]+_\d+'
,
m_name
):
l
=
re
.
split
(
r
'_(\d+)'
,
m_name
)
else
:
l
=
[
m_name
]
if
l
[
0
]
==
'kernel'
or
l
[
0
]
==
'gamma'
:
pointer
=
getattr
(
pointer
,
'weight'
)
elif
l
[
0
]
==
'output_bias'
or
l
[
0
]
==
'beta'
:
pointer
=
getattr
(
pointer
,
'bias'
)
elif
l
[
0
]
==
'output_weights'
:
pointer
=
getattr
(
pointer
,
'weight'
)
elif
l
[
0
]
==
'squad'
:
pointer
=
getattr
(
pointer
,
'classifier'
)
else
:
try
:
pointer
=
getattr
(
pointer
,
l
[
0
])
except
AttributeError
:
print
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
continue
if
len
(
l
)
>=
2
:
num
=
int
(
l
[
1
])
pointer
=
pointer
[
num
]
if
m_name
[
-
11
:]
==
'_embeddings'
:
pointer
=
getattr
(
pointer
,
'weight'
)
elif
m_name
==
'kernel'
:
if
'kernel'
in
name
and
'ff'
in
name
:
print
(
"Transposing"
)
array
=
np
.
transpose
(
array
)
try
:
assert
pointer
.
shape
==
array
.
shape
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
if
isinstance
(
pointer
,
list
):
# Here we will split the TF weigths
assert
len
(
pointer
)
==
array
.
shape
[
0
]
for
i
,
p_i
in
enumerate
(
pointer
):
arr_i
=
array
[
i
,
...]
try
:
assert
p_i
.
shape
==
arr_i
.
shape
except
AssertionError
as
e
:
e
.
args
+=
(
p_i
.
shape
,
arr_i
.
shape
)
raise
print
(
"Initialize PyTorch weight {} for layer {}"
.
format
(
name
,
i
))
p_i
.
data
=
torch
.
from_numpy
(
arr_i
)
else
:
try
:
assert
pointer
.
shape
==
array
.
shape
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
tf_weights
.
pop
(
name
,
None
)
tf_weights
.
pop
(
name
+
'/Adam'
,
None
)
tf_weights
.
pop
(
name
+
'/Adam_1'
,
None
)
print
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
return
model
...
...
@@ -181,7 +233,18 @@ class XLNetConfig(XLNetBaseConfig):
max_position_embeddings
=
512
,
initializer_range
=
0.02
,
layer_norm_eps
=
1e-12
):
layer_norm_eps
=
1e-12
,
dropout
=
0.1
,
dropatt
=
0.1
,
init
=
"normal"
,
init_range
=
0.1
,
init_std
=
0.02
,
mem_len
=
None
,
reuse_len
=
None
,
bi_data
=
False
,
clamp_len
=-
1
,
same_length
=
False
):
"""Constructs XLNetConfig.
Args:
...
...
@@ -207,6 +270,22 @@ class XLNetConfig(XLNetBaseConfig):
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
dropout: float, dropout rate.
dropatt: float, dropout rate on attention probabilities.
init: str, the initialization scheme, either "normal" or "uniform".
init_range: float, initialize the parameters with a uniform distribution
in [-init_range, init_range]. Only effective when init="uniform".
init_std: float, initialize the parameters with a normal distribution
with mean 0 and stddev init_std. Only effective when init="normal".
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the currect batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
"""
if
isinstance
(
vocab_size_or_config_json_file
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
vocab_size_or_config_json_file
,
unicode
)):
...
...
@@ -215,7 +294,7 @@ class XLNetConfig(XLNetBaseConfig):
for
key
,
value
in
json_config
.
items
():
self
.
__dict__
[
key
]
=
value
elif
isinstance
(
vocab_size_or_config_json_file
,
int
):
self
.
vocab_size
=
vocab_size_or_config_json_file
self
.
n_token
=
vocab_size_or_config_json_file
self
.
d_model
=
d_model
self
.
n_layer
=
n_layer
self
.
n_head
=
n_head
...
...
@@ -225,9 +304,21 @@ class XLNetConfig(XLNetBaseConfig):
self
.
d_inner
=
d_inner
self
.
untie_r
=
untie_r
self
.
attn_type
=
attn_type
self
.
max_position_embeddings
=
max_position_embeddings
self
.
initializer_range
=
initializer_range
self
.
layer_norm_eps
=
layer_norm_eps
self
.
init
=
init
self
.
init_range
=
init_range
self
.
init_std
=
init_std
self
.
dropout
=
dropout
self
.
dropatt
=
dropatt
self
.
mem_len
=
mem_len
self
.
reuse_len
=
reuse_len
self
.
bi_data
=
bi_data
self
.
clamp_len
=
clamp_len
self
.
same_length
=
same_length
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
...
...
@@ -327,7 +418,7 @@ class XLNetRelativeAttention(nn.Module):
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
seg_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
2
,
self
.
n_head
,
self
.
d_head
))
self
.
L
ayer
N
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
l
ayer
_n
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
def
prune_heads
(
self
,
heads
):
...
...
@@ -385,7 +476,7 @@ class XLNetRelativeAttention(nn.Module):
attn_out
=
self
.
dropout
(
attn_out
)
if
residual
:
attn_out
=
attn_out
+
h
output
=
self
.
L
ayer
N
orm
(
attn_out
)
output
=
self
.
l
ayer
_n
orm
(
attn_out
)
return
output
...
...
@@ -483,7 +574,7 @@ class XLNetRelativeAttention(nn.Module):
class
XLNetFeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
XLNetFeedForward
,
self
).
__init__
()
self
.
L
ayer
N
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
l
ayer
_n
orm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
layer_1
=
nn
.
Linear
(
config
.
d_model
,
config
.
d_inner
)
self
.
layer_2
=
nn
.
Linear
(
config
.
d_inner
,
config
.
d_model
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
...
...
@@ -499,7 +590,7 @@ class XLNetFeedForward(nn.Module):
output
=
self
.
dropout
(
output
)
output
=
self
.
layer_2
(
output
)
output
=
self
.
dropout
(
output
)
output
=
self
.
L
ayer
N
orm
(
output
+
inp
)
output
=
self
.
l
ayer
_n
orm
(
output
+
inp
)
return
output
class
XLNetLayer
(
nn
.
Module
):
...
...
@@ -691,11 +782,26 @@ class XLNetModel(XLNetPreTrainedModel):
self
.
bi_data
=
config
.
bi_data
self
.
clamp_len
=
config
.
clamp_len
self
.
word_embedding
=
nn
.
Embedding
(
config
.
n_token
,
config
.
d_model
)
self
.
mask_emb
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
1
,
config
.
d_model
))
layer
=
XLNetLayer
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
n_layer
)])
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
def
prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
def
get_multihead_outputs
(
self
):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return
[
layer
.
attention
.
self
.
multihead_output
for
layer
in
self
.
layer
]
def
create_mask
(
self
,
qlen
,
mlen
):
""" create causal attention mask.
float mask where 1.0 indicate masked, 0.0 indicated not-masked.
...
...
@@ -783,12 +889,12 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb
=
pos_emb
.
to
(
next
(
self
.
parameters
()))
return
pos_emb
def
forward
(
self
,
word_emb
_k
,
seg_id
=
None
,
input_mask
=
None
,
def
forward
(
self
,
inp
_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
Args:
word_emb_k: floa
t32 Tensor in shape [len, bsz
, d_model
], the input token
embedding
s.
inp_k: in
t32 Tensor in shape [len, bsz], the input token
ID
s.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
...
...
@@ -820,11 +926,12 @@ class XLNetModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
qlen
,
bsz
=
word_emb
_k
.
shape
[
0
],
word_emb
_k
.
shape
[
1
]
qlen
,
bsz
=
inp
_k
.
shape
[
0
],
inp
_k
.
shape
[
1
]
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
dtype_float
=
word_emb_k
.
dtype
device
=
word_emb_k
.
device
dtype_float
=
next
(
self
.
parameters
()).
dtype
device
=
next
(
self
.
parameters
()).
device
##### Attention mask
# causal attention mask
...
...
@@ -865,7 +972,8 @@ class XLNetModel(XLNetPreTrainedModel):
else
:
non_tgt_mask
=
None
##### Process Word embeddings and prepare h & g hidden states
##### Word embeddings and prepare h & g hidden states
word_emb_k
=
self
.
word_embedding
(
inp_k
)
output_h
=
self
.
dropout
(
word_emb_k
)
if
inp_q
is
not
None
:
if
target_mapping
is
not
None
:
...
...
@@ -983,30 +1091,19 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self
.
attn_type
=
config
.
attn_type
self
.
same_length
=
config
.
same_length
self
.
word_embedding
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
d_model
)
self
.
mask_emb
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
1
,
config
.
d_model
))
self
.
transformer
=
XLNetModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
lm_loss
=
nn
.
Linear
(
config
.
d_model
,
config
.
vocab_size
,
bias
=
True
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
self
.
lm_loss
=
nn
.
Linear
(
config
.
d_model
,
config
.
n_token
,
bias
=
True
)
# Tie weights
self
.
lm_loss
.
weight
=
self
.
word_embedding
.
weight
self
.
apply
(
self
.
init_xlnet_weights
)
self
.
tie_weights
()
def
prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
def
tie_weights
(
self
):
""" Make sure we are sharing the embeddings
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
def
get_multihead_outputs
(
self
):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return
[
layer
.
attention
.
self
.
multihead_output
for
layer
in
self
.
encoder
.
layer
]
self
.
lm_loss
.
weight
=
self
.
transformer
.
word_embedding
.
weight
def
forward
(
self
,
inp_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
...
...
@@ -1037,9 +1134,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
word_emb_k
=
self
.
word_embedding
(
inp_k
)
output
,
new_mems
=
self
.
transformer
(
word_emb_k
,
seg_id
,
input_mask
,
output
,
new_mems
=
self
.
transformer
(
inp_k
,
seg_id
,
input_mask
,
mems
,
perm_mask
,
target_mapping
,
inp_q
,
output_all_encoded_layers
,
head_mask
)
...
...
@@ -1059,5 +1154,5 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
# if not output_all_encoded_layers:
# encoded_layers = encoded_layers[-1]
# if self.output_attentions:
# return all_attentions, encoded_layers, pooled_output
return
logits
,
new_mems
# return all_attentions, encoded_layers, pooled_output
tests/modeling_xlnet_test.py
View file @
24d80689
...
...
@@ -186,13 +186,13 @@ class XLNetModelTest(unittest.TestCase):
self
.
run_tester
(
XLNetModelTest
.
XLNetModelTester
(
self
))
def
test_config_to_json_string
(
self
):
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
37
)
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
16
*
4
)
obj
=
json
.
loads
(
config
.
to_json_string
())
self
.
assertEqual
(
obj
[
"n_token"
],
96
)
self
.
assertEqual
(
obj
[
"d_model"
],
37
)
self
.
assertEqual
(
obj
[
"d_model"
],
16
*
4
)
def
test_config_to_json_file
(
self
):
config_first
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
37
)
config_first
=
XLNetConfig
(
vocab_size_or_config_json_file
=
96
,
d_model
=
16
*
4
)
json_file_path
=
"/tmp/config.json"
config_first
.
to_json_file
(
json_file_path
)
config_second
=
XLNetConfig
.
from_json_file
(
json_file_path
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment