chenpangpang / transformers · Commit 88e5bef5
authored Nov 05, 2019 by thomwolf

share position biases

parent 568c0ffb
Showing 1 changed file with 40 additions and 25 deletions.

transformers/modeling_t5.py (+40, -25)
```diff
@@ -154,9 +154,10 @@ class T5LayerFF(nn.Module):
 class T5Attention(nn.Module):
     NEW_ID = itertools.count()
 
-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5Attention, self).__init__()
         self.layer_id = next(T5Attention.NEW_ID)
+        self.has_relative_attention_bias = has_relative_attention_bias
         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
```
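The new `has_relative_attention_bias` flag decides whether a given attention module owns the relative-position bias parameters at all. A minimal sketch of the ownership pattern (not part of the commit; the class and variable names here are illustrative):

```python
import torch.nn as nn

# Illustrative sketch: only the module built with
# has_relative_attention_bias=True allocates the learned bias table.
class BiasOwnerSketch(nn.Module):
    def __init__(self, num_buckets, n_heads, has_relative_attention_bias=False):
        super().__init__()
        self.has_relative_attention_bias = has_relative_attention_bias
        if has_relative_attention_bias:
            # one learned scalar per (relative-position bucket, head) pair
            self.relative_attention_bias = nn.Embedding(num_buckets, n_heads)

owner = BiasOwnerSketch(32, 8, has_relative_attention_bias=True)
sharer = BiasOwnerSketch(32, 8)  # no bias parameters of its own
print(sum(p.numel() for p in owner.parameters()))   # 256 (= 32 * 8)
print(sum(p.numel() for p in sharer.parameters()))  # 0
```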
```diff
@@ -170,7 +171,8 @@ class T5Attention(nn.Module):
         self.v = nn.Linear(self.dim, self.dim, bias=False)
         self.o = nn.Linear(self.dim, self.dim, bias=False)
-        self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()
 
     def prune_heads(self, heads):
```
```diff
@@ -304,6 +306,8 @@ class T5Attention(nn.Module):
         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
 
         if position_bias is None:
+            if not self.has_relative_attention_bias:
+                raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)
         scores += position_bias
```
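`compute_bias` (unchanged by this hunk) turns the (qlen, klen) grid of relative positions into a per-head additive bias. Below is a simplified sketch of its shape contract only, with plain clamping standing in for T5's log-spaced buckets; the bucketing here is an assumption, not the library's exact code:

```python
import torch
import torch.nn as nn

# Shape sketch only: real T5 buckets relative positions on a log scale and
# distinguishes the bidirectional (encoder) and causal (decoder) cases.
def compute_bias_sketch(bias_table: nn.Embedding, qlen: int, klen: int) -> torch.Tensor:
    context_position = torch.arange(qlen)[:, None]           # (qlen, 1)
    memory_position = torch.arange(klen)[None, :]            # (1, klen)
    relative_position = memory_position - context_position   # (qlen, klen)
    half = bias_table.num_embeddings // 2
    buckets = relative_position.clamp(-half, half - 1) + half  # map into [0, num_buckets)
    values = bias_table(buckets)                             # (qlen, klen, n_heads)
    return values.permute(2, 0, 1).unsqueeze(0)              # (1, n_heads, qlen, klen)

bias = compute_bias_sketch(nn.Embedding(32, 8), qlen=5, klen=7)
assert bias.shape == (1, 8, 5, 7)  # broadcastable against (bs, n_heads, qlen, klen) scores
```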
```diff
@@ -325,20 +329,23 @@ class T5Attention(nn.Module):
         outputs = (context,)
         if self.output_attentions:
             outputs = outputs + (weights,)
+        if self.has_relative_attention_bias:
+            outputs = outputs + (position_bias,)
         return outputs
 
 
 class T5LayerSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5LayerSelfAttention, self).__init__()
-        self.SelfAttention = T5Attention(config)
+        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
         self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(norm_x,
                                               attention_mask=attention_mask,
+                                              position_bias=position_bias,
                                               head_mask=head_mask)
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y)
```
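Because `position_bias` is appended last, the shape of the returned tuple now depends on two flags. A small helper (hypothetical, not in the commit) that mirrors the tuple-building logic above:

```python
def attention_outputs_sketch(context, weights, position_bias,
                             output_attentions, has_relative_attention_bias):
    # Mirrors the appends above: context first, weights only when requested,
    # the bias (when this layer owns it) always last.
    outputs = (context,)
    if output_attentions:
        outputs = outputs + (weights,)
    if has_relative_attention_bias:
        outputs = outputs + (position_bias,)
    return outputs

print(attention_outputs_sketch("ctx", "w", "bias", True, True))   # ('ctx', 'w', 'bias')
print(attention_outputs_sketch("ctx", "w", "bias", False, True))  # ('ctx', 'bias')
```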
```diff
@@ -347,17 +354,18 @@ class T5LayerSelfAttention(nn.Module):
 class T5LayerCrossAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5LayerCrossAttention, self).__init__()
-        self.EncDecAttention = T5Attention(config)
+        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
         self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout)
 
-    def forward(self, hidden_states, kv, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(norm_x,
                                                 kv=kv,
                                                 attention_mask=attention_mask,
+                                                position_bias=position_bias,
                                                 head_mask=head_mask)
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y)
```
```diff
@@ -366,20 +374,22 @@ class T5LayerCrossAttention(nn.Module):
 class T5Block(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5Block, self).__init__()
         self.is_decoder = config.is_decoder
-        self.layer_000 = T5LayerSelfAttention(config)
+        self.layer_000 = T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)
         if self.is_decoder:
-            self.layer_001 = T5LayerCrossAttention(config)
+            self.layer_001 = T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias)
             self.layer_002 = T5LayerFF(config)
         else:
             self.layer_001 = T5LayerFF(config)
 
-    def forward(self, hidden_states, attention_mask=None,
-                encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, position_bias=None,
+                encoder_hidden_states=None, encoder_attention_mask=None,
+                encoder_decoder_position_bias=None, head_mask=None):
         self_attention_outputs = self.layer_000(hidden_states,
                                                 attention_mask=attention_mask,
+                                                position_bias=position_bias,
                                                 head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]
```
```diff
@@ -388,6 +398,7 @@ class T5Block(nn.Module):
             cross_attention_outputs = self.layer_001(hidden_states,
                                                      kv=encoder_hidden_states,
                                                      attention_mask=encoder_attention_mask,
+                                                     position_bias=encoder_decoder_position_bias,
                                                      head_mask=head_mask)
             hidden_states = cross_attention_outputs[0]
             outputs = cross_attention_outputs[1:] + outputs
```
```diff
@@ -402,7 +413,8 @@ class T5Block(nn.Module):
 class T5Stack(nn.Module):
     def __init__(self, config):
         super(T5Stack, self).__init__()
-        self.blocks = nn.ModuleList([T5Block(config) for _ in range(config.num_layers)])
+        self.blocks = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
+                                     for i in range(config.num_layers)])
         self.final_layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout)
```
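`bool(i == 0)` is the heart of the sharing: only block 0 gets its own bias table, so the per-stack parameter saving grows with depth. A rough count, assuming hypothetical sizes of 32 buckets, 8 heads, and 6 layers:

```python
import torch.nn as nn

num_buckets, n_heads, num_layers = 32, 8, 6              # illustrative sizes only
per_table = nn.Embedding(num_buckets, n_heads).weight.numel()
print("before (one table per layer):", per_table * num_layers)  # 1536
print("after (one table per stack): ", per_table)               # 256
```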
```diff
@@ -413,8 +425,12 @@ class T5Stack(nn.Module):
                 encoder_attention_mask=None,
                 head_mask=None):
 
+        batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
+        encoder_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
+
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(batch_size, seq_length).to(hidden_states.device)
+        if encoder_attention_mask is None:
+            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(hidden_states.device)
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
```
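Default masks are now derived from the embedded `hidden_states` rather than from `input_ids`, so the stack no longer assumes token ids are available. For example, with toy sizes:

```python
import torch

hidden_states = torch.randn(2, 5, 16)  # (batch_size, seq_length, d_model)
batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
attention_mask = torch.ones(batch_size, seq_length).to(hidden_states.device)
print(attention_mask.shape)  # torch.Size([2, 5])
```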
```diff
@@ -426,8 +442,7 @@ class T5Stack(nn.Module):
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
-                batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length, device=input_ids.device)
+                seq_ids = torch.arange(seq_length, device=hidden_states.device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
```
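The causal-mask expression kept here is worth unpacking: position `i` may attend to position `j` iff `j <= i`. A standalone demo for `seq_length=4`:

```python
import torch

batch_size, seq_length = 1, 4
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
print(causal_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```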
```diff
@@ -469,16 +484,22 @@ class T5Stack(nn.Module):
         all_hidden_states = ()
         all_attentions = ()
         position_bias = None
+        encoder_decoder_position_bias = None
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer_outputs = layer_module(hidden_states,
                                          attention_mask=extended_attention_mask,
+                                         position_bias=position_bias,
                                          encoder_hidden_states=encoder_hidden_states,
                                          encoder_attention_mask=encoder_extended_attention_mask,
+                                         encoder_decoder_position_bias=encoder_decoder_position_bias,
                                          head_mask=head_mask[i])
             hidden_states = layer_outputs[0]
+            if i == 0:
+                position_bias = layer_outputs[2] if len(layer_outputs) > 3 else None
+                encoder_decoder_position_bias = layer_outputs[4] if len(layer_outputs) > 5 else None
             if self.output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
```
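This loop is where the shared biases actually flow: block 0 returns the biases it computed, the stack captures them once, and every later block receives them as inputs instead of recomputing. A stripped-down sketch of the pattern, with toy callables standing in for T5Block:

```python
def run_stack_sketch(layers, hidden_states):
    # layers: callables returning (hidden_states, position_bias); only
    # layers[0] computes a bias, the rest reuse the one passed in.
    position_bias = None
    for i, layer in enumerate(layers):
        hidden_states, returned_bias = layer(hidden_states, position_bias)
        if i == 0:
            position_bias = returned_bias
    return hidden_states

def make_layer(owns_bias):
    def layer(h, bias):
        if bias is None and owns_bias:
            bias = "computed-once"  # stands in for compute_bias(qlen, klen)
        return h + 1, bias
    return layer

layers = [make_layer(owns_bias=(i == 0)) for i in range(4)]
print(run_stack_sketch(layers, 0))  # 4
```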
```diff
@@ -700,14 +721,8 @@ class T5WithLMHead(T5PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                lm_labels=None):
-        outputs = self.transformer(input_ids,
-                                   attention_mask=attention_mask,
-                                   token_type_ids=token_type_ids,
-                                   position_ids=position_ids,
-                                   head_mask=head_mask)
+    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
+        outputs = self.transformer(encoder_input_ids, decoder_input_ids, **kwargs)
 
         sequence_output = outputs[0]
         lm_logits = self.cls(sequence_output)
```
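The LM-head model now takes separate encoder and decoder token ids and forwards everything else through `**kwargs`. A hedged usage sketch (the `model` variable is hypothetical; the T5 port was still under construction at this commit):

```python
import torch

encoder_input_ids = torch.tensor([[37, 423, 215, 1]])  # toy token ids
decoder_input_ids = torch.tensor([[0, 37, 423]])
# outputs = model(encoder_input_ids, decoder_input_ids)  # extra args pass through **kwargs
```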