chenpangpang / transformers — Commits

Commit 568c0ffb, authored Nov 05, 2019 by thomwolf

    adding T5 model

Parent: 60a5babd
Showing 2 changed files with 412 additions and 63 deletions:

    transformers/modeling_encoder_decoder.py    +1    -3
    transformers/modeling_t5.py                 +411  -60
transformers/modeling_encoder_decoder.py

...
@@ -217,9 +217,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[0]  # output the last layer hidden state
+            encoder_hidden_states = encoder_outputs[0]
         else:
             encoder_outputs = ()
...
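The point of this hunk is that PreTrainedEncoderDecoder.forward only runs the encoder when no precomputed hidden states are passed in. A minimal, self-contained sketch of that caching pattern (toy modules with assumed sizes, not the library classes):

import torch
from torch import nn

class ToySeq2Seq(nn.Module):
    """Stand-in for the encoder/decoder wiring; only the caching logic matters here."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 8)

    def forward(self, encoder_inputs, decoder_inputs, encoder_hidden_states=None):
        if encoder_hidden_states is None:            # first pass: run the encoder
            encoder_hidden_states = self.encoder(encoder_inputs)
        return self.decoder(decoder_inputs + encoder_hidden_states), encoder_hidden_states

model = ToySeq2Seq()
src, tgt = torch.randn(1, 8), torch.randn(1, 8)
out, cache = model(src, tgt)                             # encoder runs here
out, _ = model(src, tgt, encoder_hidden_states=cache)    # encoder skipped on later decoding steps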
transformers/modeling_t5.py

...
 # coding=utf-8
-# Copyright 2018 T5 Authors and HuggingFace Inc. team.
+# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...

@@ -20,11 +20,14 @@ import json
 import logging
 import math
 import os
+import math
 import sys
+import itertools
 from io import open

 import torch
 from torch import nn
+import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss, MSELoss

 from .modeling_utils import PreTrainedModel, prune_linear_layer
...
@@ -119,31 +122,389 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
 # - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
 ####################################################
-class T5Layer(nn.Module):
+class T5DenseReluDense(nn.Module):
+    def __init__(self, config):
+        super(T5DenseReluDense, self).__init__()
+        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
+        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, hidden_states):
+        h = self.wi(hidden_states)
+        h = F.relu(h)
+        h = self.dropout(h)
+        h = self.wo(h)
+        return h
+
+
+class T5LayerFF(nn.Module):
     def __init__(self, config):
-        super(T5Layer, self).__init__()
-        self.attention = T5Attention(config)
-        self.intermediate = T5Intermediate(config)
-        self.output = T5Output(config)
+        super(T5LayerFF, self).__init__()
+        self.DenseReluDense = T5DenseReluDense(config)
+        self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, hidden_states):
+        norm_x = self.layer_norm(hidden_states)
+        y = self.DenseReluDense(norm_x)
+        layer_output = hidden_states + self.dropout(y)
+        return layer_output
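T5LayerFF applies the feed-forward block to a LayerNorm'ed copy of the input and adds the result back onto the un-normalized residual stream (pre-norm, rather than the post-norm layout of BERT-style blocks). A toy version with assumed sizes and a standard nn.LayerNorm(d_model), shown only to make that data flow concrete:

import torch
from torch import nn
import torch.nn.functional as F

d_model, d_ff = 16, 64
wi = nn.Linear(d_model, d_ff, bias=False)    # plays the role of T5DenseReluDense.wi
wo = nn.Linear(d_ff, d_model, bias=False)    # plays the role of T5DenseReluDense.wo
layer_norm, dropout = nn.LayerNorm(d_model), nn.Dropout(0.1)

x = torch.randn(2, 5, d_model)               # (batch, seq_len, d_model)
y = wo(dropout(F.relu(wi(layer_norm(x)))))   # feed-forward applied to the normalized input
out = x + dropout(y)                         # residual added around the whole sublayer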
+
+
+class T5Attention(nn.Module):
+    NEW_ID = itertools.count()
+
+    def __init__(self, config):
+        super(T5Attention, self).__init__()
+        self.layer_id = next(T5Attention.NEW_ID)
+
+        self.output_attentions = config.output_attentions
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.dim = config.d_model
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        assert self.dim % self.n_heads == 0
+
+        self.q = nn.Linear(self.dim, self.dim, bias=False)
+        self.k = nn.Linear(self.dim, self.dim, bias=False)
+        self.v = nn.Linear(self.dim, self.dim, bias=False)
+        self.o = nn.Linear(self.dim, self.dim, bias=False)
+        self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        attention_head_size = self.dim // self.n_heads
+        if len(heads) == 0:
+            return
+        mask = torch.ones(self.n_heads, attention_head_size)
+        heads = set(heads) - self.pruned_heads
+        for head in heads:
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        # Prune linear layers
+        self.q = prune_linear_layer(self.q, index)
+        self.k = prune_linear_layer(self.k, index)
+        self.v = prune_linear_layer(self.v, index)
+        self.o = prune_linear_layer(self.o, index, dim=1)
+        # Update hyper params
+        self.n_heads = self.n_heads - len(heads)
+        self.dim = attention_head_size * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
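A worked example (assumed sizes) of the index bookkeeping in prune_heads above: with 4 heads of size 3, pruning head 1 keeps only the positions of the projection matrices that belong to the remaining heads.

import torch

n_heads, head_size = 4, 3
mask = torch.ones(n_heads, head_size)
mask[1] = 0                                   # prune head 1
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
print(index)                                  # tensor([0, 1, 2, 6, 7, 8, 9, 10, 11]) -> heads 0, 2, 3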
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention.
+        The relative position is defined as memory_position - query_position, i.e.
+        the distance in tokens from the attending position to the attended-to
+        position.  If bidirectional=False, then positive relative positions are
+        invalid.
+        We use smaller buckets for small absolute relative_position and larger buckets
+        for larger absolute relative_positions.  All relative positions >=max_distance
+        map to the same bucket.  All relative positions <=-max_distance map to the
+        same bucket.  This should allow for more graceful generalization to longer
+        sequences than the model has been trained on.
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32
+            values in the range [0, num_buckets)
+        """
+        ret = 0
+        n = -relative_position
+        if bidirectional:
+            num_buckets //= 2
+            ret += (n < 0).to(torch.long) * num_buckets  # mtf.to_int32(mtf.less(n, 0)) * num_buckets
+            n = torch.abs(n)
+        else:
+            n = torch.max(n, 0)
+        # now n is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = (n < max_exact)
+
+        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+        val_if_large = max_exact + (torch.log(n.float() / max_exact)
+                                    / math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
+        val_if_large = torch.min(val_if_large, num_buckets - 1)
+
+        ret += torch.where(is_small, n, val_if_large)
+        return ret
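A rough numeric illustration of the bucketing formula above with the defaults (bidirectional=True, num_buckets=32, max_distance=128). After the split across the two directions there are 16 buckets per direction: the first 8 hold exact offsets 0..7, the other 8 are logarithmic bins up to max_distance, and everything at or beyond 128 collapses into the last bin. The snippet below re-applies the same formula to a few magnitudes:

import math
import torch

n = torch.tensor([0, 1, 7, 8, 20, 50, 127, 500])     # |relative position| within one direction
max_exact, num_buckets, max_distance = 8, 16, 128     # 32 buckets split over two directions
val_if_large = max_exact + (torch.log(n.float() / max_exact)
                            / math.log(max_distance / max_exact)
                            * (num_buckets - max_exact)).to(torch.long)
val_if_large = torch.min(val_if_large, torch.tensor(num_buckets - 1))
bucket = torch.where(n < max_exact, n, val_if_large)
print(bucket)   # -> buckets 0, 1, 7, 8, 10, 13, 15, 15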
+
+    def compute_bias(self, qlen, klen):
+        """ Compute binned relative position bias """
+        context_position = torch.arange(qlen, dtype=torch.long)[:, None]
+        memory_position = torch.arange(klen, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position  # shape (qlen, klen)
+        rp_bucket = self._relative_position_bucket(relative_position,
+                                                   bidirectional=not self.is_decoder,
+                                                   num_buckets=self.relative_attention_num_buckets)
+        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
+        return values
+
+    def forward(self, input, mask, kv=None, position_bias=None, cache=None, head_mask=None):
+        """
+        Self-attention (if kv is None) or attention over source sentence (provided by kv).
+        """
+        # Input is (bs, qlen, dim)
+        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
+        bs, qlen, dim = input.size()
+        if kv is None:
+            klen = qlen if cache is None else cache['slen'] + qlen
+        else:
+            klen = kv.size(1)
+        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
+        n_heads = self.n_heads
+        dim_per_head = self.dim // n_heads
+        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
+
+        def shape(x):
+            """ projection """
+            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
+
+        def unshape(x):
+            """ compute context """
+            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
+
+        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
+
+        if kv is None:
+            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
+            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
+        elif cache is None or self.layer_id not in cache:
+            k = v = kv
+            k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
+            v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)
+
+        if cache is not None:
+            if self.layer_id in cache:
+                if kv is None:
+                    k_, v_ = cache[self.layer_id]
+                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
+                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
+                else:
+                    k, v = cache[self.layer_id]
+            cache[self.layer_id] = (k, v)
+
+        # q = q / math.sqrt(dim_per_head)  # No scaling in T5
+        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
+
+        if position_bias is None:
+            position_bias = self.compute_bias(qlen, klen)
+        scores += position_bias
+
+        mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
+        scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
+
+        weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
+        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            weights = weights * head_mask
+
+        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
+        context = unshape(context)  # (bs, qlen, dim)
+        context = self.o(context)
+
+        outputs = (context,)
+        if self.output_attentions:
+            outputs = outputs + (weights,)
+        return outputs
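A stand-alone sketch of the core computation in the forward above, with toy sizes, random weights, a zero tensor standing in for compute_bias, and no masking or caching. It is only meant to make the shape/unshape bookkeeping and the additive position bias visible; note that, unlike standard scaled dot-product attention, the scores are not divided by sqrt(d).

import torch
from torch import nn
import torch.nn.functional as F

bs, qlen, d_model, n_heads = 2, 5, 16, 4
dim_per_head = d_model // n_heads
q_proj, k_proj = nn.Linear(d_model, d_model, bias=False), nn.Linear(d_model, d_model, bias=False)
v_proj, o_proj = nn.Linear(d_model, d_model, bias=False), nn.Linear(d_model, d_model, bias=False)

def shape(x):    # (bs, len, d_model) -> (bs, n_heads, len, dim_per_head)
    return x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)

x = torch.randn(bs, qlen, d_model)
q, k, v = shape(q_proj(x)), shape(k_proj(x)), shape(v_proj(x))
position_bias = torch.zeros(1, n_heads, qlen, qlen)          # stands in for compute_bias(qlen, qlen)
scores = torch.matmul(q, k.transpose(2, 3)) + position_bias  # no 1/sqrt(d) scaling in T5
weights = F.softmax(scores.float(), dim=-1).type_as(scores)
context = torch.matmul(weights, v).transpose(1, 2).contiguous().view(bs, -1, d_model)
out = o_proj(context)                                        # (bs, qlen, d_model)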
+
+
+class T5LayerSelfAttention(nn.Module):
+    def __init__(self, config):
+        super(T5LayerSelfAttention, self).__init__()
+        self.SelfAttention = T5Attention(config)
+        self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout)
+
     def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
-        attention_output = attention_outputs[0]
-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
+        norm_x = self.layer_norm(hidden_states)
+        attention_output = self.SelfAttention(norm_x,
+                                              attention_mask=attention_mask,
+                                              head_mask=head_mask)
+        y = attention_output[0]
+        layer_output = hidden_states + self.dropout(y)
+        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
         return outputs
+
+
+class T5LayerCrossAttention(nn.Module):
+    def __init__(self, config):
+        super(T5LayerCrossAttention, self).__init__()
+        self.EncDecAttention = T5Attention(config)
+        self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, hidden_states, kv, attention_mask=None, head_mask=None):
+        norm_x = self.layer_norm(hidden_states)
+        attention_output = self.EncDecAttention(norm_x,
+                                                kv=kv,
+                                                attention_mask=attention_mask,
+                                                head_mask=head_mask)
+        y = attention_output[0]
+        layer_output = hidden_states + self.dropout(y)
+        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
+        return outputs
+
+
+class T5Block(nn.Module):
+    def __init__(self, config):
+        super(T5Block, self).__init__()
+        self.is_decoder = config.is_decoder
+        self.layer_000 = T5LayerSelfAttention(config)
+        if self.is_decoder:
+            self.layer_001 = T5LayerCrossAttention(config)
+            self.layer_002 = T5LayerFF(config)
+        else:
+            self.layer_001 = T5LayerFF(config)
+
+    def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None,
+                encoder_attention_mask=None, head_mask=None):
+        self_attention_outputs = self.layer_000(hidden_states,
+                                                attention_mask=attention_mask,
+                                                head_mask=head_mask)
+        hidden_states = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]
+
+        if self.is_decoder:
+            cross_attention_outputs = self.layer_001(hidden_states,
+                                                     kv=encoder_hidden_states,
+                                                     attention_mask=encoder_attention_mask,
+                                                     head_mask=head_mask)
+            hidden_states = cross_attention_outputs[0]
+            outputs = cross_attention_outputs[1:] + outputs
+            hidden_states = self.layer_002(hidden_states)
+        else:
+            hidden_states = self.layer_001(hidden_states)
+        outputs = (hidden_states,) + outputs  # add attentions if we output them
+        return outputs
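Schematically, the branching above gives an encoder block the sublayer order [self-attention, feed-forward] and a decoder block the order [self-attention, cross-attention over the encoder output, feed-forward], all on the same residual stream. A placeholder sketch of just that ordering (plain callables, not the real modules):

def encoder_block(x, self_attn, ff):
    return ff(self_attn(x))

def decoder_block(x, encoder_states, self_attn, cross_attn, ff):
    return ff(cross_attn(self_attn(x), encoder_states))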
+
+
+class T5Stack(nn.Module):
+    def __init__(self, config):
+        super(T5Stack, self).__init__()
+        self.blocks = nn.ModuleList([T5Block(config) for _ in range(config.num_layers)])
+        self.final_layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None,
+                encoder_attention_mask=None, head_mask=None):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+
+        # Provided a padding mask of dimensions [batch_size, seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if attention_mask.dim() == 2:
+            if self.config.is_decoder:
+                batch_size, seq_length = input_ids.size()
+                seq_ids = torch.arange(seq_length, device=input_ids.device)
+                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        # If a 2D ou 3D attention mask is provided for the cross-attention
+        # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_attention_mask.dim() == 3:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+        if encoder_attention_mask.dim() == 2:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        all_hidden_states = ()
+        all_attentions = ()
+        position_bias = None
+        for i, layer_module in enumerate(self.layer):
+            if self.output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_outputs = layer_module(hidden_states,
+                                         attention_mask=extended_attention_mask,
+                                         encoder_hidden_states=encoder_hidden_states,
+                                         encoder_attention_mask=encoder_extended_attention_mask,
+                                         head_mask=head_mask[i])
+            hidden_states = layer_outputs[0]
+
+            if self.output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        layer_output = self.dropout(hidden_states)
+
+        # Add last layer
+        if self.output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        outputs = (hidden_states,)
+        if self.output_hidden_states:
+            outputs = outputs + (all_hidden_states,)
+        if self.output_attentions:
+            outputs = outputs + (all_attentions,)
+        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
+
+
-class T5PreTrainedModel(PreTrainedModel):
+class T5PreTrainedModel(PreTrainedEncoderDecoder):
     """ An abstract class to handle weights initialization and
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = T5Config
     pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_t5
     base_model_prefix = "transformer"

     def _init_weights(self, module):
         """ Initialize the weights """
...
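The additive attention-mask trick used in T5Stack.forward above (and in the removed template forward further down): a {0, 1} padding mask is broadcast to (batch, 1, 1, seq_len), cast to the parameter dtype for fp16 safety, and mapped to 0 for real tokens and -10000 for padding, so that adding it to the raw scores before the softmax effectively removes the padded positions. On a toy mask:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])                # 1 = attend, 0 = padding
extended = attention_mask[:, None, None, :].to(torch.float32)   # (batch, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0
print(extended.squeeze())                                       # 0 for real tokens, -10000 for padding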
@@ -238,19 +599,23 @@ class T5Model(T5PreTrainedModel):
     """
     def __init__(self, config):
         super(T5Model, self).__init__(config)
-        self.embeddings = T5Embeddings(config)
-        self.encoder = T5Encoder(config)
-        self.pooler = T5Pooler(config)
+        self.shared = nn.Embeddings(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        self.encoder = T5Stack(encoder_config)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        self.decoder = T5Stack(decoder_config)

         self.init_weights()

+    @property
     def get_input_embeddings(self):
-        return self.embeddings.word_embeddings
+        return self.shared

     def set_input_embeddings(self, new_embeddings):
-        self.embeddings.word_embeddings = new_embeddings
+        self.shared = new_embeddings

     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
...
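The __init__ above replaces the template's separate embedding/pooler modules with a single embedding table, self.shared, that both stacks read from, which is why get_input_embeddings/set_input_embeddings now only track self.shared. A minimal illustration of that sharing (toy sizes, not the library classes):

import torch
from torch import nn

shared = nn.Embedding(100, 16)                 # vocab_size=100, d_model=16
encoder_input_ids = torch.tensor([[5, 7, 9]])
decoder_input_ids = torch.tensor([[1, 5]])
encoder_inputs = shared(encoder_input_ids)     # both stacks look up the same table,
decoder_inputs = shared(decoder_input_ids)     # so its weights get gradients from encoder and decoder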
@@ -260,50 +625,36 @@ class T5Model(T5PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros_like(input_ids)
-
-        # We create a 3D attention mask from a 2D tensor mask.
-        # Sizes are [batch_size, 1, 1, to_seq_length]
-        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-        # this attention mask is more simple than the triangular masking of causal attention
-        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
-        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-        # masked positions, this operation will create a tensor which is 0.0 for
-        # positions we want to attend and -10000.0 for masked positions.
-        # Since we are adding it to the raw scores before the softmax, this is
-        # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if head_mask is not None:
-            if head_mask.dim() == 1:
-                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
-            elif head_mask.dim() == 2:
-                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
-            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
-        else:
-            head_mask = [None] * self.config.num_hidden_layers
+    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as whole.
+        # We let the specific kwargs override the common ones in case of conflict.
+        kwargs_common = dict((k, v) for k, v in kwargs.items()
+                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
+        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))
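The kwargs routing above, shown on a plain dict (the keys here are hypothetical): arguments prefixed with encoder_/decoder_ go to the corresponding stack, unprefixed ones go to both, and the prefixed value wins on a name clash.

kwargs = {"attention_mask": None,
          "encoder_attention_mask": "enc-mask",
          "decoder_head_mask": "dec-heads"}

kwargs_common = {k: v for k, v in kwargs.items()
                 if not k.startswith("encoder_") and not k.startswith("decoder_")}
kwargs_encoder = dict(kwargs_common,
                      **{k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")})
kwargs_decoder = dict(kwargs_common,
                      **{k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")})

print(kwargs_encoder)   # {'attention_mask': 'enc-mask'}
print(kwargs_decoder)   # {'attention_mask': None, 'head_mask': 'dec-heads'}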
+
+        # Encode if needed (training, first prediction pass)
+        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
+        if encoder_hidden_states is None:
+            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+            hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings
+            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
+            encoder_hidden_states = encoder_outputs[0]
+        else:
+            encoder_outputs = ()
-
-        ##################################
-        # Replace this with your model code
-        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
-        encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
-        sequence_output = encoder_outputs[0]
-        outputs = (sequence_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here

+        # Decode
+        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+        hidden_states = self.shared(decoder_inputs_ids)  # Convert inputs in embeddings
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
+        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
+        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

-        return outputs  # sequence_output, (hidden_states), (attentions)
+        return decoder_outputs + encoder_outputs
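For reference, the call pattern implied by the forward above (argument names are taken from the code, the surrounding setup is hypothetical): encoder and decoder token ids are passed separately, extra arguments are routed via the encoder_/decoder_ prefixes, and the decoder outputs come first in the returned tuple.

# outputs = model(encoder_input_ids=source_ids,
#                 decoder_input_ids=target_ids,
#                 encoder_attention_mask=source_mask)
# decoder_hidden_states = outputs[0]   # decoder outputs precede the encoder outputs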

 @add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
...

@@ -342,7 +693,7 @@ class T5WithLMHead(T5PreTrainedModel):
         super(T5ForMaskedLM, self).__init__(config)
         self.transformer = T5Model(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size)

         self.init_weights()
...