Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6200fd7b
Unverified
Commit
6200fd7b
authored
Oct 27, 2021
by
Patrick von Platen
Committed by
GitHub
Oct 27, 2021
Browse files
[Gradient checkpointing] Enable for Deberta + DebertaV2 + SEW-D (#14175)
* up * up * finish * up * final changes
parent
e1dc5afd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
131 additions
and
59 deletions
+131
-59
src/transformers/models/deberta/modeling_deberta.py
src/transformers/models/deberta/modeling_deberta.py
+44
-20
src/transformers/models/deberta_v2/modeling_deberta_v2.py
src/transformers/models/deberta_v2/modeling_deberta_v2.py
+44
-20
src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/sew_d/modeling_sew_d.py
+43
-19
No files found.
src/transformers/models/deberta/modeling_deberta.py
View file @
6200fd7b
...
...
@@ -272,7 +272,7 @@ class DebertaAttention(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
...
...
@@ -280,18 +280,18 @@ class DebertaAttention(nn.Module):
self_output
=
self
.
self
(
hidden_states
,
attention_mask
,
return_att
,
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
return_att
:
if
output_attentions
:
self_output
,
att_matrix
=
self_output
if
query_states
is
None
:
query_states
=
hidden_states
attention_output
=
self
.
output
(
self_output
,
query_states
)
if
return_att
:
if
output_attentions
:
return
(
attention_output
,
att_matrix
)
else
:
return
attention_output
...
...
@@ -339,24 +339,24 @@ class DebertaLayer(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
output_attentions
=
False
,
):
attention_output
=
self
.
attention
(
hidden_states
,
attention_mask
,
return_att
=
return_att
,
output_attentions
=
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
return_att
:
if
output_attentions
:
attention_output
,
att_matrix
=
attention_output
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
if
return_att
:
if
output_attentions
:
return
(
layer_output
,
att_matrix
)
else
:
return
layer_output
...
...
@@ -374,6 +374,7 @@ class DebertaEncoder(nn.Module):
if
self
.
max_relative_positions
<
1
:
self
.
max_relative_positions
=
config
.
max_position_embeddings
self
.
rel_embeddings
=
nn
.
Embedding
(
self
.
max_relative_positions
*
2
,
config
.
hidden_size
)
self
.
gradient_checkpointing
=
False
def
get_rel_embedding
(
self
):
rel_embeddings
=
self
.
rel_embeddings
.
weight
if
self
.
relative_attention
else
None
...
...
@@ -421,14 +422,32 @@ class DebertaEncoder(nn.Module):
if
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,)
hidden_states
=
layer_module
(
next_kv
,
attention_mask
,
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
self
.
gradient_checkpointing
and
self
.
training
:
def
create_custom_forward
(
module
):
def
custom_forward
(
*
inputs
):
return
module
(
*
inputs
,
output_attentions
)
return
custom_forward
hidden_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
layer_module
),
next_kv
,
attention_mask
,
query_states
,
relative_pos
,
rel_embeddings
,
)
else
:
hidden_states
=
layer_module
(
next_kv
,
attention_mask
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
output_attentions
=
output_attentions
,
)
if
output_attentions
:
hidden_states
,
att_m
=
hidden_states
...
...
@@ -547,7 +566,7 @@ class DisentangledSelfAttention(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
...
...
@@ -565,7 +584,7 @@ class DisentangledSelfAttention(nn.Module):
sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
th token.
return_att
(:obj:`bool`, optional):
output_attentions
(:obj:`bool`, optional):
Whether return the attention matrix.
query_states (:obj:`torch.FloatTensor`, optional):
...
...
@@ -629,7 +648,7 @@ class DisentangledSelfAttention(nn.Module):
context_layer
=
context_layer
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
-
1
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
if
return_att
:
if
output_attentions
:
return
(
context_layer
,
attention_probs
)
else
:
return
context_layer
...
...
@@ -774,6 +793,7 @@ class DebertaPreTrainedModel(PreTrainedModel):
base_model_prefix
=
"deberta"
_keys_to_ignore_on_load_missing
=
[
"position_ids"
]
_keys_to_ignore_on_load_unexpected
=
[
"position_embeddings"
]
supports_gradient_checkpointing
=
True
def
_init_weights
(
self
,
module
):
"""Initialize the weights."""
...
...
@@ -788,6 +808,10 @@ class DebertaPreTrainedModel(PreTrainedModel):
if
module
.
padding_idx
is
not
None
:
module
.
weight
.
data
[
module
.
padding_idx
].
zero_
()
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
DebertaEncoder
):
module
.
gradient_checkpointing
=
value
DEBERTA_START_DOCSTRING
=
r
"""
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
...
...
@@ -947,7 +971,7 @@ class DebertaModel(DebertaPreTrainedModel):
query_states
=
layer
(
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
query_states
,
relative_pos
=
rel_pos
,
rel_embeddings
=
rel_embeddings
,
...
...
src/transformers/models/deberta_v2/modeling_deberta_v2.py
View file @
6200fd7b
...
...
@@ -259,7 +259,7 @@ class DebertaV2Attention(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
...
...
@@ -267,18 +267,18 @@ class DebertaV2Attention(nn.Module):
self_output
=
self
.
self
(
hidden_states
,
attention_mask
,
return_att
,
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
return_att
:
if
output_attentions
:
self_output
,
att_matrix
=
self_output
if
query_states
is
None
:
query_states
=
hidden_states
attention_output
=
self
.
output
(
self_output
,
query_states
)
if
return_att
:
if
output_attentions
:
return
(
attention_output
,
att_matrix
)
else
:
return
attention_output
...
...
@@ -328,24 +328,24 @@ class DebertaV2Layer(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
output_attentions
=
False
,
):
attention_output
=
self
.
attention
(
hidden_states
,
attention_mask
,
return_att
=
return_att
,
output_attentions
=
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
return_att
:
if
output_attentions
:
attention_output
,
att_matrix
=
attention_output
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
if
return_att
:
if
output_attentions
:
return
(
layer_output
,
att_matrix
)
else
:
return
layer_output
...
...
@@ -415,6 +415,7 @@ class DebertaV2Encoder(nn.Module):
self
.
LayerNorm
=
LayerNorm
(
config
.
hidden_size
,
config
.
layer_norm_eps
,
elementwise_affine
=
True
)
self
.
conv
=
ConvLayer
(
config
)
if
getattr
(
config
,
"conv_kernel_size"
,
0
)
>
0
else
None
self
.
gradient_checkpointing
=
False
def
get_rel_embedding
(
self
):
rel_embeddings
=
self
.
rel_embeddings
.
weight
if
self
.
relative_attention
else
None
...
...
@@ -471,14 +472,32 @@ class DebertaV2Encoder(nn.Module):
if
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
output_states
,)
output_states
=
layer_module
(
next_kv
,
attention_mask
,
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
self
.
gradient_checkpointing
and
self
.
training
:
def
create_custom_forward
(
module
):
def
custom_forward
(
*
inputs
):
return
module
(
*
inputs
,
output_attentions
)
return
custom_forward
output_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
layer_module
),
next_kv
,
attention_mask
,
query_states
,
relative_pos
,
rel_embeddings
,
)
else
:
output_states
=
layer_module
(
next_kv
,
attention_mask
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
output_attentions
=
output_attentions
,
)
if
output_attentions
:
output_states
,
att_m
=
output_states
...
...
@@ -619,7 +638,7 @@ class DisentangledSelfAttention(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
...
...
@@ -637,7 +656,7 @@ class DisentangledSelfAttention(nn.Module):
sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
th token.
return_att
(:obj:`bool`, optional):
output_attentions
(:obj:`bool`, optional):
Whether return the attention matrix.
query_states (:obj:`torch.FloatTensor`, optional):
...
...
@@ -696,7 +715,7 @@ class DisentangledSelfAttention(nn.Module):
)
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
-
1
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
if
return_att
:
if
output_attentions
:
return
(
context_layer
,
attention_probs
)
else
:
return
context_layer
...
...
@@ -881,6 +900,7 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
base_model_prefix
=
"deberta"
_keys_to_ignore_on_load_missing
=
[
"position_ids"
]
_keys_to_ignore_on_load_unexpected
=
[
"position_embeddings"
]
supports_gradient_checkpointing
=
True
def
_init_weights
(
self
,
module
):
"""Initialize the weights."""
...
...
@@ -895,6 +915,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
if
module
.
padding_idx
is
not
None
:
module
.
weight
.
data
[
module
.
padding_idx
].
zero_
()
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
DebertaV2Encoder
):
module
.
gradient_checkpointing
=
value
DEBERTA_START_DOCSTRING
=
r
"""
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
...
...
@@ -1055,7 +1079,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
query_states
=
layer
(
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
query_states
,
relative_pos
=
rel_pos
,
rel_embeddings
=
rel_embeddings
,
...
...
src/transformers/models/sew_d/modeling_sew_d.py
View file @
6200fd7b
...
...
@@ -661,7 +661,7 @@ class DisentangledSelfAttention(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
...
...
@@ -679,7 +679,7 @@ class DisentangledSelfAttention(nn.Module):
sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
th token.
return_att
(:obj:`bool`, optional):
output_attentions
(:obj:`bool`, optional):
Whether return the attention matrix.
query_states (:obj:`torch.FloatTensor`, optional):
...
...
@@ -738,7 +738,7 @@ class DisentangledSelfAttention(nn.Module):
)
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
-
1
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
if
return_att
:
if
output_attentions
:
return
(
context_layer
,
attention_probs
)
else
:
return
context_layer
...
...
@@ -849,7 +849,7 @@ class SEWDAttention(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
output_attentions
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
...
...
@@ -857,18 +857,18 @@ class SEWDAttention(nn.Module):
self_output
=
self
.
self
(
hidden_states
,
attention_mask
,
return_att
,
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
return_att
:
if
output_attentions
:
self_output
,
att_matrix
=
self_output
if
query_states
is
None
:
query_states
=
hidden_states
attention_output
=
self
.
output
(
self_output
,
query_states
)
if
return_att
:
if
output_attentions
:
return
(
attention_output
,
att_matrix
)
else
:
return
attention_output
...
...
@@ -918,24 +918,24 @@ class SEWDLayer(nn.Module):
self
,
hidden_states
,
attention_mask
,
return_att
=
False
,
query_states
=
None
,
relative_pos
=
None
,
rel_embeddings
=
None
,
output_attentions
=
False
,
):
attention_output
=
self
.
attention
(
hidden_states
,
attention_mask
,
return_att
=
return_att
,
output_attentions
=
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
return_att
:
if
output_attentions
:
attention_output
,
att_matrix
=
attention_output
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
if
return_att
:
if
output_attentions
:
return
(
layer_output
,
att_matrix
)
else
:
return
layer_output
...
...
@@ -1007,6 +1007,7 @@ class SEWDTransformerEncoder(nn.Module):
self
.
LayerNorm
=
LayerNorm
(
config
.
hidden_size
,
config
.
layer_norm_eps
,
elementwise_affine
=
True
)
self
.
conv
=
ConvLayer
(
config
)
if
getattr
(
config
,
"conv_kernel_size"
,
0
)
>
0
else
None
self
.
gradient_checkpointing
=
False
def
get_rel_embedding
(
self
):
rel_embeddings
=
self
.
rel_embeddings
.
weight
if
self
.
relative_attention
else
None
...
...
@@ -1063,14 +1064,32 @@ class SEWDTransformerEncoder(nn.Module):
if
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
output_states
,)
output_states
=
layer_module
(
next_kv
,
attention_mask
,
output_attentions
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
)
if
self
.
gradient_checkpointing
and
self
.
training
:
def
create_custom_forward
(
module
):
def
custom_forward
(
*
inputs
):
return
module
(
*
inputs
,
output_attentions
)
return
custom_forward
output_states
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
layer_module
),
next_kv
,
attention_mask
,
query_states
,
relative_pos
,
rel_embeddings
,
)
else
:
output_states
=
layer_module
(
next_kv
,
attention_mask
,
query_states
=
query_states
,
relative_pos
=
relative_pos
,
rel_embeddings
=
rel_embeddings
,
output_attentions
=
output_attentions
,
)
if
output_attentions
:
output_states
,
att_m
=
output_states
...
...
@@ -1169,6 +1188,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
config_class
=
SEWDConfig
base_model_prefix
=
"sew-d"
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
supports_gradient_checkpointing
=
True
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
...
...
@@ -1233,6 +1253,10 @@ class SEWDPreTrainedModel(PreTrainedModel):
attention_mask
=
attention_mask
.
flip
([
-
1
]).
cumsum
(
-
1
).
flip
([
-
1
]).
bool
()
return
attention_mask
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
SEWDTransformerEncoder
):
module
.
gradient_checkpointing
=
value
SEWD_START_DOCSTRING
=
r
"""
SEW-D was proposed in `Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment