Unverified commit 06e782da authored by Younes Belkada, committed by GitHub

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
......@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
import gc
import importlib.metadata
import inspect
......@@ -1848,16 +1849,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.base_model._prune_heads(heads_to_prune)
def gradient_checkpointing_enable(self):
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
We pass the `__call__` method of the modules instead of `forward` because `__call__` runs all the hooks registered on
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
gradient_checkpointing_func = functools.partial(
torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
)
self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
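The collapsed code under this comment wires up `enable_input_require_grads()` (its counterpart, `disable_input_require_grads()`, is visible in the `gradient_checkpointing_disable` hunk below). The reason it is needed: with a frozen base model, none of the tensors entering a checkpointed block require gradients, and reentrant checkpointing then builds no graph at all. A tiny runnable demonstration, independent of transformers:

```python
import torch
import torch.utils.checkpoint

w = torch.randn(3, 3)  # frozen base-model weight, as under PEFT
x = torch.randn(2, 3)  # activations that do not require grad

y = torch.utils.checkpoint.checkpoint(lambda t: t @ w, x, use_reentrant=True)
print(y.requires_grad)  # False: backward through this block would be a no-op

x.requires_grad_(True)  # what enable_input_require_grads() arranges via a hook
y = torch.utils.checkpoint.checkpoint(lambda t: t @ w, x, use_reentrant=True)
print(y.requires_grad)  # True: gradients can now flow through the block
```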
......@@ -1874,7 +1890,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
activations".
"""
if self.supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads()
......
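With the refactor in place, checkpoint arguments become a public knob. A minimal usage sketch (the gpt2 checkpoint is illustrative; any model with `supports_gradient_checkpointing` works):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# The kwargs dict is pre-bound onto torch.utils.checkpoint.checkpoint via
# functools.partial, so anything that function accepts can be passed here.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
model.train()  # checkpointing only kicks in when self.training is True

model.gradient_checkpointing_disable()  # restores the plain forward path
```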
......@@ -1095,20 +1095,15 @@ class AlignTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -1197,9 +1192,10 @@ class AlignPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (AlignTextModel, AlignVisionModel)):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings(
......
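The Align hunk above is the template for every model file that follows: the per-call-site closure is deleted, the formerly closed-over arguments become explicit positional arguments, and the pre-bound checkpoint function is injected as `self.gradient_checkpointing_func`. `__call__` rather than `.forward` is handed over so module hooks keep firing. A self-contained sketch of the two patterns (the toy `Layer` is invented for illustration):

```python
import functools
import torch
import torch.nn as nn
import torch.utils.checkpoint

class Layer(nn.Module):
    def forward(self, hidden_states, output_attentions=False):
        return (hidden_states * 2,)

layer_module = Layer()
hidden_states = torch.randn(2, 4, requires_grad=True)
output_attentions = False

# Old pattern: a closure smuggles non-tensor arguments past checkpoint(),
# which only forwards its positional inputs to the wrapped function.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions)
    return custom_forward

out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer_module), hidden_states, use_reentrant=False
)

# New pattern: kwargs are bound once, every argument is passed explicitly,
# and __call__ (not .forward) is used so the module's hooks still run.
gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)
out_new = gradient_checkpointing_func(
    layer_module.__call__, hidden_states, output_attentions
)
assert torch.equal(out_old[0], out_new[0])
```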
......@@ -646,20 +646,15 @@ class AltRobertaEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -960,18 +955,12 @@ class AltCLIPEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1089,11 +1078,13 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, AltCLIPEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
if isinstance(module, AltRobertaEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
......
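Each model also rewrites `_set_gradient_checkpointing` to the same convention: store the function and derive the boolean flag from it. `gradient_checkpointing_enable` pushes this setter through every submodule with `nn.Module.apply`, and the `isinstance` guard decides where it takes effect. A toy sketch of the wiring (`TinyModel`/`TinyEncoder` are invented names):

```python
import functools
import torch
import torch.nn as nn
import torch.utils.checkpoint

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            return self.gradient_checkpointing_func(self.linear.__call__, x)
        return self.linear(x)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TinyEncoder()

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        if isinstance(module, TinyEncoder):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            # The function doubles as the on/off switch.
            module.gradient_checkpointing = gradient_checkpointing_func is not None

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        func = functools.partial(
            torch.utils.checkpoint.checkpoint, **(gradient_checkpointing_kwargs or {})
        )
        # apply() visits every submodule; the isinstance guard filters targets.
        self.apply(
            functools.partial(self._set_gradient_checkpointing, gradient_checkpointing_func=func)
        )

model = TinyModel()
model.gradient_checkpointing_enable({"use_reentrant": False})
model.train()
out = model.encoder(torch.randn(2, 4, requires_grad=True))
```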
......@@ -336,17 +336,11 @@ class ASTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
......@@ -395,9 +389,10 @@ class ASTPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, ASTEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
......
......@@ -946,9 +946,10 @@ class AutoformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (AutoformerDecoder, AutoformerEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
AUTOFORMER_START_DOCSTRING = r"""
......@@ -1207,18 +1208,12 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1425,16 +1420,8 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -1442,6 +1429,8 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
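In the decoders, the `None` placeholder for `past_key_value` plus the trailing `output_attentions, use_cache` are now passed positionally instead of being captured by the closure, and `use_cache` is forced to `False` first. The reason caching and checkpointing cannot coexist: the checkpointed function is executed again during backward, so any cache it fed would receive duplicate entries. A tiny runnable demonstration of that double execution:

```python
import torch
import torch.utils.checkpoint

cache = []

def layer(x):
    cache.append(x.detach())  # stand-in for appending to a KV cache
    return x * 2

x = torch.randn(2, requires_grad=True)
y = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(len(cache))  # 2: the forward ran a second time during backward
```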
......@@ -313,9 +313,10 @@ class BarkPreTrainedModel(PreTrainedModel):
return get_parameter_device(self)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BARK_MODEL_START_DOCSTRING = """
......@@ -637,20 +638,14 @@ class BarkCausalModel(BarkPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
......
......@@ -521,9 +521,10 @@ class BartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BartDecoder, BartEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
......@@ -854,18 +855,12 @@ class BartEncoder(BartPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1110,16 +1105,8 @@ class BartDecoder(BartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -1127,6 +1114,8 @@ class BartDecoder(BartPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
......@@ -510,17 +510,11 @@ class BeitEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
relative_position_bias = (
......@@ -572,9 +566,10 @@ class BeitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BeitEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BEIT_START_DOCSTRING = r"""
......
......@@ -593,20 +593,15 @@ class BertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -762,9 +757,10 @@ class BertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
......
......@@ -401,20 +401,15 @@ class BertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -607,9 +602,10 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BERT_GENERATION_START_DOCSTRING = r"""
......
......@@ -1617,15 +1617,8 @@ class BigBirdEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
......@@ -1635,6 +1628,8 @@ class BigBirdEncoder(nn.Module):
from_mask,
to_mask,
blocked_encoder_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -1784,9 +1779,10 @@ class BigBirdPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BigBirdEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BIG_BIRD_START_DOCSTRING = r"""
......
......@@ -1609,9 +1609,10 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
......@@ -1943,15 +1944,8 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
......@@ -1960,6 +1954,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
to_mask,
blocked_encoder_mask,
blocked_encoder_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -2289,16 +2284,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -2306,6 +2293,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
......@@ -376,9 +376,10 @@ class BioGptPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BioGptModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BIOGPT_START_DOCSTRING = r"""
......@@ -590,20 +591,14 @@ class BioGptModel(BioGptPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
......@@ -669,9 +669,10 @@ class BitPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BitModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BIT_START_DOCSTRING = r"""
......
......@@ -483,9 +483,10 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
......@@ -777,18 +778,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1032,16 +1027,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -1049,6 +1036,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
......@@ -480,9 +480,10 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
......@@ -775,18 +776,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1029,16 +1024,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -1046,6 +1033,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
......@@ -34,7 +34,7 @@ from ...utils import (
replace_return_docstrings,
)
from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel
logger = logging.get_logger(__name__)
......@@ -461,9 +461,10 @@ class BlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BlipEncoder):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BlipEncoder, BlipTextEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BLIP_START_DOCSTRING = r"""
......@@ -622,17 +623,11 @@ class BlipEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......
......@@ -422,20 +422,15 @@ class BlipTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......
......@@ -297,9 +297,14 @@ class Blip2PreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, Blip2Encoder):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
BLIP_2_START_DOCSTRING = r"""
......@@ -473,17 +478,11 @@ class Blip2Encoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -944,15 +943,8 @@ class Blip2QFormerEncoder(nn.Module):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
......
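Blip2 is the one composite case in this diff: it wraps an arbitrary language model whose encoder classes its own `isinstance` check cannot know, so the setter is delegated. A minimal sketch of that delegation (class names invented):

```python
import functools
import torch.nn as nn
import torch.utils.checkpoint

class InnerLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        if isinstance(module, InnerLM):
            module.gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = gradient_checkpointing_func is not None

class Composite(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = InnerLM()

    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
        # Mirror of the Blip2 hunk above: forward the call so the inner
        # model can apply its own isinstance checks.
        if hasattr(self, "language_model") and hasattr(
            self.language_model, "_set_gradient_checkpointing"
        ):
            self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)

model = Composite()
func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
model.apply(functools.partial(model._set_gradient_checkpointing, gradient_checkpointing_func=func))
assert model.language_model.gradient_checkpointing
```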
......@@ -496,9 +496,10 @@ class BloomPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
if isinstance(module, BloomModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@staticmethod
def _convert_to_standard_cache(
......@@ -761,21 +762,15 @@ class BloomModel(BloomPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
outputs = self.gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
use_cache,
output_attentions,
)
else:
outputs = block(
......
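The commit message mentions adding a test and fixing all failing GC tests; the actual test lives in the transformers suite, but the shape of such a regression check is roughly this (gpt2 is an illustrative stand-in):

```python
import torch
from transformers import AutoModelForCausalLM

def test_gradient_checkpointing_backward():
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    model.train()
    input_ids = torch.tensor([[1, 2, 3, 4]])
    loss = model(input_ids=input_ids, labels=input_ids).loss
    loss.backward()
    assert any(p.grad is not None for p in model.parameters())
```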