chenpangpang / transformers · Commits

Unverified commit 6200fd7b, authored Oct 27, 2021 by Patrick von Platen, committed by GitHub on Oct 27, 2021.

[Gradient checkpointing] Enable for Deberta + DebertaV2 + SEW-D (#14175)

* up
* up
* finish
* up
* final changes

Parent: e1dc5afd
Changes: 3 changed files with 131 additions and 59 deletions (+131 −59)
src/transformers/models/deberta/modeling_deberta.py         +44 −20
src/transformers/models/deberta_v2/modeling_deberta_v2.py   +44 −20
src/transformers/models/sew_d/modeling_sew_d.py             +43 −19
src/transformers/models/deberta/modeling_deberta.py (view file @ 6200fd7b)
@@ -272,7 +272,7 @@ class DebertaAttention(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
+        output_attentions=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
@@ -280,18 +280,18 @@ class DebertaAttention(nn.Module):
         self_output = self.self(
             hidden_states,
             attention_mask,
-            return_att,
+            output_attentions,
             query_states=query_states,
             relative_pos=relative_pos,
             rel_embeddings=rel_embeddings,
         )
-        if return_att:
+        if output_attentions:
             self_output, att_matrix = self_output
         if query_states is None:
             query_states = hidden_states
         attention_output = self.output(self_output, query_states)

-        if return_att:
+        if output_attentions:
             return (attention_output, att_matrix)
         else:
             return attention_output
@@ -339,24 +339,24 @@ class DebertaLayer(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
+        output_attentions=False,
     ):
         attention_output = self.attention(
             hidden_states,
             attention_mask,
-            return_att=return_att,
+            output_attentions=output_attentions,
             query_states=query_states,
             relative_pos=relative_pos,
             rel_embeddings=rel_embeddings,
         )
-        if return_att:
+        if output_attentions:
             attention_output, att_matrix = attention_output
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        if return_att:
+        if output_attentions:
             return (layer_output, att_matrix)
         else:
             return layer_output
@@ -374,6 +374,7 @@ class DebertaEncoder(nn.Module):
             if self.max_relative_positions < 1:
                 self.max_relative_positions = config.max_position_embeddings
             self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
+        self.gradient_checkpointing = False

     def get_rel_embedding(self):
         rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
@@ -421,14 +422,32 @@ class DebertaEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            hidden_states = layer_module(
-                next_kv,
-                attention_mask,
-                output_attentions,
-                query_states=query_states,
-                relative_pos=relative_pos,
-                rel_embeddings=rel_embeddings,
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    next_kv,
+                    attention_mask,
+                    query_states,
+                    relative_pos,
+                    rel_embeddings,
+                )
+            else:
+                hidden_states = layer_module(
+                    next_kv,
+                    attention_mask,
+                    query_states=query_states,
+                    relative_pos=relative_pos,
+                    rel_embeddings=rel_embeddings,
+                    output_attentions=output_attentions,
+                )
+
             if output_attentions:
                 hidden_states, att_m = hidden_states
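The wrapper above exists because torch.utils.checkpoint.checkpoint replays only the positional tensor arguments during the backward pass, so the Python flag output_attentions is captured in a closure rather than passed through the checkpoint call. Below is a minimal, self-contained sketch of the same pattern outside the library; the ToyLayer module, sizes, and names are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Stand-in for a transformer layer: tensor inputs plus a Python flag."""

    def __init__(self, hidden_size=64):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask, output_attentions=False):
        out = torch.relu(self.linear(hidden_states)) * attention_mask
        # Return a tuple only when attentions are requested, mirroring the models above.
        return (out, out.softmax(dim=-1)) if output_attentions else out


layer = ToyLayer()
hidden_states = torch.randn(2, 8, 64, requires_grad=True)
attention_mask = torch.ones(2, 8, 1)
output_attentions = False


def create_custom_forward(module):
    # The Python bool is captured in a closure; checkpoint() only re-feeds tensors.
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)

    return custom_forward


# Layer activations are recomputed during backward instead of being kept in memory.
out = checkpoint(create_custom_forward(layer), hidden_states, attention_mask)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 8, 64])

Running this, the backward pass recomputes the layer's activations rather than reading them from memory, which is the compute-for-memory trade the commit enables for the DeBERTa encoder stack.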
@@ -547,7 +566,7 @@ class DisentangledSelfAttention(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
+        output_attentions=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
@@ -565,7 +584,7 @@ class DisentangledSelfAttention(nn.Module):
                 sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
                 th token.

-            return_att (:obj:`bool`, optional):
+            output_attentions (:obj:`bool`, optional):
                 Whether return the attention matrix.

             query_states (:obj:`torch.FloatTensor`, optional):
@@ -629,7 +648,7 @@ class DisentangledSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        if return_att:
+        if output_attentions:
             return (context_layer, attention_probs)
         else:
             return context_layer
@@ -774,6 +793,7 @@ class DebertaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "deberta"
     _keys_to_ignore_on_load_missing = ["position_ids"]
     _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -788,6 +808,10 @@ class DebertaPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, DebertaEncoder):
+            module.gradient_checkpointing = value
+

 DEBERTA_START_DOCSTRING = r"""
     The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
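The supports_gradient_checkpointing flag and the _set_gradient_checkpointing hook added here are consumed by the shared PreTrainedModel machinery. A rough approximation of that plumbing is sketched below; the real logic lives in modeling_utils.py, so treat this as an assumption about its shape, not a verbatim copy.

from functools import partial

# Rough sketch: gradient_checkpointing_enable() walks every submodule and lets the
# model class decide which module actually owns the flag (here, DebertaEncoder).
def gradient_checkpointing_enable(model):
    if not model.supports_gradient_checkpointing:
        raise ValueError(f"{model.__class__.__name__} does not support gradient checkpointing.")
    model.apply(partial(model._set_gradient_checkpointing, value=True))


def gradient_checkpointing_disable(model):
    if model.supports_gradient_checkpointing:
        model.apply(partial(model._set_gradient_checkpointing, value=False))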
@@ -947,7 +971,7 @@ class DebertaModel(DebertaPreTrainedModel):
                 query_states = layer(
                     hidden_states,
                     attention_mask,
-                    return_att=False,
+                    output_attentions=False,
                     query_states=query_states,
                     relative_pos=rel_pos,
                     rel_embeddings=rel_embeddings,
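Taken together, the DeBERTa changes let users trade extra forward compute for activation memory during fine-tuning. A minimal usage sketch, assuming the standard gradient_checkpointing_enable() entry point and the microsoft/deberta-base checkpoint:

import torch
from transformers import DebertaModel, DebertaTokenizer

tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaModel.from_pretrained("microsoft/deberta-base")

model.gradient_checkpointing_enable()  # encoder activations recomputed in backward
model.train()

inputs = tokenizer("Gradient checkpointing trades compute for memory.", return_tensors="pt")
outputs = model(**inputs)
outputs.last_hidden_state.mean().backward()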
src/transformers/models/deberta_v2/modeling_deberta_v2.py (view file @ 6200fd7b)
@@ -259,7 +259,7 @@ class DebertaV2Attention(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
+        output_attentions=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
@@ -267,18 +267,18 @@ class DebertaV2Attention(nn.Module):
         self_output = self.self(
             hidden_states,
             attention_mask,
-            return_att,
+            output_attentions,
             query_states=query_states,
             relative_pos=relative_pos,
             rel_embeddings=rel_embeddings,
         )
-        if return_att:
+        if output_attentions:
             self_output, att_matrix = self_output
         if query_states is None:
             query_states = hidden_states
         attention_output = self.output(self_output, query_states)

-        if return_att:
+        if output_attentions:
             return (attention_output, att_matrix)
         else:
             return attention_output
@@ -328,24 +328,24 @@ class DebertaV2Layer(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
+        output_attentions=False,
     ):
         attention_output = self.attention(
             hidden_states,
             attention_mask,
-            return_att=return_att,
+            output_attentions=output_attentions,
             query_states=query_states,
             relative_pos=relative_pos,
             rel_embeddings=rel_embeddings,
         )
-        if return_att:
+        if output_attentions:
             attention_output, att_matrix = attention_output
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        if return_att:
+        if output_attentions:
             return (layer_output, att_matrix)
         else:
             return layer_output
@@ -415,6 +415,7 @@ class DebertaV2Encoder(nn.Module):
             self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)

         self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
+        self.gradient_checkpointing = False

     def get_rel_embedding(self):
         rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
@@ -471,14 +472,32 @@ class DebertaV2Encoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (output_states,)

-            output_states = layer_module(
-                next_kv,
-                attention_mask,
-                output_attentions,
-                query_states=query_states,
-                relative_pos=relative_pos,
-                rel_embeddings=rel_embeddings,
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                output_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    next_kv,
+                    attention_mask,
+                    query_states,
+                    relative_pos,
+                    rel_embeddings,
+                )
+            else:
+                output_states = layer_module(
+                    next_kv,
+                    attention_mask,
+                    query_states=query_states,
+                    relative_pos=relative_pos,
+                    rel_embeddings=rel_embeddings,
+                    output_attentions=output_attentions,
+                )
+
             if output_attentions:
                 output_states, att_m = output_states
@@ -619,7 +638,7 @@ class DisentangledSelfAttention(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
+        output_attentions=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
@@ -637,7 +656,7 @@ class DisentangledSelfAttention(nn.Module):
                 sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
                 th token.

-            return_att (:obj:`bool`, optional):
+            output_attentions (:obj:`bool`, optional):
                 Whether return the attention matrix.

             query_states (:obj:`torch.FloatTensor`, optional):
@@ -696,7 +715,7 @@ class DisentangledSelfAttention(nn.Module):
         )
         new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        if return_att:
+        if output_attentions:
             return (context_layer, attention_probs)
         else:
             return context_layer
@@ -881,6 +900,7 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
     base_model_prefix = "deberta"
     _keys_to_ignore_on_load_missing = ["position_ids"]
     _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -895,6 +915,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, DebertaV2Encoder):
+            module.gradient_checkpointing = value
+

 DEBERTA_START_DOCSTRING = r"""
     The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
@@ -1055,7 +1079,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
                 query_states = layer(
                     hidden_states,
                     attention_mask,
-                    return_att=False,
+                    output_attentions=False,
                     query_states=query_states,
                     relative_pos=rel_pos,
                     rel_embeddings=rel_embeddings,
src/transformers/models/sew_d/modeling_sew_d.py (view file @ 6200fd7b)
@@ -661,7 +661,7 @@ class DisentangledSelfAttention(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
+        output_attentions=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
@@ -679,7 +679,7 @@ class DisentangledSelfAttention(nn.Module):
                 sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
                 th token.

-            return_att (:obj:`bool`, optional):
+            output_attentions (:obj:`bool`, optional):
                 Whether return the attention matrix.

             query_states (:obj:`torch.FloatTensor`, optional):
@@ -738,7 +738,7 @@ class DisentangledSelfAttention(nn.Module):
         )
         new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        if return_att:
+        if output_attentions:
             return (context_layer, attention_probs)
         else:
             return context_layer
@@ -849,7 +849,7 @@ class SEWDAttention(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
+        output_attentions=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
@@ -857,18 +857,18 @@ class SEWDAttention(nn.Module):
         self_output = self.self(
             hidden_states,
             attention_mask,
-            return_att,
+            output_attentions,
             query_states=query_states,
             relative_pos=relative_pos,
             rel_embeddings=rel_embeddings,
         )
-        if return_att:
+        if output_attentions:
             self_output, att_matrix = self_output
         if query_states is None:
             query_states = hidden_states
         attention_output = self.output(self_output, query_states)

-        if return_att:
+        if output_attentions:
             return (attention_output, att_matrix)
         else:
             return attention_output
@@ -918,24 +918,24 @@ class SEWDLayer(nn.Module):
         self,
         hidden_states,
         attention_mask,
-        return_att=False,
         query_states=None,
         relative_pos=None,
         rel_embeddings=None,
+        output_attentions=False,
     ):
         attention_output = self.attention(
             hidden_states,
             attention_mask,
-            return_att=return_att,
+            output_attentions=output_attentions,
             query_states=query_states,
             relative_pos=relative_pos,
             rel_embeddings=rel_embeddings,
         )
-        if return_att:
+        if output_attentions:
             attention_output, att_matrix = attention_output
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        if return_att:
+        if output_attentions:
             return (layer_output, att_matrix)
         else:
             return layer_output
@@ -1007,6 +1007,7 @@ class SEWDTransformerEncoder(nn.Module):
             self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)

         self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
+        self.gradient_checkpointing = False

     def get_rel_embedding(self):
         rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
@@ -1063,14 +1064,32 @@ class SEWDTransformerEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (output_states,)

-            output_states = layer_module(
-                next_kv,
-                attention_mask,
-                output_attentions,
-                query_states=query_states,
-                relative_pos=relative_pos,
-                rel_embeddings=rel_embeddings,
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                output_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    next_kv,
+                    attention_mask,
+                    query_states,
+                    relative_pos,
+                    rel_embeddings,
+                )
+            else:
+                output_states = layer_module(
+                    next_kv,
+                    attention_mask,
+                    query_states=query_states,
+                    relative_pos=relative_pos,
+                    rel_embeddings=rel_embeddings,
+                    output_attentions=output_attentions,
+                )
+
             if output_attentions:
                 output_states, att_m = output_states
@@ -1169,6 +1188,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
     config_class = SEWDConfig
     base_model_prefix = "sew-d"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1233,6 +1253,10 @@ class SEWDPreTrainedModel(PreTrainedModel):
         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
         return attention_mask

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, SEWDTransformerEncoder):
+            module.gradient_checkpointing = value
+

 SEWD_START_DOCSTRING = r"""
     SEW-D was proposed in `Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition
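The same switch now applies to the speech encoder. A short sketch for SEW-D, assuming the asapp/sew-d-tiny-100k checkpoint and a dummy 16 kHz waveform in place of a real feature-extraction pipeline:

import torch
from transformers import SEWDModel

model = SEWDModel.from_pretrained("asapp/sew-d-tiny-100k")
model.gradient_checkpointing_enable()  # routed to SEWDTransformerEncoder via the new hook
model.train()

input_values = torch.randn(1, 16000)  # one second of fake 16 kHz audio
hidden_states = model(input_values).last_hidden_state
hidden_states.mean().backward()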