chenpangpang / transformers / Commits / 65c49bb2

Commit 65c49bb2 authored Sep 13, 2019 by thomwolf

    adding TF 2.0 adaptive softmax with logits + loss outputs

Parent: 39c38b2e

Showing 7 changed files with 869 additions and 1105 deletions (+869 / -1105):
  pytorch_transformers/modeling_tf_bert.py                     +2    -0
  pytorch_transformers/modeling_tf_transfo_xl.py               +344  -729
  pytorch_transformers/modeling_tf_transfo_xl_utilities.py     +279  -0
  pytorch_transformers/modeling_tf_xlm.py                      +2    -2
  pytorch_transformers/modeling_transfo_xl.py                  +20   -373
  pytorch_transformers/tests/modeling_tf_bert_test.py          +5    -1
  pytorch_transformers/tests/modeling_tf_transfo_xl_test.py    +217  -0
pytorch_transformers/modeling_tf_bert.py

@@ -455,6 +455,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
+    # def call(self, input_ids, attention_mask=None, token_type_ids=None,
+    #          position_ids=None, head_mask=None, training=False):
     def call(self, inputs, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
             input_ids = inputs
pytorch_transformers/modeling_tf_transfo_xl.py

@@ -30,8 +30,8 @@ import numpy as np
 import tensorflow as tf
 
 from .configuration_transfo_xl import TransfoXLConfig
-from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary
-from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
+from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list
+from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
 from .file_utils import add_start_docstrings
 from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
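Note: the new import pulls in shape_list from .modeling_tf_utils; the TF port below uses it wherever the PyTorch code called Tensor.size(). As a reading aid only (not part of this commit), a minimal sketch of what such a helper typically does, mixing static and dynamic dimensions so the result can be indexed like a Python list; the exact implementation in modeling_tf_utils may differ:

import tensorflow as tf

def shape_list(x):
    # Prefer statically known dimensions, fall back to the dynamic ones.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

print(shape_list(tf.zeros([4, 16, 32])))  # [4, 16, 32]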
@@ -49,55 +49,56 @@ def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
 
 
-class PositionalEmbedding(nn.Module):
-    def __init__(self, demb):
-        super(PositionalEmbedding, self).__init__()
-
-        self.demb = demb
-
-        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
-        self.register_buffer('inv_freq', inv_freq)
-
-    def forward(self, pos_seq, bsz=None):
-        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
-        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
+class TFPositionalEmbedding(tf.keras.layers.Layer):
+    def __init__(self, demb, **kwargs):
+        super(TFPositionalEmbedding, self).__init__(**kwargs)
+
+        self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
+
+    def call(self, pos_seq, bsz=None):
+        sinusoid_inp = tf.einsum('i,j->ij', pos_seq, self.inv_freq)
+        pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
 
         if bsz is not None:
-            return pos_emb[:,None,:].expand(-1, bsz, -1)
+            return tf.tile(pos_emb[:, None, :], [1, bsz, 1])
         else:
-            return pos_emb[:,None,:]
+            return pos_emb[:, None, :]
 
 
-class PositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
-        super(PositionwiseFF, self).__init__()
+class TFPositionwiseFF(tf.keras.layers.Layer):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, **kwargs):
+        super(TFPositionwiseFF, self).__init__(**kwargs)
 
         self.d_model = d_model
         self.d_inner = d_inner
         self.dropout = dropout
 
-        self.CoreNet = nn.Sequential(
-            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
-            nn.Dropout(dropout),
-            nn.Linear(d_inner, d_model),
-            nn.Dropout(dropout),
-        )
-
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_1 = tf.keras.layers.Dense(d_inner, activation=tf.nn.relu, name='CoreNet_._0')
+        self.drop_1 = tf.keras.layers.Dropout(dropout)
+        self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._2')
+        self.drop_2 = tf.keras.layers.Dropout(dropout)
+
+        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
 
         self.pre_lnorm = pre_lnorm
 
-    def forward(self, inp):
+    def call(self, inp, training=False):
         if self.pre_lnorm:
             ##### layer normalization + positionwise feed-forward
-            core_out = self.CoreNet(self.layer_norm(inp))
+            core_out = self.layer_norm(inp)
+            core_out = self.layer_1(core_out)
+            core_out = self.drop_1(core_out, training=training)
+            core_out = self.layer_2(core_out)
+            core_out = self.drop_2(core_out, training=training)
 
             ##### residual connection
             output = core_out + inp
         else:
             ##### positionwise feed-forward
-            core_out = self.CoreNet(inp)
+            core_out = self.layer_1(inp)
+            core_out = self.drop_1(core_out, training=training)
+            core_out = self.layer_2(core_out)
+            core_out = self.drop_2(core_out, training=training)
 
             ##### residual connection + layer normalization
             output = self.layer_norm(inp + core_out)
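The TF layer above reproduces torch.ger with tf.einsum('i,j->ij') and the sin/cos concatenation with tf.concat. A small standalone shape check (illustrative sketch only, mirroring the new layer rather than importing it):

import tensorflow as tf

demb = 8
inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))              # [demb // 2]
pos_seq = tf.range(4.0, -1.0, -1.0)                                    # [5]: 4, 3, 2, 1, 0
sinusoid_inp = tf.einsum('i,j->ij', pos_seq, inv_freq)                 # outer product, [5, demb // 2]
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)  # [5, demb]
tiled = tf.tile(pos_emb[:, None, :], [1, 3, 1])                        # [5, 3, demb], as returned when bsz=3
print(pos_emb.shape, tiled.shape)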
@@ -105,102 +106,11 @@ class PositionwiseFF(nn.Module):
         return output
 
 
-class MultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
-        super(MultiHeadAttn, self).__init__()
-
-        self.output_attentions = output_attentions
-        self.n_head = n_head
-        self.d_model = d_model
-        self.d_head = d_head
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
-        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
-
-        self.drop = nn.Dropout(dropout)
-        self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-
-        self.scale = 1 / (d_head ** 0.5)
-
-        self.pre_lnorm = pre_lnorm
-
-        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
-            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-        else:
-            self.r_r_bias = r_r_bias
-            self.r_w_bias = r_w_bias
-
-    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
-        ##### multihead attention
-        # [hlen x bsz x n_head x d_head]
-
-        if mems is not None:
-            c = torch.cat([mems, h], 0)
-        else:
-            c = h
-
-        if self.pre_lnorm:
-            ##### layer normalization
-            c = self.layer_norm(c)
-
-        head_q = self.q_net(h)
-        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
-
-        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
-        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
-
-        # [qlen x klen x bsz x n_head]
-        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
-        attn_score.mul_(self.scale)
-
-        if attn_mask is not None and torch.sum(attn_mask).item():
-            attn_mask = (attn_mask == 1)  # Switch to bool
-            if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
-            elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
-
-        # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attn_prob = attn_prob * head_mask
-
-        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
-
-        ##### linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            ##### residual connection
-            outputs = [h + attn_out]
-        else:
-            ##### residual connection + layer normalization
-            outputs = [self.layer_norm(h + attn_out)]
-
-        if self.output_attentions:
-            outputs.append(attn_prob)
-
-        return outputs
-
-
-class RelMultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 r_r_bias=None, r_w_bias=None, output_attentions=False):
-        super(RelMultiHeadAttn, self).__init__()
+class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
+    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
+                 r_r_bias=None, r_w_bias=None, output_attentions=False, **kwargs):
+        super(TFRelPartialLearnableMultiHeadAttn, self).__init__(**kwargs)
 
         self.output_attentions = output_attentions
         self.n_head = n_head
@@ -208,91 +118,60 @@ class RelMultiHeadAttn(nn.Module):
         self.d_head = d_head
         self.dropout = dropout
 
-        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
-
-        self.drop = nn.Dropout(dropout)
-        self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-
-        self.scale = 1 / (d_head ** 0.5)
-
-        self.pre_lnorm = pre_lnorm
-
-        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
-            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-        else:
-            self.r_r_bias = r_r_bias
-            self.r_w_bias = r_w_bias
-
-    def _parallelogram_mask(self, h, w, left=False):
-        mask = torch.ones((h, w)).byte()
-        m = min(h, w)
-        mask[:m,:m] = torch.triu(mask[:m,:m])
-        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
-
-        if left:
-            return mask
-        else:
-            return mask.flip(0)
-
-    def _shift(self, x, qlen, klen, mask, left=False):
-        if qlen > 1:
-            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
-                                   device=x.device, dtype=x.dtype)
-        else:
-            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
-
-        if left:
-            mask = mask.flip(1)
-            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
-        else:
-            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
-
-        x = x_padded.masked_select(mask[:,:,None,None]) \
-                    .view(qlen, klen, x.size(2), x.size(3))
-
-        return x
-
-    def _rel_shift(self, x, zero_triu=False):
-        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
-        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=1)
-
-        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
-        x_padded = x_padded.view(*x_padded_shape)
-
-        x = x_padded[1:].view_as(x)
-
-        if zero_triu:
-            ones = torch.ones((x.size(0), x.size(1)))
-            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
-
-        return x
-
-    def forward(self, w, r, attn_mask=None, mems=None):
-        raise NotImplementedError
-
-
-class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
-    def __init__(self, *args, **kwargs):
-        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
-
-        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
-
-    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
-        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
+        self.qkv_net = tf.keras.layers.Dense(3 * n_head * d_head, use_bias=False, name='qkv_net')
+
+        self.drop = tf.keras.layers.Dropout(dropout)
+        self.dropatt = tf.keras.layers.Dropout(dropatt)
+        self.o_net = tf.keras.layers.Dense(d_model, use_bias=False, name='o_net')
+
+        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
+
+        self.scale = 1 / (d_head ** 0.5)
+
+        self.pre_lnorm = pre_lnorm
+
+        if r_r_bias is not None and r_w_bias is not None:  # Biases are shared
+            self.r_r_bias = r_r_bias
+            self.r_w_bias = r_w_bias
+        else:
+            self.r_r_bias = None
+            self.r_w_bias = None
+
+        self.r_net = tf.keras.layers.Dense(self.n_head * self.d_head, use_bias=False, name='r_net')
+
+    def build(self, input_shape):
+        if self.r_r_bias is None or self.r_w_bias is None:  # Biases are not shared
+            self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head),
+                                            trainable=True, name='r_r_bias')
+            self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head),
+                                            trainable=True, name='r_w_bias')
+        super(TFRelPartialLearnableMultiHeadAttn, self).build(input_shape)
+
+    def _rel_shift(self, x):
+        x_size = shape_list(x)
+
+        x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
+        x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
+        x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
+        x = tf.reshape(x, x_size)
+
+        return x
+
+    def call(self, inputs, training=False):
+        w, r, attn_mask, mems, head_mask = inputs
+        qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
 
         if mems is not None:
-            cat = torch.cat([mems, w], 0)
+            cat = tf.concat([mems, w], 0)
 
             if self.pre_lnorm:
                 w_heads = self.qkv_net(self.layer_norm(cat))
             else:
                 w_heads = self.qkv_net(cat)
             r_head_k = self.r_net(r)
 
-            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
+
             w_head_q = w_head_q[-qlen:]
         else:
             if self.pre_lnorm:
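The new _rel_shift above swaps the PyTorch zero-pad / view / slice trick for tf.pad + tf.reshape + tf.slice. A tiny numeric sketch (illustrative only, not part of the commit) showing that it performs the same realignment of the relative-position scores:

import tensorflow as tf

def rel_shift(x):
    # Same steps as TFRelPartialLearnableMultiHeadAttn._rel_shift above.
    x_size = x.shape.as_list()
    x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
    x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
    x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
    return tf.reshape(x, x_size)

x = tf.reshape(tf.range(1.0, 7.0), [2, 3, 1, 1])   # qlen=2, klen=3, bsz=1, n_head=1
print(tf.squeeze(rel_shift(x)).numpy())
# [[2. 3. 0.]
#  [4. 5. 6.]]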
@@ -301,56 +180,52 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
                 w_heads = self.qkv_net(w)
             r_head_k = self.r_net(r)
 
-            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
 
-        klen = w_head_k.size(0)
+        klen = shape_list(w_head_k)[0]
 
-        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)                # qlen x bsz x n_head x d_head
-        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)                # qlen x bsz x n_head x d_head
-        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)                # qlen x bsz x n_head x d_head
-
-        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                     # qlen x n_head x d_head
+        w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head))       # qlen x bsz x n_head x d_head
+        w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head))       # qlen x bsz x n_head x d_head
+        w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head))       # qlen x bsz x n_head x d_head
+
+        r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head))            # qlen x n_head x d_head
 
         #### compute attention score
         rw_head_q = w_head_q + self.r_w_bias                                         # qlen x bsz x n_head x d_head
-        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))                  # qlen x klen x bsz x n_head
+        AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)                       # qlen x klen x bsz x n_head
 
         rr_head_q = w_head_q + self.r_r_bias
-        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))                   # qlen x klen x bsz x n_head
+        BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)                        # qlen x klen x bsz x n_head
         BD = self._rel_shift(BD)
 
         # [qlen x klen x bsz x n_head]
         attn_score = AC + BD
-        attn_score.mul_(self.scale)
+        attn_score = attn_score * self.scale
 
         #### compute attention probability
-        if attn_mask is not None and torch.sum(attn_mask).item():
-            attn_mask = (attn_mask == 1)  # Switch to bool
-            if attn_mask.dim() == 2:
-                attn_score = attn_score.float().masked_fill(
-                    attn_mask[None,:,:,None], -1e30).type_as(attn_score)
-            elif attn_mask.dim() == 3:
-                attn_score = attn_score.float().masked_fill(
-                    attn_mask[:,:,:,None], -1e30).type_as(attn_score)
+        if attn_mask is not None:
+            attn_mask_t = attn_mask[:, :, None, None]
+            attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
 
         # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
+        attn_prob = tf.nn.softmax(attn_score, axis=1)
+        attn_prob = self.dropatt(attn_prob, training=training)
 
         # Mask heads if we want to
         if head_mask is not None:
             attn_prob = attn_prob * head_mask
 
         #### compute attention vector
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
+        attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
 
         # [qlen x bsz x n_head x d_head]
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+        attn_vec_sizes = shape_list(attn_vec)
+        attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))
 
         ##### linear projection
         attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
+        attn_out = self.drop(attn_out, training=training)
 
         if self.pre_lnorm:
             ##### residual connection
@@ -364,166 +239,40 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         return outputs
 
 
-class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
-    def __init__(self, *args, **kwargs):
-        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
-
-    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
-        # r_emb: [klen, n_head, d_head], used for term B
-        # r_w_bias: [n_head, d_head], used for term C
-        # r_bias: [klen, n_head], used for term D
-
-        qlen, bsz = w.size(0), w.size(1)
-
-        if mems is not None:
-            cat = torch.cat([mems, w], 0)
-            if self.pre_lnorm:
-                w_heads = self.qkv_net(self.layer_norm(cat))
-            else:
-                w_heads = self.qkv_net(cat)
-            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
-
-            w_head_q = w_head_q[-qlen:]
-        else:
-            if self.pre_lnorm:
-                w_heads = self.qkv_net(self.layer_norm(w))
-            else:
-                w_heads = self.qkv_net(w)
-            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
-
-        klen = w_head_k.size(0)
-
-        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
-        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
-        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
-
-        if klen > r_emb.size(0):
-            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
-            r_emb = torch.cat([r_emb_pad, r_emb], 0)
-            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
-            r_bias = torch.cat([r_bias_pad, r_bias], 0)
-        else:
-            r_emb = r_emb[-klen:]
-            r_bias = r_bias[-klen:]
-
-        #### compute attention score
-        rw_head_q = w_head_q + r_w_bias[None]                                   # qlen x bsz x n_head x d_head
-
-        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
-        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))                  # qlen x klen x bsz x n_head
-        D_ = r_bias[None, :, None]                                              # 1    x klen x 1   x n_head
-        BD = self._rel_shift(B_ + D_)
-
-        # [qlen x klen x bsz x n_head]
-        attn_score = AC + BD
-        attn_score.mul_(self.scale)
-
-        #### compute attention probability
-        if attn_mask is not None and torch.sum(attn_mask).item():
-            attn_mask = (attn_mask == 1)  # Switch to bool
-            if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
-            elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
-
-        # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
-
-        if head_mask is not None:
-            attn_prob = attn_prob * head_mask
-
-        #### compute attention vector
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
-
-        # [qlen x bsz x n_head x d_head]
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
-
-        ##### linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            ##### residual connection
-            outputs = [w + attn_out]
-        else:
-            ##### residual connection + layer normalization
-            outputs = [self.layer_norm(w + attn_out)]
-
-        if self.output_attentions:
-            outputs.append(attn_prob)
-
-        return outputs
-
-
-class DecoderLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
-        super(DecoderLayer, self).__init__()
-
-        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))
-
-    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
-        attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
-                                     mems=mems, head_mask=head_mask)
-        ff_output = self.pos_ff(attn_outputs[0])
-
-        outputs = [ff_output] + attn_outputs[1:]
-
-        return outputs
-
-
-class RelLearnableDecoderLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
-        super(RelLearnableDecoderLayer, self).__init__()
-
-        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))
-
-    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
-        attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
-                                     attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
-        ff_output = self.pos_ff(attn_outputs[0])
-
-        outputs = [ff_output] + attn_outputs[1:]
-
-        return outputs
-
-
-class RelPartialLearnableDecoderLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
-                 **kwargs):
-        super(RelPartialLearnableDecoderLayer, self).__init__()
-
-        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
-                                                         d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
-                                     pre_lnorm=kwargs.get('pre_lnorm'))
-
-    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
-        attn_outputs = self.dec_attn(dec_inp, r, attn_mask=dec_attn_mask,
-                                     mems=mems, head_mask=head_mask)
-        ff_output = self.pos_ff(attn_outputs[0])
-
-        outputs = [ff_output] + attn_outputs[1:]
-
-        return outputs
-
-
-class AdaptiveEmbedding(nn.Module):
-    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
-                 sample_softmax=False):
-        super(AdaptiveEmbedding, self).__init__()
+class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
+    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
+                 tgt_len=None, ext_len=None, mem_len=None, dropatt=0.,
+                 pre_lnorm=False, r_w_bias=None, r_r_bias=None,
+                 output_attentions=False, **kwargs):
+        super(TFRelPartialLearnableDecoderLayer, self).__init__(**kwargs)
+
+        self.dec_attn = TFRelPartialLearnableMultiHeadAttn(n_head, d_model,
+                                                           d_head, dropout,
+                                                           tgt_len=tgt_len, ext_len=ext_len,
+                                                           mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm,
+                                                           r_w_bias=r_w_bias, r_r_bias=r_r_bias,
+                                                           output_attentions=output_attentions,
+                                                           name='dec_attn')
+        self.pos_ff = TFPositionwiseFF(d_model, d_inner, dropout,
+                                       pre_lnorm=pre_lnorm, name='pos_ff')
+
+    def call(self, inputs, training=False):
+        dec_inp, r, dec_attn_mask, mems, head_mask = inputs
+        attn_outputs = self.dec_attn([dec_inp, r, dec_attn_mask, mems, head_mask], training=training)
+        ff_output = self.pos_ff(attn_outputs[0], training=training)
+
+        outputs = [ff_output] + attn_outputs[1:]
+
+        return outputs
+
+
+class TFAdaptiveEmbedding(tf.keras.layers.Layer):
+    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
+                 sample_softmax=False, **kwargs):
+        super(TFAdaptiveEmbedding, self).__init__(**kwargs)
 
         self.n_token = n_token
         self.d_embed = d_embed
@@ -536,188 +285,53 @@ class AdaptiveEmbedding(nn.Module):
         self.cutoff_ends = [0] + self.cutoffs
 
-        self.emb_layers = nn.ModuleList()
-        self.emb_projs = nn.ParameterList()
+        self.emb_layers = []
+        self.emb_projs = []
         if div_val == 1:
-            self.emb_layers.append(
-                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
-            )
-            if d_proj != d_embed:
-                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
+            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
         else:
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                 d_emb_i = d_embed // (div_val ** i)
-                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
-                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
-
-    def forward(self, inp):
+                self.emb_layers.append(tf.keras.layers.Embedding(r_idx - l_idx, d_emb_i,
+                                                                 name='emb_layers_._{}'.format(i)))
+
+    def build(self, input_shape):
+        for i in range(len(self.cutoffs)):
+            d_emb_i = self.d_embed // (self.div_val ** i)
+            self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj),
+                                                  trainable=True,
+                                                  name='emb_projs._{}'.format(i)))
+        super(TFAdaptiveEmbedding, self).build(input_shape)
+
+    def call(self, inp):
         if self.div_val == 1:
-            embed = self.emb_layers[0](inp)
-            if self.d_proj != self.d_embed:
-                embed = F.linear(embed, self.emb_projs[0])
+            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
         else:
-            param = next(self.parameters())
-            inp_flat = inp.view(-1)
-            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
-                                   dtype=param.dtype, device=param.device)
+            inp_flat = tf.reshape(inp, (-1,))
+            emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj])
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
 
                 mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
-                indices_i = mask_i.nonzero().squeeze()
-
-                if indices_i.numel() == 0:
-                    continue
-
-                inp_i = inp_flat.index_select(0, indices_i) - l_idx
+                inp_i = tf.boolean_mask(inp_flat, mask_i) - l_idx
                 emb_i = self.emb_layers[i](inp_i)
-                emb_i = F.linear(emb_i, self.emb_projs[i])
-
-                emb_flat.index_copy_(0, indices_i, emb_i)
-
-            embed_shape = inp.size() + (self.d_proj,)
-            embed = emb_flat.view(embed_shape)
-
-        embed.mul_(self.emb_scale)
+                emb_i = tf.einsum('id,de->ie', emb_i, self.emb_projs[i])
+
+                mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
+                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(tf.shape(emb_flat), dtype=tf.int64))
+
+            embed_shape = shape_list(inp) + [self.d_proj]
+            embed = tf.reshape(emb_flat, embed_shape)
+
+        embed *= self.emb_scale
 
         return embed
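TFAdaptiveEmbedding keeps the cutoff bookkeeping of the PyTorch class but gathers tokens per bucket with tf.boolean_mask and writes the projected embeddings back with tf.scatter_nd. For orientation, a short sketch (not from this commit, with made-up numbers rather than the wt103 config) of how cutoffs and div_val partition the vocabulary into buckets with shrinking embedding sizes:

n_token, d_embed, div_val = 10000, 512, 4
cutoffs = [1000, 5000] + [n_token]      # config.cutoffs plus the vocabulary size
cutoff_ends = [0] + cutoffs             # [0, 1000, 5000, 10000]

for i in range(len(cutoffs)):
    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
    d_emb_i = d_embed // (div_val ** i)
    print('bucket %d: tokens [%d, %d) -> embedding size %d' % (i, l_idx, r_idx, d_emb_i))
# bucket 0: tokens [0, 1000)     -> embedding size 512
# bucket 1: tokens [1000, 5000)  -> embedding size 128
# bucket 2: tokens [5000, 10000) -> embedding size 32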
-class TransfoXLPreTrainedModel(PreTrainedModel):
-    """ An abstract class to handle weights initialization and
-        a simple interface for dowloading and loading pretrained models.
-    """
-    config_class = TransfoXLConfig
-    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_tf_weights = load_tf_weights_in_transfo_xl
-    base_model_prefix = "transformer"
-
-    def _init_weight(self, weight):
-        if self.config.init == 'uniform':
-            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
-        elif self.config.init == 'normal':
-            nn.init.normal_(weight, 0.0, self.config.init_std)
-
-    def _init_bias(self, bias):
-        nn.init.constant_(bias, 0.0)
-
-    def _init_weights(self, m):
-        """ Initialize the weights.
-        """
-        classname = m.__class__.__name__
-        if classname.find('Linear') != -1:
-            if hasattr(m, 'weight') and m.weight is not None:
-                self._init_weight(m.weight)
-            if hasattr(m, 'bias') and m.bias is not None:
-                self._init_bias(m.bias)
-        elif classname.find('AdaptiveEmbedding') != -1:
-            if hasattr(m, 'emb_projs'):
-                for i in range(len(m.emb_projs)):
-                    if m.emb_projs[i] is not None:
-                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
-        elif classname.find('Embedding') != -1:
-            if hasattr(m, 'weight'):
-                self._init_weight(m.weight)
-        elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
-            if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
-                self._init_weight(m.cluster_weight)
-            if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
-                self._init_bias(m.cluster_bias)
-            if hasattr(m, 'out_projs'):
-                for i in range(len(m.out_projs)):
-                    if m.out_projs[i] is not None:
-                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
-        elif classname.find('LayerNorm') != -1:
-            if hasattr(m, 'weight'):
-                nn.init.normal_(m.weight, 1.0, self.config.init_std)
-            if hasattr(m, 'bias') and m.bias is not None:
-                self._init_bias(m.bias)
-        else:
-            if hasattr(m, 'r_emb'):
-                self._init_weight(m.r_emb)
-            if hasattr(m, 'r_w_bias'):
-                self._init_weight(m.r_w_bias)
-            if hasattr(m, 'r_r_bias'):
-                self._init_weight(m.r_r_bias)
-            if hasattr(m, 'r_bias'):
-                self._init_bias(m.r_bias)
-
-    def set_num_special_tokens(self, num_special_tokens):
-        pass
+class TFTransfoXLMainLayer(tf.keras.layers.Layer):
+    def __init__(self, config, **kwargs):
+        super(TFTransfoXLMainLayer, self).__init__(**kwargs)
 
 
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
TRANSFO_XL_INPUTS_DOCSTRING
=
r
"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
                      TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLModel(TransfoXLPreTrainedModel):
    r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
"""
    def __init__(self, config):
        super(TransfoXLModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
@@ -727,11 +341,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.d_model = config.d_model
         self.n_head = config.n_head
         self.d_head = config.d_head
+        self.untie_r = config.untie_r
 
-        self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
-                                          div_val=config.div_val)
+        self.word_emb = TFAdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
+                                            div_val=config.div_val, name='word_emb')
 
-        self.drop = nn.Dropout(config.dropout)
+        self.drop = tf.keras.layers.Dropout(config.dropout)
 
         self.n_layer = config.n_layer
@@ -742,61 +357,41 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.attn_type = config.attn_type
 
-        if not config.untie_r:
-            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-
-        self.layers = nn.ModuleList()
+        self.layers = []
         if config.attn_type == 0: # the default attention
             for i in range(config.n_layer):
                 self.layers.append(
-                    RelPartialLearnableDecoderLayer(
+                    TFRelPartialLearnableDecoderLayer(
                         config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                         tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                         dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
-                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias,
-                        output_attentions=self.output_attentions)
-                )
-        elif config.attn_type == 1: # learnable embeddings
-            for i in range(config.n_layer):
-                self.layers.append(
-                    RelLearnableDecoderLayer(
-                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
-                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
-                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
-                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias,
-                        output_attentions=self.output_attentions)
-                )
-        elif config.attn_type in [2, 3]: # absolute embeddings
-            for i in range(config.n_layer):
-                self.layers.append(
-                    DecoderLayer(
-                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
-                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
-                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias,
-                        output_attentions=self.output_attentions)
-                )
+                        r_w_bias=None if self.untie_r else self.r_w_bias,
+                        r_r_bias=None if self.untie_r else self.r_r_bias,
+                        output_attentions=self.output_attentions,
+                        name='layers_._{}'.format(i)))
+        else: # learnable embeddings and absolute embeddings
+            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
 
         self.same_length = config.same_length
         self.clamp_len = config.clamp_len
 
         if self.attn_type == 0: # default attention
-            self.pos_emb = PositionalEmbedding(self.d_model)
-        elif self.attn_type == 1: # learnable
-            self.r_emb = nn.Parameter(torch.FloatTensor(
-                    self.n_layer, self.max_klen, self.n_head, self.d_head))
-            self.r_bias = nn.Parameter(torch.FloatTensor(
-                    self.n_layer, self.max_klen, self.n_head))
-        elif self.attn_type == 2: # absolute standard
-            self.pos_emb = PositionalEmbedding(self.d_model)
-        elif self.attn_type == 3: # absolute deeper SA
-            self.r_emb = nn.Parameter(torch.FloatTensor(
-                    self.n_layer, self.max_klen, self.n_head, self.d_head))
-
-        self.init_weights()
+            self.pos_emb = TFPositionalEmbedding(self.d_model, name='pos_emb')
+        else: # learnable embeddings and absolute embeddings
+            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
+
+    def build(self, input_shape):
+        if not self.untie_r:
+            self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head),
+                                            initializer='zeros',
+                                            trainable=True,
+                                            name='r_w_bias')
+            self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head),
+                                            initializer='zeros',
+                                            trainable=True,
+                                            name='r_r_bias')
+        super(TFTransfoXLMainLayer, self).build(input_shape)
 
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
@@ -810,16 +405,13 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.ext_len = ext_len
 
     def _prune_heads(self, heads):
-        logger.info("Head pruning is not implemented for Transformer-XL model")
-        pass
+        raise NotImplementedError
 
     def init_mems(self, data):
         if self.mem_len > 0:
             mems = []
-            param = next(self.parameters())
             for i in range(self.n_layer):
-                empty = torch.zeros(self.mem_len, data.size(1), self.config.d_model,
-                                    dtype=param.dtype, device=param.device)
+                empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model])
                 mems.append(empty)
 
             return mems
@@ -838,164 +430,211 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         # will be used as the extended context. Hence, we only cache
         # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
         # to `mlen + qlen - self.ext_len`.
-        with torch.no_grad():
-            new_mems = []
-            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
-            beg_idx = max(0, end_idx - self.mem_len)
-            for i in range(len(hids)):
-
-                cat = torch.cat([mems[i], hids[i]], dim=0)
-                new_mems.append(cat[beg_idx:end_idx].detach())
+        new_mems = []
+        end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+        beg_idx = max(0, end_idx - self.mem_len)
+        for i in range(len(hids)):
+            cat = tf.concat([mems[i], hids[i]], axis=0)
+            tf.stop_gradient(cat)
+            new_mems.append(cat[beg_idx:end_idx])
 
         return new_mems
 
-    def _forward(self, dec_inp, mems=None, head_mask=None):
-        qlen, bsz = dec_inp.size()
+    def call(self, inputs, training=False):
+        if not isinstance(inputs, (dict, tuple, list)):
+            input_ids = inputs
+            mems, head_mask = None, None
+        elif isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            mems = inputs[1] if len(inputs) > 1 else None
+            head_mask = inputs[2] if len(inputs) > 2 else None
+            assert len(inputs) <= 3, "Too many inputs."
+        else:
+            input_ids = inputs.get('input_ids')
+            mems = inputs.get('mems', None)
+            head_mask = inputs.get('head_mask', None)
+            assert len(inputs) <= 3, "Too many inputs."
+
+        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
+        # so we transpose here from shape [bsz, len] to shape [len, bsz]
+        input_ids = tf.transpose(input_ids, perm=(1, 0))
+
+        if mems is None:
+            mems = self.init_mems(input_ids)
+
+        qlen, bsz = shape_list(input_ids)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
         # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
-                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
-            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
+        if not head_mask is None:
+            raise NotImplementedError
         else:
             head_mask = [None] * self.n_layer
 
-        word_emb = self.word_emb(dec_inp)
+        word_emb = self.word_emb(input_ids)
 
-        mlen = mems[0].size(0) if mems is not None else 0
+        mlen = shape_list(mems[0])[0] if mems is not None else 0
         klen = mlen + qlen
-        if self.same_length:
-            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
-            mask_len = klen - self.mem_len
-            if mask_len > 0:
-                mask_shift_len = qlen - mask_len
-            else:
-                mask_shift_len = qlen
-            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
-                    + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
-        else:
-            dec_attn_mask = torch.triu(
-                word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
+
+        attn_mask = tf.ones([qlen, qlen])
+        mask_u = tf.linalg.band_part(attn_mask, 0, -1)
+        mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
+        attn_mask_pad = tf.zeros([qlen, mlen])
+        dec_attn_mask = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
+        if self.same_length:
+            mask_l = tf.linalg.band_part(attn_mask, -1, 0)
+            dec_attn_mask = tf.concat([dec_attn_mask[:, :qlen] + mask_l - mask_dia,
+                                       dec_attn_mask[:, qlen:]], 1)
+        # ::: PyTorch masking code for reference :::
+        # if self.same_length:
+        #     all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
+        #     mask_len = klen - self.mem_len
+        #     if mask_len > 0:
+        #         mask_shift_len = qlen - mask_len
+        #     else:
+        #         mask_shift_len = qlen
+        #     dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+        #             + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
+        # else:
+        #     dec_attn_mask = torch.triu(
+        #         word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None]
 
         hids = []
         attentions = []
         if self.attn_type == 0: # default
-            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
-                                   dtype=word_emb.dtype)
+            pos_seq = tf.range(klen - 1, -1, -1.0)
             if self.clamp_len > 0:
-                pos_seq.clamp_(max=self.clamp_len)
+                pos_seq = tf.minimum(pos_seq, self.clamp_len)
             pos_emb = self.pos_emb(pos_seq)
 
-            core_out = self.drop(word_emb)
-            pos_emb = self.drop(pos_emb)
+            core_out = self.drop(word_emb, training=training)
+            pos_emb = self.drop(pos_emb, training=training)
 
             for i, layer in enumerate(self.layers):
                 hids.append(core_out)
                 mems_i = None if mems is None else mems[i]
-                layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
-                                      mems=mems_i, head_mask=head_mask[i])
+                layer_outputs = layer([core_out, pos_emb, dec_attn_mask,
+                                       mems_i, head_mask[i]], training=training)
                 core_out = layer_outputs[0]
                 if self.output_attentions:
                     attentions.append(layer_outputs[1])
-        elif self.attn_type == 1: # learnable
-            core_out = self.drop(word_emb)
-            for i, layer in enumerate(self.layers):
-                hids.append(core_out)
-                if self.clamp_len > 0:
-                    r_emb = self.r_emb[i][-self.clamp_len:]
-                    r_bias = self.r_bias[i][-self.clamp_len:]
-                else:
-                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]
-
-                mems_i = None if mems is None else mems[i]
-                layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
-                                      r_bias, dec_attn_mask=dec_attn_mask,
-                                      mems=mems_i, head_mask=head_mask[i])
-                core_out = layer_outputs[0]
-                if self.output_attentions:
-                    attentions.append(layer_outputs[1])
-        elif self.attn_type == 2: # absolute
-            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
-                                   dtype=word_emb.dtype)
-            if self.clamp_len > 0:
-                pos_seq.clamp_(max=self.clamp_len)
-            pos_emb = self.pos_emb(pos_seq)
-
-            core_out = self.drop(word_emb + pos_emb[-qlen:])
-
-            for i, layer in enumerate(self.layers):
-                hids.append(core_out)
-                mems_i = None if mems is None else mems[i]
-                if mems_i is not None and i == 0:
-                    mems_i += pos_emb[:mlen]
-                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                      mems=mems_i, head_mask=head_mask[i])
-                core_out = layer_outputs[0]
-                if self.output_attentions:
-                    attentions.append(layer_outputs[1])
-        elif self.attn_type == 3:
-            core_out = self.drop(word_emb)
-
-            for i, layer in enumerate(self.layers):
-                hids.append(core_out)
-                mems_i = None if mems is None else mems[i]
-                if mems_i is not None and mlen > 0:
-                    cur_emb = self.r_emb[i][:-qlen]
-                    cur_size = cur_emb.size(0)
-                    if cur_size < mlen:
-                        cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
-                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
-                    else:
-                        cur_emb = cur_emb[-mlen:]
-                    mems_i += cur_emb.view(mlen, 1, -1)
-                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
-
-                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                      mems=mems_i, head_mask=head_mask[i])
-                core_out = layer_outputs[0]
-                if self.output_attentions:
-                    attentions.append(layer_outputs[1])
-
-        core_out = self.drop(core_out)
+        else: # learnable embeddings and absolute embeddings
+            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
+
+        core_out = self.drop(core_out, training=training)
 
         new_mems = self._update_mems(hids, mems, mlen, qlen)
 
         # We transpose back here to shape [bsz, len, hidden_dim]
-        outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
+        outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
         if self.output_hidden_states:
             # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
             hids.append(core_out)
-            hids = list(t.transpose(0, 1).contiguous() for t in hids)
+            hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
             outputs.append(hids)
         if self.output_attentions:
             # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
-            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+            attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
             outputs.append(attentions)
 
         return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
 
-    def forward(self, input_ids, mems=None, head_mask=None):
-        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
-        # so we transpose here from shape [bsz, len] to shape [len, bsz]
-        input_ids = input_ids.transpose(0, 1).contiguous()
-
-        if mems is None:
-            mems = self.init_mems(input_ids)
-        outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)
-
-        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
+
+class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
+    """ An abstract class to handle weights initialization and
+        a simple interface for dowloading and loading pretrained models.
+    """
+    config_class = TransfoXLConfig
+    pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
+    load_pt_weights = load_transfo_xl_pt_weights_in_tf2
+    base_model_prefix = "transformer"
 
 
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse
previously computed hidden-states to attend to longer context (memory).
This model also uses adaptive softmax inputs and outputs (tied).
This model is a PyTorch `torch.tf.keras.layers.Layer`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
https://arxiv.org/abs/1901.02860
.. _`torch.tf.keras.layers.Layer`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
TRANSFO_XL_INPUTS_DOCSTRING
=
r
"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
the right or on the left.
Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mems**: (`optional`)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
                      TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
    r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states, mems = outputs[:2]
"""
    def __init__(self, config, *inputs, **kwargs):
        super(TFTransfoXLModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFTransfoXLMainLayer(config, name='transformer')

    def call(self, inputs, training=False, **kwargs):
        outputs = self.transformer(inputs, training=training, **kwargs)
        return outputs
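The docstring example above still shows the PyTorch call pattern. Below is a minimal TF 2.0 usage sketch of TFTransfoXLModel; it assumes the 'transfo-xl-wt103' checkpoint key from the archive map and the existing TransfoXLTokenizer, and it is not guaranteed that pretrained weight loading works unchanged at this work-in-progress commit:

import tensorflow as tf
from pytorch_transformers import TransfoXLTokenizer
from pytorch_transformers.modeling_tf_transfo_xl import TFTransfoXLModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TFTransfoXLModel.from_pretrained('transfo-xl-wt103')

input_ids = tf.constant([tokenizer.encode("Hello, my dog is cute")])  # batch size 1
last_hidden_states, mems = model(input_ids)[:2]

# reuse the returned mems to attend over a longer context on the next segment
next_ids = tf.constant([tokenizer.encode("He likes to play fetch")])
last_hidden_states, mems = model({'input_ids': next_ids, 'mems': mems})[:2]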
@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
    TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
-class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
+class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
    r"""
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
...
@@ -1032,46 +671,16 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    """
    def __init__(self, config):
-        super(TransfoXLLMHeadModel, self).__init__(config)
-        self.transformer = TransfoXLModel(config)
+        super(TFTransfoXLLMHeadModel, self).__init__(config)
+        self.transformer = TFTransfoXLMainLayer(config, name='transformer')
        self.sample_softmax = config.sample_softmax
        # use sampled softmax
        if config.sample_softmax > 0:
-            self.out_layer = nn.Linear(config.d_model, config.n_token)
-            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
+            raise NotImplementedError
        # use adaptive softmax (including standard softmax)
        else:
-            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
-                                                    config.cutoffs, div_val=config.div_val)
-        self.init_weights()
-        self.tie_weights()
-
-    def tie_weights(self):
-        """
-        Run this to be sure output and input (adaptive) softmax weights are tied
-        """
-        # sampled softmax
-        if self.sample_softmax > 0:
-            if self.config.tie_weight:
-                self.out_layer.weight = self.transformer.word_emb.weight
-        # adaptive softmax (including standard softmax)
-        else:
-            if self.config.tie_weight:
-                for i in range(len(self.crit.out_layers)):
-                    self._tie_or_clone_weights(self.crit.out_layers[i],
-                                               self.transformer.word_emb.emb_layers[i])
-            if self.config.tie_projs:
-                for i, tie_proj in enumerate(self.config.tie_projs):
-                    if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
-                        if self.config.torchscript:
-                            self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
-                        else:
-                            self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
-                    elif tie_proj and self.config.div_val != 1:
-                        if self.config.torchscript:
-                            self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
-                        else:
-                            self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
+            self.crit = TFAdaptiveSoftmaxMask(config.n_token, config.d_embed, config.d_model,
+                                              config.cutoffs, div_val=config.div_val, name='crit')

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.transformer.reset_length(tgt_len, ext_len, mem_len)
...
@@ -1079,30 +688,36 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    def init_mems(self, data):
        return self.transformer.init_mems(data)

-    def forward(self, input_ids, mems=None, head_mask=None, labels=None):
-        bsz = input_ids.size(0)
-        tgt_len = input_ids.size(1)
+    def call(self, inputs, training=False):
+        if not isinstance(inputs, (dict, tuple, list)):
+            input_ids = inputs
+            mems, head_mask, labels = None, None, None
+        elif isinstance(inputs, (tuple, list)):
+            input_ids = inputs[0]
+            mems = inputs[1] if len(inputs) > 1 else None
+            head_mask = inputs[2] if len(inputs) > 2 else None
+            labels = inputs[3] if len(inputs) > 3 else None
+            assert len(inputs) <= 4, "Too many inputs."
+        else:
+            input_ids = inputs.get('input_ids')
+            mems = inputs.get('mems', None)
+            head_mask = inputs.get('head_mask', None)
+            labels = inputs.get('labels', None)
+            assert len(inputs) <= 4, "Too many inputs."
+
+        bsz, tgt_len = shape_list(input_ids)[:2]

-        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)
+        transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training)

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
        outputs = transformer_outputs[1:]

-        if self.sample_softmax > 0 and self.training:
-            assert self.config.tie_weight
-            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
-            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
-            outputs = [softmax_output] + outputs
-            if labels is not None:
-                # TODO: This is not implemented
-                raise NotImplementedError
+        if self.sample_softmax > 0 and training:
+            raise NotImplementedError
        else:
-            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
-            if labels is None:
-                softmax_output = softmax_output.view(bsz, tgt_len, -1)
-                outputs = [softmax_output] + outputs
-            else:
-                softmax_output = softmax_output.view(bsz, tgt_len)
-                outputs = [softmax_output, None] + outputs
+            # pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1]))
+            softmax_output = self.crit([pred_hid, labels], training=training)
+            # softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1))
+            outputs = [softmax_output] + outputs

-        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
+        return outputs  # logits, new_mems, (all hidden states), (all attentions)
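The new call above accepts a bare tensor, a list, or a dict. A rough sketch of the three input styles against a small randomly initialized config, mirroring the unit-test sizes further down; this is a sketch only and may not run unchanged against this work-in-progress commit:

import tensorflow as tf
from pytorch_transformers import TransfoXLConfig
from pytorch_transformers.modeling_tf_transfo_xl import TFTransfoXLLMHeadModel

config = TransfoXLConfig(vocab_size_or_config_json_file=99, cutoffs=[10, 50, 80],
                         d_model=32, d_embed=32, n_head=4, d_head=8, d_inner=128,
                         div_val=2, n_layer=5, mem_len=30, clamp_len=15)
model = TFTransfoXLLMHeadModel(config)

input_ids = tf.random.uniform((2, 7), maxval=99, dtype=tf.int32)
labels = tf.random.uniform((2, 7), maxval=99, dtype=tf.int32)

logits, mems = model(input_ids)                        # bare tensor
logits, mems = model([input_ids, mems])                # list: [input_ids, mems, head_mask, labels]
out, mems = model({'input_ids': input_ids,             # dict: labels make TFAdaptiveSoftmaxMask
                   'mems': mems,                       # register its loss via add_loss()
                   'labels': labels})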
pytorch_transformers/modeling_tf_transfo_xl_utilities.py
0 → 100644
View file @
65c49bb2
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
"""
from collections import defaultdict

import numpy as np
import tensorflow as tf

from .modeling_tf_utils import shape_list
class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False, **kwargs):
        super(TFAdaptiveSoftmaxMask, self).__init__(**kwargs)

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.keep_order = keep_order

        self.out_layers = []
        self.out_projs = []

    def build(self, input_shape):
        if self.n_clusters > 0:
            self.cluster_weight = self.add_weight(shape=(self.n_clusters, self.d_embed),
                                                  initializer='zeros',
                                                  trainable=True,
                                                  name='cluster_weight')
            self.cluster_bias = self.add_weight(shape=(self.n_clusters,),
                                                initializer='zeros',
                                                trainable=True,
                                                name='cluster_bias')

        if self.div_val == 1:
            for i in range(len(self.cutoffs)):
                if self.d_proj != self.d_embed:
                    weight = self.add_weight(shape=(self.d_embed, self.d_proj),
                                             initializer='zeros',
                                             trainable=True,
                                             name='out_projs_._{}'.format(i))
                    self.out_projs.append(weight)
                else:
                    self.out_projs.append(None)
                weight = self.add_weight(shape=(self.n_token, self.d_embed,),
                                         initializer='zeros',
                                         trainable=True,
                                         name='out_layers_._{}_._weight'.format(i))
                bias = self.add_weight(shape=(self.n_token,),
                                       initializer='zeros',
                                       trainable=True,
                                       name='out_layers_._{}_._bias'.format(i))
                self.out_layers.append((weight, bias))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = self.d_embed // (self.div_val ** i)

                weight = self.add_weight(shape=(d_emb_i, self.d_proj),
                                         initializer='zeros',
                                         trainable=True,
                                         name='out_projs_._{}'.format(i))
                self.out_projs.append(weight)
                weight = self.add_weight(shape=(r_idx - l_idx, d_emb_i,),
                                         initializer='zeros',
                                         trainable=True,
                                         name='out_layers_._{}_._weight'.format(i))
                bias = self.add_weight(shape=(r_idx - l_idx,),
                                       initializer='zeros',
                                       trainable=True,
                                       name='out_layers_._{}_._bias'.format(i))
                self.out_layers.append((weight, bias))
        super(TFAdaptiveSoftmaxMask, self).build(input_shape)

    @staticmethod
    def _logit(x, W, b, proj=None):
        y = x
        if proj is not None:
            y = tf.einsum('ibd,ed->ibe', y, proj)
        return tf.einsum('ibd,nd->ibn', y, W) + b

    @staticmethod
    def _gather_logprob(logprob, target):
        lp_size = tf.shape(logprob)
        r = tf.range(lp_size[0])
        idx = tf.stack([r, target], 1)
        return tf.gather_nd(logprob, idx)

    def call(self, inputs, return_mean=True, training=False):
        hidden, target = inputs
        head_logprob = 0
        if self.n_clusters == 0:
            softmax_b = tf.get_variable('bias', [n_token], initializer=tf.zeros_initializer())
            output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
            if target is not None:
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
            out = tf.nn.log_softmax(output, axis=-1)
        else:
            hidden_sizes = shape_list(hidden)
            out = []
            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                if target is not None:
                    mask = (target >= l_idx) & (target < r_idx)
                    mask_idx = tf.where(mask)
                    cur_target = tf.boolean_mask(target, mask) - l_idx

                if self.div_val == 1:
                    cur_W = self.out_layers[0][0][l_idx:r_idx]
                    cur_b = self.out_layers[0][1][l_idx:r_idx]
                else:
                    cur_W = self.out_layers[i][0]
                    cur_b = self.out_layers[i][1]

                if i == 0:
                    cur_W = tf.concat([cur_W, self.cluster_weight], 0)
                    cur_b = tf.concat([cur_b, self.cluster_bias], 0)

                    head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
                    head_logprob = tf.nn.log_softmax(head_logit)
                    out.append(head_logprob[..., :self.cutoffs[0]])
                    if target is not None:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
                else:
                    tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i])
                    tail_logprob = tf.nn.log_softmax(tail_logit)

                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
                    logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob
                    out.append(logprob_i)
                    if target is not None:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_tail_logprob = tf.boolean_mask(tail_logprob, mask)
                        cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                        cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                if target is not None:
                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(tf.shape(loss), dtype=tf.int64))
            out = tf.concat(out, axis=-1)

        if target is not None:
            if return_mean:
                loss = tf.reduce_mean(loss)
            # Add the training-time loss value to the layer using `self.add_loss()`.
            self.add_loss(loss)

            # Log the loss as a metric (we could log arbitrary metrics,
            # including different metrics for training and inference.
            self.add_metric(loss, name=self.name, aggregation='mean' if return_mean else '')

        return out
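The layer above can also be exercised on its own. A small sketch (toy shapes and random inputs, assumed to build under TF 2.0 eager execution) of how the [hidden, target] input pair and the add_loss bookkeeping are meant to be used:

import tensorflow as tf
from pytorch_transformers.modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask

# 1000-token vocabulary: head cluster [0, 200), tail clusters [200, 600) and [600, 1000)
crit = TFAdaptiveSoftmaxMask(n_token=1000, d_embed=64, d_proj=64, cutoffs=[200, 600], div_val=1)

hidden = tf.random.normal((4, 7, 64))                             # [bsz, tgt_len, d_proj]
target = tf.random.uniform((4, 7), maxval=1000, dtype=tf.int32)   # [bsz, tgt_len]

logprob = crit([hidden, target])   # [bsz, tgt_len, n_token] log-probabilities
print(logprob.shape, crit.losses)  # mean negative log-likelihood collected by add_loss()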
def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                            params, tie_projs,
                            initializer=None, proj_initializer=None,
                            div_val=1, perms=None, proj_same_dim=True,
                            scope='adaptive_softmax',
                            **kwargs):
    def _logit(x, W, b, proj):
        y = x
        if x.shape.ndims == 3:
            if proj is not None:
                y = tf.einsum('ibd,ed->ibe', y, proj)
            return tf.einsum('ibd,nd->ibn', y, W) + b
        else:
            if proj is not None:
                y = tf.einsum('id,ed->ie', y, proj)
            return tf.einsum('id,nd->in', y, W) + b

    params_W, params_projs = params[0], params[1]

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable('bias', [n_token],
                                        initializer=tf.zeros_initializer())
            output = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
                                                                 logits=output)
            nll = tf.reduce_mean(nll)
        else:
            total_loss, total_cnt = 0, 0
            cutoff_ends = [0] + cutoffs + [n_token]
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]

                    cur_d_embed = d_embed // (div_val ** i)

                    if div_val == 1:
                        cur_W = params_W[l_idx: r_idx]
                    else:
                        cur_W = params_W[i]
                    cur_b = tf.get_variable('b', [r_idx - l_idx],
                                            initializer=tf.zeros_initializer())
                    if tie_projs[i]:
                        if div_val == 1:
                            cur_proj = params_projs
                        else:
                            cur_proj = params_projs[i]
                    else:
                        if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
                            cur_proj = None
                        else:
                            cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj],
                                                       initializer=proj_initializer)

                    if i == 0:
                        cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed],
                                                    initializer=tf.zeros_initializer())
                        cluster_b = tf.get_variable('cluster_b', [len(cutoffs)],
                                                    initializer=tf.zeros_initializer())
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)

                        head_target = kwargs.get("head_target")
                        head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=head_target,
                            logits=head_logit)

                        masked_loss = head_nll * perms[i]
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(perms[i])

                        # head_logprob = tf.nn.log_softmax(head_logit)
                        # final_logprob = head_logprob * perms[i][:, :, None]
                        # final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
                        # total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
                        # total_cnt += tf.reduce_sum(perms[i])
                    else:
                        cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i])

                        cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i])
                        tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj)

                        tail_target = tf.einsum('ib,ibk->k', tf.to_float(target - l_idx), perms[i])
                        tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=tf.to_int32(tail_target),
                            logits=tail_logit)

                        sum_nll = cur_head_nll + tail_nll
                        mask = tf.reduce_sum(perms[i], [0, 1])

                        masked_loss = sum_nll * mask
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(mask)

            nll = total_loss / total_cnt

    return nll
\ No newline at end of file
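Both the Keras layer and the original TF1 helper above rely on the same cutoff bookkeeping. A short worked example of how the cutoffs split the vocabulary (the numbers are illustrative only, not taken from a real checkpoint):

n_token, cutoffs = 1000, [200, 600]

cutoff_ends = [0] + cutoffs + [n_token]            # [0, 200, 600, 1000]
clusters = [(cutoff_ends[i], cutoff_ends[i + 1])   # [(0, 200), (200, 600), (600, 1000)]
            for i in range(len(cutoff_ends) - 1)]

shortlist_size = cutoffs[0]        # 200 frequent tokens scored directly by the head softmax
n_clusters = len(cutoffs)          # 2 tail clusters, each represented by one extra head logit
head_size = shortlist_size + n_clusters
print(clusters, head_size)         # ..., 202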
pytorch_transformers/modeling_tf_xlm.py
View file @
65c49bb2
...
@@ -261,8 +261,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        self.ffns = []
        self.layer_norm2 = []
        # if self.is_decoder:
-        #     self.layer_norm15 = tf.keras.layers.LayerList()
-        #     self.encoder_attn = tf.keras.layers.LayerList()
+        #     self.layer_norm15 = []
+        #     self.encoder_attn = []

        for i in range(self.n_layers):
            self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions_._{}'.format(i)))
...
pytorch_transformers/modeling_transfo_xl.py
View file @
65c49bb2
...
@@ -229,102 +229,11 @@ class PositionwiseFF(nn.Module):
        return output

+class RelPartialLearnableMultiHeadAttn(nn.Module):
-class MultiHeadAttn(nn.Module):
-    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
-                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
-        super(MultiHeadAttn, self).__init__()
-
-        self.output_attentions = output_attentions
-        self.n_head = n_head
-        self.d_model = d_model
-        self.d_head = d_head
-        self.dropout = dropout
-
-        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
-        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
-
-        self.drop = nn.Dropout(dropout)
-        self.dropatt = nn.Dropout(dropatt)
-        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-        self.layer_norm = nn.LayerNorm(d_model)
-
-        self.scale = 1 / (d_head ** 0.5)
-
-        self.pre_lnorm = pre_lnorm
-
-        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
-            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
-        else:
-            self.r_r_bias = r_r_bias
-            self.r_w_bias = r_w_bias
-
-    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
-        ##### multihead attention
-        # [hlen x bsz x n_head x d_head]
-
-        if mems is not None:
-            c = torch.cat([mems, h], 0)
-        else:
-            c = h
-
-        if self.pre_lnorm:
-            ##### layer normalization
-            c = self.layer_norm(c)
-
-        head_q = self.q_net(h)
-        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
-
-        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
-        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
-        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
-
-        # [qlen x klen x bsz x n_head]
-        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
-        attn_score.mul_(self.scale)
-        if attn_mask is not None and torch.sum(attn_mask).item():
-            attn_mask = (attn_mask == 1)  # Switch to bool
-            if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
-            elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
-
-        # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attn_prob = attn_prob * head_mask
-
-        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
-
-        ##### linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            ##### residual connection
-            outputs = [h + attn_out]
-        else:
-            ##### residual connection + layer normalization
-            outputs = [self.layer_norm(h + attn_out)]
-
-        if self.output_attentions:
-            outputs.append(attn_prob)
-
-        return outputs
-
-class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False):
-        super(RelMultiHeadAttn, self).__init__()
+        super(RelPartialLearnableMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
...
@@ -351,36 +260,9 @@ class RelMultiHeadAttn(nn.Module):
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

+        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
-    def _parallelogram_mask(self, h, w, left=False):
-        mask = torch.ones((h, w)).byte()
-        m = min(h, w)
-        mask[:m,:m] = torch.triu(mask[:m,:m])
-        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
-
-        if left:
-            return mask
-        else:
-            return mask.flip(0)
-
-    def _shift(self, x, qlen, klen, mask, left=False):
-        if qlen > 1:
-            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
-                                   device=x.device, dtype=x.dtype)
-        else:
-            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
-
-        if left:
-            mask = mask.flip(1)
-            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
-        else:
-            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
-
-        x = x_padded.masked_select(mask[:,:,None,None]) \
-                    .view(qlen, klen, x.size(2), x.size(3))
-
-        return x

-    def _rel_shift(self, x, zero_triu=False):
+    def _rel_shift(self, x):
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)
...
@@ -390,21 +272,8 @@ class RelMultiHeadAttn(nn.Module):
        x = x_padded[1:].view_as(x)

-        if zero_triu:
-            ones = torch.ones((x.size(0), x.size(1)))
-            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
-
        return x

-    def forward(self, w, r, attn_mask=None, mems=None):
-        raise NotImplementedError
-
-class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
-    def __init__(self, *args, **kwargs):
-        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
-
-        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
-
    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
...
@@ -488,138 +357,6 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
        return outputs

-class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
-    def __init__(self, *args, **kwargs):
-        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
-
-    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
-        # r_emb: [klen, n_head, d_head], used for term B
-        # r_w_bias: [n_head, d_head], used for term C
-        # r_bias: [klen, n_head], used for term D
-
-        qlen, bsz = w.size(0), w.size(1)
-
-        if mems is not None:
-            cat = torch.cat([mems, w], 0)
-            if self.pre_lnorm:
-                w_heads = self.qkv_net(self.layer_norm(cat))
-            else:
-                w_heads = self.qkv_net(cat)
-            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
-
-            w_head_q = w_head_q[-qlen:]
-        else:
-            if self.pre_lnorm:
-                w_heads = self.qkv_net(self.layer_norm(w))
-            else:
-                w_heads = self.qkv_net(w)
-            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
-
-        klen = w_head_k.size(0)
-
-        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
-        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
-        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
-
-        if klen > r_emb.size(0):
-            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
-            r_emb = torch.cat([r_emb_pad, r_emb], 0)
-            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
-            r_bias = torch.cat([r_bias_pad, r_bias], 0)
-        else:
-            r_emb = r_emb[-klen:]
-            r_bias = r_bias[-klen:]
-
-        #### compute attention score
-        rw_head_q = w_head_q + r_w_bias[None]                        # qlen x bsz x n_head x d_head
-
-        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head
-        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))       # qlen x klen x bsz x n_head
-        D_ = r_bias[None, :, None]                                   # 1 x klen x 1 x n_head
-        BD = self._rel_shift(B_ + D_)
-
-        # [qlen x klen x bsz x n_head]
-        attn_score = AC + BD
-        attn_score.mul_(self.scale)
-
-        #### compute attention probability
-        if attn_mask is not None and torch.sum(attn_mask).item():
-            attn_mask = (attn_mask == 1)  # Switch to bool
-            if attn_mask.dim() == 2:
-                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
-            elif attn_mask.dim() == 3:
-                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
-
-        # [qlen x klen x bsz x n_head]
-        attn_prob = F.softmax(attn_score, dim=1)
-        attn_prob = self.dropatt(attn_prob)
-
-        if head_mask is not None:
-            attn_prob = attn_prob * head_mask
-
-        #### compute attention vector
-        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
-
-        # [qlen x bsz x n_head x d_head]
-        attn_vec = attn_vec.contiguous().view(
-            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
-
-        ##### linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            ##### residual connection
-            outputs = [w + attn_out]
-        else:
-            ##### residual connection + layer normalization
-            outputs = [self.layer_norm(w + attn_out)]
-
-        if self.output_attentions:
-            outputs.append(attn_prob)
-
-        return outputs
-
-class DecoderLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
-        super(DecoderLayer, self).__init__()
-
-        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))
-
-    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
-        attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
-                                     mems=mems, head_mask=head_mask)
-        ff_output = self.pos_ff(attn_outputs[0])
-
-        outputs = [ff_output] + attn_outputs[1:]
-
-        return outputs
-
-class RelLearnableDecoderLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
-        super(RelLearnableDecoderLayer, self).__init__()
-
-        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))
-
-    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
-        attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
-                                     attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
-        ff_output = self.pos_ff(attn_outputs[0])
-
-        outputs = [ff_output] + attn_outputs[1:]
-
-        return outputs
-
class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...
@@ -643,7 +380,6 @@ class RelPartialLearnableDecoderLayer(nn.Module):
        return outputs

class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
...
@@ -767,9 +503,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
            if hasattr(m, 'r_bias'):
                self._init_bias(m.r_bias)

-    def set_num_special_tokens(self, num_special_tokens):
-        pass
-
TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
    `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
...
@@ -882,43 +615,16 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions)
                )
-        elif config.attn_type == 1: # learnable embeddings
-            for i in range(config.n_layer):
-                self.layers.append(
-                    RelLearnableDecoderLayer(
-                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
-                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
-                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
-                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias,
-                        output_attentions=self.output_attentions)
-                )
-        elif config.attn_type in [2, 3]: # absolute embeddings
-            for i in range(config.n_layer):
-                self.layers.append(
-                    DecoderLayer(
-                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
-                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
-                        r_w_bias=None if config.untie_r else self.r_w_bias,
-                        r_r_bias=None if config.untie_r else self.r_r_bias,
-                        output_attentions=self.output_attentions)
-                )
+        else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
+            raise NotImplementedError  # Removed them to avoid maintaining dead code

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0: # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
-        elif self.attn_type == 1: # learnable
-            self.r_emb = nn.Parameter(torch.FloatTensor(
-                    self.n_layer, self.max_klen, self.n_head, self.d_head))
-            self.r_bias = nn.Parameter(torch.FloatTensor(
-                    self.n_layer, self.max_klen, self.n_head))
-        elif self.attn_type == 2: # absolute standard
-            self.pos_emb = PositionalEmbedding(self.d_model)
-        elif self.attn_type == 3: # absolute deeper SA
-            self.r_emb = nn.Parameter(torch.FloatTensor(
-                    self.n_layer, self.max_klen, self.n_head, self.d_head))
+        else: # learnable embeddings and absolute embeddings
+            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        self.init_weights()
...
@@ -973,8 +679,15 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        return new_mems

-    def _forward(self, dec_inp, mems=None, head_mask=None):
-        qlen, bsz = dec_inp.size()
+    def forward(self, input_ids, mems=None, head_mask=None):
+        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
+        # so we transpose here from shape [bsz, len] to shape [len, bsz]
+        input_ids = input_ids.transpose(0, 1).contiguous()
+
+        if mems is None:
+            mems = self.init_mems(input_ids)
+
+        qlen, bsz = input_ids.size()

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
...
@@ -991,7 +704,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        else:
            head_mask = [None] * self.n_layer

-        word_emb = self.word_emb(dec_inp)
+        word_emb = self.word_emb(input_ids)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
...
@@ -1028,64 +741,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
-        elif self.attn_type == 1: # learnable
-            core_out = self.drop(word_emb)
-            for i, layer in enumerate(self.layers):
-                hids.append(core_out)
-                if self.clamp_len > 0:
-                    r_emb = self.r_emb[i][-self.clamp_len:]
-                    r_bias = self.r_bias[i][-self.clamp_len:]
-                else:
-                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]
-
-                mems_i = None if mems is None else mems[i]
-                layer_outputs = layer(core_out, r_emb, self.r_w_bias[i], r_bias,
-                                      dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i])
-                core_out = layer_outputs[0]
-                if self.output_attentions:
-                    attentions.append(layer_outputs[1])
-        elif self.attn_type == 2: # absolute
-            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
-            if self.clamp_len > 0:
-                pos_seq.clamp_(max=self.clamp_len)
-            pos_emb = self.pos_emb(pos_seq)
-
-            core_out = self.drop(word_emb + pos_emb[-qlen:])
-
-            for i, layer in enumerate(self.layers):
-                hids.append(core_out)
-                mems_i = None if mems is None else mems[i]
-                if mems_i is not None and i == 0:
-                    mems_i += pos_emb[:mlen]
-                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                      mems=mems_i, head_mask=head_mask[i])
-                core_out = layer_outputs[0]
-                if self.output_attentions:
-                    attentions.append(layer_outputs[1])
-        elif self.attn_type == 3:
-            core_out = self.drop(word_emb)
-
-            for i, layer in enumerate(self.layers):
-                hids.append(core_out)
-                mems_i = None if mems is None else mems[i]
-                if mems_i is not None and mlen > 0:
-                    cur_emb = self.r_emb[i][:-qlen]
-                    cur_size = cur_emb.size(0)
-                    if cur_size < mlen:
-                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
-                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
-                    else:
-                        cur_emb = cur_emb[-mlen:]
-                    mems_i += cur_emb.view(mlen, 1, -1)
-                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
-
-                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
-                                      mems=mems_i, head_mask=head_mask[i])
-                core_out = layer_outputs[0]
-                if self.output_attentions:
-                    attentions.append(layer_outputs[1])
+        else: # learnable embeddings and absolute embeddings
+            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out)
...
@@ -1102,16 +759,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
            outputs.append(attentions)

-        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
-
-    def forward(self, input_ids, mems=None, head_mask=None):
-        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
-        # so we transpose here from shape [bsz, len] to shape [len, bsz]
-        input_ids = input_ids.transpose(0, 1).contiguous()
-
-        if mems is None:
-            mems = self.init_mems(input_ids)
-
-        outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)
-
        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
...
pytorch_transformers/tests/modeling_tf_bert_test.py
View file @
65c49bb2
...
@@ -131,10 +131,14 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):

        def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = TFBertModel(config=config)

+            # inputs = {'input_ids': input_ids,
+            #           'attention_mask': input_mask,
+            #           'token_type_ids': token_type_ids}
+            # sequence_output, pooled_output = model(**inputs)

            inputs = {'input_ids': input_ids,
                      'attention_mask': input_mask,
                      'token_type_ids': token_type_ids}
-            sequence_output, pooled_output = model(inputs)
+            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)

            inputs = [input_ids, input_mask]
            sequence_output, pooled_output = model(inputs)
...
pytorch_transformers/tests/modeling_tf_transfo_xl_test.py
0 → 100644
View file @
65c49bb2
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import random
import shutil
import pytest

from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester

from pytorch_transformers import TransfoXLConfig, is_tf_available

if is_tf_available():
    import tensorflow as tf
    from pytorch_transformers.modeling_tf_transfo_xl import (TFTransfoXLModel,
                                                             TFTransfoXLLMHeadModel,
                                                             TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")


class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):

    all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False

    class TFTransfoXLModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=30,
                     clamp_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     hidden_size=32,
                     d_embed=32,
                     num_attention_heads=4,
                     d_head=8,
                     d_inner=128,
                     div_val=2,
                     num_hidden_layers=5,
                     scope=None,
                     seed=1,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.hidden_size = hidden_size
            self.d_embed = d_embed
            self.num_attention_heads = num_attention_heads
            self.d_head = d_head
            self.d_inner = d_inner
            self.div_val = div_val
            self.num_hidden_layers = num_hidden_layers
            self.scope = scope
            self.seed = seed

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            lm_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            config = TransfoXLConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                cutoffs=self.cutoffs,
                d_model=self.hidden_size,
                d_embed=self.d_embed,
                n_head=self.num_attention_heads,
                d_head=self.d_head,
                d_inner=self.d_inner,
                div_val=self.div_val,
                n_layer=self.num_hidden_layers)

            return (config, input_ids_1, input_ids_2, lm_labels)

        def set_seed(self):
            random.seed(self.seed)
            tf.random.set_seed(self.seed)

        def create_and_check_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels):
            model = TFTransfoXLModel(config)

            hidden_states_1, mems_1 = model(input_ids_1)

            inputs = {'input_ids': input_ids_2,
                      'mems': mems_1}

            hidden_states_2, mems_2 = model(inputs)

            result = {
                "hidden_states_1": hidden_states_1.numpy(),
                "mems_1": [mem.numpy() for mem in mems_1],
                "hidden_states_2": hidden_states_2.numpy(),
                "mems_2": [mem.numpy() for mem in mems_2],
            }

            self.parent.assertListEqual(
                list(result["hidden_states_1"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(result["hidden_states_2"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
            model = TFTransfoXLLMHeadModel(config)

            lm_logits_1, mems_1 = model(input_ids_1)

            inputs = {'input_ids': input_ids_1,
                      'labels': lm_labels}
            _, mems_1 = model(inputs)

            lm_logits_2, mems_2 = model([input_ids_2, mems_1])

            inputs = {'input_ids': input_ids_1,
                      'mems': mems_1,
                      'labels': lm_labels}
            _, mems_2 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "lm_logits_1": lm_logits_1.numpy(),
                "mems_2": [mem.numpy() for mem in mems_2],
                "lm_logits_2": lm_logits_2.numpy(),
            }

            self.parent.assertListEqual(
                list(result["lm_logits_1"].shape),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
            self.parent.assertListEqual(
                list(result["lm_logits_2"].shape),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids_1}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFTransfoXLModelTest.TFTransfoXLModelTester(self)
        self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_transfo_xl_model(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_transfo_xl_model(*config_and_inputs)

    def test_transfo_xl_lm_head(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()