chenpangpang / transformers, commit c2ea5aef (parent de713fa9)
Authored Jun 20, 2019 by thomwolf

    work in progress on xlnet

Showing 1 changed file with 529 additions and 69 deletions:

pytorch_pretrained_bert/modeling_xlnet.py  (+529 / -69)
...
@@ -126,6 +126,16 @@ def swish(x):
 ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


+def positional_embedding(pos_seq, inv_freq, bsz=None):
+    sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
+    pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], -1)
+    pos_emb = pos_emb[:, None, :]
+
+    if bsz is not None:
+        pos_emb = pos_emb.expand(-1, bsz, -1)
+
+    return pos_emb
+
+
 class XLNetBaseConfig(object):

     @classmethod
     def from_dict(cls, json_object):
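As a quick aside on the new `positional_embedding` helper above, here is a minimal standalone sketch (not part of the commit, toy sizes) of the tensor it produces:

```python
import torch

d_model = 8
pos_seq = torch.arange(5, -5, -1.0)  # a decreasing range of relative positions
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2.0) / d_model))

sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], -1)[:, None, :]

print(pos_emb.shape)  # torch.Size([10, 1, 8]) -> [len, 1 (broadcast over batch), d_model]
```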
...
@@ -165,15 +175,14 @@ class XLNetConfig(XLNetBaseConfig):
     """
     def __init__(self,
                  vocab_size_or_config_json_file,
-                 d_model=768,
-                 n_layer=12,
-                 n_head=12,
-                 d_inner=3072,
+                 d_model=1024,
+                 n_layer=24,
+                 n_head=16,
+                 d_inner=4096,
                  ff_activation="gelu",
                  untie_r=True,
                  max_position_embeddings=512,
-                 type_vocab_size=2,
                  initializer_range=0.02,
                  layer_norm_eps=1e-12):
         """Constructs XLNetConfig.
...
@@ -197,8 +206,6 @@ class XLNetConfig(XLNetBaseConfig):
             max_position_embeddings: The maximum sequence length that this model might
                 ever be used with. Typically set this to something large just in case
                 (e.g., 512 or 1024 or 2048).
-            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
-                `XLNetModel`.
             initializer_range: The stdev of the truncated_normal_initializer for
                 initializing all weight matrices.
             layer_norm_eps: The epsilon used by LayerNorm.
...
@@ -214,11 +221,12 @@ class XLNetConfig(XLNetBaseConfig):
             self.d_model = d_model
             self.n_layer = n_layer
             self.n_head = n_head
+            assert d_model % n_head == 0
+            self.d_head = d_model // n_head
             self.ff_activation = ff_activation
             self.d_inner = d_inner
             self.untie_r = untie_r
             self.max_position_embeddings = max_position_embeddings
-            self.type_vocab_size = type_vocab_size
             self.initializer_range = initializer_range
             self.layer_norm_eps = layer_norm_eps
         else:
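In short, the config hunks above move the defaults to the larger model size and derive a per-head dimension while dropping `type_vocab_size`. A sketch of the arithmetic, using the new default values:

```python
new_defaults = dict(d_model=1024, n_layer=24, n_head=16, d_inner=4096)

assert new_defaults["d_model"] % new_defaults["n_head"] == 0
d_head = new_defaults["d_model"] // new_defaults["n_head"]
print(d_head)  # 64
```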
...
@@ -233,8 +241,8 @@ class XLNetRunConfig(XLNetBaseConfig):
     We store them separately from XLNetConfig for flexibility.
     """
     def __init__(self,
-                 dropout,
-                 dropatt,
+                 dropout=0.1,
+                 dropatt=0.1,
                  init="normal",
                  init_range=0.1,
                  init_std=0.02,
...
@@ -278,12 +286,12 @@ try:
 except ImportError:
     logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")

     class XLNetLayerNorm(nn.Module):
-        def __init__(self, hidden_size, eps=1e-12):
+        def __init__(self, d_model, eps=1e-12):
             """Construct a layernorm module in the TF style (epsilon inside the square root).
             """
             super(XLNetLayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.ones(hidden_size))
-            self.bias = nn.Parameter(torch.zeros(hidden_size))
+            self.weight = nn.Parameter(torch.ones(d_model))
+            self.bias = nn.Parameter(torch.zeros(d_model))
             self.variance_epsilon = eps

         def forward(self, x):
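Aside: the TF-style normalization used by `XLNetLayerNorm` (epsilon inside the square root, mean and biased variance over the last dimension) agrees numerically with `torch.nn.functional.layer_norm`. A quick check with made-up sizes, assuming `u` and `s` are the usual last-dimension mean and variance computed in the unchanged part of `forward`:

```python
import torch

x = torch.randn(2, 5, 16)
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
manual = (x - u) / torch.sqrt(s + 1e-12)

reference = torch.nn.functional.layer_norm(x, (16,), eps=1e-12)
print(torch.allclose(manual, reference, atol=1e-5))  # True
```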
...
@@ -292,6 +300,220 @@ except ImportError:
             x = (x - u) / torch.sqrt(s + self.variance_epsilon)
             return self.weight * x + self.bias


+class XLNetRelativeAttention(nn.Module):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+        super(XLNetRelativeAttention, self).__init__()
+        self.output_attentions = output_attentions
+
+        if config.d_model % config.num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.d_model, config.num_attention_heads))
+
+        self.output_attentions = output_attentions
+        self.keep_multihead_output = keep_multihead_output
+        self.multihead_output = None
+
+        self.n_head = config.num_attention_heads
+        self.d_head = config.d_head
+        self.d_model = config.d_model
+        self.scale = 1 / (config.d_head ** 0.5)
+
+        self.q = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
+        self.k = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
+        self.v = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
+        self.o = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
+        self.r = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
+
+        self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+        self.r_s_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+        self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+        self.seg_embed = nn.Parameter(torch.Tensor(self.n_head, 2, self.d_head))
+
+        self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def prune_heads(self, heads):
+        raise NotImplementedError
+
+    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None):
+        """Core relative positional attention operations."""
+        # content based attention score
+        ac = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_w_bias, k_head_h)
+
+        # position based attention score
+        bd = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
+        bd = rel_shift(bd, klen=ac.shape[1])
+
+        # segment based attention score
+        if seg_mat is None:
+            ef = 0
+        else:
+            ef = torch.einsum('ibnd,snd->ibns', q_head + self.r_s_bias, self.seg_embed)
+            ef = torch.einsum('ijbs,ibns->ijbn', seg_mat, ef)
+
+        # merge attention scores and perform masking
+        attn_score = (ac + bd + ef) * self.scale
+        if attn_mask is not None:
+            # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
+            attn_score = attn_score - 1e30 * attn_mask
+
+        # attention probability
+        attn_prob = F.softmax(attn_score, dim=1)
+        attn_prob = self.dropout(attn_prob)
+
+        # attention output
+        attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
+
+        return attn_vec
+
+    def post_attention(self, h, attn_vec, residual=True):
+        """Post-attention processing."""
+        # post-attention projection (back to `d_model`)
+        attn_out = torch.einsum('ibnd,hnd->ibh', attn_vec, self.o)
+
+        attn_out = self.dropout(attn_out)
+        if residual:
+            attn_out = attn_out + h
+        output = self.LayerNorm(attn_out)
+
+        return output
+
+    def forward(self, h, g, attn_mask_h, attn_mask_g,
+                r, seg_mat, mems=None, target_mapping=None, head_mask=None):
+        if g is not None:
+            ###### Two-stream attention with relative positional encoding.
+            # content based attention score
+            if mems is not None and mems.dim() > 1:
+                cat = torch.cat([mems, h], dim=0)
+            else:
+                cat = h
+
+            # content-based key head
+            k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
+
+            # content-based value head
+            v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
+
+            # position-based key head
+            k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)
+
+            ##### h-stream
+            # content-stream query head
+            q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
+
+            # core attention ops
+            attn_vec_h = self.rel_attn_core(q_head_h, k_head_h, v_head_h, k_head_r,
+                                            seg_mat=seg_mat, attn_mask=attn_mask_h)
+
+            # post processing
+            output_h = self.post_attention(h, attn_vec_h)
+
+            ##### g-stream
+            # query-stream query head
+            q_head_g = torch.einsum('ibh,hnd->ibnd', g, self.q)
+
+            # core attention ops
+            if target_mapping is not None:
+                q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
+                attn_vec_g = self.rel_attn_core(q_head_g, k_head_h, v_head_h, k_head_r,
+                                                seg_mat=seg_mat, attn_mask=attn_mask_g)
+                attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
+            else:
+                attn_vec_g = self.rel_attn_core(q_head_g, k_head_h, v_head_h, k_head_r,
+                                                seg_mat=seg_mat, attn_mask=attn_mask_g)
+
+            # post processing
+            output_g = self.post_attention(g, attn_vec_g)
+
+            attention_output = output_h, output_g
+        else:
+            ###### Multi-head attention with relative positional encoding
+            if mems is not None and mems.dim() > 1:
+                cat = torch.cat([mems, h], dim=0)
+            else:
+                cat = h
+
+            # content heads
+            q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
+            k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
+            v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
+
+            # positional heads
+            k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)
+
+            # core attention ops
+            attn_vec = self.rel_attn_core(q_head_h, k_head_h, v_head_h, k_head_r,
+                                          seg_mat=seg_mat, attn_mask=attn_mask_h)
+
+            # post processing
+            attention_output = self.post_attention(h, attn_vec)
+
+        # Mask heads if we want to
+        # if head_mask is not None:
+        #     attention_probs = attention_probs * head_mask
+        # context_layer = torch.matmul(attention_probs, value_layer)
+        # if self.keep_multihead_output:
+        #     self.multihead_output = context_layer
+        #     self.multihead_output.retain_grad()
+        # context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        # context_layer = context_layer.view(*new_context_layer_shape)
+        # if self.output_attentions:
+        #     attentions, self_output = self_output
+        # if self.output_attentions:
+        #     return attentions, attention_output
+        return attention_output
+
+
+class XLNetFeedForward(nn.Module):
+    def __init__(self, config):
+        super(XLNetFeedForward, self).__init__()
+        self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
+        self.layer_1 = nn.Linear(config.d_model, config.d_inner)
+        self.layer_2 = nn.Linear(config.d_inner, config.d_model)
+        self.dropout = nn.Dropout(config.dropout)
+        if isinstance(config.ff_activation, str) or (sys.version_info[0] == 2
+                                                     and isinstance(config.ff_activation, unicode)):
+            self.activation_function = ACT2FN[config.ff_activation]
+        else:
+            self.activation_function = config.ff_activation
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.layer_1(hidden_states)
+        hidden_states = self.activation_function(hidden_states)
+        hidden_states = self.layer_2(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class XLNetLayer(nn.Module):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+        super(XLNetLayer, self).__init__()
+        self.output_attentions = output_attentions
+        self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions,
+                                               keep_multihead_output=keep_multihead_output)
+        self.ff = XLNetFeedForward(config)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat,
+                two_streams=False, mems=None, target_mapping=None, head_mask=None):
+        output_h, output_g = self.rel_attn(output_h, output_g, attn_mask_h, attn_mask_g,
+                                           r, seg_mat, mems=mems, target_mapping=target_mapping,
+                                           head_mask=head_mask)
+        if two_streams:
+            output_g = self.ff(output_g)
+        output_h = self.ff(output_h)
+
+        # if self.output_attentions:
+        #     return attentions, layer_output
+        return output_h, output_g
+
+
 class XLNetPreTrainedModel(nn.Module):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.
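To make the einsum subscripts in `rel_attn_core` easier to follow, here is a standalone sketch (not part of the commit) of the content (`ac`), position (`bd`) and segment (`ef`) score terms. Sizes are invented, the biases are zeroed, the `rel_shift` of `bd` is skipped, and `seg_embed` is given the `[2, n_head, d_head]` layout implied by the `'ibnd,snd->ibns'` subscripts:

```python
import torch

qlen, klen, bsz, n_head, d_head = 4, 6, 2, 3, 8

q_head    = torch.randn(qlen, bsz, n_head, d_head)   # i b n d
k_head_h  = torch.randn(klen, bsz, n_head, d_head)   # j b n d (content keys)
k_head_r  = torch.randn(klen, bsz, n_head, d_head)   # j b n d (position keys)
seg_embed = torch.randn(2, n_head, d_head)           # s n d (same / different segment)
seg_mat   = torch.nn.functional.one_hot(torch.randint(0, 2, (qlen, klen, bsz)), 2).float()
r_w_bias = r_r_bias = r_s_bias = torch.zeros(n_head, d_head)

ac = torch.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
bd = torch.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
ef = torch.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
ef = torch.einsum('ijbs,ibns->ijbn', seg_mat, ef)

attn_score = (ac + bd + ef) / (d_head ** 0.5)
print(attn_score.shape)  # torch.Size([4, 6, 2, 3]) -> [qlen, klen, bsz, n_head]
```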
...
@@ -445,6 +667,228 @@ class XLNetPreTrainedModel(nn.Module):


 class XLNetModel(XLNetPreTrainedModel):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+        super(XLNetModel, self).__init__()
+        self.output_attentions = output_attentions
+        self.mem_len = config.mem_len
+        self.reuse_len = config.reuse_len
+
+        layer = XLNetLayer(config, output_attentions=output_attentions,
+                           keep_multihead_output=keep_multihead_output)
+        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+
+    @staticmethod
+    def _create_mask(qlen, mlen, dtype=torch.float, same_length=False):
+        """create causal attention mask."""
+        attn_mask = torch.ones([qlen, qlen], dtype=dtype)
+        mask_u = torch.triu(attn_mask)       # upper triangle, diagonal included
+        mask_dia = torch.eye(qlen, dtype=dtype)
+        attn_mask_pad = torch.zeros([qlen, mlen], dtype=dtype)
+        ret = torch.cat([attn_mask_pad, mask_u - mask_dia], dim=1)
+        if same_length:
+            mask_l = torch.tril(attn_mask)   # lower triangle, diagonal included
+            ret = torch.cat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], dim=1)
+
+        return ret
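A tiny illustration (not part of the commit) of the mask this builds, where `1` marks a key position the query may not attend to:

```python
import torch

qlen, mlen = 3, 2
attn_mask = torch.ones(qlen, qlen)
ret = torch.cat([torch.zeros(qlen, mlen), torch.triu(attn_mask) - torch.eye(qlen)], dim=1)
print(ret)
# tensor([[0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0.]])
# query i may attend to all mlen memory slots and to keys j <= i in the current segment
```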
+
+    def cache_mem(self, curr_out, prev_mem):
+        """cache hidden states into memory."""
+        if self.mem_len is None or self.mem_len == 0:
+            return None
+        else:
+            if self.reuse_len is not None and self.reuse_len > 0:
+                curr_out = curr_out[:self.reuse_len]
+
+            if prev_mem is None:
+                new_mem = curr_out[-self.mem_len:]
+            else:
+                new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
+
+        return new_mem.detach()
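The memory update rule above, shown on plain tensors with made-up sizes (illustration only):

```python
import torch

mem_len, reuse_len = 4, 2
prev_mem = torch.arange(4.0).view(4, 1, 1)            # [mem_len, bsz, d_model]
curr_out = torch.arange(10.0, 13.0).view(3, 1, 1)     # current segment's hidden states

curr_out = curr_out[:reuse_len]                        # keep only the reusable prefix
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-mem_len:].detach()
print(new_mem.view(-1))  # tensor([ 2.,  3., 10., 11.]) -> oldest states drop out of the window
```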
+
+    def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=torch.float):
+        """create relative positional encoding."""
+        freq_seq = torch.arange(0, self.config.d_model, 2.0, dtype=dtype)
+        inv_freq = 1 / (10000 ** (freq_seq / self.config.d_model))
+
+        if self.attn_type == 'bi':
+            # beg, end = klen - 1, -qlen
+            beg, end = klen, -qlen
+        elif self.attn_type == 'uni':
+            # beg, end = klen - 1, -1
+            beg, end = klen, -1
+        else:
+            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
+
+        if self.bi_data:
+            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
+            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=dtype)
+
+            if self.clamp_len > 0:
+                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
+                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
+
+            if bsz is not None:
+                fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
+                bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
+            else:
+                fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
+                bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
+
+            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
+        else:
+            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
+            if self.clamp_len > 0:
+                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
+            pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)
+
+        return pos_emb
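For reference, the forward and backward relative-position ranges generated above for the bidirectional case, before clamping, with invented lengths:

```python
import torch

qlen, mlen = 3, 2
klen = qlen + mlen
beg, end = klen, -qlen
print(torch.arange(beg, end, -1.0))   # tensor([ 5.,  4.,  3.,  2.,  1.,  0., -1., -2.])
print(torch.arange(-beg, -end, 1.0))  # tensor([-5., -4., -3., -2., -1.,  0.,  1.,  2.])
```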
+
+    def forward(self, inp_k, seg_id=None, input_mask=None,
+                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
+                output_all_encoded_layers=True, head_mask=None):
+        """
+        Args:
+            inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
+            seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
+            input_mask: float32 Tensor in shape [len, bsz], the input mask.
+                0 for real tokens and 1 for padding.
+            mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+                from previous batches. The length of the list equals n_layer.
+                If None, no memory is used.
+            perm_mask: float32 Tensor in shape [len, len, bsz].
+                If perm_mask[i, j, k] = 0, i attends to j in batch k;
+                if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
+                If None, each position attends to all the others.
+            target_mapping: float32 Tensor in shape [num_predict, len, bsz].
+                If target_mapping[i, j, k] = 1, the i-th prediction in batch k is
+                on the j-th token.
+                Only used during pretraining for partial prediction.
+                Set to None during finetuning.
+            inp_q: float32 Tensor in shape [len, bsz].
+                1 for tokens with losses and 0 for tokens without losses.
+                Only used during pretraining for two-stream attention.
+                Set to None during finetuning.
+            mem_len: int, the number of tokens to cache.
+            reuse_len: int, the number of tokens in the current batch to be cached
+                and reused in the future.
+            bi_data: bool, whether to use bidirectional input pipeline.
+                Usually set to True during pretraining and False during finetuning.
+            clamp_len: int, clamp all relative distances larger than clamp_len.
+                -1 means no clamping.
+            same_length: bool, whether to use the same attention length for each token.
+            summary_type: str, "last", "first", "mean", or "attn". The method
+                to pool the input to get a vector representation.
+        """
+        qlen, bsz = inp_k.shape
+
+        mlen = mems[0].shape[0] if mems is not None else 0
+        klen = mlen + qlen
+
+        ##### Attention mask
+        # causal attention mask
+        if self.attn_type == 'uni':
+            attn_mask = self._create_mask(qlen, mlen, inp_k.dtype, self.same_length)
+            attn_mask = attn_mask[:, :, None, None]
+        elif self.attn_type == 'bi':
+            attn_mask = None
+        else:
+            raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
+
+        # data mask: input mask & perm mask
+        if input_mask is not None and perm_mask is not None:
+            data_mask = input_mask[None] + perm_mask
+        elif input_mask is not None and perm_mask is None:
+            data_mask = input_mask[None]
+        elif input_mask is None and perm_mask is not None:
+            data_mask = perm_mask
+        else:
+            data_mask = None
+
+        if data_mask is not None:
+            # all mems can be attended to
+            mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz],
+                                    dtype=data_mask.dtype, device=data_mask.device)
+            data_mask = torch.cat([mems_mask, data_mask], dim=1)
+            if attn_mask is None:
+                attn_mask = data_mask[:, :, :, None]
+            else:
+                attn_mask += data_mask[:, :, :, None]
+
+        if attn_mask is not None:
+            attn_mask = (attn_mask > 0).float()
+
+        if attn_mask is not None:
+            non_tgt_mask = -torch.eye(qlen, dtype=attn_mask.dtype, device=attn_mask.device)
+            non_tgt_mask = torch.cat([torch.zeros([qlen, mlen], dtype=attn_mask.dtype,
+                                                  device=attn_mask.device), non_tgt_mask], dim=-1)
+            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).float()
+        else:
+            non_tgt_mask = None
+
+        ##### Word embedding
+        word_emb_k = self.word_embedding(inp_k)
+        output_h = self.dropout(word_emb_k)
+        if inp_q is not None:
+            if target_mapping is not None:
+                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
+            else:
+                inp_q_ext = inp_q[:, :, None]
+                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
+            output_g = self.dropout(word_emb_q)
+        else:
+            output_g = None
+
+        ##### Segment embedding
+        if seg_id is not None:
+            # Convert `seg_id` to one-hot `seg_mat`
+            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long)
+            cat_ids = torch.cat([mem_pad, seg_id], dim=0)
+            # `1` indicates not in the same segment [qlen x klen x bsz]
+            seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()
+            # seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
+        else:
+            seg_mat = None
+
+        ##### Positional encoding
+        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=inp_k.dtype)
+        pos_emb = self.dropout(pos_emb)
+
+        ##### Head mask if needed (for bertology/pruning)
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        new_mems = []
+        if mems is None:
+            mems = [None] * len(self.layer)
+        for i, layer_module in enumerate(self.layer):
+            # cache new mems
+            new_mems.append(self.cache_mem(output_h, mems[i]))
+            output_h, output_g = layer_module(output_h, output_g, attn_mask_h, attn_mask_g,
+                                              r, seg_mat, mems=mems[i],
+                                              target_mapping=target_mapping, head_mask=head_mask)
+
+        output = self.dropout(output_g if output_g is not None else output_h)
+
+        return output
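A small illustration (not part of the commit) of the `seg_mat` construction in the forward pass above; `1` marks a query/key pair that is not in the same segment, with memory slots padded as segment 0:

```python
import torch

mlen, qlen, bsz = 2, 3, 1
seg_id = torch.tensor([[0], [0], [1]])                               # [qlen, bsz]
cat_ids = torch.cat([torch.zeros(mlen, bsz, dtype=torch.long), seg_id], dim=0)
seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()               # [qlen, klen, bsz]
print(seg_mat[:, :, 0])
# tensor([[0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 1],
#         [1, 1, 1, 1, 0]])
```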
+
+
+class XLNetLMHeadModel(XLNetPreTrainedModel):
     """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").

     Params:
...
@@ -473,10 +917,10 @@ class XLNetModel(XLNetPreTrainedModel):
         `encoded_layers`: controlled by `output_all_encoded_layers` argument:
             - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                 of each attention block (i.e. 12 full sequences for XLNet-base, 24 for XLNet-large), each
-                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
+                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, d_model],
             - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-                to the last attention block of shape [batch_size, sequence_length, hidden_size],
+                to the last attention block of shape [batch_size, sequence_length, d_model],
-        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
+        `pooled_output`: a torch.FloatTensor of size [batch_size, d_model] which is the output of a
             classifier pretrained on top of the hidden state associated to the first character of the
             input (`CLS`) to train on the Next-Sentence task (see XLNet's paper).
...
@@ -487,16 +931,30 @@ class XLNetModel(XLNetPreTrainedModel):
     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

-    config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+    config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

     model = modeling.XLNetModel(config=config)
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
-        super(XLNetModel, self).__init__(config)
+    def __init__(self, config, run_config, output_attentions=False, keep_multihead_output=False):
+        super(XLNetLMHeadModel, self).__init__(config)
         self.output_attentions = output_attentions
+        self.attn_type = run_config.attn_type
+        self.same_length = run_config.same_length
+
+        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
+        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, self.d_model))
+        self.transformer = XLNetModel(config, output_attentions=output_attentions,
+                                      keep_multihead_output=keep_multihead_output)
+        self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
+        self.dropout = nn.Dropout(config.dropout)
+
+        # Tie weights
+        if config.tie_weight:
+            self.lm_loss.weight = self.word_embedding.weight
+
         self.apply(self.init_xlnet_weights)

     def prune_heads(self, heads_to_prune):
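A short, self-contained illustration (not part of the commit) of what the `# Tie weights` branch does: the output projection reuses the input embedding matrix, so both attributes point at the same `Parameter`:

```python
import torch.nn as nn

vocab_size, d_model = 100, 16
word_embedding = nn.Embedding(vocab_size, d_model)
lm_loss = nn.Linear(d_model, vocab_size, bias=True)

lm_loss.weight = word_embedding.weight          # both are [vocab_size, d_model]
print(lm_loss.weight is word_embedding.weight)  # True
```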
...
@@ -512,54 +970,56 @@ class XLNetModel(XLNetPreTrainedModel):
         """
         return [layer.attention.self.multihead_output for layer in self.encoder.layer]

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, head_mask=None):
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
-
-        # We create a 3D attention mask from a 2D tensor mask.
-        # Sizes are [batch_size, 1, 1, to_seq_length]
-        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-        # this attention mask is more simple than the triangular masking of causal attention
-        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -10000.0 for masked positions.
-        # Since we are adding it to the raw scores before the softmax, this is
-        # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
-            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_hidden_layers
-
-        embedding_output = self.embeddings(input_ids, token_type_ids)
-        encoded_layers = self.encoder(embedding_output,
-                                      extended_attention_mask,
-                                      output_all_encoded_layers=output_all_encoded_layers,
-                                      head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, encoded_layers = encoded_layers
-        sequence_output = encoded_layers[-1]
-        pooled_output = self.pooler(sequence_output)
-        if not output_all_encoded_layers:
-            encoded_layers = encoded_layers[-1]
-        if self.output_attentions:
-            return all_attentions, encoded_layers, pooled_output
-        return encoded_layers, pooled_output
+    def forward(self, inp_k, seg_id=None, input_mask=None,
+                mems=None, perm_mask=None, target_mapping=None, inp_q=None,
+                output_all_encoded_layers=True, head_mask=None):
+        """
+        Args:
+            inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
+            seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
+            input_mask: float32 Tensor in shape [len, bsz], the input mask.
+                0 for real tokens and 1 for padding.
+            mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
+                from previous batches. The length of the list equals n_layer.
+                If None, no memory is used.
+            perm_mask: float32 Tensor in shape [len, len, bsz].
+                If perm_mask[i, j, k] = 0, i attends to j in batch k;
+                if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
+                If None, each position attends to all the others.
+            target_mapping: float32 Tensor in shape [num_predict, len, bsz].
+                If target_mapping[i, j, k] = 1, the i-th prediction in batch k is
+                on the j-th token.
+                Only used during pretraining for partial prediction.
+                Set to None during finetuning.
+            inp_q: float32 Tensor in shape [len, bsz].
+                1 for tokens with losses and 0 for tokens without losses.
+                Only used during pretraining for two-stream attention.
+                Set to None during finetuning.
+            mem_len: int, the number of tokens to cache.
+            reuse_len: int, the number of tokens in the current batch to be cached
+                and reused in the future.
+            bi_data: bool, whether to use bidirectional input pipeline.
+                Usually set to True during pretraining and False during finetuning.
+            clamp_len: int, clamp all relative distances larger than clamp_len.
+                -1 means no clamping.
+            same_length: bool, whether to use the same attention length for each token.
+            summary_type: str, "last", "first", "mean", or "attn". The method
+                to pool the input to get a vector representation.
+        """
+        output, new_mems = self.transformer(output_h, non_tgt_mask, r, seg_mat,
+                                            output_g=output_g, attn_mask_g=attn_mask,
+                                            mems=mems, target_mapping=target_mapping,
+                                            head_mask=head_mask)
+
+        logits = self.lm_loss(output)
+
+        # if self.output_attentions:
+        #     all_attentions, encoded_layers = encoded_layers
+        # sequence_output = encoded_layers[-1]
+        # pooled_output = self.pooler(sequence_output)
+        # if not output_all_encoded_layers:
+        #     encoded_layers = encoded_layers[-1]
+        # if self.output_attentions:
+        #     return all_attentions, encoded_layers, pooled_output
+        return output, new_mems
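Finally, a hypothetical usage sketch of the `(output, new_mems)` return convention introduced here, using a stand-in module rather than the still-work-in-progress `XLNetLMHeadModel`; the idea is that `new_mems` from one segment is fed back as `mems` for the next:

```python
import torch
import torch.nn as nn

class DummySegmentModel(nn.Module):
    def __init__(self, d_model=8, mem_len=4):
        super(DummySegmentModel, self).__init__()
        self.proj = nn.Linear(d_model, d_model)
        self.mem_len = mem_len

    def forward(self, inp, mems=None):
        output = self.proj(inp)                                   # [len, bsz, d_model]
        prev = mems if mems is not None else output[:0]
        new_mems = torch.cat([prev, output], dim=0)[-self.mem_len:].detach()
        return output, new_mems

model = DummySegmentModel()
mems = None
for segment in torch.randn(3, 5, 1, 8):                           # 3 segments of length 5
    output, mems = model(segment, mems=mems)
print(mems.shape)  # torch.Size([4, 1, 8])
```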