Unverified Commit 6200fd7b authored by Patrick von Platen, committed by GitHub

[Gradient checkpointing] Enable for Deberta + DebertaV2 + SEW-D (#14175)

* up

* up

* finish

* up

* final changes
parent e1dc5afd
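With `supports_gradient_checkpointing` wired up for DeBERTa, DeBERTa-v2 and SEW-D, activation checkpointing can now be switched on through the standard `PreTrainedModel` API. A minimal usage sketch (the checkpoint name and tensor shapes below are placeholders, not part of this diff):

```python
import torch
from transformers import DebertaModel

# Illustrative only: "microsoft/deberta-base" and the shapes are placeholders.
model = DebertaModel.from_pretrained("microsoft/deberta-base")
model.gradient_checkpointing_enable()  # the encoder only honors the flag while model.training is True
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (2, 128))
outputs = model(input_ids)
outputs.last_hidden_state.sum().backward()  # activations are recomputed during this backward pass
```

Recent versions of `Trainer` expose the same switch via `TrainingArguments(gradient_checkpointing=True)`.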
@@ -272,7 +272,7 @@ class DebertaAttention(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
@@ -280,18 +280,18 @@ class DebertaAttention(nn.Module):
self_output = self.self(
hidden_states,
attention_mask,
return_att,
output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
if return_att:
if output_attentions:
self_output, att_matrix = self_output
if query_states is None:
query_states = hidden_states
attention_output = self.output(self_output, query_states)
if return_att:
if output_attentions:
return (attention_output, att_matrix)
else:
return attention_output
@@ -339,24 +339,24 @@ class DebertaLayer(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
output_attentions=False,
):
attention_output = self.attention(
hidden_states,
attention_mask,
return_att=return_att,
output_attentions=output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
if return_att:
if output_attentions:
attention_output, att_matrix = attention_output
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
if return_att:
if output_attentions:
return (layer_output, att_matrix)
else:
return layer_output
@@ -374,6 +374,7 @@ class DebertaEncoder(nn.Module):
if self.max_relative_positions < 1:
self.max_relative_positions = config.max_position_embeddings
self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
self.gradient_checkpointing = False
def get_rel_embedding(self):
rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
@@ -421,14 +422,32 @@ class DebertaEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
next_kv,
attention_mask,
query_states,
relative_pos,
rel_embeddings,
)
else:
hidden_states = layer_module(
next_kv,
attention_mask,
output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
if output_attentions:
hidden_states, att_m = hidden_states
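The hunk above follows the checkpointing pattern used elsewhere in the library: `torch.utils.checkpoint.checkpoint` re-runs the wrapped call during the backward pass and, at the time of this PR, did not accept keyword arguments, so the non-tensor `output_attentions` flag is bound in a closure while the tensors are passed positionally. A self-contained sketch of the same idea, with a hypothetical `Block` standing in for a DeBERTa layer (names and shapes are illustrative only):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class Block(nn.Module):
    """Hypothetical stand-in for one transformer layer."""

    def __init__(self, dim=64):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask, output_attentions=False):
        out = torch.relu(self.lin(hidden_states)) * attention_mask
        if output_attentions:
            # Dummy "attention matrix", only to mirror the tuple-shaped return.
            return out, hidden_states @ hidden_states.transpose(-1, -2)
        return out


block = Block()
hidden_states = torch.randn(2, 16, 64, requires_grad=True)
attention_mask = torch.ones(2, 16, 1)
output_attentions = False


# The bool is captured by the closure; checkpoint() only replays the tensors.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)

    return custom_forward


hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden_states,
    attention_mask,
)
hidden_states.sum().backward()  # the layer's forward is recomputed here instead of being stored
```

The `if self.gradient_checkpointing and self.training:` guard keeps inference untouched; recomputation only kicks in while training.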
@@ -547,7 +566,7 @@ class DisentangledSelfAttention(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
@@ -565,7 +584,7 @@ class DisentangledSelfAttention(nn.Module):
sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
th token.
return_att (:obj:`bool`, optional):
output_attentions (:obj:`bool`, optional):
Whether return the attention matrix.
query_states (:obj:`torch.FloatTensor`, optional):
@@ -629,7 +648,7 @@ class DisentangledSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
if return_att:
if output_attentions:
return (context_layer, attention_probs)
else:
return context_layer
@@ -774,6 +793,7 @@ class DebertaPreTrainedModel(PreTrainedModel):
base_model_prefix = "deberta"
_keys_to_ignore_on_load_missing = ["position_ids"]
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights."""
@@ -788,6 +808,10 @@ class DebertaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, DebertaEncoder):
module.gradient_checkpointing = value
DEBERTA_START_DOCSTRING = r"""
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
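For reference, `supports_gradient_checkpointing` and `_set_gradient_checkpointing` are the hooks that `PreTrainedModel.gradient_checkpointing_enable()` relies on: around this release it roughly walks the module tree and lets each architecture decide which submodule carries the flag (here, the encoder). A simplified, self-contained sketch of that mechanism, not the library's exact code:

```python
from functools import partial

import torch.nn as nn


class MiniEncoder(nn.Module):
    """Hypothetical encoder that owns the gradient_checkpointing flag."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class MiniPreTrainedModel(nn.Module):
    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        # Each architecture decides which submodule actually carries the flag,
        # mirroring the isinstance(module, DebertaEncoder) check above.
        if isinstance(module, MiniEncoder):
            module.gradient_checkpointing = value

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # apply() visits every submodule, so the encoder receives value=True.
        self.apply(partial(self._set_gradient_checkpointing, value=True))


class MiniModel(MiniPreTrainedModel):
    def __init__(self):
        super().__init__()
        self.encoder = MiniEncoder()


model = MiniModel()
model.gradient_checkpointing_enable()
assert model.encoder.gradient_checkpointing
```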
@@ -947,7 +971,7 @@ class DebertaModel(DebertaPreTrainedModel):
query_states = layer(
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=query_states,
relative_pos=rel_pos,
rel_embeddings=rel_embeddings,
......
@@ -259,7 +259,7 @@ class DebertaV2Attention(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
@@ -267,18 +267,18 @@ class DebertaV2Attention(nn.Module):
self_output = self.self(
hidden_states,
attention_mask,
return_att,
output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
if return_att:
if output_attentions:
self_output, att_matrix = self_output
if query_states is None:
query_states = hidden_states
attention_output = self.output(self_output, query_states)
if return_att:
if output_attentions:
return (attention_output, att_matrix)
else:
return attention_output
@@ -328,24 +328,24 @@ class DebertaV2Layer(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
output_attentions=False,
):
attention_output = self.attention(
hidden_states,
attention_mask,
return_att=return_att,
output_attentions=output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
if return_att:
if output_attentions:
attention_output, att_matrix = attention_output
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
if return_att:
if output_attentions:
return (layer_output, att_matrix)
else:
return layer_output
@@ -415,6 +415,7 @@ class DebertaV2Encoder(nn.Module):
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
self.gradient_checkpointing = False
def get_rel_embedding(self):
rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
@@ -471,14 +472,32 @@ class DebertaV2Encoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (output_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
output_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
next_kv,
attention_mask,
query_states,
relative_pos,
rel_embeddings,
)
else:
output_states = layer_module(
next_kv,
attention_mask,
output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
if output_attentions:
output_states, att_m = output_states
@@ -619,7 +638,7 @@ class DisentangledSelfAttention(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
@@ -637,7 +656,7 @@ class DisentangledSelfAttention(nn.Module):
sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
th token.
return_att (:obj:`bool`, optional):
output_attentions (:obj:`bool`, optional):
Whether return the attention matrix.
query_states (:obj:`torch.FloatTensor`, optional):
@@ -696,7 +715,7 @@ class DisentangledSelfAttention(nn.Module):
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
if return_att:
if output_attentions:
return (context_layer, attention_probs)
else:
return context_layer
@@ -881,6 +900,7 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
base_model_prefix = "deberta"
_keys_to_ignore_on_load_missing = ["position_ids"]
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights."""
@@ -895,6 +915,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, DebertaV2Encoder):
module.gradient_checkpointing = value
DEBERTA_START_DOCSTRING = r"""
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
@@ -1055,7 +1079,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
query_states = layer(
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=query_states,
relative_pos=rel_pos,
rel_embeddings=rel_embeddings,
......
@@ -661,7 +661,7 @@ class DisentangledSelfAttention(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
@@ -679,7 +679,7 @@ class DisentangledSelfAttention(nn.Module):
sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
th token.
return_att (:obj:`bool`, optional):
output_attentions (:obj:`bool`, optional):
Whether return the attention matrix.
query_states (:obj:`torch.FloatTensor`, optional):
@@ -738,7 +738,7 @@ class DisentangledSelfAttention(nn.Module):
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(*new_context_layer_shape)
if return_att:
if output_attentions:
return (context_layer, attention_probs)
else:
return context_layer
@@ -849,7 +849,7 @@ class SEWDAttention(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
output_attentions=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
@@ -857,18 +857,18 @@ class SEWDAttention(nn.Module):
self_output = self.self(
hidden_states,
attention_mask,
return_att,
output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
if return_att:
if output_attentions:
self_output, att_matrix = self_output
if query_states is None:
query_states = hidden_states
attention_output = self.output(self_output, query_states)
if return_att:
if output_attentions:
return (attention_output, att_matrix)
else:
return attention_output
@@ -918,24 +918,24 @@ class SEWDLayer(nn.Module):
self,
hidden_states,
attention_mask,
return_att=False,
query_states=None,
relative_pos=None,
rel_embeddings=None,
output_attentions=False,
):
attention_output = self.attention(
hidden_states,
attention_mask,
return_att=return_att,
output_attentions=output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
if return_att:
if output_attentions:
attention_output, att_matrix = attention_output
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
if return_att:
if output_attentions:
return (layer_output, att_matrix)
else:
return layer_output
@@ -1007,6 +1007,7 @@ class SEWDTransformerEncoder(nn.Module):
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
self.gradient_checkpointing = False
def get_rel_embedding(self):
rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
@@ -1063,14 +1064,32 @@ class SEWDTransformerEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (output_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
output_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
next_kv,
attention_mask,
query_states,
relative_pos,
rel_embeddings,
)
else:
output_states = layer_module(
next_kv,
attention_mask,
output_attentions,
query_states=query_states,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
output_attentions=output_attentions,
)
if output_attentions:
output_states, att_m = output_states
@@ -1169,6 +1188,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
config_class = SEWDConfig
base_model_prefix = "sew-d"
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -1233,6 +1253,10 @@ class SEWDPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, SEWDTransformerEncoder):
module.gradient_checkpointing = value
SEWD_START_DOCSTRING = r"""
SEW-D was proposed in `Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition
......