Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
45709d75
Commit
45709d75
authored
Jun 21, 2019
by
thomwolf
Browse files
model running with simple inputs
parent
b407972e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
404 additions
and
108 deletions
+404
-108
pytorch_pretrained_bert/__init__.py
pytorch_pretrained_bert/__init__.py
+3
-0
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+2
-2
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+1
-1
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+1
-1
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+1
-1
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+149
-103
tests/modeling_xlnet_test.py
tests/modeling_xlnet_test.py
+247
-0
No files found.
pytorch_pretrained_bert/__init__.py
View file @
45709d75
...
...
@@ -17,6 +17,9 @@ from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHe
from
.modeling_gpt2
import
(
GPT2Config
,
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
,
GPT2MultipleChoiceHead
,
load_tf_weights_in_gpt2
)
from
.modeling_xlnet
import
(
XLNetBaseConfig
,
XLNetConfig
,
XLNetRunConfig
,
XLNetPreTrainedModel
,
XLNetModel
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
)
from
.optimization
import
BertAdam
from
.optimization_openai
import
OpenAIAdam
...
...
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
45709d75
...
...
@@ -21,13 +21,13 @@ from __future__ import print_function
import
argparse
import
torch
from
pytorch_pretrained_bert.modeling_xlnet
import
XLNetConfig
,
XLNetRunConfig
,
XLNetModel
,
load_tf_weights_in_xlnet
from
pytorch_pretrained_bert.modeling_xlnet
import
XLNetConfig
,
XLNetRunConfig
,
XLNet
LMHead
Model
,
load_tf_weights_in_xlnet
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
# Initialise PyTorch model
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
XLNetModel
(
config
)
model
=
XLNet
LMHead
Model
(
config
)
# Load weights from tf checkpoint
load_tf_weights_in_xlnet
(
model
,
tf_checkpoint_path
)
...
...
pytorch_pretrained_bert/modeling.py
View file @
45709d75
...
...
@@ -867,7 +867,7 @@ class BertModel(BertPreTrainedModel):
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
...
...
pytorch_pretrained_bert/modeling_gpt2.py
View file @
45709d75
...
...
@@ -722,7 +722,7 @@ class GPT2Model(GPT2PreTrainedModel):
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
45709d75
...
...
@@ -718,7 +718,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
...
...
pytorch_pretrained_bert/modeling_xlnet.py
View file @
45709d75
...
...
@@ -29,6 +29,7 @@ from io import open
import
torch
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
CrossEntropyLoss
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
CONFIG_NAME
...
...
@@ -126,32 +127,27 @@ def swish(x):
ACT2FN
=
{
"gelu"
:
gelu
,
"relu"
:
torch
.
nn
.
functional
.
relu
,
"swish"
:
swish
}
def
positional_embedding
(
pos_seq
,
inv_freq
,
bsz
=
None
):
sinusoid_inp
=
torch
.
einsum
(
'i,d->id'
,
pos_seq
,
inv_freq
)
pos_emb
=
torch
.
cat
([
tf
.
sin
(
sinusoid_inp
),
tf
.
cos
(
sinusoid_inp
)],
-
1
)
pos_emb
=
pos_emb
[:,
None
,
:]
if
bsz
is
not
None
:
pos_emb
=
pos_emb
.
expand
(
1
,
bsz
,
1
)
return
pos_emb
class
XLNetBaseConfig
(
object
):
@
classmethod
def
from_dict
(
cls
,
json_object
):
"""Constructs a `XLNetConfig` from a Python dictionary of parameters."""
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=-
1
)
"""Constructs a `XLNet
Base
Config` from a Python dictionary of parameters."""
config
=
cls
(
vocab_size_or_config_json_file
=-
1
)
for
key
,
value
in
json_object
.
items
():
config
.
__dict__
[
key
]
=
value
return
config
@
classmethod
def
from_json_file
(
cls
,
json_file
):
"""Constructs a `XLNetConfig` from a json file of parameters."""
"""Constructs a `XLNet
Base
Config` from a json file of parameters."""
with
open
(
json_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
text
=
reader
.
read
()
return
cls
.
from_dict
(
json
.
loads
(
text
))
def
update
(
self
,
other
):
dict_b
=
other
.
to_dict
()
for
key
,
value
in
dict_b
.
items
():
self
.
__dict__
[
key
]
=
value
def
__repr__
(
self
):
return
str
(
self
.
to_json_string
())
...
...
@@ -181,6 +177,7 @@ class XLNetConfig(XLNetBaseConfig):
d_inner
=
4096
,
ff_activation
=
"gelu"
,
untie_r
=
True
,
attn_type
=
"bi"
,
max_position_embeddings
=
512
,
initializer_range
=
0.02
,
...
...
@@ -198,6 +195,7 @@ class XLNetConfig(XLNetBaseConfig):
ff_activation: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
untie_r: untie relative position biases
attn_type: 'bi' for XLNet, 'uni' for Transformer-XL
dropout: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
...
...
@@ -226,6 +224,7 @@ class XLNetConfig(XLNetBaseConfig):
self
.
ff_activation
=
ff_activation
self
.
d_inner
=
d_inner
self
.
untie_r
=
untie_r
self
.
attn_type
=
attn_type
self
.
max_position_embeddings
=
max_position_embeddings
self
.
initializer_range
=
initializer_range
self
.
layer_norm_eps
=
layer_norm_eps
...
...
@@ -304,15 +303,15 @@ class XLNetRelativeAttention(nn.Module):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
XLNetRelativeAttention
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
if
config
.
d_model
%
config
.
n
um_attention
_head
s
!=
0
:
if
config
.
d_model
%
config
.
n_head
!=
0
:
raise
ValueError
(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
config
.
d_model
,
config
.
n
um_attention
_head
s
))
"heads (%d)"
%
(
config
.
d_model
,
config
.
n_head
))
self
.
output_attentions
=
output_attentions
self
.
keep_multihead_output
=
keep_multihead_output
self
.
multihead_output
=
None
self
.
n_head
=
config
.
n
um_attention
_head
s
self
.
n_head
=
config
.
n_head
self
.
d_head
=
config
.
d_head
self
.
d_model
=
config
.
d_model
self
.
scale
=
1
/
(
config
.
d_head
**
0.5
)
...
...
@@ -326,7 +325,7 @@ class XLNetRelativeAttention(nn.Module):
self
.
r_r_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_s_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
seg_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
2
,
self
.
d_head
))
self
.
seg_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
2
,
self
.
n_head
,
self
.
d_head
))
self
.
LayerNorm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
...
...
@@ -334,6 +333,18 @@ class XLNetRelativeAttention(nn.Module):
def
prune_heads
(
self
,
heads
):
raise
NotImplementedError
@
staticmethod
def
rel_shift
(
x
,
klen
=-
1
):
"""perform relative shift to form the relative attention score."""
x_size
=
x
.
shape
x
=
x
.
reshape
(
x_size
[
1
],
x_size
[
0
],
x_size
[
2
],
x_size
[
3
])
x
=
x
[
1
:,
...]
x
=
x
.
reshape
(
x_size
[
0
],
x_size
[
1
]
-
1
,
x_size
[
2
],
x_size
[
3
])
x
=
x
[:,
0
:
klen
,
:,
:]
return
x
def
rel_attn_core
(
self
,
q_head
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_mat
=
None
,
attn_mask
=
None
):
"""Core relative positional attention operations."""
...
...
@@ -342,7 +353,7 @@ class XLNetRelativeAttention(nn.Module):
# position based attention score
bd
=
torch
.
einsum
(
'ibnd,jbnd->ijbn'
,
q_head
+
self
.
r_r_bias
,
k_head_r
)
bd
=
rel_shift
(
bd
,
klen
=
torch
.
shape
(
ac
)
[
1
])
bd
=
self
.
rel_shift
(
bd
,
klen
=
ac
.
shape
[
1
])
# segment based attention score
if
seg_mat
is
None
:
...
...
@@ -426,7 +437,6 @@ class XLNetRelativeAttention(nn.Module):
# post processing
output_g
=
self
.
post_attention
(
g
,
attn_vec_g
)
attention_output
=
output_h
,
output_g
else
:
###### Multi-head attention with relative positional encoding
if
mems
is
not
None
and
mems
.
dim
()
>
1
:
...
...
@@ -447,7 +457,8 @@ class XLNetRelativeAttention(nn.Module):
q_head_h
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_mat
=
seg_mat
,
attn_mask
=
attn_mask_h
)
# post processing
attention_output
=
self
.
post_attention
(
h
,
attn_vec
)
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_g
=
None
# Mask heads if we want to
...
...
@@ -467,7 +478,7 @@ class XLNetRelativeAttention(nn.Module):
# attentions, self_output = self_output
# if self.output_attentions:
# return attentions, attention_output
return
attention_
output
return
output_h
,
output
_g
class
XLNetFeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
...
...
@@ -481,13 +492,15 @@ class XLNetFeedForward(nn.Module):
else
:
self
.
activation_function
=
config
.
ff_activation
def
forward
(
self
,
hidden_states
,
input_tensor
):
hidden_states
=
self
.
layer_1
(
hidden_states
)
hidden_states
=
self
.
activation_function
(
hidden_states
)
hidden_states
=
self
.
layer_2
(
hidden_states
)
hidden_states
=
self
.
dropout
(
hidden_states
)
hidden_states
=
self
.
LayerNorm
(
hidden_states
+
input_tensor
)
return
hidden_states
def
forward
(
self
,
inp
):
output
=
inp
output
=
self
.
layer_1
(
output
)
output
=
self
.
activation_function
(
output
)
output
=
self
.
dropout
(
output
)
output
=
self
.
layer_2
(
output
)
output
=
self
.
dropout
(
output
)
output
=
self
.
LayerNorm
(
output
+
inp
)
return
output
class
XLNetLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
...
...
@@ -500,13 +513,13 @@ class XLNetLayer(nn.Module):
def
forward
(
self
,
output_h
,
output_g
,
attn_mask_h
,
attn_mask_g
,
r
,
seg_mat
,
r
,
seg_mat
,
two_streams
=
False
,
mems
=
None
,
target_mapping
=
None
,
head_mask
=
None
):
r
,
seg_mat
,
mems
=
None
,
target_mapping
=
None
,
head_mask
=
None
):
output_h
,
output_g
=
self
.
rel_attn
(
output_h
,
output_g
,
attn_mask_h
,
attn_mask_g
,
r
,
seg_mat
,
mems
=
mems
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
if
two_streams
:
if
output_g
is
not
None
:
output_g
=
self
.
ff
(
output_g
)
output_h
=
self
.
ff
(
output_h
)
...
...
@@ -520,9 +533,9 @@ class XLNetPreTrainedModel(nn.Module):
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
XLNetPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
XLNetConfig
):
if
not
isinstance
(
config
,
XLNet
Base
Config
):
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
"Parameter config in `{}(config)` should be an instance of class `XLNet
Base
Config`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
...
...
@@ -668,26 +681,41 @@ class XLNetPreTrainedModel(nn.Module):
class
XLNetModel
(
XLNetPreTrainedModel
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
XLNetModel
,
self
).
__init__
()
super
(
XLNetModel
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
mem_len
=
config
.
mem_len
self
.
reuse_len
=
config
.
reuse_len
layer
=
XLNetLayer
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
d_model
=
config
.
d_model
self
.
same_length
=
config
.
same_length
self
.
attn_type
=
config
.
attn_type
self
.
bi_data
=
config
.
bi_data
self
.
clamp_len
=
config
.
clamp_len
@
classmethod
def
_create_mask
(
qlen
,
mlen
,
dtype
=
torch
.
float
,
same_length
=
False
):
"""create causal attention mask."""
attn_mask
=
torch
.
ones
([
qlen
,
qlen
],
dtype
=
dtype
)
mask_u
=
tf
.
matrix_band_part
(
attn_mask
,
0
,
-
1
)
mask_dia
=
tf
.
matrix_band_part
(
attn_mask
,
0
,
0
)
attn_mask_pad
=
tf
.
zeros
([
qlen
,
mlen
],
dtype
=
dtype
)
ret
=
tf
.
concat
([
attn_mask_pad
,
mask_u
-
mask_dia
],
1
)
if
same_length
:
mask_l
=
tf
.
matrix_band_part
(
attn_mask
,
-
1
,
0
)
ret
=
tf
.
concat
([
ret
[:,
:
qlen
]
+
mask_l
-
mask_dia
,
ret
[:,
qlen
:]],
1
)
layer
=
XLNetLayer
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
n_layer
)])
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
def
create_mask
(
self
,
qlen
,
mlen
):
""" create causal attention mask.
float mask where 1.0 indicate masked, 0.0 indicated not-masked.
same_length=False: same_length=True:
<mlen > < qlen > <mlen > < qlen >
^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
attn_mask
=
torch
.
ones
([
qlen
,
qlen
])
mask_up
=
torch
.
triu
(
attn_mask
,
diagonal
=
1
)
attn_mask_pad
=
torch
.
zeros
([
qlen
,
mlen
])
ret
=
torch
.
cat
([
attn_mask_pad
,
mask_up
],
dim
=
1
)
if
self
.
same_length
:
mask_lo
=
torch
.
tril
(
attn_mask
,
diagonal
=-
1
)
ret
=
torch
.
cat
([
ret
[:,
:
qlen
]
+
mask_lo
,
ret
[:,
qlen
:]],
dim
=
1
)
ret
=
ret
.
to
(
next
(
self
.
parameters
()))
return
ret
def
cache_mem
(
self
,
curr_out
,
prev_mem
):
...
...
@@ -705,10 +733,21 @@ class XLNetModel(XLNetPreTrainedModel):
return
new_mem
.
detach
()
def
relative_positional_encoding
(
self
,
qlen
,
klen
,
bsz
=
None
,
dtype
=
torch
.
float
):
@
staticmethod
def
positional_embedding
(
pos_seq
,
inv_freq
,
bsz
=
None
):
sinusoid_inp
=
torch
.
einsum
(
'i,d->id'
,
pos_seq
,
inv_freq
)
pos_emb
=
torch
.
cat
([
torch
.
sin
(
sinusoid_inp
),
torch
.
cos
(
sinusoid_inp
)],
dim
=-
1
)
pos_emb
=
pos_emb
[:,
None
,
:]
if
bsz
is
not
None
:
pos_emb
=
pos_emb
.
expand
(
-
1
,
bsz
,
-
1
)
return
pos_emb
def
relative_positional_encoding
(
self
,
qlen
,
klen
,
bsz
=
None
):
"""create relative positional encoding."""
freq_seq
=
torch
.
z
range
(
0
,
d_model
,
2.0
,
dtype
=
dtype
)
inv_freq
=
1
/
(
10000
**
(
freq_seq
/
self
.
config
.
d_model
))
freq_seq
=
torch
.
a
range
(
0
,
self
.
d_model
,
2.0
,
dtype
=
torch
.
float
)
inv_freq
=
1
/
(
10000
**
(
freq_seq
/
self
.
d_model
))
if
self
.
attn_type
==
'bi'
:
# beg, end = klen - 1, -qlen
...
...
@@ -720,51 +759,52 @@ class XLNetModel(XLNetPreTrainedModel):
raise
ValueError
(
'Unknown `attn_type` {}.'
.
format
(
self
.
attn_type
))
if
self
.
bi_data
:
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
,
dtype
=
dtype
)
bwd_pos_seq
=
torch
.
arange
(
-
beg
,
-
end
,
1.0
,
dtype
=
dtype
)
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
,
dtype
=
torch
.
float
)
bwd_pos_seq
=
torch
.
arange
(
-
beg
,
-
end
,
1.0
,
dtype
=
torch
.
float
)
if
self
.
clamp_len
>
0
:
fwd_pos_seq
=
fwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
bwd_pos_seq
=
bwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
if
bsz
is
not
None
:
fwd_pos_emb
=
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
//
2
)
bwd_pos_emb
=
positional_embedding
(
bwd_pos_seq
,
inv_freq
,
bsz
//
2
)
fwd_pos_emb
=
self
.
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
//
2
)
bwd_pos_emb
=
self
.
positional_embedding
(
bwd_pos_seq
,
inv_freq
,
bsz
//
2
)
else
:
fwd_pos_emb
=
positional_embedding
(
fwd_pos_seq
,
inv_freq
)
bwd_pos_emb
=
positional_embedding
(
bwd_pos_seq
,
inv_freq
)
fwd_pos_emb
=
self
.
positional_embedding
(
fwd_pos_seq
,
inv_freq
)
bwd_pos_emb
=
self
.
positional_embedding
(
bwd_pos_seq
,
inv_freq
)
pos_emb
=
torch
.
cat
([
fwd_pos_emb
,
bwd_pos_emb
],
dim
=
1
)
else
:
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
,
dtype
=
dtype
)
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
)
if
self
.
clamp_len
>
0
:
fwd_pos_seq
=
fwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
pos_emb
=
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
)
pos_emb
=
self
.
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
)
pos_emb
=
pos_emb
.
to
(
next
(
self
.
parameters
()))
return
pos_emb
def
forward
(
self
,
inp
_k
,
seg_id
=
None
,
input_mask
=
None
,
def
forward
(
self
,
word_emb
_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
Args:
inp_k: in
t32 Tensor in shape [len, bsz], the input token
ID
s.
word_emb_k: floa
t32 Tensor in shape [len, bsz
, d_model
], the input token
embedding
s.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: float32 Tensor in shape [len, bsz], the input mask.
input_mask:
[optional]
float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
mems:
[optional]
a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: float32 Tensor in shape [len, len, bsz].
perm_mask:
[optional]
float32 Tensor in shape [len, len, bsz].
If perm_mask[i, j, k] = 0, i attend to j in batch k;
if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [num_predict, len, bsz].
target_mapping:
[optional]
float32 Tensor in shape [num_predict, len, bsz].
If target_mapping[i, j, k] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: float32 Tensor in shape [len, bsz].
inp_q:
[optional]
float32 Tensor in shape [len, bsz].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
...
...
@@ -780,14 +820,16 @@ class XLNetModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
qlen
,
bsz
=
inp
_k
.
shape
qlen
,
bsz
=
word_emb_k
.
shape
[
0
],
word_emb
_k
.
shape
[
1
]
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
dtype_float
=
word_emb_k
.
dtype
device
=
word_emb_k
.
device
##### Attention mask
# causal attention mask
if
self
.
attn_type
==
'uni'
:
attn_mask
=
_
create_mask
(
qlen
,
mlen
,
inp_k
.
dtype
,
self
.
same_length
)
attn_mask
=
self
.
create_mask
(
qlen
,
mlen
)
attn_mask
=
attn_mask
[:,
:,
None
,
None
]
elif
self
.
attn_type
==
'bi'
:
attn_mask
=
None
...
...
@@ -806,7 +848,7 @@ class XLNetModel(XLNetPreTrainedModel):
if
data_mask
is
not
None
:
# all mems can be attended to
mems_mask
=
torch
.
zeros
([
data_mask
.
shape
[
0
],
mlen
,
bsz
]
,
dtype
=
data_mask
.
dtype
,
device
=
data_mask
.
device
)
mems_mask
=
torch
.
zeros
([
data_mask
.
shape
[
0
],
mlen
,
bsz
]
).
to
(
data_mask
)
data_mask
=
torch
.
cat
([
mems_mask
,
data_mask
],
dim
=
1
)
if
attn_mask
is
None
:
attn_mask
=
data_mask
[:,
:,
:,
None
]
...
...
@@ -814,23 +856,20 @@ class XLNetModel(XLNetPreTrainedModel):
attn_mask
+=
data_mask
[:,
:,
:,
None
]
if
attn_mask
is
not
None
:
attn_mask
=
(
attn_mask
>
0
).
float
(
)
attn_mask
=
(
attn_mask
>
0
).
to
(
dtype_
float
)
if
attn_mask
is
not
None
:
non_tgt_mask
=
-
tf
.
eye
(
qlen
,
dtype
=
tf_float
)
non_tgt_mask
=
tf
.
concat
([
tf
.
zeros
([
qlen
,
mlen
],
dtype
=
tf_float
),
non_tgt_mask
],
axis
=-
1
)
non_tgt_mask
=
tf
.
cast
((
attn_mask
+
non_tgt_mask
[:,
:,
None
,
None
])
>
0
,
dtype
=
tf_float
)
non_tgt_mask
=
-
torch
.
eye
(
qlen
).
to
(
attn_mask
)
non_tgt_mask
=
torch
.
cat
([
torch
.
zeros
([
qlen
,
mlen
]).
to
(
attn_mask
),
non_tgt_mask
],
dim
=-
1
)
non_tgt_mask
=
((
attn_mask
+
non_tgt_mask
[:,
:,
None
,
None
])
>
0
).
to
(
attn_mask
)
else
:
non_tgt_mask
=
None
##### Word embedding
word_emb_k
=
self
.
word_embedding
(
inp_k
)
##### Process Word embeddings and prepare h & g hidden states
output_h
=
self
.
dropout
(
word_emb_k
)
if
inp_q
is
not
None
:
if
target_mapping
is
not
None
:
word_emb_q
=
mask_emb
.
expand
(
target_mapping
.
shape
[
0
],
bsz
,
1
)
word_emb_q
=
mask_emb
.
expand
(
target_mapping
.
shape
[
0
],
bsz
,
-
1
)
else
:
inp_q_ext
=
inp_q
[:,
:,
None
]
word_emb_q
=
inp_q_ext
*
mask_emb
+
(
1
-
inp_q_ext
)
*
word_emb_k
...
...
@@ -841,33 +880,33 @@ class XLNetModel(XLNetPreTrainedModel):
##### Segment embedding
if
seg_id
is
not
None
:
# Convert `seg_id` to one-hot `seg_mat`
mem_pad
=
torch
.
zeros
([
mlen
,
bsz
],
dtype
=
torch
.
long
)
mem_pad
=
torch
.
zeros
([
mlen
,
bsz
],
dtype
=
torch
.
long
,
device
=
device
)
cat_ids
=
torch
.
cat
([
mem_pad
,
seg_id
],
dim
=
0
)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat
=
(
seg_id
[:,
None
]
!=
cat_ids
[
None
,
:]).
long
()
#
seg_mat =
tf
.one_hot(seg_mat,
2,
dtype
=tf
_float)
seg_mat
=
F
.
one_hot
(
seg_mat
,
num_classes
=
2
).
to
(
dtype_float
)
else
:
seg_mat
=
None
##### Positional encoding
pos_emb
=
relative_positional_encoding
(
qlen
,
klen
,
bsz
=
bsz
,
dtype
=
inp_k
.
dtype
)
pos_emb
=
self
.
relative_positional_encoding
(
qlen
,
klen
,
bsz
=
bsz
)
pos_emb
=
self
.
dropout
(
pos_emb
)
##### Head mask if needed (for bertology/pruning)
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [n
um_hidden
_layer
s
x num_heads]
# and head_mask is converted to shape [n
um_hidden
_layer
s
x batch x num_heads x seq_length x seq_length]
# input head_mask has shape [num_heads] or [n_layer x num_heads]
# and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
n
um_hidden
_layer
s
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
n
um_hidden
_layer
s
head_mask
=
[
None
]
*
self
.
config
.
n_layer
new_mems
=
[]
if
mems
is
None
:
...
...
@@ -878,14 +917,14 @@ class XLNetModel(XLNetPreTrainedModel):
new_mems
.
append
(
self
.
cache_mem
(
output_h
,
mems
[
i
]))
output_h
,
output_g
=
layer_module
(
output_h
,
output_g
,
attn_mask_h
,
attn_mask_g
,
r
,
seg_mat
,
attn_mask_h
=
non_tgt_mask
,
attn_mask_g
=
attn_mask
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
return
output
return
output
,
new_mems
class
XLNetLMHeadModel
(
XLNetPreTrainedModel
):
...
...
@@ -932,28 +971,27 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
num_hidde
n_layer
s
=12, num_attention_heads=12, intermediate_size=3072)
n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def
__init__
(
self
,
config
,
run_config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
XLNetLMHeadModel
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
attn_type
=
run_
config
.
attn_type
self
.
same_length
=
run_
config
.
same_length
self
.
attn_type
=
config
.
attn_type
self
.
same_length
=
config
.
same_length
self
.
word_embedding
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
d_model
)
self
.
mask_emb
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
1
,
self
.
d_model
))
self
.
transformer
=
XLNetModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
mask_emb
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
1
,
config
.
d_model
))
self
.
transformer
=
XLNetModel
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
self
.
lm_loss
=
nn
.
Linear
(
config
.
d_model
,
config
.
vocab_size
,
bias
=
True
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
# Tie weights
if
config
.
tie_weight
:
self
.
lm_loss
.
weight
=
self
.
word_embedding
.
weight
self
.
lm_loss
.
weight
=
self
.
word_embedding
.
weight
self
.
apply
(
self
.
init_xlnet_weights
)
...
...
@@ -972,7 +1010,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def forward(self, inp_k, seg_id=None, input_mask=None, mems=None, perm_mask=None,
            target_mapping=None, inp_q=None, target=None,
            output_all_encoded_layers=True, head_mask=None):
    """Run the transformer and project hidden states to vocabulary logits.

    Args:
        inp_k: int Tensor in shape [len, bsz], the input token IDs.
        seg_id: optional segment IDs, same shape as ``inp_k``.
        input_mask: optional float mask over input positions.
        mems: optional list of memory tensors from a previous forward pass.
        perm_mask: optional permutation mask for two-stream attention.
        target_mapping: optional mapping from predict positions to sequence.
        inp_q: optional query-stream mask.
        target: optional LM labels; when given, a cross-entropy loss is
            returned instead of logits.
        output_all_encoded_layers: whether the transformer keeps every layer.
        head_mask: optional mask to disable attention heads.

    Returns:
        ``(loss, new_mems)`` when ``target`` is given, else
        ``(logits, new_mems)``.
    """
    word_emb_k = self.word_embedding(inp_k)
    output, new_mems = self.transformer(word_emb_k, seg_id, input_mask, mems,
                                        perm_mask, target_mapping, inp_q,
                                        output_all_encoded_layers, head_mask)
    logits = self.lm_loss(output)
    if target is not None:
        # Flatten the tokens; ignore_index=-1 skips masked/padded targets.
        loss_fct = CrossEntropyLoss(ignore_index=-1)
        loss = loss_fct(logits.view(-1, logits.size(-1)), target.view(-1))
        return loss, new_mems
    # NOTE(review): the residue showed no return for the no-target path, but
    # the accompanying tests unpack (lm_logits, mems) from a target-less call,
    # so logits must be returned here.
    return logits, new_mems
    # if self.output_attentions:
    #     all_attentions, encoded_layers = encoded_layers
    # sequence_output = encoded_layers[-1]
...
...
tests/modeling_xlnet_test.py
0 → 100644
View file @
45709d75
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
unittest
import
json
import
random
import
shutil
import
pytest
import
torch
from
pytorch_pretrained_bert
import
(
XLNetConfig
,
XLNetRunConfig
,
XLNetModel
,
XLNetLMHeadModel
)
from
pytorch_pretrained_bert.modeling_xlnet
import
PRETRAINED_MODEL_ARCHIVE_MAP
class XLNetModelTest(unittest.TestCase):
    """Unit tests for the PyTorch XLNet model, LM head, and config (de)serialization."""

    class XLNetModelTester(object):
        """Builds a small random XLNet config plus inputs and checks output shapes.

        Two consecutive forward passes are run so that memory (``mems``)
        propagation between segments is exercised as well.
        """

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=30,
                     clamp_len=15,
                     reuse_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     d_model=32,
                     n_head=4,
                     d_inner=128,
                     n_layer=5,
                     max_position_embeddings=10,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
                     seed=1,
                     type_vocab_size=2):
            # `parent` is the enclosing TestCase, used for its assert* helpers.
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.d_model = d_model
            self.n_head = n_head
            self.d_inner = d_inner
            self.n_layer = n_layer
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.seed = seed
            self.type_vocab_size = type_vocab_size

        def prepare_config_and_inputs(self):
            """Create a merged XLNet config and two random input segments.

            Note: inputs are laid out [seq_length, batch_size] (time-major),
            matching the original TF XLNet convention.
            """
            input_ids_1 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size)
            lm_labels = None
            if self.use_labels:
                lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.d_model,
                n_head=self.n_head,
                d_inner=self.d_inner,
                n_layer=self.n_layer,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings)
            run_config = XLNetRunConfig(
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data)
            # Fold the run-time options into the model config.
            config.update(run_config)

            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)

        def set_seed(self):
            # Seed both Python and torch RNGs for reproducible model init.
            random.seed(self.seed)
            torch.manual_seed(self.seed)

        def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            """Run the LM head model on two segments, threading mems between them."""
            model = XLNetLMHeadModel(config)
            model.eval()

            hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids)
            hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1)
            outputs = {
                "hidden_states_1": hidden_states_1,
                "mems_1": mems_1,
                "hidden_states_2": hidden_states_2,
                "mems_2": mems_2,
            }
            return outputs

        def check_transfo_xl_model_output(self, result):
            # Hidden states: [seq_length, batch_size, d_model].
            self.parent.assertListEqual(
                list(result["hidden_states_1"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            self.parent.assertListEqual(
                list(result["hidden_states_2"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            # Memories: one [mem_len, batch_size, d_model] tensor per layer.
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)

        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            """Run the LM head both with targets (loss) and without (logits)."""
            model = XLNetLMHeadModel(config)
            model.eval()

            loss_1, mems_1a = model(input_ids_1, target=lm_labels)
            lm_logits_1, mems_1b = model(input_ids_1)

            loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
            lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)

            outputs = {
                "loss_1": loss_1,
                "mems_1a": mems_1a,
                "lm_logits_1": lm_logits_1,
                "mems_1b": mems_1b,
                "loss_2": loss_2,
                "mems_2a": mems_2a,
                "lm_logits_2": lm_logits_2,
                "mems_2b": mems_2b,
            }
            return outputs

        def check_transfo_xl_lm_head_output(self, result):
            self.parent.assertListEqual(
                list(result["loss_1"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_1"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            # The mems must not depend on whether a target was supplied;
            # NaN entries are masked out before comparing sums.
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))

            self.parent.assertListEqual(
                list(result["loss_2"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_2"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))

    def test_default(self):
        self.run_tester(XLNetModelTest.XLNetModelTester(self))

    def test_config_to_json_string(self):
        # vocab_size is serialized under the TF-XLNet key "n_token".
        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["n_token"], 96)
        self.assertEqual(obj["d_model"], 37)

    def test_config_to_json_file(self):
        # Round-trip a config through a JSON file on disk.
        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        json_file_path = "/tmp/config.json"
        config_first.to_json_file(json_file_path)
        config_second = XLNetConfig.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.assertEqual(config_second.to_dict(), config_first.to_dict())

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        # Downloads real weights; marked slow so it only runs on demand.
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)

    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()

        tester.set_seed()
        output_result = tester.create_transfo_xl_model(*config_and_inputs)
        tester.check_transfo_xl_model_output(output_result)

        tester.set_seed()
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        tester.check_transfo_xl_lm_head_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment