chenpangpang / transformers · Commits · 169fea68

Commit 169fea68, authored Dec 09, 2019 by thomwolf
updating T5
Parent: f3776df0

Showing 1 changed file with 13 additions and 18 deletions:
transformers/modeling_t5.py (+13, -18)
transformers/modeling_t5.py (view file @ 169fea68)

@@ -281,7 +281,7 @@ class T5Attention(nn.Module):
         context_position = torch.arange(qlen, dtype=torch.long)[:, None]
         memory_position = torch.arange(klen, dtype=torch.long)[None, :]
         relative_position = memory_position - context_position  # shape (qlen, klen)
-        rp_bucket = self._relative_position_bucket(relative_position,
+        rp_bucket = self._relative_position_bucket(relative_position,  # shape (qlen, klen)
                                                    bidirectional=not self.is_decoder,
                                                    num_buckets=self.relative_attention_num_buckets)
         values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
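As a minimal illustration (not part of the diff) of what compute_bias builds here: with qlen = klen = 4, the broadcasted subtraction yields a (qlen, klen) grid of signed offsets key_position - query_position.

import torch

qlen, klen = 4, 4
context_position = torch.arange(qlen, dtype=torch.long)[:, None]  # (qlen, 1)
memory_position = torch.arange(klen, dtype=torch.long)[None, :]   # (1, klen)
relative_position = memory_position - context_position            # (qlen, klen)
print(relative_position)
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1],
#         [-3, -2, -1,  0]])
# _relative_position_bucket maps these signed offsets into a fixed number of buckets,
# which then index the relative_attention_bias embedding to give one bias value per head.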
@@ -337,14 +337,10 @@ class T5Attention(nn.Module):
             if not self.has_relative_attention_bias:
                 raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)

-        scores += position_bias
-        special_out = position_bias
-
-        if mask is not None:
-            scores += mask
-            # mask = (mask == 0).expand_as(scores)  # (bs, n_heads, qlen, klen)
-            # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
+            if mask is not None:
+                position_bias += mask  # (bs, n_heads, qlen, klen)

+        scores += position_bias
         weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
         weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
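The rewritten hunk folds the additive mask into position_bias instead of adding it to scores separately; since both terms are simply added to the raw scores before the softmax, the two paths produce identical attention weights, and a -1e9 additive term drives the masked positions to ~0 probability. A minimal sketch with made-up numbers (not from the file):

import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 0.1]])    # raw scores for one query position
position_bias = torch.zeros_like(scores)         # placeholder bias for the illustration
mask = torch.tensor([[0.0, 0.0, -1e9, -1e9]])    # additive mask: 0 keeps, -1e9 drops

old_weights = F.softmax(scores + position_bias + mask, dim=-1)    # old path: scores += bias; scores += mask
new_weights = F.softmax(scores + (position_bias + mask), dim=-1)  # new path: bias += mask; scores += bias
print(torch.allclose(old_weights, new_weights))  # True; masked positions get ~0 weight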
@@ -362,7 +358,7 @@ class T5Attention(nn.Module):
             outputs = outputs + (weights,)
         if self.has_relative_attention_bias:
             outputs = outputs + (position_bias,)
-        return outputs + (special_out,)
+        return outputs


 class T5LayerSelfAttention(nn.Module):
@@ -379,11 +375,9 @@ class T5LayerSelfAttention(nn.Module):
                                               position_bias=position_bias,
                                               head_mask=head_mask)
         y = attention_output[0]
-        special_out = attention_output[-1]
-        attention_output = attention_output[:-1]
         layer_output = hidden_states + self.dropout(y)
         outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
-        return outputs + (special_out,)
+        return outputs


 class T5LayerCrossAttention(nn.Module):
@@ -426,8 +420,7 @@ class T5Block(nn.Module):
                                                position_bias=position_bias,
                                                head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
-        special_out = self_attention_outputs[-1]
-        outputs = self_attention_outputs[1:-1]  # Keep self-attention outputs and relative position weights
+        outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

         if not self.is_decoder:
             hidden_states = self.layer[1](hidden_states)
@@ -442,7 +435,7 @@ class T5Block(nn.Module):
             hidden_states = self.layer[2](hidden_states)

         outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs + (special_out,)  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


 class T5PreTrainedModel(PreTrainedModel):
@@ -536,6 +529,10 @@ class T5Stack(T5PreTrainedModel):
         # positions we want to attend and -1e9 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
+        # T5 has a mask that can compare sequence ids, we simulate this here with this transposition
+        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+        extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
+
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
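The added comparison against the transpose reproduces the Mesh TensorFlow behaviour where the mask can carry sequence ids, so that packed sequences only attend within themselves. A small sketch (not from the commit), using a hypothetical row of segment ids broadcast to a square mask:

import torch

# Hypothetical packed example: two sequences packed into one row, identified by segment ids.
segment_ids = torch.tensor([1, 1, 1, 2, 2])          # (seq_len,)
m = segment_ids[None, :].expand(5, 5)                # (seq_len, seq_len): every row holds the ids
same_sequence = (m == m.transpose(-1, -2))           # True only where query and key share an id
print(same_sequence.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
# After casting to float, (1.0 - mask) * -1e9 turns the cross-sequence positions into
# large negative additive biases, exactly as in the two lines above.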
@@ -584,8 +581,6 @@ class T5Stack(T5PreTrainedModel):
                                          encoder_attention_mask=encoder_extended_attention_mask,
                                          encoder_decoder_position_bias=encoder_decoder_position_bias,
                                          head_mask=head_mask[i])
-            if i == 0:
-                special_out = layer_outputs[-1]
             # layer_outputs is a tuple with:
             # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states = layer_outputs[0]
@@ -610,7 +605,7 @@ class T5Stack(T5PreTrainedModel):
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             outputs = outputs + (all_attentions,)
-        return outputs + (special_out,)  # last-layer hidden state, (all hidden states), (all attentions)
+        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


 T5_START_DOCSTRING = r""" The T5 model was proposed in
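With special_out gone, the stack again returns only the documented tuple: last-layer hidden state, optionally all hidden states, optionally all attentions. A hedged unpacking sketch, assuming stack_outputs is the tuple returned by T5Stack.forward with both config.output_hidden_states and config.output_attentions enabled (otherwise the optional entries are simply absent):

last_hidden_state = stack_outputs[0]  # (bs, seq_len, d_model)
all_hidden_states = stack_outputs[1]  # tuple of (bs, seq_len, d_model) tensors, one per recorded layer
all_attentions = stack_outputs[2]     # tuple of (bs, n_heads, qlen, klen) attention weights, one per layer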