Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
65c49bb2
Commit
65c49bb2
authored
Sep 13, 2019
by
thomwolf
Browse files
adding TF 2.0 adaptive softmax with logits + loss outputs
parent
39c38b2e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
869 additions
and
1105 deletions
+869
-1105
pytorch_transformers/modeling_tf_bert.py
pytorch_transformers/modeling_tf_bert.py
+2
-0
pytorch_transformers/modeling_tf_transfo_xl.py
pytorch_transformers/modeling_tf_transfo_xl.py
+344
-729
pytorch_transformers/modeling_tf_transfo_xl_utilities.py
pytorch_transformers/modeling_tf_transfo_xl_utilities.py
+279
-0
pytorch_transformers/modeling_tf_xlm.py
pytorch_transformers/modeling_tf_xlm.py
+2
-2
pytorch_transformers/modeling_transfo_xl.py
pytorch_transformers/modeling_transfo_xl.py
+20
-373
pytorch_transformers/tests/modeling_tf_bert_test.py
pytorch_transformers/tests/modeling_tf_bert_test.py
+5
-1
pytorch_transformers/tests/modeling_tf_transfo_xl_test.py
pytorch_transformers/tests/modeling_tf_transfo_xl_test.py
+217
-0
No files found.
pytorch_transformers/modeling_tf_bert.py
View file @
65c49bb2
...
...
@@ -455,6 +455,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
"""
raise
NotImplementedError
# def call(self, input_ids, attention_mask=None, token_type_ids=None,
# position_ids=None, head_mask=None, training=False):
def
call
(
self
,
inputs
,
training
=
False
):
if
not
isinstance
(
inputs
,
(
dict
,
tuple
,
list
)):
input_ids
=
inputs
...
...
pytorch_transformers/modeling_tf_transfo_xl.py
View file @
65c49bb2
...
...
@@ -30,8 +30,8 @@ import numpy as np
import
tensorflow
as
tf
from
.configuration_transfo_xl
import
TransfoXLConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFConv1D
,
TFSequenceSummary
from
.modeling_transfo_xl_utilities
import
Projected
Adaptive
Log
Softmax
,
sample_logits
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFConv1D
,
TFSequenceSummary
,
shape_list
from
.modeling_
tf_
transfo_xl_utilities
import
TF
AdaptiveSoftmax
Mask
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
...
...
@@ -49,55 +49,56 @@ def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
class
PositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
demb
):
super
(
PositionalEmbedding
,
self
).
__init__
()
class
TF
PositionalEmbedding
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
demb
,
**
kwargs
):
super
(
TF
PositionalEmbedding
,
self
).
__init__
(
**
kwargs
)
self
.
demb
=
demb
self
.
inv_freq
=
1
/
(
10000
**
(
tf
.
range
(
0
,
demb
,
2.0
)
/
demb
))
inv_freq
=
1
/
(
10000
**
(
torch
.
arange
(
0.0
,
demb
,
2.0
)
/
demb
))
self
.
register_buffer
(
'inv_freq'
,
inv_freq
)
def
forward
(
self
,
pos_seq
,
bsz
=
None
):
sinusoid_inp
=
torch
.
ger
(
pos_seq
,
self
.
inv_freq
)
pos_emb
=
torch
.
cat
([
sinusoid_inp
.
sin
(),
sinusoid_inp
.
cos
()],
dim
=-
1
)
def
call
(
self
,
pos_seq
,
bsz
=
None
):
sinusoid_inp
=
tf
.
einsum
(
'i,j->ij'
,
pos_seq
,
self
.
inv_freq
)
pos_emb
=
tf
.
concat
([
tf
.
sin
(
sinusoid_inp
),
tf
.
cos
(
sinusoid_inp
)],
-
1
)
if
bsz
is
not
None
:
return
pos_emb
[:,
None
,:]
.
expand
(
-
1
,
bsz
,
-
1
)
return
tf
.
tile
(
pos_emb
[:,
None
,
:]
,
[
1
,
bsz
,
1
]
)
else
:
return
pos_emb
[:,
None
,:]
return
pos_emb
[:,
None
,
:]
class
PositionwiseFF
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
False
):
super
(
PositionwiseFF
,
self
).
__init__
()
class
TF
PositionwiseFF
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
False
,
**
kwargs
):
super
(
TF
PositionwiseFF
,
self
).
__init__
(
**
kwargs
)
self
.
d_model
=
d_model
self
.
d_inner
=
d_inner
self
.
dropout
=
dropout
self
.
CoreNet
=
nn
.
Sequential
(
nn
.
Linear
(
d_model
,
d_inner
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
d_inner
,
d_model
),
nn
.
Dropout
(
dropout
),
)
self
.
layer_1
=
tf
.
keras
.
layers
.
Dense
(
d_inner
,
activation
=
tf
.
nn
.
relu
,
name
=
'CoreNet_._0'
)
self
.
drop_1
=
tf
.
keras
.
layers
.
Dropout
(
dropout
)
self
.
layer_2
=
tf
.
keras
.
layers
.
Dense
(
d_model
,
name
=
'CoreNet_._2'
)
self
.
drop_2
=
tf
.
keras
.
layers
.
Dropout
(
dropout
)
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
1e-12
,
name
=
'layer_norm'
)
self
.
pre_lnorm
=
pre_lnorm
def
forward
(
self
,
inp
):
def
call
(
self
,
inp
,
training
=
False
):
if
self
.
pre_lnorm
:
##### layer normalization + positionwise feed-forward
core_out
=
self
.
CoreNet
(
self
.
layer_norm
(
inp
))
core_out
=
self
.
layer_norm
(
inp
)
core_out
=
self
.
layer_1
(
core_out
)
core_out
=
self
.
drop_1
(
core_out
,
training
=
training
)
core_out
=
self
.
layer_2
(
core_out
)
core_out
=
self
.
drop_2
(
core_out
,
training
=
training
)
##### residual connection
output
=
core_out
+
inp
else
:
##### positionwise feed-forward
core_out
=
self
.
CoreNet
(
inp
)
core_out
=
self
.
layer_1
(
inp
)
core_out
=
self
.
drop_1
(
core_out
,
training
=
training
)
core_out
=
self
.
layer_2
(
core_out
)
core_out
=
self
.
drop_2
(
core_out
,
training
=
training
)
##### residual connection + layer normalization
output
=
self
.
layer_norm
(
inp
+
core_out
)
...
...
@@ -105,102 +106,11 @@ class PositionwiseFF(nn.Module):
return
output
class
MultiHeadAttn
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
dropout
,
dropatt
=
0
,
pre_lnorm
=
False
,
r_r_bias
=
None
,
r_w_bias
=
None
,
output_attentions
=
False
):
super
(
MultiHeadAttn
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
self
.
n_head
=
n_head
self
.
d_model
=
d_model
self
.
d_head
=
d_head
self
.
dropout
=
dropout
self
.
q_net
=
nn
.
Linear
(
d_model
,
n_head
*
d_head
,
bias
=
False
)
self
.
kv_net
=
nn
.
Linear
(
d_model
,
2
*
n_head
*
d_head
,
bias
=
False
)
self
.
drop
=
nn
.
Dropout
(
dropout
)
self
.
dropatt
=
nn
.
Dropout
(
dropatt
)
self
.
o_net
=
nn
.
Linear
(
n_head
*
d_head
,
d_model
,
bias
=
False
)
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
scale
=
1
/
(
d_head
**
0.5
)
self
.
pre_lnorm
=
pre_lnorm
if
r_r_bias
is
None
or
r_w_bias
is
None
:
# Biases are not shared
self
.
r_r_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
else
:
self
.
r_r_bias
=
r_r_bias
self
.
r_w_bias
=
r_w_bias
def
forward
(
self
,
h
,
attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if
mems
is
not
None
:
c
=
torch
.
cat
([
mems
,
h
],
0
)
else
:
c
=
h
if
self
.
pre_lnorm
:
##### layer normalization
c
=
self
.
layer_norm
(
c
)
head_q
=
self
.
q_net
(
h
)
head_k
,
head_v
=
torch
.
chunk
(
self
.
kv_net
(
c
),
2
,
-
1
)
head_q
=
head_q
.
view
(
h
.
size
(
0
),
h
.
size
(
1
),
self
.
n_head
,
self
.
d_head
)
head_k
=
head_k
.
view
(
c
.
size
(
0
),
c
.
size
(
1
),
self
.
n_head
,
self
.
d_head
)
head_v
=
head_v
.
view
(
c
.
size
(
0
),
c
.
size
(
1
),
self
.
n_head
,
self
.
d_head
)
# [qlen x klen x bsz x n_head]
attn_score
=
torch
.
einsum
(
'ibnd,jbnd->ijbn'
,
(
head_q
,
head_k
))
attn_score
.
mul_
(
self
.
scale
)
if
attn_mask
is
not
None
and
torch
.
sum
(
attn_mask
).
item
():
attn_mask
=
(
attn_mask
==
1
)
# Switch to bool
if
attn_mask
.
dim
()
==
2
:
attn_score
.
masked_fill_
(
attn_mask
[
None
,:,:,
None
],
-
float
(
'inf'
))
elif
attn_mask
.
dim
()
==
3
:
attn_score
.
masked_fill_
(
attn_mask
[:,:,:,
None
],
-
float
(
'inf'
))
# [qlen x klen x bsz x n_head]
attn_prob
=
F
.
softmax
(
attn_score
,
dim
=
1
)
attn_prob
=
self
.
dropatt
(
attn_prob
)
# Mask heads if we want to
if
head_mask
is
not
None
:
attn_prob
=
attn_prob
*
head_mask
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec
=
torch
.
einsum
(
'ijbn,jbnd->ibnd'
,
(
attn_prob
,
head_v
))
attn_vec
=
attn_vec
.
contiguous
().
view
(
attn_vec
.
size
(
0
),
attn_vec
.
size
(
1
),
self
.
n_head
*
self
.
d_head
)
##### linear projection
attn_out
=
self
.
o_net
(
attn_vec
)
attn_out
=
self
.
drop
(
attn_out
)
if
self
.
pre_lnorm
:
##### residual connection
outputs
=
[
h
+
attn_out
]
else
:
##### residual connection + layer normalization
outputs
=
[
self
.
layer_norm
(
h
+
attn_out
)]
if
self
.
output_attentions
:
outputs
.
append
(
attn_prob
)
return
outputs
class
RelMultiHeadAttn
(
nn
.
Module
):
class
TFRelPartialLearnableMultiHeadAttn
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
dropout
,
dropatt
=
0
,
tgt_len
=
None
,
ext_len
=
None
,
mem_len
=
None
,
pre_lnorm
=
False
,
r_r_bias
=
None
,
r_w_bias
=
None
,
output_attentions
=
False
):
super
(
RelMultiHeadAttn
,
self
).
__init__
()
r_r_bias
=
None
,
r_w_bias
=
None
,
output_attentions
=
False
,
**
kwargs
):
super
(
TF
Rel
PartialLearnable
MultiHeadAttn
,
self
).
__init__
(
**
kwargs
)
self
.
output_attentions
=
output_attentions
self
.
n_head
=
n_head
...
...
@@ -208,91 +118,60 @@ class RelMultiHeadAttn(nn.Module):
self
.
d_head
=
d_head
self
.
dropout
=
dropout
self
.
qkv_net
=
nn
.
Linear
(
d_model
,
3
*
n_head
*
d_head
,
bias
=
False
)
self
.
qkv_net
=
tf
.
keras
.
layers
.
Dense
(
3
*
n_head
*
d_head
,
use_
bias
=
False
,
name
=
'qkv_net'
)
self
.
drop
=
nn
.
Dropout
(
dropout
)
self
.
dropatt
=
nn
.
Dropout
(
dropatt
)
self
.
o_net
=
nn
.
Linear
(
n_head
*
d_head
,
d_model
,
bias
=
False
)
self
.
drop
=
tf
.
keras
.
layers
.
Dropout
(
dropout
)
self
.
dropatt
=
tf
.
keras
.
layers
.
Dropout
(
dropatt
)
self
.
o_net
=
tf
.
keras
.
layers
.
Dense
(
d_model
,
use_
bias
=
False
,
name
=
'o_net'
)
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
1e-12
,
name
=
'layer_norm'
)
self
.
scale
=
1
/
(
d_head
**
0.5
)
self
.
pre_lnorm
=
pre_lnorm
if
r_r_bias
is
None
or
r_w_bias
is
None
:
# Biases are not shared
self
.
r_r_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
else
:
if
r_r_bias
is
not
None
and
r_w_bias
is
not
None
:
# Biases are shared
self
.
r_r_bias
=
r_r_bias
self
.
r_w_bias
=
r_w_bias
def
_parallelogram_mask
(
self
,
h
,
w
,
left
=
False
):
mask
=
torch
.
ones
((
h
,
w
)).
byte
()
m
=
min
(
h
,
w
)
mask
[:
m
,:
m
]
=
torch
.
triu
(
mask
[:
m
,:
m
])
mask
[
-
m
:,
-
m
:]
=
torch
.
tril
(
mask
[
-
m
:,
-
m
:])
if
left
:
return
mask
else
:
return
mask
.
flip
(
0
)
def
_shift
(
self
,
x
,
qlen
,
klen
,
mask
,
left
=
False
):
if
qlen
>
1
:
zero_pad
=
torch
.
zeros
((
x
.
size
(
0
),
qlen
-
1
,
x
.
size
(
2
),
x
.
size
(
3
)),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
else
:
zero_pad
=
torch
.
zeros
(
0
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
self
.
r_r_bias
=
None
self
.
r_w_bias
=
None
if
left
:
mask
=
mask
.
flip
(
1
)
x_padded
=
torch
.
cat
([
zero_pad
,
x
],
dim
=
1
).
expand
(
qlen
,
-
1
,
-
1
,
-
1
)
else
:
x_padded
=
torch
.
cat
([
x
,
zero_pad
],
dim
=
1
).
expand
(
qlen
,
-
1
,
-
1
,
-
1
)
x
=
x_padded
.
masked_select
(
mask
[:,:,
None
,
None
])
\
.
view
(
qlen
,
klen
,
x
.
size
(
2
),
x
.
size
(
3
))
return
x
def
_rel_shift
(
self
,
x
,
zero_triu
=
False
):
zero_pad_shape
=
(
x
.
size
(
0
),
1
)
+
x
.
size
()[
2
:]
zero_pad
=
torch
.
zeros
(
zero_pad_shape
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x_padded
=
torch
.
cat
([
zero_pad
,
x
],
dim
=
1
)
self
.
r_net
=
tf
.
keras
.
layers
.
Dense
(
self
.
n_head
*
self
.
d_head
,
use_bias
=
False
,
name
=
'r_net'
)
x_padded_shape
=
(
x
.
size
(
1
)
+
1
,
x
.
size
(
0
))
+
x
.
size
()[
2
:]
x_padded
=
x_padded
.
view
(
*
x_padded_shape
)
def
build
(
self
,
input_shape
):
if
self
.
r_r_bias
is
None
or
self
.
r_w_bias
is
None
:
# Biases are not shared
self
.
r_r_bias
=
self
.
add_weight
(
shape
=
(
self
.
n_head
,
self
.
d_head
),
trainable
=
True
,
name
=
'r_r_bias'
)
self
.
r_w_bias
=
self
.
add_weight
(
shape
=
(
self
.
n_head
,
self
.
d_head
),
trainable
=
True
,
name
=
'r_w_bias'
)
super
(
TFRelPartialLearnableMultiHeadAttn
,
self
).
build
(
input_shape
)
x
=
x_padded
[
1
:].
view_as
(
x
)
def
_rel_shift
(
self
,
x
):
x_size
=
shape_list
(
x
)
if
zero_triu
:
ones
=
torch
.
ones
((
x
.
size
(
0
),
x
.
size
(
1
)))
x
=
x
*
torch
.
tril
(
ones
,
x
.
size
(
1
)
-
x
.
size
(
0
))[:,:,
None
,
None
]
x
=
tf
.
pad
(
x
,
[[
0
,
0
],
[
1
,
0
],
[
0
,
0
],
[
0
,
0
]])
x
=
tf
.
reshape
(
x
,
[
x_size
[
1
]
+
1
,
x_size
[
0
],
x_size
[
2
],
x_size
[
3
]])
x
=
tf
.
slice
(
x
,
[
1
,
0
,
0
,
0
],
[
-
1
,
-
1
,
-
1
,
-
1
])
x
=
tf
.
reshape
(
x
,
x_size
)
return
x
def
forward
(
self
,
w
,
r
,
attn_mask
=
None
,
mems
=
None
):
raise
NotImplementedError
class
RelPartialLearnableMultiHeadAttn
(
RelMultiHeadAttn
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
RelPartialLearnableMultiHeadAttn
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
r_net
=
nn
.
Linear
(
self
.
d_model
,
self
.
n_head
*
self
.
d_head
,
bias
=
False
)
def
forward
(
self
,
w
,
r
,
attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
qlen
,
rlen
,
bsz
=
w
.
size
(
0
),
r
.
size
(
0
),
w
.
size
(
1
)
def
call
(
self
,
inputs
,
training
=
False
):
w
,
r
,
attn_mask
,
mems
,
head_mask
=
inputs
qlen
,
rlen
,
bsz
=
shape_list
(
w
)[
0
],
shape_list
(
r
)[
0
],
shape_list
(
w
)[
1
]
if
mems
is
not
None
:
cat
=
t
orch
.
cat
([
mems
,
w
],
0
)
cat
=
t
f
.
con
cat
([
mems
,
w
],
0
)
if
self
.
pre_lnorm
:
w_heads
=
self
.
qkv_net
(
self
.
layer_norm
(
cat
))
else
:
w_heads
=
self
.
qkv_net
(
cat
)
r_head_k
=
self
.
r_net
(
r
)
w_head_q
,
w_head_k
,
w_head_v
=
t
orch
.
chunk
(
w_heads
,
3
,
dim
=-
1
)
w_head_q
,
w_head_k
,
w_head_v
=
t
f
.
split
(
w_heads
,
3
,
axis
=-
1
)
w_head_q
=
w_head_q
[
-
qlen
:]
else
:
if
self
.
pre_lnorm
:
...
...
@@ -301,56 +180,52 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
w_heads
=
self
.
qkv_net
(
w
)
r_head_k
=
self
.
r_net
(
r
)
w_head_q
,
w_head_k
,
w_head_v
=
t
orch
.
chunk
(
w_heads
,
3
,
dim
=-
1
)
w_head_q
,
w_head_k
,
w_head_v
=
t
f
.
split
(
w_heads
,
3
,
axis
=-
1
)
klen
=
w_head_k
.
size
(
0
)
klen
=
shape_list
(
w_head_k
)[
0
]
w_head_q
=
w_head_q
.
view
(
qlen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
# qlen x bsz x n_head x d_head
w_head_k
=
w_head_k
.
view
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
# qlen x bsz x n_head x d_head
w_head_v
=
w_head_v
.
view
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
# qlen x bsz x n_head x d_head
w_head_q
=
tf
.
reshape
(
w_head_q
,
(
qlen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
)
# qlen x bsz x n_head x d_head
w_head_k
=
tf
.
reshape
(
w_head_k
,
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
)
# qlen x bsz x n_head x d_head
w_head_v
=
tf
.
reshape
(
w_head_v
,
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
)
# qlen x bsz x n_head x d_head
r_head_k
=
r_head_k
.
view
(
rlen
,
self
.
n_head
,
self
.
d_head
)
# qlen x n_head x d_head
r_head_k
=
tf
.
reshape
(
r_head_k
,
(
rlen
,
self
.
n_head
,
self
.
d_head
)
)
# qlen x n_head x d_head
#### compute attention score
rw_head_q
=
w_head_q
+
self
.
r_w_bias
# qlen x bsz x n_head x d_head
AC
=
t
orch
.
einsum
(
'ibnd,jbnd->ijbn'
,
(
rw_head_q
,
w_head_k
)
)
# qlen x klen x bsz x n_head
AC
=
t
f
.
einsum
(
'ibnd,jbnd->ijbn'
,
rw_head_q
,
w_head_k
)
# qlen x klen x bsz x n_head
rr_head_q
=
w_head_q
+
self
.
r_r_bias
BD
=
t
orch
.
einsum
(
'ibnd,jnd->ijbn'
,
(
rr_head_q
,
r_head_k
)
)
# qlen x klen x bsz x n_head
BD
=
t
f
.
einsum
(
'ibnd,jnd->ijbn'
,
rr_head_q
,
r_head_k
)
# qlen x klen x bsz x n_head
BD
=
self
.
_rel_shift
(
BD
)
# [qlen x klen x bsz x n_head]
attn_score
=
AC
+
BD
attn_score
.
mul_
(
self
.
scale
)
attn_score
=
attn_score
*
self
.
scale
#### compute attention probability
if
attn_mask
is
not
None
and
torch
.
sum
(
attn_mask
).
item
():
attn_mask
=
(
attn_mask
==
1
)
# Switch to bool
if
attn_mask
.
dim
()
==
2
:
attn_score
=
attn_score
.
float
().
masked_fill
(
attn_mask
[
None
,:,:,
None
],
-
1e30
).
type_as
(
attn_score
)
elif
attn_mask
.
dim
()
==
3
:
attn_score
=
attn_score
.
float
().
masked_fill
(
attn_mask
[:,:,:,
None
],
-
1e30
).
type_as
(
attn_score
)
if
attn_mask
is
not
None
:
attn_mask_t
=
attn_mask
[:,
:,
None
,
None
]
attn_score
=
attn_score
*
(
1
-
attn_mask_t
)
-
1e30
*
attn_mask_t
# [qlen x klen x bsz x n_head]
attn_prob
=
F
.
softmax
(
attn_score
,
dim
=
1
)
attn_prob
=
self
.
dropatt
(
attn_prob
)
attn_prob
=
tf
.
nn
.
softmax
(
attn_score
,
axis
=
1
)
attn_prob
=
self
.
dropatt
(
attn_prob
,
training
=
training
)
# Mask heads if we want to
if
head_mask
is
not
None
:
attn_prob
=
attn_prob
*
head_mask
#### compute attention vector
attn_vec
=
t
orch
.
einsum
(
'ijbn,jbnd->ibnd'
,
(
attn_prob
,
w_head_v
)
)
attn_vec
=
t
f
.
einsum
(
'ijbn,jbnd->ibnd'
,
attn_prob
,
w_head_v
)
# [qlen x bsz x n_head x d_head]
attn_vec
=
attn_vec
.
contiguous
().
view
(
attn_vec
.
size
(
0
),
attn_vec
.
size
(
1
),
self
.
n_head
*
self
.
d_head
)
attn_vec_sizes
=
shape_list
(
attn_vec
)
attn_vec
=
tf
.
reshape
(
attn_vec
,
(
attn_vec_sizes
[
0
],
attn_vec_sizes
[
1
],
self
.
n_head
*
self
.
d_head
))
##### linear projection
attn_out
=
self
.
o_net
(
attn_vec
)
attn_out
=
self
.
drop
(
attn_out
)
attn_out
=
self
.
drop
(
attn_out
,
training
=
training
)
if
self
.
pre_lnorm
:
##### residual connection
...
...
@@ -364,166 +239,40 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
return
outputs
class
RelLearnableMultiHeadAttn
(
RelMultiHeadAttn
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
RelLearnableMultiHeadAttn
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
w
,
r_emb
,
r_w_bias
,
r_bias
,
attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
qlen
,
bsz
=
w
.
size
(
0
),
w
.
size
(
1
)
if
mems
is
not
None
:
cat
=
torch
.
cat
([
mems
,
w
],
0
)
if
self
.
pre_lnorm
:
w_heads
=
self
.
qkv_net
(
self
.
layer_norm
(
cat
))
else
:
w_heads
=
self
.
qkv_net
(
cat
)
w_head_q
,
w_head_k
,
w_head_v
=
torch
.
chunk
(
w_heads
,
3
,
dim
=-
1
)
w_head_q
=
w_head_q
[
-
qlen
:]
else
:
if
self
.
pre_lnorm
:
w_heads
=
self
.
qkv_net
(
self
.
layer_norm
(
w
))
else
:
w_heads
=
self
.
qkv_net
(
w
)
w_head_q
,
w_head_k
,
w_head_v
=
torch
.
chunk
(
w_heads
,
3
,
dim
=-
1
)
klen
=
w_head_k
.
size
(
0
)
w_head_q
=
w_head_q
.
view
(
qlen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
w_head_k
=
w_head_k
.
view
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
w_head_v
=
w_head_v
.
view
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
if
klen
>
r_emb
.
size
(
0
):
r_emb_pad
=
r_emb
[
0
:
1
].
expand
(
klen
-
r_emb
.
size
(
0
),
-
1
,
-
1
)
r_emb
=
torch
.
cat
([
r_emb_pad
,
r_emb
],
0
)
r_bias_pad
=
r_bias
[
0
:
1
].
expand
(
klen
-
r_bias
.
size
(
0
),
-
1
)
r_bias
=
torch
.
cat
([
r_bias_pad
,
r_bias
],
0
)
else
:
r_emb
=
r_emb
[
-
klen
:]
r_bias
=
r_bias
[
-
klen
:]
#### compute attention score
rw_head_q
=
w_head_q
+
r_w_bias
[
None
]
# qlen x bsz x n_head x d_head
AC
=
torch
.
einsum
(
'ibnd,jbnd->ijbn'
,
(
rw_head_q
,
w_head_k
))
# qlen x klen x bsz x n_head
B_
=
torch
.
einsum
(
'ibnd,jnd->ijbn'
,
(
w_head_q
,
r_emb
))
# qlen x klen x bsz x n_head
D_
=
r_bias
[
None
,
:,
None
]
# 1 x klen x 1 x n_head
BD
=
self
.
_rel_shift
(
B_
+
D_
)
# [qlen x klen x bsz x n_head]
attn_score
=
AC
+
BD
attn_score
.
mul_
(
self
.
scale
)
#### compute attention probability
if
attn_mask
is
not
None
and
torch
.
sum
(
attn_mask
).
item
():
attn_mask
=
(
attn_mask
==
1
)
# Switch to bool
if
attn_mask
.
dim
()
==
2
:
attn_score
.
masked_fill_
(
attn_mask
[
None
,:,:,
None
],
-
float
(
'inf'
))
elif
attn_mask
.
dim
()
==
3
:
attn_score
.
masked_fill_
(
attn_mask
[:,:,:,
None
],
-
float
(
'inf'
))
# [qlen x klen x bsz x n_head]
attn_prob
=
F
.
softmax
(
attn_score
,
dim
=
1
)
attn_prob
=
self
.
dropatt
(
attn_prob
)
if
head_mask
is
not
None
:
attn_prob
=
attn_prob
*
head_mask
#### compute attention vector
attn_vec
=
torch
.
einsum
(
'ijbn,jbnd->ibnd'
,
(
attn_prob
,
w_head_v
))
# [qlen x bsz x n_head x d_head]
attn_vec
=
attn_vec
.
contiguous
().
view
(
attn_vec
.
size
(
0
),
attn_vec
.
size
(
1
),
self
.
n_head
*
self
.
d_head
)
##### linear projection
attn_out
=
self
.
o_net
(
attn_vec
)
attn_out
=
self
.
drop
(
attn_out
)
if
self
.
pre_lnorm
:
##### residual connection
outputs
=
[
w
+
attn_out
]
else
:
##### residual connection + layer normalization
outputs
=
[
self
.
layer_norm
(
w
+
attn_out
)]
if
self
.
output_attentions
:
outputs
.
append
(
attn_prob
)
return
outputs
class
DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
**
kwargs
):
super
(
DecoderLayer
,
self
).
__init__
()
self
.
dec_attn
=
MultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
**
kwargs
)
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
def
forward
(
self
,
dec_inp
,
dec_attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
attn_outputs
=
self
.
dec_attn
(
dec_inp
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
,
head_mask
=
head_mask
)
ff_output
=
self
.
pos_ff
(
attn_outputs
[
0
])
outputs
=
[
ff_output
]
+
attn_outputs
[
1
:]
return
outputs
class
RelLearnableDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
**
kwargs
):
super
(
RelLearnableDecoderLayer
,
self
).
__init__
()
self
.
dec_attn
=
RelLearnableMultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
**
kwargs
)
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
def
forward
(
self
,
dec_inp
,
r_emb
,
r_w_bias
,
r_bias
,
dec_attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
attn_outputs
=
self
.
dec_attn
(
dec_inp
,
r_emb
,
r_w_bias
,
r_bias
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
,
head_mask
=
head_mask
)
ff_output
=
self
.
pos_ff
(
attn_outputs
[
0
])
outputs
=
[
ff_output
]
+
attn_outputs
[
1
:]
return
outputs
class
RelPartialLearnableDecoderLayer
(
nn
.
Module
):
class
TFRelPartialLearnableDecoderLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
tgt_len
=
None
,
ext_len
=
None
,
mem_len
=
None
,
dropatt
=
0.
,
pre_lnorm
=
False
,
r_w_bias
=
None
,
r_r_bias
=
None
,
output_attentions
=
False
,
**
kwargs
):
super
(
RelPartialLearnableDecoderLayer
,
self
).
__init__
()
self
.
dec_attn
=
RelPartialLearnableMultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
**
kwargs
)
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
def
forward
(
self
,
dec_inp
,
r
,
dec_attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
attn_outputs
=
self
.
dec_attn
(
dec_inp
,
r
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
,
head_mask
=
head_mask
)
ff_output
=
self
.
pos_ff
(
attn_outputs
[
0
])
super
(
TFRelPartialLearnableDecoderLayer
,
self
).
__init__
(
**
kwargs
)
self
.
dec_attn
=
TFRelPartialLearnableMultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
tgt_len
=
tgt_len
,
ext_len
=
ext_len
,
mem_len
=
mem_len
,
dropatt
=
dropatt
,
pre_lnorm
=
pre_lnorm
,
r_w_bias
=
r_w_bias
,
r_r_bias
=
r_r_bias
,
output_attentions
=
output_attentions
,
name
=
'dec_attn'
)
self
.
pos_ff
=
TFPositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
pre_lnorm
,
name
=
'pos_ff'
)
def
call
(
self
,
inputs
,
training
=
False
):
dec_inp
,
r
,
dec_attn_mask
,
mems
,
head_mask
=
inputs
attn_outputs
=
self
.
dec_attn
([
dec_inp
,
r
,
dec_attn_mask
,
mems
,
head_mask
],
training
=
training
)
ff_output
=
self
.
pos_ff
(
attn_outputs
[
0
],
training
=
training
)
outputs
=
[
ff_output
]
+
attn_outputs
[
1
:]
return
outputs
class
AdaptiveEmbedding
(
nn
.
Module
):
class
TFAdaptiveEmbedding
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
n_token
,
d_embed
,
d_proj
,
cutoffs
,
div_val
=
1
,
sample_softmax
=
False
):
super
(
AdaptiveEmbedding
,
self
).
__init__
()
sample_softmax
=
False
,
**
kwargs
):
super
(
TF
AdaptiveEmbedding
,
self
).
__init__
(
**
kwargs
)
self
.
n_token
=
n_token
self
.
d_embed
=
d_embed
...
...
@@ -536,188 +285,53 @@ class AdaptiveEmbedding(nn.Module):
self
.
cutoff_ends
=
[
0
]
+
self
.
cutoffs
self
.
emb_layers
=
nn
.
ModuleList
()
self
.
emb_projs
=
nn
.
ParameterList
()
self
.
emb_layers
=
[]
self
.
emb_projs
=
[]
if
div_val
==
1
:
self
.
emb_layers
.
append
(
nn
.
Embedding
(
n_token
,
d_embed
,
sparse
=
sample_softmax
>
0
)
)
if
d_proj
!=
d_embed
:
self
.
emb_projs
.
append
(
nn
.
Parameter
(
torch
.
FloatTensor
(
d_proj
,
d_embed
)))
raise
NotImplementedError
# Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
else
:
for
i
in
range
(
len
(
self
.
cutoffs
)):
l_idx
,
r_idx
=
self
.
cutoff_ends
[
i
],
self
.
cutoff_ends
[
i
+
1
]
d_emb_i
=
d_embed
//
(
div_val
**
i
)
self
.
emb_layers
.
append
(
nn
.
Embedding
(
r_idx
-
l_idx
,
d_emb_i
))
self
.
emb_projs
.
append
(
nn
.
Parameter
(
torch
.
FloatTensor
(
d_proj
,
d_emb_i
)))
self
.
emb_layers
.
append
(
tf
.
keras
.
layers
.
Embedding
(
r_idx
-
l_idx
,
d_emb_i
,
name
=
'emb_layers_._{}'
.
format
(
i
)))
def
build
(
self
,
input_shape
):
for
i
in
range
(
len
(
self
.
cutoffs
)):
d_emb_i
=
self
.
d_embed
//
(
self
.
div_val
**
i
)
self
.
emb_projs
.
append
(
self
.
add_weight
(
shape
=
(
d_emb_i
,
self
.
d_proj
),
trainable
=
True
,
name
=
'emb_projs._{}'
.
format
(
i
)))
super
(
TFAdaptiveEmbedding
,
self
).
build
(
input_shape
)
def
forward
(
self
,
inp
):
def
call
(
self
,
inp
):
if
self
.
div_val
==
1
:
embed
=
self
.
emb_layers
[
0
](
inp
)
if
self
.
d_proj
!=
self
.
d_embed
:
embed
=
F
.
linear
(
embed
,
self
.
emb_projs
[
0
])
raise
NotImplementedError
# Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
else
:
param
=
next
(
self
.
parameters
())
inp_flat
=
inp
.
view
(
-
1
)
emb_flat
=
torch
.
zeros
([
inp_flat
.
size
(
0
),
self
.
d_proj
],
dtype
=
param
.
dtype
,
device
=
param
.
device
)
inp_flat
=
tf
.
reshape
(
inp
,
(
-
1
,))
emb_flat
=
tf
.
zeros
([
shape_list
(
inp_flat
)[
0
],
self
.
d_proj
])
for
i
in
range
(
len
(
self
.
cutoffs
)):
l_idx
,
r_idx
=
self
.
cutoff_ends
[
i
],
self
.
cutoff_ends
[
i
+
1
]
mask_i
=
(
inp_flat
>=
l_idx
)
&
(
inp_flat
<
r_idx
)
indices_i
=
mask_i
.
nonzero
().
squeeze
()
if
indices_i
.
numel
()
==
0
:
continue
inp_i
=
inp_flat
.
index_select
(
0
,
indices
_i
)
-
l_idx
inp_i
=
tf
.
boolean_mask
(
inp_flat
,
mask
_i
)
-
l_idx
emb_i
=
self
.
emb_layers
[
i
](
inp_i
)
emb_i
=
F
.
linear
(
emb_i
,
self
.
emb_projs
[
i
])
emb_i
=
tf
.
einsum
(
'id,de->ie'
,
emb_i
,
self
.
emb_projs
[
i
])
emb_flat
.
index_copy_
(
0
,
indices_i
,
emb_i
)
mask_idx
=
tf
.
cast
(
tf
.
where
(
mask_i
),
dtype
=
tf
.
int64
)
emb_flat
+=
tf
.
scatter_nd
(
mask_idx
,
emb_i
,
tf
.
cast
(
tf
.
shape
(
emb_flat
),
dtype
=
tf
.
int64
))
embed_shape
=
inp
.
size
(
)
+
(
self
.
d_proj
,)
embed
=
emb_flat
.
view
(
embed_shape
)
embed_shape
=
shape_list
(
inp
)
+
[
self
.
d_proj
]
embed
=
tf
.
reshape
(
emb_flat
,
embed_shape
)
embed
.
mul_
(
self
.
emb_scale
)
embed
*=
self
.
emb_scale
return
embed
class
TransfoXLPreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class
=
TransfoXLConfig
pretrained_model_archive_map
=
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_transfo_xl
base_model_prefix
=
"transformer"
def
_init_weight
(
self
,
weight
):
if
self
.
config
.
init
==
'uniform'
:
nn
.
init
.
uniform_
(
weight
,
-
self
.
config
.
init_range
,
self
.
config
.
init_range
)
elif
self
.
config
.
init
==
'normal'
:
nn
.
init
.
normal_
(
weight
,
0.0
,
self
.
config
.
init_std
)
def
_init_bias
(
self
,
bias
):
nn
.
init
.
constant_
(
bias
,
0.0
)
def
_init_weights
(
self
,
m
):
""" Initialize the weights.
"""
classname
=
m
.
__class__
.
__name__
if
classname
.
find
(
'Linear'
)
!=
-
1
:
if
hasattr
(
m
,
'weight'
)
and
m
.
weight
is
not
None
:
self
.
_init_weight
(
m
.
weight
)
if
hasattr
(
m
,
'bias'
)
and
m
.
bias
is
not
None
:
self
.
_init_bias
(
m
.
bias
)
elif
classname
.
find
(
'AdaptiveEmbedding'
)
!=
-
1
:
if
hasattr
(
m
,
'emb_projs'
):
for
i
in
range
(
len
(
m
.
emb_projs
)):
if
m
.
emb_projs
[
i
]
is
not
None
:
nn
.
init
.
normal_
(
m
.
emb_projs
[
i
],
0.0
,
self
.
config
.
proj_init_std
)
elif
classname
.
find
(
'Embedding'
)
!=
-
1
:
if
hasattr
(
m
,
'weight'
):
self
.
_init_weight
(
m
.
weight
)
elif
classname
.
find
(
'ProjectedAdaptiveLogSoftmax'
)
!=
-
1
:
if
hasattr
(
m
,
'cluster_weight'
)
and
m
.
cluster_weight
is
not
None
:
self
.
_init_weight
(
m
.
cluster_weight
)
if
hasattr
(
m
,
'cluster_bias'
)
and
m
.
cluster_bias
is
not
None
:
self
.
_init_bias
(
m
.
cluster_bias
)
if
hasattr
(
m
,
'out_projs'
):
for
i
in
range
(
len
(
m
.
out_projs
)):
if
m
.
out_projs
[
i
]
is
not
None
:
nn
.
init
.
normal_
(
m
.
out_projs
[
i
],
0.0
,
self
.
config
.
proj_init_std
)
elif
classname
.
find
(
'LayerNorm'
)
!=
-
1
:
if
hasattr
(
m
,
'weight'
):
nn
.
init
.
normal_
(
m
.
weight
,
1.0
,
self
.
config
.
init_std
)
if
hasattr
(
m
,
'bias'
)
and
m
.
bias
is
not
None
:
self
.
_init_bias
(
m
.
bias
)
else
:
if
hasattr
(
m
,
'r_emb'
):
self
.
_init_weight
(
m
.
r_emb
)
if
hasattr
(
m
,
'r_w_bias'
):
self
.
_init_weight
(
m
.
r_w_bias
)
if
hasattr
(
m
,
'r_r_bias'
):
self
.
_init_weight
(
m
.
r_r_bias
)
if
hasattr
(
m
,
'r_bias'
):
self
.
_init_bias
(
m
.
r_bias
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
pass
TRANSFO_XL_START_DOCSTRING
=
r
""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
TRANSFO_XL_INPUTS_DOCSTRING
=
r
"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@
add_start_docstrings
(
"The bare Bert Model transformer outputing raw hidden-states without any specific head on top."
,
TRANSFO_XL_START_DOCSTRING
,
TRANSFO_XL_INPUTS_DOCSTRING
)
class
TransfoXLModel
(
TransfoXLPreTrainedModel
):
r
"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
"""
def
__init__
(
self
,
config
):
super
(
TransfoXLModel
,
self
).
__init__
(
config
)
class
TFTransfoXLMainLayer
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
(
TFTransfoXLMainLayer
,
self
).
__init__
(
**
kwargs
)
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
...
...
@@ -727,11 +341,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self
.
d_model
=
config
.
d_model
self
.
n_head
=
config
.
n_head
self
.
d_head
=
config
.
d_head
self
.
untie_r
=
config
.
untie_r
self
.
word_emb
=
AdaptiveEmbedding
(
config
.
n_token
,
config
.
d_embed
,
config
.
d_model
,
config
.
cutoffs
,
div_val
=
config
.
div_val
)
self
.
word_emb
=
TF
AdaptiveEmbedding
(
config
.
n_token
,
config
.
d_embed
,
config
.
d_model
,
config
.
cutoffs
,
div_val
=
config
.
div_val
,
name
=
'word_emb'
)
self
.
drop
=
nn
.
Dropout
(
config
.
dropout
)
self
.
drop
=
tf
.
keras
.
layers
.
Dropout
(
config
.
dropout
)
self
.
n_layer
=
config
.
n_layer
...
...
@@ -742,61 +357,41 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self
.
attn_type
=
config
.
attn_type
if
not
config
.
untie_r
:
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_r_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
self
.
layers
=
nn
.
ModuleList
()
self
.
layers
=
[]
if
config
.
attn_type
==
0
:
# the default attention
for
i
in
range
(
config
.
n_layer
):
self
.
layers
.
append
(
RelPartialLearnableDecoderLayer
(
config
.
n_head
,
config
.
d_model
,
config
.
d_head
,
config
.
d_inner
,
config
.
dropout
,
tgt_len
=
config
.
tgt_len
,
ext_len
=
config
.
ext_len
,
mem_len
=
config
.
mem_len
,
dropatt
=
config
.
dropatt
,
pre_lnorm
=
config
.
pre_lnorm
,
r_w_bias
=
None
if
config
.
untie_r
else
self
.
r_w_bias
,
r_r_bias
=
None
if
config
.
untie_r
else
self
.
r_r_bias
,
output_attentions
=
self
.
output_attentions
)
)
elif
config
.
attn_type
==
1
:
# learnable embeddings
for
i
in
range
(
config
.
n_layer
):
self
.
layers
.
append
(
RelLearnableDecoderLayer
(
TFRelPartialLearnableDecoderLayer
(
config
.
n_head
,
config
.
d_model
,
config
.
d_head
,
config
.
d_inner
,
config
.
dropout
,
tgt_len
=
config
.
tgt_len
,
ext_len
=
config
.
ext_len
,
mem_len
=
config
.
mem_len
,
dropatt
=
config
.
dropatt
,
pre_lnorm
=
config
.
pre_lnorm
,
r_w_bias
=
None
if
config
.
untie_r
else
self
.
r_w_bias
,
r_r_bias
=
None
if
config
.
untie_r
else
self
.
r_r_bias
,
output_attentions
=
self
.
output_attentions
)
)
elif
config
.
attn_type
in
[
2
,
3
]:
# absolute embeddings
for
i
in
range
(
config
.
n_layer
):
self
.
layers
.
append
(
DecoderLayer
(
config
.
n_head
,
config
.
d_model
,
config
.
d_head
,
config
.
d_inner
,
config
.
dropout
,
dropatt
=
config
.
dropatt
,
pre_lnorm
=
config
.
pre_lnorm
,
r_w_bias
=
None
if
config
.
untie_r
else
self
.
r_w_bias
,
r_r_bias
=
None
if
config
.
untie_r
else
self
.
r_r_bias
,
output_attentions
=
self
.
output_attentions
)
r_w_bias
=
None
if
self
.
untie_r
else
self
.
r_w_bias
,
r_r_bias
=
None
if
self
.
untie_r
else
self
.
r_r_bias
,
output_attentions
=
self
.
output_attentions
,
name
=
'layers_._{}'
.
format
(
i
))
)
else
:
# learnable embeddings and absolute embeddings
raise
NotImplementedError
# Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
self
.
same_length
=
config
.
same_length
self
.
clamp_len
=
config
.
clamp_len
if
self
.
attn_type
==
0
:
# default attention
self
.
pos_emb
=
PositionalEmbedding
(
self
.
d_model
)
elif
self
.
attn_type
==
1
:
# learnable
self
.
r_emb
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_layer
,
self
.
max_klen
,
self
.
n_head
,
self
.
d_head
))
self
.
r_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_layer
,
self
.
max_klen
,
self
.
n_head
))
elif
self
.
attn_type
==
2
:
# absolute standard
self
.
pos_emb
=
PositionalEmbedding
(
self
.
d_model
)
elif
self
.
attn_type
==
3
:
# absolute deeper SA
self
.
r_emb
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_layer
,
self
.
max_klen
,
self
.
n_head
,
self
.
d_head
))
self
.
init_weights
()
self
.
pos_emb
=
TFPositionalEmbedding
(
self
.
d_model
,
name
=
'pos_emb'
)
else
:
# learnable embeddings and absolute embeddings
raise
NotImplementedError
# Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
def
build
(
self
,
input_shape
):
if
not
self
.
untie_r
:
self
.
r_w_bias
=
self
.
add_weight
(
shape
=
(
self
.
n_head
,
self
.
d_head
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'r_w_bias'
)
self
.
r_r_bias
=
self
.
add_weight
(
shape
=
(
self
.
n_head
,
self
.
d_head
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'r_r_bias'
)
super
(
TFTransfoXLMainLayer
,
self
).
build
(
input_shape
)
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
return
self
.
word_emb
...
...
@@ -810,16 +405,13 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self
.
ext_len
=
ext_len
def
_prune_heads
(
self
,
heads
):
logger
.
info
(
"Head pruning is not implemented for Transformer-XL model"
)
pass
raise
NotImplementedError
def
init_mems
(
self
,
data
):
if
self
.
mem_len
>
0
:
mems
=
[]
param
=
next
(
self
.
parameters
())
for
i
in
range
(
self
.
n_layer
):
empty
=
torch
.
zeros
(
self
.
mem_len
,
data
.
size
(
1
),
self
.
config
.
d_model
,
dtype
=
param
.
dtype
,
device
=
param
.
device
)
empty
=
tf
.
zeros
([
self
.
mem_len
,
shape_list
(
data
)[
1
],
self
.
d_model
])
mems
.
append
(
empty
)
return
mems
...
...
@@ -838,164 +430,211 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
# will be used as the extended context. Hence, we only cache
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
# to `mlen + qlen - self.ext_len`.
with
torch
.
no_grad
():
new_mems
=
[]
end_idx
=
mlen
+
max
(
0
,
qlen
-
0
-
self
.
ext_len
)
beg_idx
=
max
(
0
,
end_idx
-
self
.
mem_len
)
for
i
in
range
(
len
(
hids
)):
new_mems
=
[]
end_idx
=
mlen
+
max
(
0
,
qlen
-
0
-
self
.
ext_len
)
beg_idx
=
max
(
0
,
end_idx
-
self
.
mem_len
)
for
i
in
range
(
len
(
hids
)):
cat
=
torch
.
cat
([
mems
[
i
],
hids
[
i
]],
dim
=
0
)
new_mems
.
append
(
cat
[
beg_idx
:
end_idx
].
detach
())
cat
=
tf
.
concat
([
mems
[
i
],
hids
[
i
]],
axis
=
0
)
tf
.
stop_gradient
(
cat
)
new_mems
.
append
(
cat
[
beg_idx
:
end_idx
])
return
new_mems
def
_forward
(
self
,
dec_inp
,
mems
=
None
,
head_mask
=
None
):
qlen
,
bsz
=
dec_inp
.
size
()
def
call
(
self
,
inputs
,
training
=
False
):
if
not
isinstance
(
inputs
,
(
dict
,
tuple
,
list
)):
input_ids
=
inputs
mems
,
head_mask
=
None
,
None
elif
isinstance
(
inputs
,
(
tuple
,
list
)):
input_ids
=
inputs
[
0
]
mems
=
inputs
[
1
]
if
len
(
inputs
)
>
1
else
None
head_mask
=
inputs
[
2
]
if
len
(
inputs
)
>
2
else
None
assert
len
(
inputs
)
<=
3
,
"Too many inputs."
else
:
input_ids
=
inputs
.
get
(
'input_ids'
)
mems
=
inputs
.
get
(
'mems'
,
None
)
head_mask
=
inputs
.
get
(
'head_mask'
,
None
)
assert
len
(
inputs
)
<=
3
,
"Too many inputs."
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids
=
tf
.
transpose
(
input_ids
,
perm
=
(
1
,
0
))
if
mems
is
None
:
mems
=
self
.
init_mems
(
input_ids
)
qlen
,
bsz
=
shape_list
(
input_ids
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
0
)
head_mask
=
head_mask
.
expand
(
self
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
1
).
unsqueeze
(
1
)
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
if
not
head_mask
is
None
:
raise
NotImplementedError
else
:
head_mask
=
[
None
]
*
self
.
n_layer
word_emb
=
self
.
word_emb
(
dec_inp
)
word_emb
=
self
.
word_emb
(
input_ids
)
mlen
=
mems
[
0
].
size
(
0
)
if
mems
is
not
None
else
0
mlen
=
shape_list
(
mems
[
0
])[
0
]
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
attn_mask
=
tf
.
ones
([
qlen
,
qlen
])
mask_u
=
tf
.
linalg
.
band_part
(
attn_mask
,
0
,
-
1
)
mask_dia
=
tf
.
linalg
.
band_part
(
attn_mask
,
0
,
0
)
attn_mask_pad
=
tf
.
zeros
([
qlen
,
mlen
])
dec_attn_mask
=
tf
.
concat
([
attn_mask_pad
,
mask_u
-
mask_dia
],
1
)
if
self
.
same_length
:
all_ones
=
word_emb
.
new_ones
((
qlen
,
klen
),
dtype
=
torch
.
uint8
)
mask_len
=
klen
-
self
.
mem_len
if
mask_len
>
0
:
mask_shift_len
=
qlen
-
mask_len
else
:
mask_shift_len
=
qlen
dec_attn_mask
=
(
torch
.
triu
(
all_ones
,
1
+
mlen
)
+
torch
.
tril
(
all_ones
,
-
mask_shift_len
))[:,
:,
None
]
# -1
else
:
dec_attn_mask
=
torch
.
triu
(
word_emb
.
new_ones
((
qlen
,
klen
),
dtype
=
torch
.
uint8
),
diagonal
=
1
+
mlen
)[:,:,
None
]
mask_l
=
tf
.
linalg
.
band_part
(
attn_mask
,
-
1
,
0
)
dec_attn_mask
=
tf
.
concat
([
dec_attn_mask
[:,
:
qlen
]
+
mask_l
-
mask_dia
,
dec_attn_mask
[:,
qlen
:]],
1
)
# ::: PyTorch masking code for reference :::
# if self.same_length:
# all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
# mask_len = klen - self.mem_len
# if mask_len > 0:
# mask_shift_len = qlen - mask_len
# else:
# mask_shift_len = qlen
# dec_attn_mask = (torch.triu(all_ones, 1+mlen)
# + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
# else:
# dec_attn_mask = torch.triu(
# word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
hids
=
[]
attentions
=
[]
if
self
.
attn_type
==
0
:
# default
pos_seq
=
torch
.
arange
(
klen
-
1
,
-
1
,
-
1.0
,
device
=
word_emb
.
device
,
dtype
=
word_emb
.
dtype
)
if
self
.
clamp_len
>
0
:
pos_seq
.
clamp_
(
max
=
self
.
clamp_len
)
pos_emb
=
self
.
pos_emb
(
pos_seq
)
core_out
=
self
.
drop
(
word_emb
)
pos_emb
=
self
.
drop
(
pos_emb
)
for
i
,
layer
in
enumerate
(
self
.
layers
):
hids
.
append
(
core_out
)
mems_i
=
None
if
mems
is
None
else
mems
[
i
]
layer_outputs
=
layer
(
core_out
,
pos_emb
,
dec_attn_mask
=
dec_attn_mask
,
mems
=
mems_i
,
head_mask
=
head_mask
[
i
])
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
elif
self
.
attn_type
==
1
:
# learnable
core_out
=
self
.
drop
(
word_emb
)
for
i
,
layer
in
enumerate
(
self
.
layers
):
hids
.
append
(
core_out
)
if
self
.
clamp_len
>
0
:
r_emb
=
self
.
r_emb
[
i
][
-
self
.
clamp_len
:]
r_bias
=
self
.
r_bias
[
i
][
-
self
.
clamp_len
:]
else
:
r_emb
,
r_bias
=
self
.
r_emb
[
i
],
self
.
r_bias
[
i
]
mems_i
=
None
if
mems
is
None
else
mems
[
i
]
layer_outputs
=
layer
(
core_out
,
r_emb
,
self
.
r_w_bias
[
i
],
r_bias
,
dec_attn_mask
=
dec_attn_mask
,
mems
=
mems_i
,
head_mask
=
head_mask
[
i
])
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
elif
self
.
attn_type
==
2
:
# absolute
pos_seq
=
torch
.
arange
(
klen
-
1
,
-
1
,
-
1.0
,
device
=
word_emb
.
device
,
dtype
=
word_emb
.
dtype
)
pos_seq
=
tf
.
range
(
klen
-
1
,
-
1
,
-
1.0
)
if
self
.
clamp_len
>
0
:
pos_seq
.
clamp_
(
max
=
self
.
clamp_len
)
pos_seq
=
tf
.
minimum
(
pos_seq
,
self
.
clamp_len
)
pos_emb
=
self
.
pos_emb
(
pos_seq
)
core_out
=
self
.
drop
(
word_emb
+
pos_emb
[
-
qlen
:])
for
i
,
layer
in
enumerate
(
self
.
layers
):
hids
.
append
(
core_out
)
mems_i
=
None
if
mems
is
None
else
mems
[
i
]
if
mems_i
is
not
None
and
i
==
0
:
mems_i
+=
pos_emb
[:
mlen
]
layer_outputs
=
layer
(
core_out
,
dec_attn_mask
=
dec_attn_mask
,
mems
=
mems_i
,
head_mask
=
head_mask
[
i
])
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
elif
self
.
attn_type
==
3
:
core_out
=
self
.
drop
(
word_emb
)
core_out
=
self
.
drop
(
word_emb
,
training
=
training
)
pos_emb
=
self
.
drop
(
pos_emb
,
training
=
training
)
for
i
,
layer
in
enumerate
(
self
.
layers
):
hids
.
append
(
core_out
)
mems_i
=
None
if
mems
is
None
else
mems
[
i
]
if
mems_i
is
not
None
and
mlen
>
0
:
cur_emb
=
self
.
r_emb
[
i
][:
-
qlen
]
cur_size
=
cur_emb
.
size
(
0
)
if
cur_size
<
mlen
:
cur_emb_pad
=
cur_emb
[
0
:
1
].
expand
(
mlen
-
cur_size
,
-
1
,
-
1
)
cur_emb
=
torch
.
cat
([
cur_emb_pad
,
cur_emb
],
0
)
else
:
cur_emb
=
cur_emb
[
-
mlen
:]
mems_i
+=
cur_emb
.
view
(
mlen
,
1
,
-
1
)
core_out
+=
self
.
r_emb
[
i
][
-
qlen
:].
view
(
qlen
,
1
,
-
1
)
layer_outputs
=
layer
(
core_out
,
dec_attn_mask
=
dec_attn_mask
,
mems
=
mems_i
,
head_mask
=
head_mask
[
i
])
layer_outputs
=
layer
([
core_out
,
pos_emb
,
dec_attn_mask
,
mems_i
,
head_mask
[
i
]],
training
=
training
)
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
else
:
# learnable embeddings and absolute embeddings
raise
NotImplementedError
# Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
core_out
=
self
.
drop
(
core_out
)
core_out
=
self
.
drop
(
core_out
,
training
=
training
)
new_mems
=
self
.
_update_mems
(
hids
,
mems
,
mlen
,
qlen
)
# We transpose back here to shape [bsz, len, hidden_dim]
outputs
=
[
core_ou
t
.
transpose
(
0
,
1
).
contiguous
(
),
new_mems
]
outputs
=
[
t
f
.
transpose
(
core_out
,
perm
=
(
1
,
0
,
2
)
),
new_mems
]
if
self
.
output_hidden_states
:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids
.
append
(
core_out
)
hids
=
list
(
t
.
transpose
(
0
,
1
).
contiguous
(
)
for
t
in
hids
)
hids
=
list
(
t
f
.
transpose
(
t
,
perm
=
(
1
,
0
,
2
)
)
for
t
in
hids
)
outputs
.
append
(
hids
)
if
self
.
output_attentions
:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions
=
list
(
t
.
permute
(
2
,
3
,
0
,
1
).
contiguous
(
)
for
t
in
attentions
)
attentions
=
list
(
t
f
.
transpose
(
t
,
perm
=
(
2
,
3
,
0
,
1
)
)
for
t
in
attentions
)
outputs
.
append
(
attentions
)
return
outputs
# last hidden state, new_mems, (all hidden states), (all attentions)
def
forward
(
self
,
input_ids
,
mems
=
None
,
head_mask
=
None
):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids
=
input_ids
.
transpose
(
0
,
1
).
contiguous
()
if
mems
is
None
:
mems
=
self
.
init_mems
(
input_ids
)
outputs
=
self
.
_forward
(
input_ids
,
mems
=
mems
,
head_mask
=
head_mask
)
class
TFTransfoXLPreTrainedModel
(
TFPreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class
=
TransfoXLConfig
pretrained_model_archive_map
=
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_transfo_xl_pt_weights_in_tf2
base_model_prefix
=
"transformer"
return
outputs
# last hidden state, new_mems, (all hidden states), (all attentions)
TRANSFO_XL_START_DOCSTRING
=
r
""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).
This model is a PyTorch `torch.tf.keras.layers.Layer`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860
.. _`torch.tf.keras.layers.Layer`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
TRANSFO_XL_INPUTS_DOCSTRING
=
r
"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@
add_start_docstrings
(
"The bare Bert Model transformer outputing raw hidden-states without any specific head on top."
,
TRANSFO_XL_START_DOCSTRING
,
TRANSFO_XL_INPUTS_DOCSTRING
)
class
TFTransfoXLModel
(
TFTransfoXLPreTrainedModel
):
r
"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFTransfoXLModel
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
transformer
=
TFTransfoXLMainLayer
(
config
,
name
=
'transformer'
)
def
call
(
self
,
inputs
,
training
=
False
,
**
kwargs
):
outputs
=
self
.
transformer
(
inputs
,
training
=
training
,
**
kwargs
)
return
outputs
@
add_start_docstrings
(
"""The Transformer-XL Model with a language modeling head on top
(adaptive softmax with weights tied to the adaptive input embeddings)"""
,
TRANSFO_XL_START_DOCSTRING
,
TRANSFO_XL_INPUTS_DOCSTRING
)
class
TransfoXLLMHeadModel
(
TransfoXLPreTrainedModel
):
class
TF
TransfoXLLMHeadModel
(
TF
TransfoXLPreTrainedModel
):
r
"""
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
...
...
@@ -1032,46 +671,16 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
"""
def
__init__
(
self
,
config
):
super
(
TransfoXLLMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
TransfoXLM
odel
(
config
)
super
(
TF
TransfoXLLMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
TF
TransfoXLM
ainLayer
(
config
,
name
=
'transformer'
)
self
.
sample_softmax
=
config
.
sample_softmax
# use sampled softmax
if
config
.
sample_softmax
>
0
:
self
.
out_layer
=
nn
.
Linear
(
config
.
d_model
,
config
.
n_token
)
self
.
sampler
=
LogUniformSampler
(
config
.
n_token
,
config
.
sample_softmax
)
raise
NotImplementedError
# use adaptive softmax (including standard softmax)
else
:
self
.
crit
=
ProjectedAdaptiveLogSoftmax
(
config
.
n_token
,
config
.
d_embed
,
config
.
d_model
,
config
.
cutoffs
,
div_val
=
config
.
div_val
)
self
.
init_weights
()
self
.
tie_weights
()
def
tie_weights
(
self
):
"""
Run this to be sure output and input (adaptive) softmax weights are tied
"""
# sampled softmax
if
self
.
sample_softmax
>
0
:
if
self
.
config
.
tie_weight
:
self
.
out_layer
.
weight
=
self
.
transformer
.
word_emb
.
weight
# adaptive softmax (including standard softmax)
else
:
if
self
.
config
.
tie_weight
:
for
i
in
range
(
len
(
self
.
crit
.
out_layers
)):
self
.
_tie_or_clone_weights
(
self
.
crit
.
out_layers
[
i
],
self
.
transformer
.
word_emb
.
emb_layers
[
i
])
if
self
.
config
.
tie_projs
:
for
i
,
tie_proj
in
enumerate
(
self
.
config
.
tie_projs
):
if
tie_proj
and
self
.
config
.
div_val
==
1
and
self
.
config
.
d_model
!=
self
.
config
.
d_embed
:
if
self
.
config
.
torchscript
:
self
.
crit
.
out_projs
[
i
]
=
nn
.
Parameter
(
self
.
transformer
.
word_emb
.
emb_projs
[
0
].
clone
())
else
:
self
.
crit
.
out_projs
[
i
]
=
self
.
transformer
.
word_emb
.
emb_projs
[
0
]
elif
tie_proj
and
self
.
config
.
div_val
!=
1
:
if
self
.
config
.
torchscript
:
self
.
crit
.
out_projs
[
i
]
=
nn
.
Parameter
(
self
.
transformer
.
word_emb
.
emb_projs
[
i
].
clone
())
else
:
self
.
crit
.
out_projs
[
i
]
=
self
.
transformer
.
word_emb
.
emb_projs
[
i
]
self
.
crit
=
TFAdaptiveSoftmaxMask
(
config
.
n_token
,
config
.
d_embed
,
config
.
d_model
,
config
.
cutoffs
,
div_val
=
config
.
div_val
,
name
=
'crit'
)
def
reset_length
(
self
,
tgt_len
,
ext_len
,
mem_len
):
self
.
transformer
.
reset_length
(
tgt_len
,
ext_len
,
mem_len
)
...
...
@@ -1079,30 +688,36 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
def
init_mems
(
self
,
data
):
return
self
.
transformer
.
init_mems
(
data
)
def
forward
(
self
,
input_ids
,
mems
=
None
,
head_mask
=
None
,
labels
=
None
):
bsz
=
input_ids
.
size
(
0
)
tgt_len
=
input_ids
.
size
(
1
)
def
call
(
self
,
inputs
,
training
=
False
):
if
not
isinstance
(
inputs
,
(
dict
,
tuple
,
list
)):
input_ids
=
inputs
mems
,
head_mask
,
labels
=
None
,
None
,
None
elif
isinstance
(
inputs
,
(
tuple
,
list
)):
input_ids
=
inputs
[
0
]
mems
=
inputs
[
1
]
if
len
(
inputs
)
>
1
else
None
head_mask
=
inputs
[
2
]
if
len
(
inputs
)
>
2
else
None
labels
=
inputs
[
3
]
if
len
(
inputs
)
>
3
else
None
assert
len
(
inputs
)
<=
4
,
"Too many inputs."
else
:
input_ids
=
inputs
.
get
(
'input_ids'
)
mems
=
inputs
.
get
(
'mems'
,
None
)
head_mask
=
inputs
.
get
(
'head_mask'
,
None
)
labels
=
inputs
.
get
(
'labels'
,
None
)
assert
len
(
inputs
)
<=
4
,
"Too many inputs."
bsz
,
tgt_len
=
shape_list
(
input_ids
)[:
2
]
transformer_outputs
=
self
.
transformer
(
input_ids
,
mems
=
mems
,
head_mask
=
head_mask
)
transformer_outputs
=
self
.
transformer
(
[
input_ids
,
mems
,
head_mask
],
training
=
training
)
last_hidden
=
transformer_outputs
[
0
]
pred_hid
=
last_hidden
[:,
-
tgt_len
:]
outputs
=
transformer_outputs
[
1
:]
if
self
.
sample_softmax
>
0
and
self
.
training
:
assert
self
.
config
.
tie_weight
logit
=
sample_logits
(
self
.
transformer
.
word_emb
,
self
.
out_layer
.
bias
,
labels
,
pred_hid
,
self
.
sampler
)
softmax_output
=
-
F
.
log_softmax
(
logit
,
-
1
)[:,
:,
0
]
outputs
=
[
softmax_output
]
+
outputs
if
labels
is
not
None
:
# TODO: This is not implemented
raise
NotImplementedError
if
self
.
sample_softmax
>
0
and
training
:
raise
NotImplementedError
else
:
softmax_output
=
self
.
crit
(
pred_hid
.
view
(
-
1
,
pred_hid
.
size
(
-
1
)),
labels
)
if
labels
is
None
:
softmax_output
=
softmax_output
.
view
(
bsz
,
tgt_len
,
-
1
)
outputs
=
[
softmax_output
]
+
outputs
else
:
softmax_output
=
softmax_output
.
view
(
bsz
,
tgt_len
)
outputs
=
[
softmax_output
,
None
]
+
outputs
# pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1]))
softmax_output
=
self
.
crit
([
pred_hid
,
labels
],
training
=
training
)
# softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1))
outputs
=
[
softmax_output
]
+
outputs
return
outputs
#
(loss), logits or None if labels is not None (speed up adaptive softmax)
, new_mems, (all hidden states), (all attentions)
return
outputs
#
logits
, new_mems, (all hidden states), (all attentions)
pytorch_transformers/modeling_tf_transfo_xl_utilities.py
0 → 100644
View file @
65c49bb2
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
"""
from
collections
import
defaultdict
import
numpy
as
np
import
tensorflow
as
tf
from
.modeling_tf_utils
import
shape_list
class
TFAdaptiveSoftmaxMask
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
n_token
,
d_embed
,
d_proj
,
cutoffs
,
div_val
=
1
,
keep_order
=
False
,
**
kwargs
):
super
(
TFAdaptiveSoftmaxMask
,
self
).
__init__
(
**
kwargs
)
self
.
n_token
=
n_token
self
.
d_embed
=
d_embed
self
.
d_proj
=
d_proj
self
.
cutoffs
=
cutoffs
+
[
n_token
]
self
.
cutoff_ends
=
[
0
]
+
self
.
cutoffs
self
.
div_val
=
div_val
self
.
shortlist_size
=
self
.
cutoffs
[
0
]
self
.
n_clusters
=
len
(
self
.
cutoffs
)
-
1
self
.
head_size
=
self
.
shortlist_size
+
self
.
n_clusters
self
.
keep_order
=
keep_order
self
.
out_layers
=
[]
self
.
out_projs
=
[]
def
build
(
self
,
input_shape
):
if
self
.
n_clusters
>
0
:
self
.
cluster_weight
=
self
.
add_weight
(
shape
=
(
self
.
n_clusters
,
self
.
d_embed
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'cluster_weight'
)
self
.
cluster_bias
=
self
.
add_weight
(
shape
=
(
self
.
n_clusters
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'cluster_bias'
)
if
self
.
div_val
==
1
:
for
i
in
range
(
len
(
self
.
cutoffs
)):
if
self
.
d_proj
!=
self
.
d_embed
:
weight
=
self
.
add_weight
(
shape
=
(
self
.
d_embed
,
self
.
d_proj
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_projs_._{}'
.
format
(
i
))
self
.
out_projs
.
append
(
weight
)
else
:
self
.
out_projs
.
append
(
None
)
weight
=
self
.
add_weight
(
shape
=
(
self
.
n_token
,
self
.
d_embed
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._weight'
.
format
(
i
))
bias
=
self
.
add_weight
(
shape
=
(
self
.
n_token
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._bias'
.
format
(
i
))
self
.
out_layers
.
append
((
weight
,
bias
))
else
:
for
i
in
range
(
len
(
self
.
cutoffs
)):
l_idx
,
r_idx
=
self
.
cutoff_ends
[
i
],
self
.
cutoff_ends
[
i
+
1
]
d_emb_i
=
self
.
d_embed
//
(
self
.
div_val
**
i
)
weight
=
self
.
add_weight
(
shape
=
(
d_emb_i
,
self
.
d_proj
),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_projs_._{}'
.
format
(
i
))
self
.
out_projs
.
append
(
weight
)
weight
=
self
.
add_weight
(
shape
=
(
r_idx
-
l_idx
,
d_emb_i
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._weight'
.
format
(
i
))
bias
=
self
.
add_weight
(
shape
=
(
r_idx
-
l_idx
,),
initializer
=
'zeros'
,
trainable
=
True
,
name
=
'out_layers_._{}_._bias'
.
format
(
i
))
self
.
out_layers
.
append
((
weight
,
bias
))
super
(
TFAdaptiveSoftmaxMask
,
self
).
build
(
input_shape
)
@
staticmethod
def
_logit
(
x
,
W
,
b
,
proj
=
None
):
y
=
x
if
proj
is
not
None
:
y
=
tf
.
einsum
(
'ibd,ed->ibe'
,
y
,
proj
)
return
tf
.
einsum
(
'ibd,nd->ibn'
,
y
,
W
)
+
b
@
staticmethod
def
_gather_logprob
(
logprob
,
target
):
lp_size
=
tf
.
shape
(
logprob
)
r
=
tf
.
range
(
lp_size
[
0
])
idx
=
tf
.
stack
([
r
,
target
],
1
)
return
tf
.
gather_nd
(
logprob
,
idx
)
def
call
(
self
,
inputs
,
return_mean
=
True
,
training
=
False
):
hidden
,
target
=
inputs
head_logprob
=
0
if
self
.
n_clusters
==
0
:
softmax_b
=
tf
.
get_variable
(
'bias'
,
[
n_token
],
initializer
=
tf
.
zeros_initializer
())
output
=
self
.
_logit
(
hidden
,
self
.
out_layers
[
0
][
0
],
self
.
out_layers
[
0
][
1
],
self
.
out_projs
[
0
])
if
target
is
not
None
:
loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
target
,
logits
=
output
)
out
=
tf
.
nn
.
log_softmax
(
output
,
axis
=-
1
)
else
:
hidden_sizes
=
shape_list
(
hidden
)
out
=
[]
loss
=
tf
.
zeros
(
hidden_sizes
[:
2
],
dtype
=
tf
.
float32
)
for
i
in
range
(
len
(
self
.
cutoffs
)):
l_idx
,
r_idx
=
self
.
cutoff_ends
[
i
],
self
.
cutoff_ends
[
i
+
1
]
if
target
is
not
None
:
mask
=
(
target
>=
l_idx
)
&
(
target
<
r_idx
)
mask_idx
=
tf
.
where
(
mask
)
cur_target
=
tf
.
boolean_mask
(
target
,
mask
)
-
l_idx
if
self
.
div_val
==
1
:
cur_W
=
self
.
out_layers
[
0
][
0
][
l_idx
:
r_idx
]
cur_b
=
self
.
out_layers
[
0
][
1
][
l_idx
:
r_idx
]
else
:
cur_W
=
self
.
out_layers
[
i
][
0
]
cur_b
=
self
.
out_layers
[
i
][
1
]
if
i
==
0
:
cur_W
=
tf
.
concat
([
cur_W
,
self
.
cluster_weight
],
0
)
cur_b
=
tf
.
concat
([
cur_b
,
self
.
cluster_bias
],
0
)
head_logit
=
self
.
_logit
(
hidden
,
cur_W
,
cur_b
,
self
.
out_projs
[
0
])
head_logprob
=
tf
.
nn
.
log_softmax
(
head_logit
)
out
.
append
(
head_logprob
[...,
:
self
.
cutoffs
[
0
]])
if
target
is
not
None
:
cur_head_logprob
=
tf
.
boolean_mask
(
head_logprob
,
mask
)
cur_logprob
=
self
.
_gather_logprob
(
cur_head_logprob
,
cur_target
)
else
:
tail_logit
=
self
.
_logit
(
hidden
,
cur_W
,
cur_b
,
self
.
out_projs
[
i
])
tail_logprob
=
tf
.
nn
.
log_softmax
(
tail_logit
)
cluster_prob_idx
=
self
.
cutoffs
[
0
]
+
i
-
1
# No probability for the head cluster
logprob_i
=
head_logprob
[...,
cluster_prob_idx
,
None
]
+
tail_logprob
out
.
append
(
logprob_i
)
if
target
is
not
None
:
cur_head_logprob
=
tf
.
boolean_mask
(
head_logprob
,
mask
)
cur_tail_logprob
=
tf
.
boolean_mask
(
tail_logprob
,
mask
)
cur_logprob
=
self
.
_gather_logprob
(
cur_tail_logprob
,
cur_target
)
cur_logprob
+=
cur_head_logprob
[:,
self
.
cutoff_ends
[
1
]
+
i
-
1
]
if
target
is
not
None
:
loss
+=
tf
.
scatter_nd
(
mask_idx
,
-
cur_logprob
,
tf
.
cast
(
tf
.
shape
(
loss
),
dtype
=
tf
.
int64
))
out
=
tf
.
concat
(
out
,
axis
=-
1
)
if
target
is
not
None
:
if
return_mean
:
loss
=
tf
.
reduce_mean
(
loss
)
# Add the training-time loss value to the layer using `self.add_loss()`.
self
.
add_loss
(
loss
)
# Log the loss as a metric (we could log arbitrary metrics,
# including different metrics for training and inference.
self
.
add_metric
(
loss
,
name
=
self
.
name
,
aggregation
=
'mean'
if
return_mean
else
''
)
return
out
def
mul_adaptive_logsoftmax
(
hidden
,
target
,
n_token
,
d_embed
,
d_proj
,
cutoffs
,
params
,
tie_projs
,
initializer
=
None
,
proj_initializer
=
None
,
div_val
=
1
,
perms
=
None
,
proj_same_dim
=
True
,
scope
=
'adaptive_softmax'
,
**
kwargs
):
def
_logit
(
x
,
W
,
b
,
proj
):
y
=
x
if
x
.
shape
.
ndims
==
3
:
if
proj
is
not
None
:
y
=
tf
.
einsum
(
'ibd,ed->ibe'
,
y
,
proj
)
return
tf
.
einsum
(
'ibd,nd->ibn'
,
y
,
W
)
+
b
else
:
if
proj
is
not
None
:
y
=
tf
.
einsum
(
'id,ed->ie'
,
y
,
proj
)
return
tf
.
einsum
(
'id,nd->in'
,
y
,
W
)
+
b
params_W
,
params_projs
=
params
[
0
],
params
[
1
]
with
tf
.
variable_scope
(
scope
):
if
len
(
cutoffs
)
==
0
:
softmax_b
=
tf
.
get_variable
(
'bias'
,
[
n_token
],
initializer
=
tf
.
zeros_initializer
())
output
=
_logit
(
hidden
,
params_W
,
softmax_b
,
params_projs
)
nll
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
target
,
logits
=
output
)
nll
=
tf
.
reduce_mean
(
nll
)
else
:
total_loss
,
total_cnt
=
0
,
0
cutoff_ends
=
[
0
]
+
cutoffs
+
[
n_token
]
for
i
in
range
(
len
(
cutoff_ends
)
-
1
):
with
tf
.
variable_scope
(
'cutoff_{}'
.
format
(
i
)):
l_idx
,
r_idx
=
cutoff_ends
[
i
],
cutoff_ends
[
i
+
1
]
cur_d_embed
=
d_embed
//
(
div_val
**
i
)
if
div_val
==
1
:
cur_W
=
params_W
[
l_idx
:
r_idx
]
else
:
cur_W
=
params_W
[
i
]
cur_b
=
tf
.
get_variable
(
'b'
,
[
r_idx
-
l_idx
],
initializer
=
tf
.
zeros_initializer
())
if
tie_projs
[
i
]:
if
div_val
==
1
:
cur_proj
=
params_projs
else
:
cur_proj
=
params_projs
[
i
]
else
:
if
(
div_val
==
1
or
not
proj_same_dim
)
and
d_proj
==
cur_d_embed
:
cur_proj
=
None
else
:
cur_proj
=
tf
.
get_variable
(
'proj'
,
[
cur_d_embed
,
d_proj
],
initializer
=
proj_initializer
)
if
i
==
0
:
cluster_W
=
tf
.
get_variable
(
'cluster_W'
,
[
len
(
cutoffs
),
d_embed
],
initializer
=
tf
.
zeros_initializer
())
cluster_b
=
tf
.
get_variable
(
'cluster_b'
,
[
len
(
cutoffs
)],
initializer
=
tf
.
zeros_initializer
())
cur_W
=
tf
.
concat
([
cur_W
,
cluster_W
],
0
)
cur_b
=
tf
.
concat
([
cur_b
,
cluster_b
],
0
)
head_logit
=
_logit
(
hidden
,
cur_W
,
cur_b
,
cur_proj
)
head_target
=
kwargs
.
get
(
"head_target"
)
head_nll
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
head_target
,
logits
=
head_logit
)
masked_loss
=
head_nll
*
perms
[
i
]
total_loss
+=
tf
.
reduce_sum
(
masked_loss
)
total_cnt
+=
tf
.
reduce_sum
(
perms
[
i
])
# head_logprob = tf.nn.log_softmax(head_logit)
# final_logprob = head_logprob * perms[i][:, :, None]
# final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
# total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
# total_cnt += tf.reduce_sum(perms[i])
else
:
cur_head_nll
=
tf
.
einsum
(
'ib,ibk->k'
,
head_nll
,
perms
[
i
])
cur_hidden
=
tf
.
einsum
(
'ibd,ibk->kd'
,
hidden
,
perms
[
i
])
tail_logit
=
_logit
(
cur_hidden
,
cur_W
,
cur_b
,
cur_proj
)
tail_target
=
tf
.
einsum
(
'ib,ibk->k'
,
tf
.
to_float
(
target
-
l_idx
),
perms
[
i
])
tail_nll
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
tf
.
to_int32
(
tail_target
),
logits
=
tail_logit
)
sum_nll
=
cur_head_nll
+
tail_nll
mask
=
tf
.
reduce_sum
(
perms
[
i
],
[
0
,
1
])
masked_loss
=
sum_nll
*
mask
total_loss
+=
tf
.
reduce_sum
(
masked_loss
)
total_cnt
+=
tf
.
reduce_sum
(
mask
)
nll
=
total_loss
/
total_cnt
return
nll
\ No newline at end of file
pytorch_transformers/modeling_tf_xlm.py
View file @
65c49bb2
...
...
@@ -261,8 +261,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
self
.
ffns
=
[]
self
.
layer_norm2
=
[]
# if self.is_decoder:
# self.layer_norm15 =
tf.keras.layers.LayerList()
# self.encoder_attn =
tf.keras.layers.LayerList()
# self.layer_norm15 =
[]
# self.encoder_attn =
[]
for
i
in
range
(
self
.
n_layers
):
self
.
attentions
.
append
(
TFMultiHeadAttention
(
self
.
n_heads
,
self
.
dim
,
config
=
config
,
name
=
'attentions_._{}'
.
format
(
i
)))
...
...
pytorch_transformers/modeling_transfo_xl.py
View file @
65c49bb2
...
...
@@ -229,102 +229,11 @@ class PositionwiseFF(nn.Module):
return
output
class
MultiHeadAttn
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
dropout
,
dropatt
=
0
,
pre_lnorm
=
False
,
r_r_bias
=
None
,
r_w_bias
=
None
,
output_attentions
=
False
):
super
(
MultiHeadAttn
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
self
.
n_head
=
n_head
self
.
d_model
=
d_model
self
.
d_head
=
d_head
self
.
dropout
=
dropout
self
.
q_net
=
nn
.
Linear
(
d_model
,
n_head
*
d_head
,
bias
=
False
)
self
.
kv_net
=
nn
.
Linear
(
d_model
,
2
*
n_head
*
d_head
,
bias
=
False
)
self
.
drop
=
nn
.
Dropout
(
dropout
)
self
.
dropatt
=
nn
.
Dropout
(
dropatt
)
self
.
o_net
=
nn
.
Linear
(
n_head
*
d_head
,
d_model
,
bias
=
False
)
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
scale
=
1
/
(
d_head
**
0.5
)
self
.
pre_lnorm
=
pre_lnorm
if
r_r_bias
is
None
or
r_w_bias
is
None
:
# Biases are not shared
self
.
r_r_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
self
.
r_w_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_head
,
self
.
d_head
))
else
:
self
.
r_r_bias
=
r_r_bias
self
.
r_w_bias
=
r_w_bias
def
forward
(
self
,
h
,
attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if
mems
is
not
None
:
c
=
torch
.
cat
([
mems
,
h
],
0
)
else
:
c
=
h
if
self
.
pre_lnorm
:
##### layer normalization
c
=
self
.
layer_norm
(
c
)
head_q
=
self
.
q_net
(
h
)
head_k
,
head_v
=
torch
.
chunk
(
self
.
kv_net
(
c
),
2
,
-
1
)
head_q
=
head_q
.
view
(
h
.
size
(
0
),
h
.
size
(
1
),
self
.
n_head
,
self
.
d_head
)
head_k
=
head_k
.
view
(
c
.
size
(
0
),
c
.
size
(
1
),
self
.
n_head
,
self
.
d_head
)
head_v
=
head_v
.
view
(
c
.
size
(
0
),
c
.
size
(
1
),
self
.
n_head
,
self
.
d_head
)
# [qlen x klen x bsz x n_head]
attn_score
=
torch
.
einsum
(
'ibnd,jbnd->ijbn'
,
(
head_q
,
head_k
))
attn_score
.
mul_
(
self
.
scale
)
if
attn_mask
is
not
None
and
torch
.
sum
(
attn_mask
).
item
():
attn_mask
=
(
attn_mask
==
1
)
# Switch to bool
if
attn_mask
.
dim
()
==
2
:
attn_score
.
masked_fill_
(
attn_mask
[
None
,:,:,
None
],
-
float
(
'inf'
))
elif
attn_mask
.
dim
()
==
3
:
attn_score
.
masked_fill_
(
attn_mask
[:,:,:,
None
],
-
float
(
'inf'
))
# [qlen x klen x bsz x n_head]
attn_prob
=
F
.
softmax
(
attn_score
,
dim
=
1
)
attn_prob
=
self
.
dropatt
(
attn_prob
)
# Mask heads if we want to
if
head_mask
is
not
None
:
attn_prob
=
attn_prob
*
head_mask
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec
=
torch
.
einsum
(
'ijbn,jbnd->ibnd'
,
(
attn_prob
,
head_v
))
attn_vec
=
attn_vec
.
contiguous
().
view
(
attn_vec
.
size
(
0
),
attn_vec
.
size
(
1
),
self
.
n_head
*
self
.
d_head
)
##### linear projection
attn_out
=
self
.
o_net
(
attn_vec
)
attn_out
=
self
.
drop
(
attn_out
)
if
self
.
pre_lnorm
:
##### residual connection
outputs
=
[
h
+
attn_out
]
else
:
##### residual connection + layer normalization
outputs
=
[
self
.
layer_norm
(
h
+
attn_out
)]
if
self
.
output_attentions
:
outputs
.
append
(
attn_prob
)
return
outputs
class
RelMultiHeadAttn
(
nn
.
Module
):
class
RelPartialLearnableMultiHeadAttn
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
dropout
,
dropatt
=
0
,
tgt_len
=
None
,
ext_len
=
None
,
mem_len
=
None
,
pre_lnorm
=
False
,
r_r_bias
=
None
,
r_w_bias
=
None
,
output_attentions
=
False
):
super
(
RelMultiHeadAttn
,
self
).
__init__
()
super
(
Rel
PartialLearnable
MultiHeadAttn
,
self
).
__init__
()
self
.
output_attentions
=
output_attentions
self
.
n_head
=
n_head
...
...
@@ -351,36 +260,9 @@ class RelMultiHeadAttn(nn.Module):
self
.
r_r_bias
=
r_r_bias
self
.
r_w_bias
=
r_w_bias
def
_parallelogram_mask
(
self
,
h
,
w
,
left
=
False
):
mask
=
torch
.
ones
((
h
,
w
)).
byte
()
m
=
min
(
h
,
w
)
mask
[:
m
,:
m
]
=
torch
.
triu
(
mask
[:
m
,:
m
])
mask
[
-
m
:,
-
m
:]
=
torch
.
tril
(
mask
[
-
m
:,
-
m
:])
if
left
:
return
mask
else
:
return
mask
.
flip
(
0
)
def
_shift
(
self
,
x
,
qlen
,
klen
,
mask
,
left
=
False
):
if
qlen
>
1
:
zero_pad
=
torch
.
zeros
((
x
.
size
(
0
),
qlen
-
1
,
x
.
size
(
2
),
x
.
size
(
3
)),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
else
:
zero_pad
=
torch
.
zeros
(
0
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
if
left
:
mask
=
mask
.
flip
(
1
)
x_padded
=
torch
.
cat
([
zero_pad
,
x
],
dim
=
1
).
expand
(
qlen
,
-
1
,
-
1
,
-
1
)
else
:
x_padded
=
torch
.
cat
([
x
,
zero_pad
],
dim
=
1
).
expand
(
qlen
,
-
1
,
-
1
,
-
1
)
x
=
x_padded
.
masked_select
(
mask
[:,:,
None
,
None
])
\
.
view
(
qlen
,
klen
,
x
.
size
(
2
),
x
.
size
(
3
))
return
x
self
.
r_net
=
nn
.
Linear
(
self
.
d_model
,
self
.
n_head
*
self
.
d_head
,
bias
=
False
)
def
_rel_shift
(
self
,
x
,
zero_triu
=
False
):
def
_rel_shift
(
self
,
x
):
zero_pad_shape
=
(
x
.
size
(
0
),
1
)
+
x
.
size
()[
2
:]
zero_pad
=
torch
.
zeros
(
zero_pad_shape
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x_padded
=
torch
.
cat
([
zero_pad
,
x
],
dim
=
1
)
...
...
@@ -390,21 +272,8 @@ class RelMultiHeadAttn(nn.Module):
x
=
x_padded
[
1
:].
view_as
(
x
)
if
zero_triu
:
ones
=
torch
.
ones
((
x
.
size
(
0
),
x
.
size
(
1
)))
x
=
x
*
torch
.
tril
(
ones
,
x
.
size
(
1
)
-
x
.
size
(
0
))[:,:,
None
,
None
]
return
x
def
forward
(
self
,
w
,
r
,
attn_mask
=
None
,
mems
=
None
):
raise
NotImplementedError
class
RelPartialLearnableMultiHeadAttn
(
RelMultiHeadAttn
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
RelPartialLearnableMultiHeadAttn
,
self
).
__init__
(
*
args
,
**
kwargs
)
self
.
r_net
=
nn
.
Linear
(
self
.
d_model
,
self
.
n_head
*
self
.
d_head
,
bias
=
False
)
def
forward
(
self
,
w
,
r
,
attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
qlen
,
rlen
,
bsz
=
w
.
size
(
0
),
r
.
size
(
0
),
w
.
size
(
1
)
...
...
@@ -488,138 +357,6 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
return
outputs
class
RelLearnableMultiHeadAttn
(
RelMultiHeadAttn
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
RelLearnableMultiHeadAttn
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
forward
(
self
,
w
,
r_emb
,
r_w_bias
,
r_bias
,
attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
qlen
,
bsz
=
w
.
size
(
0
),
w
.
size
(
1
)
if
mems
is
not
None
:
cat
=
torch
.
cat
([
mems
,
w
],
0
)
if
self
.
pre_lnorm
:
w_heads
=
self
.
qkv_net
(
self
.
layer_norm
(
cat
))
else
:
w_heads
=
self
.
qkv_net
(
cat
)
w_head_q
,
w_head_k
,
w_head_v
=
torch
.
chunk
(
w_heads
,
3
,
dim
=-
1
)
w_head_q
=
w_head_q
[
-
qlen
:]
else
:
if
self
.
pre_lnorm
:
w_heads
=
self
.
qkv_net
(
self
.
layer_norm
(
w
))
else
:
w_heads
=
self
.
qkv_net
(
w
)
w_head_q
,
w_head_k
,
w_head_v
=
torch
.
chunk
(
w_heads
,
3
,
dim
=-
1
)
klen
=
w_head_k
.
size
(
0
)
w_head_q
=
w_head_q
.
view
(
qlen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
w_head_k
=
w_head_k
.
view
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
w_head_v
=
w_head_v
.
view
(
klen
,
bsz
,
self
.
n_head
,
self
.
d_head
)
if
klen
>
r_emb
.
size
(
0
):
r_emb_pad
=
r_emb
[
0
:
1
].
expand
(
klen
-
r_emb
.
size
(
0
),
-
1
,
-
1
)
r_emb
=
torch
.
cat
([
r_emb_pad
,
r_emb
],
0
)
r_bias_pad
=
r_bias
[
0
:
1
].
expand
(
klen
-
r_bias
.
size
(
0
),
-
1
)
r_bias
=
torch
.
cat
([
r_bias_pad
,
r_bias
],
0
)
else
:
r_emb
=
r_emb
[
-
klen
:]
r_bias
=
r_bias
[
-
klen
:]
#### compute attention score
rw_head_q
=
w_head_q
+
r_w_bias
[
None
]
# qlen x bsz x n_head x d_head
AC
=
torch
.
einsum
(
'ibnd,jbnd->ijbn'
,
(
rw_head_q
,
w_head_k
))
# qlen x klen x bsz x n_head
B_
=
torch
.
einsum
(
'ibnd,jnd->ijbn'
,
(
w_head_q
,
r_emb
))
# qlen x klen x bsz x n_head
D_
=
r_bias
[
None
,
:,
None
]
# 1 x klen x 1 x n_head
BD
=
self
.
_rel_shift
(
B_
+
D_
)
# [qlen x klen x bsz x n_head]
attn_score
=
AC
+
BD
attn_score
.
mul_
(
self
.
scale
)
#### compute attention probability
if
attn_mask
is
not
None
and
torch
.
sum
(
attn_mask
).
item
():
attn_mask
=
(
attn_mask
==
1
)
# Switch to bool
if
attn_mask
.
dim
()
==
2
:
attn_score
.
masked_fill_
(
attn_mask
[
None
,:,:,
None
],
-
float
(
'inf'
))
elif
attn_mask
.
dim
()
==
3
:
attn_score
.
masked_fill_
(
attn_mask
[:,:,:,
None
],
-
float
(
'inf'
))
# [qlen x klen x bsz x n_head]
attn_prob
=
F
.
softmax
(
attn_score
,
dim
=
1
)
attn_prob
=
self
.
dropatt
(
attn_prob
)
if
head_mask
is
not
None
:
attn_prob
=
attn_prob
*
head_mask
#### compute attention vector
attn_vec
=
torch
.
einsum
(
'ijbn,jbnd->ibnd'
,
(
attn_prob
,
w_head_v
))
# [qlen x bsz x n_head x d_head]
attn_vec
=
attn_vec
.
contiguous
().
view
(
attn_vec
.
size
(
0
),
attn_vec
.
size
(
1
),
self
.
n_head
*
self
.
d_head
)
##### linear projection
attn_out
=
self
.
o_net
(
attn_vec
)
attn_out
=
self
.
drop
(
attn_out
)
if
self
.
pre_lnorm
:
##### residual connection
outputs
=
[
w
+
attn_out
]
else
:
##### residual connection + layer normalization
outputs
=
[
self
.
layer_norm
(
w
+
attn_out
)]
if
self
.
output_attentions
:
outputs
.
append
(
attn_prob
)
return
outputs
class
DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
**
kwargs
):
super
(
DecoderLayer
,
self
).
__init__
()
self
.
dec_attn
=
MultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
**
kwargs
)
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
def
forward
(
self
,
dec_inp
,
dec_attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
attn_outputs
=
self
.
dec_attn
(
dec_inp
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
,
head_mask
=
head_mask
)
ff_output
=
self
.
pos_ff
(
attn_outputs
[
0
])
outputs
=
[
ff_output
]
+
attn_outputs
[
1
:]
return
outputs
class
RelLearnableDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
**
kwargs
):
super
(
RelLearnableDecoderLayer
,
self
).
__init__
()
self
.
dec_attn
=
RelLearnableMultiHeadAttn
(
n_head
,
d_model
,
d_head
,
dropout
,
**
kwargs
)
self
.
pos_ff
=
PositionwiseFF
(
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
kwargs
.
get
(
'pre_lnorm'
))
def
forward
(
self
,
dec_inp
,
r_emb
,
r_w_bias
,
r_bias
,
dec_attn_mask
=
None
,
mems
=
None
,
head_mask
=
None
):
attn_outputs
=
self
.
dec_attn
(
dec_inp
,
r_emb
,
r_w_bias
,
r_bias
,
attn_mask
=
dec_attn_mask
,
mems
=
mems
,
head_mask
=
head_mask
)
ff_output
=
self
.
pos_ff
(
attn_outputs
[
0
])
outputs
=
[
ff_output
]
+
attn_outputs
[
1
:]
return
outputs
class
RelPartialLearnableDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
n_head
,
d_model
,
d_head
,
d_inner
,
dropout
,
...
...
@@ -643,7 +380,6 @@ class RelPartialLearnableDecoderLayer(nn.Module):
return
outputs
class
AdaptiveEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
n_token
,
d_embed
,
d_proj
,
cutoffs
,
div_val
=
1
,
sample_softmax
=
False
):
...
...
@@ -767,9 +503,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
if
hasattr
(
m
,
'r_bias'
):
self
.
_init_bias
(
m
.
r_bias
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
pass
TRANSFO_XL_START_DOCSTRING
=
r
""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
...
...
@@ -882,43 +615,16 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
r_r_bias
=
None
if
config
.
untie_r
else
self
.
r_r_bias
,
output_attentions
=
self
.
output_attentions
)
)
elif
config
.
attn_type
==
1
:
# learnable embeddings
for
i
in
range
(
config
.
n_layer
):
self
.
layers
.
append
(
RelLearnableDecoderLayer
(
config
.
n_head
,
config
.
d_model
,
config
.
d_head
,
config
.
d_inner
,
config
.
dropout
,
tgt_len
=
config
.
tgt_len
,
ext_len
=
config
.
ext_len
,
mem_len
=
config
.
mem_len
,
dropatt
=
config
.
dropatt
,
pre_lnorm
=
config
.
pre_lnorm
,
r_w_bias
=
None
if
config
.
untie_r
else
self
.
r_w_bias
,
r_r_bias
=
None
if
config
.
untie_r
else
self
.
r_r_bias
,
output_attentions
=
self
.
output_attentions
)
)
elif
config
.
attn_type
in
[
2
,
3
]:
# absolute embeddings
for
i
in
range
(
config
.
n_layer
):
self
.
layers
.
append
(
DecoderLayer
(
config
.
n_head
,
config
.
d_model
,
config
.
d_head
,
config
.
d_inner
,
config
.
dropout
,
dropatt
=
config
.
dropatt
,
pre_lnorm
=
config
.
pre_lnorm
,
r_w_bias
=
None
if
config
.
untie_r
else
self
.
r_w_bias
,
r_r_bias
=
None
if
config
.
untie_r
else
self
.
r_r_bias
,
output_attentions
=
self
.
output_attentions
)
)
else
:
# learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
raise
NotImplementedError
# Removed them to avoid maintaining dead code
self
.
same_length
=
config
.
same_length
self
.
clamp_len
=
config
.
clamp_len
if
self
.
attn_type
==
0
:
# default attention
self
.
pos_emb
=
PositionalEmbedding
(
self
.
d_model
)
elif
self
.
attn_type
==
1
:
# learnable
self
.
r_emb
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_layer
,
self
.
max_klen
,
self
.
n_head
,
self
.
d_head
))
self
.
r_bias
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_layer
,
self
.
max_klen
,
self
.
n_head
))
elif
self
.
attn_type
==
2
:
# absolute standard
self
.
pos_emb
=
PositionalEmbedding
(
self
.
d_model
)
elif
self
.
attn_type
==
3
:
# absolute deeper SA
self
.
r_emb
=
nn
.
Parameter
(
torch
.
FloatTensor
(
self
.
n_layer
,
self
.
max_klen
,
self
.
n_head
,
self
.
d_head
))
else
:
# learnable embeddings and absolute embeddings
raise
NotImplementedError
# Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
self
.
init_weights
()
...
...
@@ -973,8 +679,15 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return
new_mems
def
_forward
(
self
,
dec_inp
,
mems
=
None
,
head_mask
=
None
):
qlen
,
bsz
=
dec_inp
.
size
()
def
forward
(
self
,
input_ids
,
mems
=
None
,
head_mask
=
None
):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids
=
input_ids
.
transpose
(
0
,
1
).
contiguous
()
if
mems
is
None
:
mems
=
self
.
init_mems
(
input_ids
)
qlen
,
bsz
=
input_ids
.
size
()
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
...
...
@@ -991,7 +704,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
else
:
head_mask
=
[
None
]
*
self
.
n_layer
word_emb
=
self
.
word_emb
(
dec_inp
)
word_emb
=
self
.
word_emb
(
input_ids
)
mlen
=
mems
[
0
].
size
(
0
)
if
mems
is
not
None
else
0
klen
=
mlen
+
qlen
...
...
@@ -1028,64 +741,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
elif
self
.
attn_type
==
1
:
# learnable
core_out
=
self
.
drop
(
word_emb
)
for
i
,
layer
in
enumerate
(
self
.
layers
):
hids
.
append
(
core_out
)
if
self
.
clamp_len
>
0
:
r_emb
=
self
.
r_emb
[
i
][
-
self
.
clamp_len
:]
r_bias
=
self
.
r_bias
[
i
][
-
self
.
clamp_len
:]
else
:
r_emb
,
r_bias
=
self
.
r_emb
[
i
],
self
.
r_bias
[
i
]
mems_i
=
None
if
mems
is
None
else
mems
[
i
]
layer_outputs
=
layer
(
core_out
,
r_emb
,
self
.
r_w_bias
[
i
],
r_bias
,
dec_attn_mask
=
dec_attn_mask
,
mems
=
mems_i
,
head_mask
=
head_mask
[
i
])
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
elif
self
.
attn_type
==
2
:
# absolute
pos_seq
=
torch
.
arange
(
klen
-
1
,
-
1
,
-
1.0
,
device
=
word_emb
.
device
,
dtype
=
word_emb
.
dtype
)
if
self
.
clamp_len
>
0
:
pos_seq
.
clamp_
(
max
=
self
.
clamp_len
)
pos_emb
=
self
.
pos_emb
(
pos_seq
)
core_out
=
self
.
drop
(
word_emb
+
pos_emb
[
-
qlen
:])
for
i
,
layer
in
enumerate
(
self
.
layers
):
hids
.
append
(
core_out
)
mems_i
=
None
if
mems
is
None
else
mems
[
i
]
if
mems_i
is
not
None
and
i
==
0
:
mems_i
+=
pos_emb
[:
mlen
]
layer_outputs
=
layer
(
core_out
,
dec_attn_mask
=
dec_attn_mask
,
mems
=
mems_i
,
head_mask
=
head_mask
[
i
])
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
elif
self
.
attn_type
==
3
:
core_out
=
self
.
drop
(
word_emb
)
for
i
,
layer
in
enumerate
(
self
.
layers
):
hids
.
append
(
core_out
)
mems_i
=
None
if
mems
is
None
else
mems
[
i
]
if
mems_i
is
not
None
and
mlen
>
0
:
cur_emb
=
self
.
r_emb
[
i
][:
-
qlen
]
cur_size
=
cur_emb
.
size
(
0
)
if
cur_size
<
mlen
:
cur_emb_pad
=
cur_emb
[
0
:
1
].
expand
(
mlen
-
cur_size
,
-
1
,
-
1
)
cur_emb
=
torch
.
cat
([
cur_emb_pad
,
cur_emb
],
0
)
else
:
cur_emb
=
cur_emb
[
-
mlen
:]
mems_i
+=
cur_emb
.
view
(
mlen
,
1
,
-
1
)
core_out
+=
self
.
r_emb
[
i
][
-
qlen
:].
view
(
qlen
,
1
,
-
1
)
layer_outputs
=
layer
(
core_out
,
dec_attn_mask
=
dec_attn_mask
,
mems
=
mems_i
,
head_mask
=
head_mask
[
i
])
core_out
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
attentions
.
append
(
layer_outputs
[
1
])
else
:
# learnable embeddings and absolute embeddings
raise
NotImplementedError
# Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
core_out
=
self
.
drop
(
core_out
)
...
...
@@ -1102,16 +759,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions
=
list
(
t
.
permute
(
2
,
3
,
0
,
1
).
contiguous
()
for
t
in
attentions
)
outputs
.
append
(
attentions
)
return
outputs
# last hidden state, new_mems, (all hidden states), (all attentions)
def
forward
(
self
,
input_ids
,
mems
=
None
,
head_mask
=
None
):
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
input_ids
=
input_ids
.
transpose
(
0
,
1
).
contiguous
()
if
mems
is
None
:
mems
=
self
.
init_mems
(
input_ids
)
outputs
=
self
.
_forward
(
input_ids
,
mems
=
mems
,
head_mask
=
head_mask
)
return
outputs
# last hidden state, new_mems, (all hidden states), (all attentions)
...
...
pytorch_transformers/tests/modeling_tf_bert_test.py
View file @
65c49bb2
...
...
@@ -131,10 +131,14 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
def
create_and_check_bert_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
model
=
TFBertModel
(
config
=
config
)
# inputs = {'input_ids': input_ids,
# 'attention_mask': input_mask,
# 'token_type_ids': token_type_ids}
# sequence_output, pooled_output = model(**inputs)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
sequence_output
,
pooled_output
=
model
(
inputs
)
sequence_output
,
pooled_output
=
model
(
input
_ids
,
attention_mask
=
input_mask
,
token_type_ids
=
token_type_id
s
)
inputs
=
[
input_ids
,
input_mask
]
sequence_output
,
pooled_output
=
model
(
inputs
)
...
...
pytorch_transformers/tests/modeling_tf_transfo_xl_test.py
0 → 100644
View file @
65c49bb2
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
random
import
shutil
import
pytest
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
pytorch_transformers
import
TransfoXLConfig
,
is_tf_available
if
is_tf_available
():
import
tensorflow
as
tf
from
pytorch_transformers.modeling_tf_transfo_xl
import
(
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require TensorFlow"
)
class
TFTransfoXLModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
)
if
is_tf_available
()
else
()
test_pruning
=
False
test_torchscript
=
False
test_resize_embeddings
=
False
class
TFTransfoXLModelTester
(
object
):
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
mem_len
=
30
,
clamp_len
=
15
,
is_training
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
cutoffs
=
[
10
,
50
,
80
],
hidden_size
=
32
,
d_embed
=
32
,
num_attention_heads
=
4
,
d_head
=
8
,
d_inner
=
128
,
div_val
=
2
,
num_hidden_layers
=
5
,
scope
=
None
,
seed
=
1
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
seq_length
=
seq_length
self
.
mem_len
=
mem_len
self
.
key_len
=
seq_length
+
mem_len
self
.
clamp_len
=
clamp_len
self
.
is_training
=
is_training
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
cutoffs
=
cutoffs
self
.
hidden_size
=
hidden_size
self
.
d_embed
=
d_embed
self
.
num_attention_heads
=
num_attention_heads
self
.
d_head
=
d_head
self
.
d_inner
=
d_inner
self
.
div_val
=
div_val
self
.
num_hidden_layers
=
num_hidden_layers
self
.
scope
=
scope
self
.
seed
=
seed
def
prepare_config_and_inputs
(
self
):
input_ids_1
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids_2
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
lm_labels
=
None
if
self
.
use_labels
:
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
TransfoXLConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
mem_len
=
self
.
mem_len
,
clamp_len
=
self
.
clamp_len
,
cutoffs
=
self
.
cutoffs
,
d_model
=
self
.
hidden_size
,
d_embed
=
self
.
d_embed
,
n_head
=
self
.
num_attention_heads
,
d_head
=
self
.
d_head
,
d_inner
=
self
.
d_inner
,
div_val
=
self
.
div_val
,
n_layer
=
self
.
num_hidden_layers
)
return
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
tf
.
random
.
set_seed
(
self
.
seed
)
def
create_and_check_transfo_xl_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
lm_labels
):
model
=
TFTransfoXLModel
(
config
)
hidden_states_1
,
mems_1
=
model
(
input_ids_1
)
inputs
=
{
'input_ids'
:
input_ids_2
,
'mems'
:
mems_1
}
hidden_states_2
,
mems_2
=
model
(
inputs
)
result
=
{
"hidden_states_1"
:
hidden_states_1
.
numpy
(),
"mems_1"
:
[
mem
.
numpy
()
for
mem
in
mems_1
],
"hidden_states_2"
:
hidden_states_2
.
numpy
(),
"mems_2"
:
[
mem
.
numpy
()
for
mem
in
mems_2
],
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
result
[
"hidden_states_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_transfo_xl_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
lm_labels
):
model
=
TFTransfoXLLMHeadModel
(
config
)
lm_logits_1
,
mems_1
=
model
(
input_ids_1
)
inputs
=
{
'input_ids'
:
input_ids_1
,
'labels'
:
lm_labels
}
_
,
mems_1
=
model
(
inputs
)
lm_logits_2
,
mems_2
=
model
([
input_ids_2
,
mems_1
])
inputs
=
{
'input_ids'
:
input_ids_1
,
'mems'
:
mems_1
,
'labels'
:
lm_labels
}
_
,
mems_2
=
model
(
inputs
)
result
=
{
"mems_1"
:
[
mem
.
numpy
()
for
mem
in
mems_1
],
"lm_logits_1"
:
lm_logits_1
.
numpy
(),
"mems_2"
:
[
mem
.
numpy
()
for
mem
in
mems_2
],
"lm_logits_2"
:
lm_logits_2
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_1"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits_2"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_2"
]),
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
lm_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFTransfoXLModelTest
.
TFTransfoXLModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
TransfoXLConfig
,
d_embed
=
37
)
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
def
test_transfo_xl_model
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_transfo_xl_model
(
*
config_and_inputs
)
def
test_transfo_xl_lm_head
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_transfo_xl_lm_head
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/pytorch_transformers_test/"
for
model_name
in
list
(
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
TFTransfoXLModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment