chenpangpang / transformers · Commits
"...resnet50_tensorflow.git" did not exist on "04ef7772b7cbc885d5d55ded6f541e796dd76884"
Commit f3776df0, authored Dec 02, 2019 by thomwolf

WIP debugging

Parent: 268d4f20

Showing 1 changed file with 45 additions and 16 deletions

transformers/modeling_t5.py  +45  -16
@@ -132,6 +132,21 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
 # - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
 ####################################################
+class T5LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """ Construct a layernorm module in the T5 style
+            No bias and no substraction of mean.
+        """
+        super(T5LayerNorm, self).__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x):
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x = x / torch.sqrt(variance + self.variance_epsilon)
+        return self.weight * x
+
+
 class T5DenseReluDense(nn.Module):
     def __init__(self, config):
         super(T5DenseReluDense, self).__init__()
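Note: unlike nn.LayerNorm, the T5LayerNorm added above only rescales activations by the root-mean-square of the last dimension; there is no mean subtraction and no bias term. A minimal standalone sketch of the same computation (the function below is illustrative, not part of this patch):

    import torch

    def t5_style_layer_norm(x, weight, eps=1e-6):
        # Divide by the RMS of the last dimension, then apply the learned scale.
        variance = x.pow(2).mean(-1, keepdim=True)
        return weight * (x / torch.sqrt(variance + eps))

    x = torch.randn(2, 4, 8)
    out = t5_style_layer_norm(x, torch.ones(8))
    print(out.shape)  # torch.Size([2, 4, 8])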
@@ -151,7 +166,7 @@ class T5LayerFF(nn.Module):
     def __init__(self, config):
         super(T5LayerFF, self).__init__()
         self.DenseReluDense = T5DenseReluDense(config)
-        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)

     def forward(self, hidden_states):
@@ -316,13 +331,14 @@ class T5Attention(nn.Module):
             cache[self.layer_id] = (k, v)

         # q = q / math.sqrt(dim_per_head) # No scaling in T5
-        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
+        scores = torch.einsum('bnqd,bnkd->bnqk', q, k)  # (bs, n_heads, qlen, klen)

         if position_bias is None:
             if not self.has_relative_attention_bias:
                 raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)
         scores += position_bias
+        special_out = position_bias

         if mask is not None:
             scores += mask
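Note: the attention-score change swaps an explicit matmul over transposed keys for an equivalent einsum; both produce a (bs, n_heads, qlen, klen) tensor. A quick equivalence check (illustrative, not part of this patch):

    import torch

    bs, n_heads, qlen, klen, d = 2, 3, 5, 7, 4
    q = torch.randn(bs, n_heads, qlen, d)
    k = torch.randn(bs, n_heads, klen, d)

    scores_matmul = torch.matmul(q, k.transpose(2, 3))     # previous form
    scores_einsum = torch.einsum('bnqd,bnkd->bnqk', q, k)  # new form

    print(torch.allclose(scores_matmul, scores_einsum, atol=1e-6))  # True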
@@ -346,14 +362,14 @@ class T5Attention(nn.Module):
             outputs = outputs + (weights,)
         if self.has_relative_attention_bias:
             outputs = outputs + (position_bias,)
-        return outputs
+        return outputs + (special_out,)


 class T5LayerSelfAttention(nn.Module):
     def __init__(self, config, has_relative_attention_bias=False):
         super(T5LayerSelfAttention, self).__init__()
         self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
-        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)

     def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
@@ -363,16 +379,18 @@ class T5LayerSelfAttention(nn.Module):
                                               position_bias=position_bias,
                                               head_mask=head_mask)
         y = attention_output[0]
+        special_out = attention_output[-1]
+        attention_output = attention_output[:-1]
         layer_output = hidden_states + self.dropout(y)
         outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
-        return outputs
+        return outputs + (special_out,)


 class T5LayerCrossAttention(nn.Module):
     def __init__(self, config, has_relative_attention_bias=False):
         super(T5LayerCrossAttention, self).__init__()
         self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
-        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)

     def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
@@ -408,7 +426,8 @@ class T5Block(nn.Module):
                                                     position_bias=position_bias,
                                                     head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights
+        special_out = self_attention_outputs[-1]
+        outputs = self_attention_outputs[1:-1]  # Keep self-attention outputs and relative position weights

         if not self.is_decoder:
             hidden_states = self.layer[1](hidden_states)
@@ -423,7 +442,7 @@ class T5Block(nn.Module):
             hidden_states = self.layer[2](hidden_states)
         outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+        return outputs + (special_out,)  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


 class T5PreTrainedModel(PreTrainedModel):
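Note: the special_out additions thread the first self-attention layer's position bias out through each return tuple for debugging; each caller peels the extra element off the end before handling the regular outputs. A toy sketch of the pattern (names are illustrative, not T5 code):

    def inner():
        regular_outputs = ("hidden", "weights")
        debug_value = "position_bias"
        # Append the debug value as the last tuple element.
        return regular_outputs + (debug_value,)

    outputs = inner()
    special_out = outputs[-1]   # the appended debug value
    outputs = outputs[:-1]      # the regular outputs, unchanged
    print(outputs, special_out)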
@@ -438,8 +457,7 @@ class T5PreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """ Initialize the weights """
         factor = self.config.initializer_factor  # Used for testing weights initialization
-        if isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
+        if isinstance(module, T5LayerNorm):
             module.weight.data.fill_(factor * 1.0)
         elif isinstance(module, (T5Model, T5WithLMHeadModel)):
             # Mesh TensorFlow embeddings initialization
@@ -478,7 +496,7 @@ class T5Stack(T5PreTrainedModel):
         self.block = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
                                     for i in range(config.num_layers)])
-        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)

         self.init_weights()
@@ -515,11 +533,11 @@ class T5Stack(T5PreTrainedModel):
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -10000.0 for masked positions.
+        # positions we want to attend and -1e9 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

         if self.is_decoder:
             # If a 2D ou 3D attention mask is provided for the cross-attention
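Note: the mask is additive, so positions to attend contribute 0.0 and masked positions now contribute -1e9 (instead of -10000.0), which softmax maps to roughly zero probability. A small sketch of the effect (illustrative, not part of this patch):

    import torch

    attention_mask = torch.tensor([[1.0, 1.0, 0.0]])   # 1 = attend, 0 = masked
    extended = (1.0 - attention_mask) * -1e9            # [[0., 0., -1e9]]

    scores = torch.zeros(1, 3) + extended               # added to the raw attention scores
    print(torch.softmax(scores, dim=-1))                # ~[[0.5, 0.5, 0.0]]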
@@ -530,7 +548,7 @@ class T5Stack(T5PreTrainedModel):
             encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
             encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
         else:
             encoder_extended_attention_mask = None
@@ -553,6 +571,8 @@ class T5Stack(T5PreTrainedModel):
         all_attentions = ()
         position_bias = None
         encoder_decoder_position_bias = None
+
+        hidden_states = self.dropout(hidden_states)
         for i, layer_module in enumerate(self.block):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -564,6 +584,8 @@ class T5Stack(T5PreTrainedModel):
                                          encoder_attention_mask=encoder_extended_attention_mask,
                                          encoder_decoder_position_bias=encoder_decoder_position_bias,
                                          head_mask=head_mask[i])
+            if i == 0:
+                special_out = layer_outputs[-1]
             # layer_outputs is a tuple with:
             # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states = layer_outputs[0]
@@ -588,7 +610,7 @@ class T5Stack(T5PreTrainedModel):
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             outputs = outputs + (all_attentions,)
-        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
+        return outputs + (special_out,)  # last-layer hidden state, (all hidden states), (all attentions)


 T5_START_DOCSTRING = r""" The T5 model was proposed in
@@ -707,9 +729,16 @@ class T5Model(T5PreTrainedModel):
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
+        encoder_attention_mask = kwargs_encoder.get("attention_mask", None)
         if encoder_hidden_states is None:
             encoder_inputs_ids = kwargs_encoder.pop("input_ids")
             hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings
+
+            if encoder_attention_mask is not None:
+                # Apply masking
+                encoder_attention_mask = (encoder_attention_mask != 0).to(hidden_states)
+                hidden_states = hidden_states * encoder_attention_mask.unsqueeze(-1)
+
             encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0]
         else:
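Note: the new encoder-side block zeroes the input embeddings at padded positions before they enter the encoder stack. A small sketch of that masking step (shapes and values are illustrative, not part of this patch):

    import torch

    hidden_states = torch.randn(2, 4, 8)            # (batch, seq_len, d_model)
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])    # 0 marks padding

    mask = (attention_mask != 0).to(hidden_states)   # cast to the embeddings' dtype/device
    hidden_states = hidden_states * mask.unsqueeze(-1)

    print(hidden_states[0, 3].abs().sum())           # tensor(0.)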
@@ -719,7 +748,7 @@ class T5Model(T5PreTrainedModel):
             decoder_inputs_ids = kwargs_decoder.pop("input_ids")
             hidden_states = self.shared(decoder_inputs_ids)  # Convert inputs in embeddings

         kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
-        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
+        kwargs_decoder["encoder_attention_mask"] = encoder_attention_mask
         decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

         return decoder_outputs + encoder_outputs