chenpangpang / transformers · Commits

Commit c2ea5aef authored Jun 20, 2019 by thomwolf

    work in progress on xlnet

Parent: de713fa9

Showing 1 changed file with 529 additions and 69 deletions

pytorch_pretrained_bert/modeling_xlnet.py  (+529 −69)
@@ -126,6 +126,16 @@ def swish(x):

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


def positional_embedding(pos_seq, inv_freq, bsz=None):
    sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
    pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
    pos_emb = pos_emb[:, None, :]

    if bsz is not None:
        pos_emb = pos_emb.expand(-1, bsz, -1)

    return pos_emb
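A quick sanity check of the table this builds (not part of the commit; toy sizes assumed), using the same [len, bsz, d_model] layout as the rest of this file:

import torch

pos_seq = torch.arange(5.0, -5.0, -1.0)                      # relative offsets 5 .. -4
inv_freq = 1 / (10000 ** (torch.arange(0.0, 8.0, 2.0) / 8))  # d_model=8 -> 4 frequencies
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)[:, None, :]
print(pos_emb.shape)  # torch.Size([10, 1, 8]): one d_model-dim vector per relative offset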
class XLNetBaseConfig(object):
    @classmethod
    def from_dict(cls, json_object):
@@ -165,15 +175,14 @@ class XLNetConfig(XLNetBaseConfig):
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
-                d_model=768,
-                n_layer=12,
-                n_head=12,
-                d_inner=3072,
+                d_model=1024,
+                n_layer=24,
+                n_head=16,
+                d_inner=4096,
                 ff_activation="gelu",
                 untie_r=True,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12):
        """Constructs XLNetConfig.
@@ -197,8 +206,6 @@ class XLNetConfig(XLNetBaseConfig):
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `XLNetModel`.
            initializer_range: The stdev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
@@ -214,11 +221,12 @@ class XLNetConfig(XLNetBaseConfig):
            self.d_model = d_model
            self.n_layer = n_layer
            self.n_head = n_head
            assert d_model % n_head == 0
            self.d_head = d_model // n_head
            self.ff_activation = ff_activation
            self.d_inner = d_inner
            self.untie_r = untie_r
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
@@ -233,8 +241,8 @@ class XLNetRunConfig(XLNetBaseConfig):
    We store them separately from XLNetConfig for flexibility.
    """
    def __init__(self,
-                dropout,
-                dropatt,
+                dropout=0.1,
+                dropatt=0.1,
                 init="normal",
                 init_range=0.1,
                 init_std=0.02,
@@ -278,12 +286,12 @@ try:
except ImportError:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")

    class XLNetLayerNorm(nn.Module):
-       def __init__(self, hidden_size, eps=1e-12):
+       def __init__(self, d_model, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(XLNetLayerNorm, self).__init__()
-           self.weight = nn.Parameter(torch.ones(hidden_size))
-           self.bias = nn.Parameter(torch.zeros(hidden_size))
+           self.weight = nn.Parameter(torch.ones(d_model))
+           self.bias = nn.Parameter(torch.zeros(d_model))
            self.variance_epsilon = eps

        def forward(self, x):
@@ -292,6 +300,220 @@ except ImportError:
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
class XLNetRelativeAttention(nn.Module):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(XLNetRelativeAttention, self).__init__()
        self.output_attentions = output_attentions
        self.keep_multihead_output = keep_multihead_output
        self.multihead_output = None

        if config.d_model % config.n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.d_model, config.n_head))

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head ** 0.5)

        self.q = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.k = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.v = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.o = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))
        self.r = nn.Parameter(torch.Tensor(config.d_model, self.n_head, self.d_head))

        self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        self.r_s_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        self.seg_embed = nn.Parameter(torch.Tensor(self.n_head, 2, self.d_head))

        self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def prune_heads(self, heads):
        raise NotImplementedError
    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None):
        """Core relative positional attention operations."""
        # content based attention score
        ac = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_w_bias, k_head_h)

        # position based attention score
        bd = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
        bd = rel_shift(bd, klen=ac.shape[1])

        # segment based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = torch.einsum('ibnd,snd->ibns', q_head + self.r_s_bias, self.seg_embed)
            ef = torch.einsum('ijbs,ibns->ijbn', seg_mat, ef)

        # merge attention scores and perform masking
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
            attn_score = attn_score - 1e30 * attn_mask

        # attention probability
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropout(attn_prob)

        # attention output
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)

        return attn_vec
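`rel_shift` is called above but is not defined anywhere in this hunk. A minimal sketch of the Transformer-XL relative-shift trick it presumably corresponds to, assuming scores shaped [qlen, pos_len, bsz, n_head] as produced by the einsum above:

def rel_shift(x, klen=-1):
    """Shift the position axis so row i lines up with relative distance i - j (a sketch)."""
    x_size = x.shape
    x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
    x = x[1:, ...]                # drop one row to realign the diagonal
    x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
    x = x[:, 0:klen, :, :]        # keep scores for the klen valid key positions
    return x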
    def post_attention(self, h, attn_vec, residual=True):
        """Post-attention processing."""
        # post-attention projection (back to `d_model`)
        attn_out = torch.einsum('ibnd,hnd->ibh', attn_vec, self.o)

        attn_out = self.dropout(attn_out)
        if residual:
            attn_out = attn_out + h
        output = self.LayerNorm(attn_out)

        return output
    def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat,
                mems=None, target_mapping=None, head_mask=None):
        if g is not None:
            ###### Two-stream attention with relative positional encoding.
            # content based attention score
            if mems is not None and mems.dim() > 1:
                cat = torch.cat([mems, h], dim=0)
            else:
                cat = h

            # content-based key head
            k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
            # content-based value head
            v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)
            # position-based key head
            k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)

            ##### h-stream
            # content-stream query head
            q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
            # core attention ops
            attn_vec_h = self.rel_attn_core(q_head_h, k_head_h, v_head_h, k_head_r,
                                            seg_mat=seg_mat, attn_mask=attn_mask_h)
            # post processing
            output_h = self.post_attention(h, attn_vec_h)

            ##### g-stream
            # query-stream query head
            q_head_g = torch.einsum('ibh,hnd->ibnd', g, self.q)
            # core attention ops
            if target_mapping is not None:
                q_head_g = torch.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
                attn_vec_g = self.rel_attn_core(q_head_g, k_head_h, v_head_h, k_head_r,
                                                seg_mat=seg_mat, attn_mask=attn_mask_g)
                attn_vec_g = torch.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
            else:
                attn_vec_g = self.rel_attn_core(q_head_g, k_head_h, v_head_h, k_head_r,
                                                seg_mat=seg_mat, attn_mask=attn_mask_g)
            # post processing
            output_g = self.post_attention(g, attn_vec_g)

            attention_output = output_h, output_g
        else:
            ###### Multi-head attention with relative positional encoding
            if mems is not None and mems.dim() > 1:
                cat = torch.cat([mems, h], dim=0)
            else:
                cat = h

            # content heads
            q_head_h = torch.einsum('ibh,hnd->ibnd', h, self.q)
            k_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.k)
            v_head_h = torch.einsum('ibh,hnd->ibnd', cat, self.v)

            # positional heads
            k_head_r = torch.einsum('ibh,hnd->ibnd', r, self.r)

            # core attention ops
            attn_vec = self.rel_attn_core(q_head_h, k_head_h, v_head_h, k_head_r,
                                          seg_mat=seg_mat, attn_mask=attn_mask_h)

            # post processing; pair with a None g-stream so callers can always unpack two outputs
            attention_output = self.post_attention(h, attn_vec), None

        # Mask heads if we want to
        # if head_mask is not None:
        #     attention_probs = attention_probs * head_mask
        # context_layer = torch.matmul(attention_probs, value_layer)
        # if self.keep_multihead_output:
        #     self.multihead_output = context_layer
        #     self.multihead_output.retain_grad()
        # context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        # context_layer = context_layer.view(*new_context_layer_shape)
        # if self.output_attentions:
        #     attentions, self_output = self_output
        # if self.output_attentions:
        #     return attentions, attention_output
        return attention_output
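For orientation, a hedged shape walk-through of the two-stream branch, with made-up sizes (qlen=4, no mems so klen=4, bsz=2, n_head=3, d_head=5, d_model=15):

# h, g:     [4, 2, 15]  content / query streams, [len, bsz, d_model]
# r:        [8, 2, 15]  relative positions klen .. -qlen+1, from positional_embedding
# q/k/v:    einsum('ibh,hnd->ibnd') -> [4, 2, 3, 5]
# ac:       einsum('ibnd,jbnd->ijbn') -> [4, 4, 2, 3]
# bd:       [4, 8, 2, 3] before rel_shift, [4, 4, 2, 3] after it
# attn_vec: einsum('ijbn,jbnd->ibnd') -> [4, 2, 3, 5], projected back to [4, 2, 15]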
class XLNetFeedForward(nn.Module):
    def __init__(self, config):
        super(XLNetFeedForward, self).__init__()
        self.LayerNorm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.layer_1 = nn.Linear(config.d_model, config.d_inner)
        self.layer_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        if isinstance(config.ff_activation, str) or \
                (sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)):
            self.activation_function = ACT2FN[config.ff_activation]
        else:
            self.activation_function = config.ff_activation

    def forward(self, inp):
        hidden_states = self.layer_1(inp)
        hidden_states = self.activation_function(hidden_states)
        hidden_states = self.layer_2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + inp)
        return hidden_states
class XLNetLayer(nn.Module):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(XLNetLayer, self).__init__()
        self.output_attentions = output_attentions
        self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions,
                                               keep_multihead_output=keep_multihead_output)
        self.ff = XLNetFeedForward(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat,
                two_streams=False, mems=None, target_mapping=None, head_mask=None):
        output_h, output_g = self.rel_attn(output_h, output_g, attn_mask_h, attn_mask_g,
                                           r, seg_mat, mems=mems,
                                           target_mapping=target_mapping, head_mask=head_mask)
        if two_streams:
            output_g = self.ff(output_g)
        output_h = self.ff(output_h)

        # if self.output_attentions:
        #     return attentions, layer_output
        return output_h, output_g


class XLNetPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
@@ -445,6 +667,228 @@ class XLNetPreTrainedModel(nn.Module):

class XLNetModel(XLNetPreTrainedModel):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(XLNetModel, self).__init__()
        self.output_attentions = output_attentions
        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len

        layer = XLNetLayer(config, output_attentions=output_attentions,
                           keep_multihead_output=keep_multihead_output)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
    @staticmethod
    def _create_mask(qlen, mlen, dtype=torch.float, same_length=False):
        """create causal attention mask."""
        attn_mask = torch.ones([qlen, qlen], dtype=dtype)
        mask_u = torch.triu(attn_mask)        # upper triangle, diagonal included
        mask_dia = torch.eye(qlen, dtype=dtype)
        attn_mask_pad = torch.zeros([qlen, mlen], dtype=dtype)
        ret = torch.cat([attn_mask_pad, mask_u - mask_dia], dim=1)
        if same_length:
            mask_l = torch.tril(attn_mask)    # lower triangle, diagonal included
            ret = torch.cat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], dim=1)

        return ret
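Not part of the commit: the mask `_create_mask(3, 2)` produces. 1 marks key positions a query must not attend to; the two memory columns are always visible:

print(XLNetModel._create_mask(3, 2))
# tensor([[0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0.]])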
    def cache_mem(self, curr_out, prev_mem):
        """cache hidden states into memory."""
        if self.mem_len is None or self.mem_len == 0:
            return None
        else:
            if self.reuse_len is not None and self.reuse_len > 0:
                curr_out = curr_out[:self.reuse_len]

            if prev_mem is None:
                new_mem = curr_out[-self.mem_len:]
            else:
                new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]

        return new_mem.detach()
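A hedged illustration of the sliding window this implements, with hypothetical sizes (mem_len=4, reuse_len unset):

import torch

mem_len = 4
seg_1 = torch.randn(6, 2, 8)                              # [qlen, bsz, d_model]
mem = seg_1[-mem_len:].detach()                           # first segment: keep the last mem_len states
seg_2 = torch.randn(6, 2, 8)
mem = torch.cat([mem, seg_2], dim=0)[-mem_len:].detach()  # next segment: slide the window
print(mem.shape)                                          # torch.Size([4, 2, 8])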
    def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=torch.float):
        """create relative positional encoding."""
        freq_seq = torch.arange(0, self.config.d_model, 2.0, dtype=dtype)
        inv_freq = 1 / (10000 ** (freq_seq / self.config.d_model))

        if self.attn_type == 'bi':
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == 'uni':
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=dtype)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=dtype)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)

        return pos_emb
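For the 'bi' branch, a quick check of the index range (not part of the diff), assuming qlen=3 and mlen=2, so klen=5:

import torch
fwd_pos_seq = torch.arange(5.0, -3.0, -1.0)
print(fwd_pos_seq)  # tensor([ 5.,  4.,  3.,  2.,  1.,  0., -1., -2.])
# one entry per possible query-key offset; positional_embedding turns this into
# the [qlen + klen, bsz, d_model] sinusoid table that rel_shift later realigns.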
    def forward(self, inp_k, seg_id=None, input_mask=None, mems=None, perm_mask=None,
                target_mapping=None, inp_q=None, output_all_encoded_layers=True, head_mask=None):
        """
        Args:
            inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
            seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
            input_mask: float32 Tensor in shape [len, bsz], the input mask.
                0 for real tokens and 1 for padding.
            mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
                from previous batches. The length of the list equals n_layer.
                If None, no memory is used.
            perm_mask: float32 Tensor in shape [len, len, bsz].
                If perm_mask[i, j, k] = 0, i attends to j in batch k;
                if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
                If None, each position attends to all the others.
            target_mapping: float32 Tensor in shape [num_predict, len, bsz].
                If target_mapping[i, j, k] = 1, the i-th prediction in batch k is
                on the j-th token.
                Only used during pretraining for partial prediction.
                Set to None during finetuning.
            inp_q: float32 Tensor in shape [len, bsz].
                1 for tokens with losses and 0 for tokens without losses.
                Only used during pretraining for two-stream attention.
                Set to None during finetuning.
            mem_len: int, the number of tokens to cache.
            reuse_len: int, the number of tokens in the current batch to be cached
                and reused in the future.
            bi_data: bool, whether to use bidirectional input pipeline.
                Usually set to True during pretraining and False during finetuning.
            clamp_len: int, clamp all relative distances larger than clamp_len.
                -1 means no clamping.
            same_length: bool, whether to use the same attention length for each token.
            summary_type: str, "last", "first", "mean", or "attn". The method
                to pool the input to get a vector representation.
        """
        qlen, bsz = inp_k.shape
        mlen = mems[0].shape[0] if mems is not None else 0
        klen = mlen + qlen

        ##### Attention mask
        # causal attention mask
        if self.attn_type == 'uni':
            attn_mask = self._create_mask(qlen, mlen, inp_k.dtype, self.same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
        # data mask: input mask & perm mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz],
                                    dtype=data_mask.dtype, device=data_mask.device)
            data_mask = torch.cat([mems_mask, data_mask], dim=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = (attn_mask > 0).float()
        if attn_mask is not None:
            non_tgt_mask = -torch.eye(qlen, dtype=attn_mask.dtype)
            non_tgt_mask = torch.cat([torch.zeros([qlen, mlen], dtype=attn_mask.dtype),
                                      non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask.dtype)
        else:
            non_tgt_mask = None
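Not part of the diff: what the eye-based correction looks like for qlen=3, mlen=1. Adding it to `attn_mask` before thresholding re-opens each token's own position for the content stream:

import torch
qlen, mlen = 3, 1
m = torch.cat([torch.zeros(qlen, mlen), -torch.eye(qlen)], dim=-1)
print(m)
# tensor([[ 0., -1.,  0.,  0.],
#         [ 0.,  0., -1.,  0.],
#         [ 0.,  0.,  0., -1.]])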
        ##### Word embedding
        word_emb_k = self.word_embedding(inp_k)
        output_h = self.dropout(word_emb_k)
        if inp_q is not None:
            if target_mapping is not None:
                word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            else:
                inp_q_ext = inp_q[:, :, None]
                word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None
        ##### Segment embedding
        if seg_id is not None:
            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long)
            cat_ids = torch.cat([mem_pad, seg_id], dim=0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = (seg_id[:, None] != cat_ids[None, :]).long()
            # seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None
        ##### Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=inp_k.dtype)
        pos_emb = self.dropout(pos_emb)
        ##### Head mask if needed (for bertology/pruning)
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer
        new_mems = []
        if mems is None:
            mems = [None] * len(self.layer)
        for i, layer_module in enumerate(self.layer):
            # cache new mems
            new_mems.append(self.cache_mem(output_h, mems[i]))
            output_h, output_g = layer_module(output_h, output_g, non_tgt_mask, attn_mask,
                                              pos_emb, seg_mat, mems=mems[i],
                                              target_mapping=target_mapping,
                                              head_mask=head_mask[i])

        output = self.dropout(output_g if output_g is not None else output_h)
        return output, new_mems

class XLNetLMHeadModel(XLNetPreTrainedModel):
    """XLNet model ("XLNet: Generalized Autoregressive Pretraining for Language Understanding").

    Params:
@@ -473,10 +917,10 @@ class XLNetModel(XLNetPreTrainedModel):
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for XLNet-base, 24 for XLNet-large), each
-               encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
+               encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, d_model],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
-               to the last attention block of shape [batch_size, sequence_length, hidden_size],
+               to the last attention block of shape [batch_size, sequence_length, d_model],
-       `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
+       `pooled_output`: a torch.FloatTensor of size [batch_size, d_model] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see XLNet's paper).
@@ -487,16 +931,30 @@ class XLNetModel(XLNetPreTrainedModel):
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

-   config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+   config = modeling.XLNetConfig(vocab_size_or_config_json_file=32000, d_model=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.XLNetModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
-       super(XLNetModel, self).__init__(config)
+   def __init__(self, config, run_config, output_attentions=False, keep_multihead_output=False):
+       super(XLNetLMHeadModel, self).__init__(config)
        self.output_attentions = output_attentions
        self.attn_type = run_config.attn_type
        self.same_length = run_config.same_length

        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
        self.transformer = XLNetModel(config, output_attentions=output_attentions,
                                      keep_multihead_output=keep_multihead_output)
        self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
        self.dropout = nn.Dropout(config.dropout)

        # Tie weights
        if config.tie_weight:
            self.lm_loss.weight = self.word_embedding.weight

        self.apply(self.init_xlnet_weights)

    def prune_heads(self, heads_to_prune):

@@ -512,54 +970,56 @@ class XLNetModel(XLNetPreTrainedModel):
"""
return
[
layer
.
attention
.
self
.
multihead_output
for
layer
in
self
.
encoder
.
layer
]
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
if
attention_mask
is
None
:
attention_mask
=
torch
.
ones_like
(
input_ids
)
if
token_type_ids
is
None
:
token_type_ids
=
torch
.
zeros_like
(
input_ids
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand_as
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
def
forward
(
self
,
inp_k
,
seg_id
=
None
,
input_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
inp_q
=
None
,
output_all_encoded_layers
=
True
,
head_mask
=
None
):
"""
Args:
inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
input_mask: float32 Tensor in shape [len, bsz], the input mask.
0 for real tokens and 1 for padding.
mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: float32 Tensor in shape [len, len, bsz].
If perm_mask[i, j, k] = 0, i attend to j in batch k;
if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: float32 Tensor in shape [num_predict, len, bsz].
If target_mapping[i, j, k] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: float32 Tensor in shape [len, bsz].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
embedding_output
=
self
.
embeddings
(
input_ids
,
token_type_ids
)
encoded_layers
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
,
output_all_encoded_layers
=
output_all_encoded_layers
,
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the currect batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
output
,
new_mems
=
self
.
transformer
(
output_h
,
non_tgt_mask
,
r
,
seg_mat
,
output_g
=
output_g
,
attn_mask_g
=
attn_mask
,
mems
=
mems
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
if
self
.
output_attentions
:
all_attentions
,
encoded_layers
=
encoded_layers
sequence_output
=
encoded_layers
[
-
1
]
pooled_output
=
self
.
pooler
(
sequence_output
)
if
not
output_all_encoded_layers
:
encoded_layers
=
encoded_layers
[
-
1
]
if
self
.
output_attentions
:
return
all_attentions
,
encoded_layers
,
pooled_output
return
encoded_layers
,
pooled_output
logits
=
self
.
lm_loss
(
output
)
# if self.output_attentions:
# all_attentions, encoded_layers = encoded_layers
# sequence_output = encoded_layers[-1]
# pooled_output = self.pooler(sequence_output)
# if not output_all_encoded_layers:
# encoded_layers = encoded_layers[-1]
# if self.output_attentions:
# return all_attentions, encoded_layers, pooled_output
return
output
,
new_mems