Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dbd04124
Unverified
Commit
dbd04124
authored
Apr 16, 2020
by
Sam Shleifer
Committed by
GitHub
Apr 16, 2020
Browse files
[cleanup] factor out get_head_mask, invert_attn_mask, get_exten… (#3806)
* Delete some copy pasted code
parent
d22894df
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
132 additions
and
379 deletions
+132
-379
src/transformers/modeling_albert.py
src/transformers/modeling_albert.py
+1
-13
src/transformers/modeling_bert.py
src/transformers/modeling_bert.py
+5
-59
src/transformers/modeling_ctrl.py
src/transformers/modeling_ctrl.py
+2
-17
src/transformers/modeling_distilbert.py
src/transformers/modeling_distilbert.py
+1
-17
src/transformers/modeling_electra.py
src/transformers/modeling_electra.py
+1
-60
src/transformers/modeling_flaubert.py
src/transformers/modeling_flaubert.py
+1
-17
src/transformers/modeling_gpt2.py
src/transformers/modeling_gpt2.py
+1
-13
src/transformers/modeling_mmbt.py
src/transformers/modeling_mmbt.py
+5
-57
src/transformers/modeling_openai.py
src/transformers/modeling_openai.py
+1
-16
src/transformers/modeling_t5.py
src/transformers/modeling_t5.py
+4
-62
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+106
-18
src/transformers/modeling_xlm.py
src/transformers/modeling_xlm.py
+1
-17
templates/adding_a_new_model/modeling_xxx.py
templates/adding_a_new_model/modeling_xxx.py
+3
-13
No files found.
src/transformers/modeling_albert.py
View file @
dbd04124
...
@@ -552,19 +552,7 @@ class AlbertModel(AlbertPreTrainedModel):
...
@@ -552,19 +552,7 @@ class AlbertModel(AlbertPreTrainedModel):
extended_attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
)
extended_attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
)
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
if
head_mask
is
not
None
:
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
num_hidden_layers
)
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
embedding_output
=
self
.
embeddings
(
embedding_output
=
self
.
embeddings
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
...
...
src/transformers/modeling_bert.py
View file @
dbd04124
...
@@ -703,36 +703,9 @@ class BertModel(BertPreTrainedModel):
...
@@ -703,36 +703,9 @@ class BertModel(BertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
# ourselves in which case we just need to make it broadcastable to all heads.
if
attention_mask
.
dim
()
==
3
:
extended_attention_mask
:
torch
.
Tensor
=
self
.
get_extended_attention_mask
(
extended_attention_mask
=
attention_mask
[:,
None
,
:,
:]
attention_mask
,
input_shape
,
self
.
device
elif
attention_mask
.
dim
()
==
2
:
)
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if
self
.
config
.
is_decoder
:
batch_size
,
seq_length
=
input_shape
seq_ids
=
torch
.
arange
(
seq_length
,
device
=
device
)
causal_mask
=
seq_ids
[
None
,
None
,
:].
repeat
(
batch_size
,
seq_length
,
1
)
<=
seq_ids
[
None
,
:,
None
]
causal_mask
=
causal_mask
.
to
(
attention_mask
.
dtype
)
# causal and attention masks must have same type with pytorch version < 1.3
extended_attention_mask
=
causal_mask
[:,
None
,
:,
:]
*
attention_mask
[:,
None
,
None
,
:]
else
:
extended_attention_mask
=
attention_mask
[:,
None
,
None
,
:]
else
:
raise
ValueError
(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})"
.
format
(
input_shape
,
attention_mask
.
shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
# If a 2D ou 3D attention mask is provided for the cross-attention
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
...
@@ -741,22 +714,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -741,22 +714,7 @@ class BertModel(BertPreTrainedModel):
encoder_hidden_shape
=
(
encoder_batch_size
,
encoder_sequence_length
)
encoder_hidden_shape
=
(
encoder_batch_size
,
encoder_sequence_length
)
if
encoder_attention_mask
is
None
:
if
encoder_attention_mask
is
None
:
encoder_attention_mask
=
torch
.
ones
(
encoder_hidden_shape
,
device
=
device
)
encoder_attention_mask
=
torch
.
ones
(
encoder_hidden_shape
,
device
=
device
)
encoder_extended_attention_mask
=
self
.
invert_attention_mask
(
encoder_attention_mask
)
if
encoder_attention_mask
.
dim
()
==
3
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
:,
:]
elif
encoder_attention_mask
.
dim
()
==
2
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
else
:
raise
ValueError
(
"Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})"
.
format
(
encoder_hidden_shape
,
encoder_attention_mask
.
shape
)
)
encoder_extended_attention_mask
=
encoder_extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
10000.0
else
:
else
:
encoder_extended_attention_mask
=
None
encoder_extended_attention_mask
=
None
...
@@ -765,19 +723,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -765,19 +723,7 @@ class BertModel(BertPreTrainedModel):
# attention_probs has shape bsz x n_heads x N x N
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if
head_mask
is
not
None
:
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
num_hidden_layers
)
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
embedding_output
=
self
.
embeddings
(
embedding_output
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
input_ids
=
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
...
...
src/transformers/modeling_ctrl.py
View file @
dbd04124
...
@@ -392,26 +392,11 @@ class CTRLModel(CTRLPreTrainedModel):
...
@@ -392,26 +392,11 @@ class CTRLModel(CTRLPreTrainedModel):
# positions we want to attend and -10000.0 for masked positions.
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# effectively the same as removing these entirely.
attention_mask
=
attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
())
.
dtype
)
# fp16 compatibility
attention_mask
=
attention_mask
.
to
(
dtype
=
self
.
dtype
)
# fp16 compatibility
attention_mask
=
(
1.0
-
attention_mask
)
*
-
10000.0
attention_mask
=
(
1.0
-
attention_mask
)
*
-
10000.0
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
n_layer
)
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
token_type_ids
=
token_type_ids
.
view
(
-
1
,
input_shape
[
-
1
])
...
...
src/transformers/modeling_distilbert.py
View file @
dbd04124
...
@@ -460,23 +460,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
...
@@ -460,23 +460,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
attention_mask
=
torch
.
ones
(
input_shape
,
device
=
device
)
# (bs, seq_length)
attention_mask
=
torch
.
ones
(
input_shape
,
device
=
device
)
# (bs, seq_length)
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
num_hidden_layers
)
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embeddings
(
input_ids
)
# (bs, seq_length, dim)
inputs_embeds
=
self
.
embeddings
(
input_ids
)
# (bs, seq_length, dim)
...
...
src/transformers/modeling_electra.py
View file @
dbd04124
...
@@ -164,65 +164,6 @@ class ElectraPreTrainedModel(BertPreTrainedModel):
...
@@ -164,65 +164,6 @@ class ElectraPreTrainedModel(BertPreTrainedModel):
load_tf_weights
=
load_tf_weights_in_electra
load_tf_weights
=
load_tf_weights_in_electra
base_model_prefix
=
"electra"
base_model_prefix
=
"electra"
def
get_extended_attention_mask
(
self
,
attention_mask
,
input_shape
,
device
):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if
attention_mask
.
dim
()
==
3
:
extended_attention_mask
=
attention_mask
[:,
None
,
:,
:]
elif
attention_mask
.
dim
()
==
2
:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if
self
.
config
.
is_decoder
:
batch_size
,
seq_length
=
input_shape
seq_ids
=
torch
.
arange
(
seq_length
,
device
=
device
)
causal_mask
=
seq_ids
[
None
,
None
,
:].
repeat
(
batch_size
,
seq_length
,
1
)
<=
seq_ids
[
None
,
:,
None
]
causal_mask
=
causal_mask
.
to
(
attention_mask
.
dtype
)
# causal and attention masks must have same type with pytorch version < 1.3
extended_attention_mask
=
causal_mask
[:,
None
,
:,
:]
*
attention_mask
[:,
None
,
None
,
:]
else
:
extended_attention_mask
=
attention_mask
[:,
None
,
None
,
:]
else
:
raise
ValueError
(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})"
.
format
(
input_shape
,
attention_mask
.
shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
return
extended_attention_mask
def
get_head_mask
(
self
,
head_mask
):
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
num_hidden_layers
=
self
.
config
.
num_hidden_layers
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
num_hidden_layers
return
head_mask
ELECTRA_START_DOCSTRING
=
r
"""
ELECTRA_START_DOCSTRING
=
r
"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
...
@@ -376,7 +317,7 @@ class ElectraModel(ElectraPreTrainedModel):
...
@@ -376,7 +317,7 @@ class ElectraModel(ElectraPreTrainedModel):
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
device
)
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
device
)
extended_attention_mask
=
self
.
get_extended_attention_mask
(
attention_mask
,
input_shape
,
device
)
extended_attention_mask
=
self
.
get_extended_attention_mask
(
attention_mask
,
input_shape
,
device
)
head_mask
=
self
.
get_head_mask
(
head_mask
)
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
num_hidden_layers
)
hidden_states
=
self
.
embeddings
(
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
input_ids
=
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
...
...
src/transformers/modeling_flaubert.py
View file @
dbd04124
...
@@ -201,23 +201,7 @@ class FlaubertModel(XLMModel):
...
@@ -201,23 +201,7 @@ class FlaubertModel(XLMModel):
# langs = langs.transpose(0, 1)
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
n_layers
)
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
n_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
n_layers
# do not recompute cached elements
# do not recompute cached elements
if
cache
is
not
None
and
input_ids
is
not
None
:
if
cache
is
not
None
and
input_ids
is
not
None
:
...
...
src/transformers/modeling_gpt2.py
View file @
dbd04124
...
@@ -471,19 +471,7 @@ class GPT2Model(GPT2PreTrainedModel):
...
@@ -471,19 +471,7 @@ class GPT2Model(GPT2PreTrainedModel):
# 1.0 in head_mask indicate we keep the head
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if
head_mask
is
not
None
:
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
n_layer
)
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
wte
(
input_ids
)
inputs_embeds
=
self
.
wte
(
input_ids
)
...
...
src/transformers/modeling_mmbt.py
View file @
dbd04124
...
@@ -23,6 +23,7 @@ import torch.nn as nn
...
@@ -23,6 +23,7 @@ import torch.nn as nn
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_utils
import
ModuleUtilsMixin
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -148,7 +149,7 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs:
...
@@ -148,7 +149,7 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs:
MMBT_START_DOCSTRING
,
MMBT_START_DOCSTRING
,
MMBT_INPUTS_DOCSTRING
,
MMBT_INPUTS_DOCSTRING
,
)
)
class
MMBTModel
(
nn
.
Module
):
class
MMBTModel
(
Module
UtilsMixin
):
r
"""
r
"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
...
@@ -237,7 +238,6 @@ class MMBTModel(nn.Module):
...
@@ -237,7 +238,6 @@ class MMBTModel(nn.Module):
attention_mask
=
torch
.
cat
(
attention_mask
=
torch
.
cat
(
[
torch
.
ones
(
input_modal_shape
,
device
=
device
,
dtype
=
torch
.
long
),
attention_mask
],
dim
=
1
[
torch
.
ones
(
input_modal_shape
,
device
=
device
,
dtype
=
torch
.
long
),
attention_mask
],
dim
=
1
)
)
if
encoder_attention_mask
is
None
:
if
encoder_attention_mask
is
None
:
encoder_attention_mask
=
torch
.
ones
(
input_shape
,
device
=
device
)
encoder_attention_mask
=
torch
.
ones
(
input_shape
,
device
=
device
)
else
:
else
:
...
@@ -245,61 +245,9 @@ class MMBTModel(nn.Module):
...
@@ -245,61 +245,9 @@ class MMBTModel(nn.Module):
[
torch
.
ones
(
input_modal_shape
,
device
=
device
),
encoder_attention_mask
],
dim
=
1
[
torch
.
ones
(
input_modal_shape
,
device
=
device
),
encoder_attention_mask
],
dim
=
1
)
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
extended_attention_mask
=
self
.
get_extended_attention_mask
(
attention_mask
,
input_shape
,
self
.
device
)
# ourselves in which case we just need to make it broadcastable to all heads.
encoder_extended_attention_mask
=
self
.
invert_attention_mask
(
encoder_attention_mask
)
if
attention_mask
.
dim
()
==
3
:
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
num_hidden_layers
)
extended_attention_mask
=
attention_mask
[:,
None
,
:,
:]
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if
attention_mask
.
dim
()
==
2
:
if
self
.
config
.
is_decoder
:
batch_size
,
seq_length
=
input_shape
seq_ids
=
torch
.
arange
(
seq_length
,
device
=
device
)
causal_mask
=
seq_ids
[
None
,
None
,
:].
repeat
(
batch_size
,
seq_length
,
1
)
<=
seq_ids
[
None
,
:,
None
]
extended_attention_mask
=
causal_mask
[:,
None
,
:,
:]
*
attention_mask
[:,
None
,
None
,
:]
else
:
extended_attention_mask
=
attention_mask
[:,
None
,
None
,
:]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
if
encoder_attention_mask
.
dim
()
==
3
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
:,
:]
if
encoder_attention_mask
.
dim
()
==
2
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
encoder_extended_attention_mask
=
encoder_extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
encoder_outputs
=
self
.
transformer
.
encoder
(
encoder_outputs
=
self
.
transformer
.
encoder
(
embedding_output
,
embedding_output
,
...
...
src/transformers/modeling_openai.py
View file @
dbd04124
...
@@ -425,22 +425,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
...
@@ -425,22 +425,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
attention_mask
=
(
1.0
-
attention_mask
)
*
-
10000.0
attention_mask
=
(
1.0
-
attention_mask
)
*
-
10000.0
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
n_layer
)
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
tokens_embed
(
input_ids
)
inputs_embeds
=
self
.
tokens_embed
(
input_ids
)
...
...
src/transformers/modeling_t5.py
View file @
dbd04124
...
@@ -184,7 +184,7 @@ class T5LayerFF(nn.Module):
...
@@ -184,7 +184,7 @@ class T5LayerFF(nn.Module):
class
T5Attention
(
nn
.
Module
):
class
T5Attention
(
nn
.
Module
):
def
__init__
(
self
,
config
,
has_relative_attention_bias
=
False
):
def
__init__
(
self
,
config
:
T5Config
,
has_relative_attention_bias
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
is_decoder
=
config
.
is_decoder
self
.
is_decoder
=
config
.
is_decoder
self
.
has_relative_attention_bias
=
has_relative_attention_bias
self
.
has_relative_attention_bias
=
has_relative_attention_bias
...
@@ -693,73 +693,15 @@ class T5Stack(T5PreTrainedModel):
...
@@ -693,73 +693,15 @@ class T5Stack(T5PreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
# ourselves in which case we just need to make it broadcastable to all heads.
if
attention_mask
.
dim
()
==
3
:
extended_attention_mask
=
self
.
get_extended_attention_mask
(
attention_mask
,
input_shape
,
self
.
device
)
extended_attention_mask
=
attention_mask
[:,
None
,
:,
:]
elif
attention_mask
.
dim
()
==
2
:
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
if
self
.
config
.
is_decoder
:
seq_ids
=
torch
.
arange
(
mask_seq_length
,
device
=
inputs_embeds
.
device
)
causal_mask
=
seq_ids
[
None
,
None
,
:].
repeat
(
batch_size
,
mask_seq_length
,
1
)
<=
seq_ids
[
None
,
:,
None
]
causal_mask
=
causal_mask
.
to
(
attention_mask
)
extended_attention_mask
=
causal_mask
[:,
None
,
:,
:]
*
attention_mask
[:,
None
,
None
,
:]
if
past_key_value_states
[
0
]
is
not
None
:
extended_attention_mask
=
extended_attention_mask
[:,
:,
-
1
:,
:]
else
:
extended_attention_mask
=
attention_mask
[:,
None
,
None
,
:]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -1e9 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
1e9
if
self
.
is_decoder
and
encoder_attention_mask
is
not
None
:
if
self
.
is_decoder
and
encoder_attention_mask
is
not
None
:
# If a 2D ou 3D attention mask is provided for the cross-attention
encoder_extended_attention_mask
=
self
.
invert_attention_mask
(
encoder_attention_mask
)
# we need to make broadcastabe to [batch_size, num_heads, mask_seq_length, mask_seq_length]
if
encoder_attention_mask
.
dim
()
==
3
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
:,
:]
if
encoder_attention_mask
.
dim
()
==
2
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask
=
encoder_extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
1e9
else
:
else
:
encoder_extended_attention_mask
=
None
encoder_extended_attention_mask
=
None
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
num_layers
)
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x mask_seq_length x mask_seq_length]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_layers
present_key_value_states
=
()
present_key_value_states
=
()
all_hidden_states
=
()
all_hidden_states
=
()
all_attentions
=
()
all_attentions
=
()
...
...
src/transformers/modeling_utils.py
View file @
dbd04124
...
@@ -17,10 +17,10 @@
...
@@ -17,10 +17,10 @@
import
logging
import
logging
import
os
import
os
import
typing
from
typing
import
Callable
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
Tensor
,
device
,
dtype
,
nn
from
torch.nn
import
CrossEntropyLoss
from
torch.nn
import
CrossEntropyLoss
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
...
@@ -109,9 +109,102 @@ class ModuleUtilsMixin:
...
@@ -109,9 +109,102 @@ class ModuleUtilsMixin:
module
.
mem_rss_pre_forward
=
0
module
.
mem_rss_pre_forward
=
0
@
property
@
property
def
device
(
self
):
def
device
(
self
)
->
device
:
return
next
(
self
.
parameters
()).
device
return
next
(
self
.
parameters
()).
device
@
property
def
dtype
(
self
)
->
dtype
:
return
next
(
self
.
parameters
()).
dtype
def
invert_attention_mask
(
self
,
encoder_attention_mask
:
Tensor
)
->
Tensor
:
"""type: torch.Tensor -> torch.Tensor"""
if
encoder_attention_mask
.
dim
()
==
3
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
:,
:]
if
encoder_attention_mask
.
dim
()
==
2
:
encoder_extended_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
# /transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask
=
encoder_extended_attention_mask
.
to
(
dtype
=
self
.
dtype
)
# fp16 compatibility
encoder_extended_attention_mask
=
(
1.0
-
encoder_extended_attention_mask
)
*
-
1e9
return
encoder_extended_attention_mask
def
get_extended_attention_mask
(
self
,
attention_mask
:
Tensor
,
input_shape
:
tuple
,
device
:
device
):
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
Arguments:
attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
input_shape: tuple, shape of input_ids
device: torch.Device, usually self.device
Returns:
torch.Tensor with dtype of attention_mask.dtype
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if
attention_mask
.
dim
()
==
3
:
extended_attention_mask
=
attention_mask
[:,
None
,
:,
:]
elif
attention_mask
.
dim
()
==
2
:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if
self
.
config
.
is_decoder
:
batch_size
,
seq_length
=
input_shape
seq_ids
=
torch
.
arange
(
seq_length
,
device
=
device
)
causal_mask
=
seq_ids
[
None
,
None
,
:].
repeat
(
batch_size
,
seq_length
,
1
)
<=
seq_ids
[
None
,
:,
None
]
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask
=
causal_mask
.
to
(
attention_mask
.
dtype
)
extended_attention_mask
=
causal_mask
[:,
None
,
:,
:]
*
attention_mask
[:,
None
,
None
,
:]
else
:
extended_attention_mask
=
attention_mask
[:,
None
,
None
,
:]
else
:
raise
ValueError
(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})"
.
format
(
input_shape
,
attention_mask
.
shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
self
.
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
return
extended_attention_mask
def
get_head_mask
(
self
,
head_mask
,
num_hidden_layers
):
"""
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
attention_probs has shape bsz x n_heads x N x N
Arguments:
head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
num_hidden_layers: int
Returns:
Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
or list with [None] for each layer
"""
if
head_mask
is
not
None
:
head_mask
=
self
.
_convert_head_mask_to_5d
(
head_mask
,
num_hidden_layers
)
else
:
head_mask
=
[
None
]
*
num_hidden_layers
return
head_mask
def
_convert_head_mask_to_5d
(
self
,
head_mask
,
num_hidden_layers
):
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
assert
head_mask
.
dim
()
==
5
,
f
"head_mask.dim != 5, instead
{
head_mask
.
dim
()
}
"
head_mask
=
head_mask
.
to
(
dtype
=
self
.
dtype
)
# switch to fload if need + fp16 compatibility
return
head_mask
class
PreTrainedModel
(
nn
.
Module
,
ModuleUtilsMixin
):
class
PreTrainedModel
(
nn
.
Module
,
ModuleUtilsMixin
):
r
""" Base class for all models.
r
""" Base class for all models.
...
@@ -340,7 +433,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -340,7 +433,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# If we save using the predefined names, we can load using `from_pretrained`
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file
=
os
.
path
.
join
(
save_directory
,
WEIGHTS_NAME
)
output_model_file
=
os
.
path
.
join
(
save_directory
,
WEIGHTS_NAME
)
if
has
attr
(
self
.
config
,
"xla_device"
)
and
self
.
config
.
xla_device
:
if
get
attr
(
self
.
config
,
"xla_device"
,
False
)
:
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
if
xm
.
is_master_ordinal
():
if
xm
.
is_master_ordinal
():
...
@@ -588,13 +681,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -588,13 +681,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Make sure we are able to load base models as well as derived models (with heads)
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix
=
""
start_prefix
=
""
model_to_load
=
model
model_to_load
=
model
if
not
hasattr
(
model
,
cls
.
base_model_prefix
)
and
any
(
has_prefix_module
=
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
())
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()
if
not
hasattr
(
model
,
cls
.
base_model_prefix
)
and
has_prefix_module
:
):
start_prefix
=
cls
.
base_model_prefix
+
"."
start_prefix
=
cls
.
base_model_prefix
+
"."
if
hasattr
(
model
,
cls
.
base_model_prefix
)
and
not
any
(
if
hasattr
(
model
,
cls
.
base_model_prefix
)
and
not
has_prefix_module
:
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()
):
model_to_load
=
getattr
(
model
,
cls
.
base_model_prefix
)
model_to_load
=
getattr
(
model
,
cls
.
base_model_prefix
)
load
(
model_to_load
,
prefix
=
start_prefix
)
load
(
model_to_load
,
prefix
=
start_prefix
)
...
@@ -627,7 +717,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -627,7 +717,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
)
)
model
.
tie_weights
()
# make sure token embedding weights are still tied if needed
model
.
tie_weights
()
# make sure token embedding weights are still tied if needed
# Set model in evaluation mode to de
s
activate DropOut modules by default
# Set model in evaluation mode to deactivate DropOut modules by default
model
.
eval
()
model
.
eval
()
if
output_loading_info
:
if
output_loading_info
:
...
@@ -944,7 +1034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -944,7 +1034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# get encoder and store encoder outputs
# get encoder and store encoder outputs
encoder
=
self
.
get_encoder
()
encoder
=
self
.
get_encoder
()
encoder_outputs
=
encoder
(
input_ids
,
attention_mask
=
attention_mask
)
encoder_outputs
:
tuple
=
encoder
(
input_ids
,
attention_mask
=
attention_mask
)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if
num_return_sequences
>
1
or
num_beams
>
1
:
if
num_return_sequences
>
1
or
num_beams
>
1
:
...
@@ -1446,12 +1536,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -1446,12 +1536,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
scores
[:,
all_but_token_ids_mask
]
=
-
float
(
"inf"
)
scores
[:,
all_but_token_ids_mask
]
=
-
float
(
"inf"
)
@
staticmethod
@
staticmethod
def
_reorder_cache
(
past
,
beam_idx
)
:
def
_reorder_cache
(
past
:
Tuple
,
beam_idx
:
Tensor
)
->
Tuple
[
Tensor
]
:
return
tuple
(
layer_past
.
index_select
(
1
,
beam_idx
)
for
layer_past
in
past
)
return
tuple
(
layer_past
.
index_select
(
1
,
beam_idx
)
for
layer_past
in
past
)
def
calc_banned_ngram_tokens
(
prev_input_ids
,
num_hypos
,
no_repeat_ngram_size
,
cur_len
)
:
def
calc_banned_ngram_tokens
(
prev_input_ids
:
Tensor
,
num_hypos
:
int
,
no_repeat_ngram_size
:
int
,
cur_len
:
int
)
->
None
:
#
Copied from fairseq for no_repeat_ngram in beam_search"""
"""
Copied from fairseq for no_repeat_ngram in beam_search"""
if
cur_len
+
1
<
no_repeat_ngram_size
:
if
cur_len
+
1
<
no_repeat_ngram_size
:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return
[[]
for
_
in
range
(
num_hypos
)]
return
[[]
for
_
in
range
(
num_hypos
)]
...
@@ -1883,9 +1973,7 @@ class SequenceSummary(nn.Module):
...
@@ -1883,9 +1973,7 @@ class SequenceSummary(nn.Module):
self
.
summary
=
nn
.
Linear
(
config
.
hidden_size
,
num_classes
)
self
.
summary
=
nn
.
Linear
(
config
.
hidden_size
,
num_classes
)
activation_string
=
getattr
(
config
,
"summary_activation"
,
None
)
activation_string
=
getattr
(
config
,
"summary_activation"
,
None
)
self
.
activation
=
(
self
.
activation
:
Callable
=
(
get_activation
(
activation_string
)
if
activation_string
else
Identity
())
get_activation
(
activation_string
)
if
activation_string
else
Identity
()
)
# type: typing.Callable
self
.
first_dropout
=
Identity
()
self
.
first_dropout
=
Identity
()
if
hasattr
(
config
,
"summary_first_dropout"
)
and
config
.
summary_first_dropout
>
0
:
if
hasattr
(
config
,
"summary_first_dropout"
)
and
config
.
summary_first_dropout
>
0
:
...
...
src/transformers/modeling_xlm.py
View file @
dbd04124
...
@@ -479,23 +479,7 @@ class XLMModel(XLMPreTrainedModel):
...
@@ -479,23 +479,7 @@ class XLMModel(XLMPreTrainedModel):
# langs = langs.transpose(0, 1)
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
n_layers
)
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
n_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
n_layers
# do not recompute cached elements
# do not recompute cached elements
if
cache
is
not
None
and
input_ids
is
not
None
:
if
cache
is
not
None
and
input_ids
is
not
None
:
...
...
templates/adding_a_new_model/modeling_xxx.py
View file @
dbd04124
...
@@ -349,10 +349,12 @@ class XxxModel(XxxPreTrainedModel):
...
@@ -349,10 +349,12 @@ class XxxModel(XxxPreTrainedModel):
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
device
)
token_type_ids
=
torch
.
zeros
(
input_shape
,
dtype
=
torch
.
long
,
device
=
device
)
# We create a 3D attention mask from a 2D tensor mask.
# We create a 3D attention mask from a 2D tensor mask.
# (this can be done with self.invert_attention_mask)
# Sizes are [batch_size, 1, 1, to_seq_length]
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
)
extended_attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
...
@@ -368,19 +370,7 @@ class XxxModel(XxxPreTrainedModel):
...
@@ -368,19 +370,7 @@ class XxxModel(XxxPreTrainedModel):
# attention_probs has shape bsz x n_heads x N x N
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if
head_mask
is
not
None
:
head_mask
=
self
.
get_head_mask
(
head_mask
,
self
.
config
.
num_hidden_layers
)
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
num_hidden_layers
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
(
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
##################################
##################################
# Replace this with your model code
# Replace this with your model code
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment