Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
45709d75
Commit
45709d75
authored
Jun 21, 2019
by
thomwolf
Browse files
model running with simple inputs
parent
b407972e
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
404 additions
and
108 deletions
+404
-108
pytorch_pretrained_bert/__init__.py
pytorch_pretrained_bert/__init__.py
+3
-0
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
...ch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
+2
-2
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+1
-1
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+1
-1
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+1
-1
pytorch_pretrained_bert/modeling_xlnet.py
pytorch_pretrained_bert/modeling_xlnet.py
+149
-103
tests/modeling_xlnet_test.py
tests/modeling_xlnet_test.py
+247
-0
No files found.
pytorch_pretrained_bert/__init__.py
View file @
45709d75
...
@@ -17,6 +17,9 @@ from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHe
...
@@ -17,6 +17,9 @@ from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHe
from
.modeling_gpt2
import
(
GPT2Config
,
GPT2Model
,
from
.modeling_gpt2
import
(
GPT2Config
,
GPT2Model
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
,
GPT2MultipleChoiceHead
,
GPT2LMHeadModel
,
GPT2DoubleHeadsModel
,
GPT2MultipleChoiceHead
,
load_tf_weights_in_gpt2
)
load_tf_weights_in_gpt2
)
from
.modeling_xlnet
import
(
XLNetBaseConfig
,
XLNetConfig
,
XLNetRunConfig
,
XLNetPreTrainedModel
,
XLNetModel
,
XLNetLMHeadModel
,
load_tf_weights_in_xlnet
)
from
.optimization
import
BertAdam
from
.optimization
import
BertAdam
from
.optimization_openai
import
OpenAIAdam
from
.optimization_openai
import
OpenAIAdam
...
...
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
View file @
45709d75
...
@@ -21,13 +21,13 @@ from __future__ import print_function
...
@@ -21,13 +21,13 @@ from __future__ import print_function
import
argparse
import
argparse
import
torch
import
torch
from
pytorch_pretrained_bert.modeling_xlnet
import
XLNetConfig
,
XLNetRunConfig
,
XLNetModel
,
load_tf_weights_in_xlnet
from
pytorch_pretrained_bert.modeling_xlnet
import
XLNetConfig
,
XLNetRunConfig
,
XLNet
LMHead
Model
,
load_tf_weights_in_xlnet
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
def
convert_xlnet_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
# Initialise PyTorch model
# Initialise PyTorch model
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
config
=
XLNetConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
XLNetModel
(
config
)
model
=
XLNet
LMHead
Model
(
config
)
# Load weights from tf checkpoint
# Load weights from tf checkpoint
load_tf_weights_in_xlnet
(
model
,
tf_checkpoint_path
)
load_tf_weights_in_xlnet
(
model
,
tf_checkpoint_path
)
...
...
pytorch_pretrained_bert/modeling.py
View file @
45709d75
...
@@ -867,7 +867,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -867,7 +867,7 @@ class BertModel(BertPreTrainedModel):
if
head_mask
is
not
None
:
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
...
...
pytorch_pretrained_bert/modeling_gpt2.py
View file @
45709d75
...
@@ -722,7 +722,7 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -722,7 +722,7 @@ class GPT2Model(GPT2PreTrainedModel):
if
head_mask
is
not
None
:
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
45709d75
...
@@ -718,7 +718,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -718,7 +718,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
if
head_mask
is
not
None
:
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
...
...
pytorch_pretrained_bert/modeling_xlnet.py
View file @
45709d75
...
@@ -29,6 +29,7 @@ from io import open
...
@@ -29,6 +29,7 @@ from io import open
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
CrossEntropyLoss
from
torch.nn
import
CrossEntropyLoss
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
CONFIG_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
CONFIG_NAME
...
@@ -126,32 +127,27 @@ def swish(x):
...
@@ -126,32 +127,27 @@ def swish(x):
ACT2FN
=
{
"gelu"
:
gelu
,
"relu"
:
torch
.
nn
.
functional
.
relu
,
"swish"
:
swish
}
ACT2FN
=
{
"gelu"
:
gelu
,
"relu"
:
torch
.
nn
.
functional
.
relu
,
"swish"
:
swish
}
def
positional_embedding
(
pos_seq
,
inv_freq
,
bsz
=
None
):
sinusoid_inp
=
torch
.
einsum
(
'i,d->id'
,
pos_seq
,
inv_freq
)
pos_emb
=
torch
.
cat
([
tf
.
sin
(
sinusoid_inp
),
tf
.
cos
(
sinusoid_inp
)],
-
1
)
pos_emb
=
pos_emb
[:,
None
,
:]
if
bsz
is
not
None
:
pos_emb
=
pos_emb
.
expand
(
1
,
bsz
,
1
)
return
pos_emb
class
XLNetBaseConfig
(
object
):
class
XLNetBaseConfig
(
object
):
@
classmethod
@
classmethod
def
from_dict
(
cls
,
json_object
):
def
from_dict
(
cls
,
json_object
):
"""Constructs a `XLNetConfig` from a Python dictionary of parameters."""
"""Constructs a `XLNet
Base
Config` from a Python dictionary of parameters."""
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=-
1
)
config
=
cls
(
vocab_size_or_config_json_file
=-
1
)
for
key
,
value
in
json_object
.
items
():
for
key
,
value
in
json_object
.
items
():
config
.
__dict__
[
key
]
=
value
config
.
__dict__
[
key
]
=
value
return
config
return
config
@
classmethod
@
classmethod
def
from_json_file
(
cls
,
json_file
):
def
from_json_file
(
cls
,
json_file
):
"""Constructs a `XLNetConfig` from a json file of parameters."""
"""Constructs a `XLNet
Base
Config` from a json file of parameters."""
with
open
(
json_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
with
open
(
json_file
,
"r"
,
encoding
=
'utf-8'
)
as
reader
:
text
=
reader
.
read
()
text
=
reader
.
read
()
return
cls
.
from_dict
(
json
.
loads
(
text
))
return
cls
.
from_dict
(
json
.
loads
(
text
))
def
update
(
self
,
other
):
dict_b
=
other
.
to_dict
()
for
key
,
value
in
dict_b
.
items
():
self
.
__dict__
[
key
]
=
value
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
.
to_json_string
())
return
str
(
self
.
to_json_string
())
...
@@ -181,6 +177,7 @@ class XLNetConfig(XLNetBaseConfig):
...
@@ -181,6 +177,7 @@ class XLNetConfig(XLNetBaseConfig):
d_inner
=
4096
,
d_inner
=
4096
,
ff_activation
=
"gelu"
,
ff_activation
=
"gelu"
,
untie_r
=
True
,
untie_r
=
True
,
attn_type
=
"bi"
,
max_position_embeddings
=
512
,
max_position_embeddings
=
512
,
initializer_range
=
0.02
,
initializer_range
=
0.02
,
...
@@ -198,6 +195,7 @@ class XLNetConfig(XLNetBaseConfig):
...
@@ -198,6 +195,7 @@ class XLNetConfig(XLNetBaseConfig):
ff_activation: The non-linear activation function (function or string) in the
ff_activation: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
untie_r: untie relative position biases
untie_r: untie relative position biases
attn_type: 'bi' for XLNet, 'uni' for Transformer-XL
dropout: The dropout probabilitiy for all fully connected
dropout: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
layers in the embeddings, encoder, and pooler.
...
@@ -226,6 +224,7 @@ class XLNetConfig(XLNetBaseConfig):
...
@@ -226,6 +224,7 @@ class XLNetConfig(XLNetBaseConfig):
self
.
ff_activation
=
ff_activation
self
.
ff_activation
=
ff_activation
self
.
d_inner
=
d_inner
self
.
d_inner
=
d_inner
self
.
untie_r
=
untie_r
self
.
untie_r
=
untie_r
self
.
attn_type
=
attn_type
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
initializer_range
=
initializer_range
self
.
initializer_range
=
initializer_range
self
.
layer_norm_eps
=
layer_norm_eps
self
.
layer_norm_eps
=
layer_norm_eps
...
@@ -304,15 +303,15 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -304,15 +303,15 @@ class XLNetRelativeAttention(nn.Module):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
XLNetRelativeAttention
,
self
).
__init__
()
super
(
XLNetRelativeAttention
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
self
.
output_attentions
=
output_attentions
if
config
.
d_model
%
config
.
n
um_attention
_head
s
!=
0
:
if
config
.
d_model
%
config
.
n_head
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The hidden size (%d) is not a multiple of the number of attention "
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
config
.
d_model
,
config
.
n
um_attention
_head
s
))
"heads (%d)"
%
(
config
.
d_model
,
config
.
n_head
))
self
.
output_attentions
=
output_attentions
self
.
output_attentions
=
output_attentions
self
.
keep_multihead_output
=
keep_multihead_output
self
.
keep_multihead_output
=
keep_multihead_output
self
.
multihead_output
=
None
self
.
multihead_output
=
None
self
.
n_head
=
config
.
n
um_attention
_head
s
self
.
n_head
=
config
.
n_head
self
.
d_head
=
config
.
d_head
self
.
d_head
=
config
.
d_head
self
.
d_model
=
config
.
d_model
self
.
d_model
=
config
.
d_model
self
.
scale
=
1
/
(
config
.
d_head
**
0.5
)
self
.
scale
=
1
/
(
config
.
d_head
**
0.5
)
...
@@ -326,7 +325,7 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -326,7 +325,7 @@ class XLNetRelativeAttention(nn.Module):
self
.
r_r_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_r_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_s_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_s_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
self
.
d_head
))
self
.
seg_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
n_head
,
2
,
self
.
d_head
))
self
.
seg_embed
=
nn
.
Parameter
(
torch
.
Tensor
(
2
,
self
.
n_head
,
self
.
d_head
))
self
.
LayerNorm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
LayerNorm
=
XLNetLayerNorm
(
config
.
d_model
,
eps
=
config
.
layer_norm_eps
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
...
@@ -334,6 +333,18 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -334,6 +333,18 @@ class XLNetRelativeAttention(nn.Module):
def
prune_heads
(
self
,
heads
):
def
prune_heads
(
self
,
heads
):
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
def
rel_shift
(
x
,
klen
=-
1
):
"""perform relative shift to form the relative attention score."""
x_size
=
x
.
shape
x
=
x
.
reshape
(
x_size
[
1
],
x_size
[
0
],
x_size
[
2
],
x_size
[
3
])
x
=
x
[
1
:,
...]
x
=
x
.
reshape
(
x_size
[
0
],
x_size
[
1
]
-
1
,
x_size
[
2
],
x_size
[
3
])
x
=
x
[:,
0
:
klen
,
:,
:]
return
x
def
rel_attn_core
(
self
,
q_head
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_mat
=
None
,
attn_mask
=
None
):
def
rel_attn_core
(
self
,
q_head
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_mat
=
None
,
attn_mask
=
None
):
"""Core relative positional attention operations."""
"""Core relative positional attention operations."""
...
@@ -342,7 +353,7 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -342,7 +353,7 @@ class XLNetRelativeAttention(nn.Module):
# position based attention score
# position based attention score
bd
=
torch
.
einsum
(
'ibnd,jbnd->ijbn'
,
q_head
+
self
.
r_r_bias
,
k_head_r
)
bd
=
torch
.
einsum
(
'ibnd,jbnd->ijbn'
,
q_head
+
self
.
r_r_bias
,
k_head_r
)
bd
=
rel_shift
(
bd
,
klen
=
torch
.
shape
(
ac
)
[
1
])
bd
=
self
.
rel_shift
(
bd
,
klen
=
ac
.
shape
[
1
])
# segment based attention score
# segment based attention score
if
seg_mat
is
None
:
if
seg_mat
is
None
:
...
@@ -426,7 +437,6 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -426,7 +437,6 @@ class XLNetRelativeAttention(nn.Module):
# post processing
# post processing
output_g
=
self
.
post_attention
(
g
,
attn_vec_g
)
output_g
=
self
.
post_attention
(
g
,
attn_vec_g
)
attention_output
=
output_h
,
output_g
else
:
else
:
###### Multi-head attention with relative positional encoding
###### Multi-head attention with relative positional encoding
if
mems
is
not
None
and
mems
.
dim
()
>
1
:
if
mems
is
not
None
and
mems
.
dim
()
>
1
:
...
@@ -447,7 +457,8 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -447,7 +457,8 @@ class XLNetRelativeAttention(nn.Module):
q_head_h
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_mat
=
seg_mat
,
attn_mask
=
attn_mask_h
)
q_head_h
,
k_head_h
,
v_head_h
,
k_head_r
,
seg_mat
=
seg_mat
,
attn_mask
=
attn_mask_h
)
# post processing
# post processing
attention_output
=
self
.
post_attention
(
h
,
attn_vec
)
output_h
=
self
.
post_attention
(
h
,
attn_vec
)
output_g
=
None
# Mask heads if we want to
# Mask heads if we want to
...
@@ -467,7 +478,7 @@ class XLNetRelativeAttention(nn.Module):
...
@@ -467,7 +478,7 @@ class XLNetRelativeAttention(nn.Module):
# attentions, self_output = self_output
# attentions, self_output = self_output
# if self.output_attentions:
# if self.output_attentions:
# return attentions, attention_output
# return attentions, attention_output
return
attention_
output
return
output_h
,
output
_g
class
XLNetFeedForward
(
nn
.
Module
):
class
XLNetFeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -481,13 +492,15 @@ class XLNetFeedForward(nn.Module):
...
@@ -481,13 +492,15 @@ class XLNetFeedForward(nn.Module):
else
:
else
:
self
.
activation_function
=
config
.
ff_activation
self
.
activation_function
=
config
.
ff_activation
def
forward
(
self
,
hidden_states
,
input_tensor
):
def
forward
(
self
,
inp
):
hidden_states
=
self
.
layer_1
(
hidden_states
)
output
=
inp
hidden_states
=
self
.
activation_function
(
hidden_states
)
output
=
self
.
layer_1
(
output
)
hidden_states
=
self
.
layer_2
(
hidden_states
)
output
=
self
.
activation_function
(
output
)
hidden_states
=
self
.
dropout
(
hidden_states
)
output
=
self
.
dropout
(
output
)
hidden_states
=
self
.
LayerNorm
(
hidden_states
+
input_tensor
)
output
=
self
.
layer_2
(
output
)
return
hidden_states
output
=
self
.
dropout
(
output
)
output
=
self
.
LayerNorm
(
output
+
inp
)
return
output
class
XLNetLayer
(
nn
.
Module
):
class
XLNetLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
...
@@ -500,13 +513,13 @@ class XLNetLayer(nn.Module):
...
@@ -500,13 +513,13 @@ class XLNetLayer(nn.Module):
def
forward
(
self
,
output_h
,
output_g
,
def
forward
(
self
,
output_h
,
output_g
,
attn_mask_h
,
attn_mask_g
,
attn_mask_h
,
attn_mask_g
,
r
,
seg_mat
,
r
,
seg_mat
,
r
,
seg_mat
,
two_streams
=
False
,
mems
=
None
,
target_mapping
=
None
,
head_mask
=
None
):
mems
=
None
,
target_mapping
=
None
,
head_mask
=
None
):
output_h
,
output_g
=
self
.
rel_attn
(
output_h
,
output_g
,
output_h
,
output_g
=
self
.
rel_attn
(
output_h
,
output_g
,
attn_mask_h
,
attn_mask_g
,
attn_mask_h
,
attn_mask_g
,
r
,
seg_mat
,
r
,
seg_mat
,
mems
=
mems
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
mems
=
mems
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
if
two_streams
:
if
output_g
is
not
None
:
output_g
=
self
.
ff
(
output_g
)
output_g
=
self
.
ff
(
output_g
)
output_h
=
self
.
ff
(
output_h
)
output_h
=
self
.
ff
(
output_h
)
...
@@ -520,9 +533,9 @@ class XLNetPreTrainedModel(nn.Module):
...
@@ -520,9 +533,9 @@ class XLNetPreTrainedModel(nn.Module):
"""
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
XLNetPreTrainedModel
,
self
).
__init__
()
super
(
XLNetPreTrainedModel
,
self
).
__init__
()
if
not
isinstance
(
config
,
XLNetConfig
):
if
not
isinstance
(
config
,
XLNet
Base
Config
):
raise
ValueError
(
raise
ValueError
(
"Parameter config in `{}(config)` should be an instance of class `XLNetConfig`. "
"Parameter config in `{}(config)` should be an instance of class `XLNet
Base
Config`. "
"To create a model from a Google pretrained model use "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
...
@@ -668,26 +681,41 @@ class XLNetPreTrainedModel(nn.Module):
...
@@ -668,26 +681,41 @@ class XLNetPreTrainedModel(nn.Module):
class
XLNetModel
(
XLNetPreTrainedModel
):
class
XLNetModel
(
XLNetPreTrainedModel
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
def
__init__
(
self
,
config
,
output_attentions
=
False
,
keep_multihead_output
=
False
):
super
(
XLNetModel
,
self
).
__init__
()
super
(
XLNetModel
,
self
).
__init__
(
config
)
self
.
output_attentions
=
output_attentions
self
.
output_attentions
=
output_attentions
self
.
mem_len
=
config
.
mem_len
self
.
mem_len
=
config
.
mem_len
self
.
reuse_len
=
config
.
reuse_len
self
.
reuse_len
=
config
.
reuse_len
self
.
d_model
=
config
.
d_model
self
.
same_length
=
config
.
same_length
self
.
attn_type
=
config
.
attn_type
self
.
bi_data
=
config
.
bi_data
self
.
clamp_len
=
config
.
clamp_len
layer
=
XLNetLayer
(
config
,
output_attentions
=
output_attentions
,
layer
=
XLNetLayer
(
config
,
output_attentions
=
output_attentions
,
keep_multihead_output
=
keep_multihead_output
)
keep_multihead_output
=
keep_multihead_output
)
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
layer
=
nn
.
ModuleList
([
copy
.
deepcopy
(
layer
)
for
_
in
range
(
config
.
n_layer
)])
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
@
classmethod
def
_create_mask
(
qlen
,
mlen
,
dtype
=
torch
.
float
,
same_length
=
False
):
"""create causal attention mask."""
attn_mask
=
torch
.
ones
([
qlen
,
qlen
],
dtype
=
dtype
)
mask_u
=
tf
.
matrix_band_part
(
attn_mask
,
0
,
-
1
)
mask_dia
=
tf
.
matrix_band_part
(
attn_mask
,
0
,
0
)
attn_mask_pad
=
tf
.
zeros
([
qlen
,
mlen
],
dtype
=
dtype
)
ret
=
tf
.
concat
([
attn_mask_pad
,
mask_u
-
mask_dia
],
1
)
if
same_length
:
mask_l
=
tf
.
matrix_band_part
(
attn_mask
,
-
1
,
0
)
ret
=
tf
.
concat
([
ret
[:,
:
qlen
]
+
mask_l
-
mask_dia
,
ret
[:,
qlen
:]],
1
)
def
create_mask
(
self
,
qlen
,
mlen
):
""" create causal attention mask.
float mask where 1.0 indicate masked, 0.0 indicated not-masked.
same_length=False: same_length=True:
<mlen > < qlen > <mlen > < qlen >
^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
attn_mask
=
torch
.
ones
([
qlen
,
qlen
])
mask_up
=
torch
.
triu
(
attn_mask
,
diagonal
=
1
)
attn_mask_pad
=
torch
.
zeros
([
qlen
,
mlen
])
ret
=
torch
.
cat
([
attn_mask_pad
,
mask_up
],
dim
=
1
)
if
self
.
same_length
:
mask_lo
=
torch
.
tril
(
attn_mask
,
diagonal
=-
1
)
ret
=
torch
.
cat
([
ret
[:,
:
qlen
]
+
mask_lo
,
ret
[:,
qlen
:]],
dim
=
1
)
ret
=
ret
.
to
(
next
(
self
.
parameters
()))
return
ret
return
ret
def
cache_mem
(
self
,
curr_out
,
prev_mem
):
def
cache_mem
(
self
,
curr_out
,
prev_mem
):
...
@@ -705,10 +733,21 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -705,10 +733,21 @@ class XLNetModel(XLNetPreTrainedModel):
return
new_mem
.
detach
()
return
new_mem
.
detach
()
def
relative_positional_encoding
(
self
,
qlen
,
klen
,
bsz
=
None
,
dtype
=
torch
.
float
):
@
staticmethod
def
positional_embedding
(
pos_seq
,
inv_freq
,
bsz
=
None
):
sinusoid_inp
=
torch
.
einsum
(
'i,d->id'
,
pos_seq
,
inv_freq
)
pos_emb
=
torch
.
cat
([
torch
.
sin
(
sinusoid_inp
),
torch
.
cos
(
sinusoid_inp
)],
dim
=-
1
)
pos_emb
=
pos_emb
[:,
None
,
:]
if
bsz
is
not
None
:
pos_emb
=
pos_emb
.
expand
(
-
1
,
bsz
,
-
1
)
return
pos_emb
def
relative_positional_encoding
(
self
,
qlen
,
klen
,
bsz
=
None
):
"""create relative positional encoding."""
"""create relative positional encoding."""
freq_seq
=
torch
.
z
range
(
0
,
d_model
,
2.0
,
dtype
=
dtype
)
freq_seq
=
torch
.
a
range
(
0
,
self
.
d_model
,
2.0
,
dtype
=
torch
.
float
)
inv_freq
=
1
/
(
10000
**
(
freq_seq
/
self
.
config
.
d_model
))
inv_freq
=
1
/
(
10000
**
(
freq_seq
/
self
.
d_model
))
if
self
.
attn_type
==
'bi'
:
if
self
.
attn_type
==
'bi'
:
# beg, end = klen - 1, -qlen
# beg, end = klen - 1, -qlen
...
@@ -720,51 +759,52 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -720,51 +759,52 @@ class XLNetModel(XLNetPreTrainedModel):
raise
ValueError
(
'Unknown `attn_type` {}.'
.
format
(
self
.
attn_type
))
raise
ValueError
(
'Unknown `attn_type` {}.'
.
format
(
self
.
attn_type
))
if
self
.
bi_data
:
if
self
.
bi_data
:
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
,
dtype
=
dtype
)
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
,
dtype
=
torch
.
float
)
bwd_pos_seq
=
torch
.
arange
(
-
beg
,
-
end
,
1.0
,
dtype
=
dtype
)
bwd_pos_seq
=
torch
.
arange
(
-
beg
,
-
end
,
1.0
,
dtype
=
torch
.
float
)
if
self
.
clamp_len
>
0
:
if
self
.
clamp_len
>
0
:
fwd_pos_seq
=
fwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
fwd_pos_seq
=
fwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
bwd_pos_seq
=
bwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
bwd_pos_seq
=
bwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
if
bsz
is
not
None
:
if
bsz
is
not
None
:
fwd_pos_emb
=
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
//
2
)
fwd_pos_emb
=
self
.
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
//
2
)
bwd_pos_emb
=
positional_embedding
(
bwd_pos_seq
,
inv_freq
,
bsz
//
2
)
bwd_pos_emb
=
self
.
positional_embedding
(
bwd_pos_seq
,
inv_freq
,
bsz
//
2
)
else
:
else
:
fwd_pos_emb
=
positional_embedding
(
fwd_pos_seq
,
inv_freq
)
fwd_pos_emb
=
self
.
positional_embedding
(
fwd_pos_seq
,
inv_freq
)
bwd_pos_emb
=
positional_embedding
(
bwd_pos_seq
,
inv_freq
)
bwd_pos_emb
=
self
.
positional_embedding
(
bwd_pos_seq
,
inv_freq
)
pos_emb
=
torch
.
cat
([
fwd_pos_emb
,
bwd_pos_emb
],
dim
=
1
)
pos_emb
=
torch
.
cat
([
fwd_pos_emb
,
bwd_pos_emb
],
dim
=
1
)
else
:
else
:
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
,
dtype
=
dtype
)
fwd_pos_seq
=
torch
.
arange
(
beg
,
end
,
-
1.0
)
if
self
.
clamp_len
>
0
:
if
self
.
clamp_len
>
0
:
fwd_pos_seq
=
fwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
fwd_pos_seq
=
fwd_pos_seq
.
clamp
(
-
self
.
clamp_len
,
self
.
clamp_len
)
pos_emb
=
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
)
pos_emb
=
self
.
positional_embedding
(
fwd_pos_seq
,
inv_freq
,
bsz
)
pos_emb
=
pos_emb
.
to
(
next
(
self
.
parameters
()))
return
pos_emb
return
pos_emb
def
forward
(
self
,
inp
_k
,
seg_id
=
None
,
input_mask
=
None
,
def
forward
(
self
,
word_emb
_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
"""
Args:
Args:
inp_k: in
t32 Tensor in shape [len, bsz], the input token
ID
s.
word_emb_k: floa
t32 Tensor in shape [len, bsz
, d_model
], the input token
embedding
s.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: float32 Tensor in shape [len, bsz], the input mask.
input_mask:
[optional]
float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
mems:
[optional]
a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
If None, no memory is used.
perm_mask: float32 Tensor in shape [len, len, bsz].
perm_mask:
[optional]
float32 Tensor in shape [len, len, bsz].
If perm_mask[i, j, k] = 0, i attend to j in batch k;
If perm_mask[i, j, k] = 0, i attend to j in batch k;
if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [num_predict, len, bsz].
target_mapping:
[optional]
float32 Tensor in shape [num_predict, len, bsz].
If target_mapping[i, j, k] = 1, the i-th predict in batch k is
If target_mapping[i, j, k] = 1, the i-th predict in batch k is
on the j-th token.
on the j-th token.
Only used during pretraining for partial prediction.
Only used during pretraining for partial prediction.
Set to None during finetuning.
Set to None during finetuning.
inp_q: float32 Tensor in shape [len, bsz].
inp_q:
[optional]
float32 Tensor in shape [len, bsz].
1 for tokens with losses and 0 for tokens without losses.
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
Set to None during finetuning.
...
@@ -780,14 +820,16 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -780,14 +820,16 @@ class XLNetModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
to pool the input to get a vector representation.
"""
"""
qlen
,
bsz
=
inp
_k
.
shape
qlen
,
bsz
=
word_emb_k
.
shape
[
0
],
word_emb
_k
.
shape
[
1
]
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
mlen
=
mems
[
0
].
shape
[
0
]
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
klen
=
mlen
+
qlen
dtype_float
=
word_emb_k
.
dtype
device
=
word_emb_k
.
device
##### Attention mask
##### Attention mask
# causal attention mask
# causal attention mask
if
self
.
attn_type
==
'uni'
:
if
self
.
attn_type
==
'uni'
:
attn_mask
=
_
create_mask
(
qlen
,
mlen
,
inp_k
.
dtype
,
self
.
same_length
)
attn_mask
=
self
.
create_mask
(
qlen
,
mlen
)
attn_mask
=
attn_mask
[:,
:,
None
,
None
]
attn_mask
=
attn_mask
[:,
:,
None
,
None
]
elif
self
.
attn_type
==
'bi'
:
elif
self
.
attn_type
==
'bi'
:
attn_mask
=
None
attn_mask
=
None
...
@@ -806,7 +848,7 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -806,7 +848,7 @@ class XLNetModel(XLNetPreTrainedModel):
if
data_mask
is
not
None
:
if
data_mask
is
not
None
:
# all mems can be attended to
# all mems can be attended to
mems_mask
=
torch
.
zeros
([
data_mask
.
shape
[
0
],
mlen
,
bsz
]
,
dtype
=
data_mask
.
dtype
,
device
=
data_mask
.
device
)
mems_mask
=
torch
.
zeros
([
data_mask
.
shape
[
0
],
mlen
,
bsz
]
).
to
(
data_mask
)
data_mask
=
torch
.
cat
([
mems_mask
,
data_mask
],
dim
=
1
)
data_mask
=
torch
.
cat
([
mems_mask
,
data_mask
],
dim
=
1
)
if
attn_mask
is
None
:
if
attn_mask
is
None
:
attn_mask
=
data_mask
[:,
:,
:,
None
]
attn_mask
=
data_mask
[:,
:,
:,
None
]
...
@@ -814,23 +856,20 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -814,23 +856,20 @@ class XLNetModel(XLNetPreTrainedModel):
attn_mask
+=
data_mask
[:,
:,
:,
None
]
attn_mask
+=
data_mask
[:,
:,
:,
None
]
if
attn_mask
is
not
None
:
if
attn_mask
is
not
None
:
attn_mask
=
(
attn_mask
>
0
).
float
(
)
attn_mask
=
(
attn_mask
>
0
).
to
(
dtype_
float
)
if
attn_mask
is
not
None
:
if
attn_mask
is
not
None
:
non_tgt_mask
=
-
tf
.
eye
(
qlen
,
dtype
=
tf_float
)
non_tgt_mask
=
-
torch
.
eye
(
qlen
).
to
(
attn_mask
)
non_tgt_mask
=
tf
.
concat
([
tf
.
zeros
([
qlen
,
mlen
],
dtype
=
tf_float
),
non_tgt_mask
=
torch
.
cat
([
torch
.
zeros
([
qlen
,
mlen
]).
to
(
attn_mask
),
non_tgt_mask
],
dim
=-
1
)
non_tgt_mask
],
axis
=-
1
)
non_tgt_mask
=
((
attn_mask
+
non_tgt_mask
[:,
:,
None
,
None
])
>
0
).
to
(
attn_mask
)
non_tgt_mask
=
tf
.
cast
((
attn_mask
+
non_tgt_mask
[:,
:,
None
,
None
])
>
0
,
dtype
=
tf_float
)
else
:
else
:
non_tgt_mask
=
None
non_tgt_mask
=
None
##### Word embedding
##### Process Word embeddings and prepare h & g hidden states
word_emb_k
=
self
.
word_embedding
(
inp_k
)
output_h
=
self
.
dropout
(
word_emb_k
)
output_h
=
self
.
dropout
(
word_emb_k
)
if
inp_q
is
not
None
:
if
inp_q
is
not
None
:
if
target_mapping
is
not
None
:
if
target_mapping
is
not
None
:
word_emb_q
=
mask_emb
.
expand
(
target_mapping
.
shape
[
0
],
bsz
,
1
)
word_emb_q
=
mask_emb
.
expand
(
target_mapping
.
shape
[
0
],
bsz
,
-
1
)
else
:
else
:
inp_q_ext
=
inp_q
[:,
:,
None
]
inp_q_ext
=
inp_q
[:,
:,
None
]
word_emb_q
=
inp_q_ext
*
mask_emb
+
(
1
-
inp_q_ext
)
*
word_emb_k
word_emb_q
=
inp_q_ext
*
mask_emb
+
(
1
-
inp_q_ext
)
*
word_emb_k
...
@@ -841,33 +880,33 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -841,33 +880,33 @@ class XLNetModel(XLNetPreTrainedModel):
##### Segment embedding
##### Segment embedding
if
seg_id
is
not
None
:
if
seg_id
is
not
None
:
# Convert `seg_id` to one-hot `seg_mat`
# Convert `seg_id` to one-hot `seg_mat`
mem_pad
=
torch
.
zeros
([
mlen
,
bsz
],
dtype
=
torch
.
long
)
mem_pad
=
torch
.
zeros
([
mlen
,
bsz
],
dtype
=
torch
.
long
,
device
=
device
)
cat_ids
=
torch
.
cat
([
mem_pad
,
seg_id
],
dim
=
0
)
cat_ids
=
torch
.
cat
([
mem_pad
,
seg_id
],
dim
=
0
)
# `1` indicates not in the same segment [qlen x klen x bsz]
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat
=
(
seg_id
[:,
None
]
!=
cat_ids
[
None
,
:]).
long
()
seg_mat
=
(
seg_id
[:,
None
]
!=
cat_ids
[
None
,
:]).
long
()
#
seg_mat =
tf
.one_hot(seg_mat,
2,
dtype
=tf
_float)
seg_mat
=
F
.
one_hot
(
seg_mat
,
num_classes
=
2
).
to
(
dtype_float
)
else
:
else
:
seg_mat
=
None
seg_mat
=
None
##### Positional encoding
##### Positional encoding
pos_emb
=
relative_positional_encoding
(
qlen
,
klen
,
bsz
=
bsz
,
dtype
=
inp_k
.
dtype
)
pos_emb
=
self
.
relative_positional_encoding
(
qlen
,
klen
,
bsz
=
bsz
)
pos_emb
=
self
.
dropout
(
pos_emb
)
pos_emb
=
self
.
dropout
(
pos_emb
)
##### Head mask if needed (for bertology/pruning)
##### Head mask if needed (for bertology/pruning)
# 1.0 in head_mask indicate we keep the head
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [n
um_hidden
_layer
s
x num_heads]
# input head_mask has shape [num_heads] or [n_layer x num_heads]
# and head_mask is converted to shape [n
um_hidden
_layer
s
x batch x num_heads x seq_length x seq_length]
# and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
if
head_mask
is
not
None
:
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
_as
(
self
.
config
.
n
um_hidden
_layer
s
,
-
1
,
-
1
,
-
1
,
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
else
:
head_mask
=
[
None
]
*
self
.
config
.
n
um_hidden
_layer
s
head_mask
=
[
None
]
*
self
.
config
.
n_layer
new_mems
=
[]
new_mems
=
[]
if
mems
is
None
:
if
mems
is
None
:
...
@@ -878,14 +917,14 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -878,14 +917,14 @@ class XLNetModel(XLNetPreTrainedModel):
new_mems
.
append
(
self
.
cache_mem
(
output_h
,
mems
[
i
]))
new_mems
.
append
(
self
.
cache_mem
(
output_h
,
mems
[
i
]))
output_h
,
output_g
=
layer_module
(
output_h
,
output_g
,
output_h
,
output_g
=
layer_module
(
output_h
,
output_g
,
attn_mask_h
,
attn_mask_g
,
attn_mask_h
=
non_tgt_mask
,
attn_mask_g
=
attn_mask
,
r
,
seg_mat
,
r
=
pos_emb
,
seg_mat
=
seg_mat
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
mems
=
mems
[
i
],
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
head_mask
=
head_mask
)
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
output
=
self
.
dropout
(
output_g
if
output_g
is
not
None
else
output_h
)
return
output
return
output
,
new_mems
class
XLNetLMHeadModel
(
XLNetPreTrainedModel
):
class
XLNetLMHeadModel
(
XLNetPreTrainedModel
):
...
@@ -932,27 +971,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -932,27 +971,26 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
num_hidde
n_layer
s
=12, num_attention_heads=12, intermediate_size=3072)
n_layer=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLNetModel(config=config)
model = modeling.XLNetModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
```
"""
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
    """Build the XLNet language-modeling head model.

    Args:
        config: `XLNetConfig` carrying the model hyper-parameters
            (also read here for `attn_type`, `same_length`, `tie_weight`).
        output_attentions: if True, the underlying transformer also
            returns attention weights.
        keep_multihead_output: forwarded to `XLNetModel` (used for
            head-pruning / bertology analyses).
    """
    super(XLNetLMHeadModel, self).__init__(config)
    self.output_attentions = output_attentions
    self.attn_type = config.attn_type
    self.same_length = config.same_length

    # Token embedding table and the shared mask embedding used for
    # masked (query-stream) positions.
    self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
    self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))

    self.transformer = XLNetModel(
        config,
        output_attentions=output_attentions,
        keep_multihead_output=keep_multihead_output)

    # Projection from hidden states back to the vocabulary.
    self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
    self.dropout = nn.Dropout(config.dropout)

    # Tie weights: share the softmax projection with the input embedding.
    if config.tie_weight:
        self.lm_loss.weight = self.word_embedding.weight

    self.apply(self.init_xlnet_weights)
...
@@ -972,7 +1010,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -972,7 +1010,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def
forward
(
self
,
inp_k
,
seg_id
=
None
,
input_mask
=
None
,
def
forward
(
self
,
inp_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
target
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
"""
Args:
Args:
inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
...
@@ -1007,13 +1045,21 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1007,13 +1045,21 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
to pool the input to get a vector representation.
"""
"""
output
,
new_mems
=
self
.
transformer
(
output_h
,
non_tgt_mask
,
r
,
seg_mat
,
word_emb_k
=
self
.
word_embedding
(
inp_k
)
output_g
=
output_g
,
attn_mask_g
=
attn_mask
,
mems
=
mems
,
target_mapping
=
target_mapping
,
output
,
new_mems
=
self
.
transformer
(
word_emb_k
,
seg_id
,
input_mask
,
head_mask
=
head_mask
)
mems
,
perm_mask
,
target_mapping
,
inp_q
,
output_all_encoded_layers
,
head_mask
)
logits
=
self
.
lm_loss
(
output
)
logits
=
self
.
lm_loss
(
output
)
if
target
is
not
None
:
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
logits
.
view
(
-
1
,
logits
.
size
(
-
1
)),
target
.
view
(
-
1
))
return
loss
,
new_mems
# if self.output_attentions:
# if self.output_attentions:
# all_attentions, encoded_layers = encoded_layers
# all_attentions, encoded_layers = encoded_layers
# sequence_output = encoded_layers[-1]
# sequence_output = encoded_layers[-1]
...
...
tests/modeling_xlnet_test.py
0 → 100644
View file @
45709d75
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
unittest
import
json
import
random
import
shutil
import
pytest
import
torch
from
pytorch_pretrained_bert
import
(
XLNetConfig
,
XLNetRunConfig
,
XLNetModel
,
XLNetLMHeadModel
)
from
pytorch_pretrained_bert.modeling_xlnet
import
PRETRAINED_MODEL_ARCHIVE_MAP
class XLNetModelTest(unittest.TestCase):
    """Unit tests for the XLNet config and model classes."""

    class XLNetModelTester(object):
        """Builds a small random XLNet config/inputs and checks model outputs."""

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=30,
                     clamp_len=15,
                     reuse_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=None,
                     d_model=32,
                     n_head=4,
                     d_inner=128,
                     n_layer=5,
                     max_position_embeddings=10,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
                     seed=1,
                     type_vocab_size=2):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            # `None` sentinel instead of a mutable list default ([10, 50, 80]),
            # which would be shared across every tester instance.
            self.cutoffs = [10, 50, 80] if cutoffs is None else cutoffs
            self.d_model = d_model
            self.n_head = n_head
            self.d_inner = d_inner
            self.n_layer = n_layer
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.seed = seed
            self.type_vocab_size = type_vocab_size

        def prepare_config_and_inputs(self):
            """Create a model config plus two random input batches and labels.

            Tensors are laid out [seq_length, batch_size] (length-major), as
            expected by the XLNet model.
            """
            input_ids_1 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            input_ids_2 = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
            segment_ids = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.type_vocab_size)

            lm_labels = None
            if self.use_labels:
                lm_labels = XLNetModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.d_model,
                n_head=self.n_head,
                d_inner=self.d_inner,
                n_layer=self.n_layer,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings)
            run_config = XLNetRunConfig(
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data)
            # Merge the runtime options into the model config.
            config.update(run_config)

            return (config, input_ids_1, input_ids_2, segment_ids, lm_labels)

        def set_seed(self):
            # Seed both the Python and torch RNGs for reproducible runs.
            random.seed(self.seed)
            torch.manual_seed(self.seed)

        def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            """Run two consecutive forward passes, feeding back the memories."""
            model = XLNetLMHeadModel(config)
            model.eval()

            hidden_states_1, mems_1 = model(input_ids_1, seg_id=segment_ids)
            hidden_states_2, mems_2 = model(input_ids_2, seg_id=segment_ids, mems=mems_1)

            outputs = {
                "hidden_states_1": hidden_states_1,
                "mems_1": mems_1,
                "hidden_states_2": hidden_states_2,
                "mems_2": mems_2,
            }
            return outputs

        def check_transfo_xl_model_output(self, result):
            """Check shapes of the hidden states and the cached memories."""
            self.parent.assertListEqual(
                list(result["hidden_states_1"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            self.parent.assertListEqual(
                list(result["hidden_states_2"].size()),
                [self.seq_length, self.batch_size, self.d_model])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)

        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, segment_ids, lm_labels):
            """Run the LM head with and without targets, chaining memories."""
            model = XLNetLMHeadModel(config)
            model.eval()

            loss_1, mems_1a = model(input_ids_1, target=lm_labels)
            lm_logits_1, mems_1b = model(input_ids_1)

            loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
            lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)

            outputs = {
                "loss_1": loss_1,
                "mems_1a": mems_1a,
                "lm_logits_1": lm_logits_1,
                "mems_1b": mems_1b,
                "loss_2": loss_2,
                "mems_2a": mems_2a,
                "lm_logits_2": lm_logits_2,
                "mems_2b": mems_2b,
            }
            return outputs

        def check_transfo_xl_lm_head_output(self, result):
            """Check loss/logit shapes and that both runs cache identical mems."""
            self.parent.assertListEqual(
                list(result["loss_1"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_1"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            # NaN-masked sums: the memories must agree between the run with
            # targets and the run without (targets must not perturb caching).
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))

            self.parent.assertListEqual(
                list(result["loss_2"].size()),
                [self.seq_length, self.batch_size])
            self.parent.assertListEqual(
                list(result["lm_logits_2"].size()),
                [self.seq_length, self.batch_size, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2a"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2b"]),
                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
            self.parent.assertListEqual(
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))

    def test_default(self):
        self.run_tester(XLNetModelTest.XLNetModelTester(self))

    def test_config_to_json_string(self):
        config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["n_token"], 96)
        self.assertEqual(obj["d_model"], 37)

    def test_config_to_json_file(self):
        """Round-trip a config through a JSON file on disk."""
        import tempfile
        config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=37)
        # Temporary directory instead of a hard-coded /tmp path (portable,
        # and safe under parallel test runs).
        tmp_dir = tempfile.mkdtemp()
        try:
            json_file_path = os.path.join(tmp_dir, "config.json")
            config_first.to_json_file(json_file_path)
            config_second = XLNetConfig.from_json_file(json_file_path)
        finally:
            shutil.rmtree(tmp_dir)
        self.assertEqual(config_second.to_dict(), config_first.to_dict())

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)

    def run_tester(self, tester):
        config_and_inputs = tester.prepare_config_and_inputs()

        tester.set_seed()
        output_result = tester.create_transfo_xl_model(*config_and_inputs)
        tester.check_transfo_xl_model_output(output_result)

        tester.set_seed()
        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
        tester.check_transfo_xl_lm_head_output(output_result)

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment