Unverified Commit 06e782da authored by Younes Belkada, committed by GitHub

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
+import functools
 import gc
 import importlib.metadata
 import inspect
@@ -1848,16 +1849,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         self.base_model._prune_heads(heads_to_prune)
-    def gradient_checkpointing_enable(self):
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         """
         Activates gradient checkpointing for the current model.
         Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
         activations".
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+        Args:
+            gradient_checkpointing_kwargs (dict, *optional*):
+                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
         """
         if not self.supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
-        self.apply(partial(self._set_gradient_checkpointing, value=True))
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {}
+        gradient_checkpointing_func = functools.partial(
+            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
+        )
+        self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))
         if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@@ -1874,7 +1890,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         activations".
         """
         if self.supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, value=False))
+            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
         if getattr(self, "_hf_peft_config_loaded", False):
             self.disable_input_require_grads()
...
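For context, a minimal usage sketch of the refactored API. The checkpoint name is illustrative only; `use_reentrant` is a standard keyword argument of `torch.utils.checkpoint.checkpoint` that can now be forwarded through `gradient_checkpointing_kwargs`:

```python
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any model that supports gradient checkpointing works the same way.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.train()

# Extra kwargs are bound to `torch.utils.checkpoint.checkpoint` via `functools.partial`;
# calling `gradient_checkpointing_enable()` with no arguments keeps the previous default behaviour.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Disabling clears `gradient_checkpointing_func` on the submodules again.
model.gradient_checkpointing_disable()
```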
@@ -1095,20 +1095,15 @@ class AlignTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -1197,9 +1192,10 @@ class AlignPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (AlignTextModel, AlignVisionModel)):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 @add_start_docstrings(
...
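The same mechanical change is repeated in every model file below: the `create_custom_forward` closure is dropped, the layer's `__call__` is handed to `self.gradient_checkpointing_func`, and the formerly captured arguments are passed positionally. A condensed, self-contained sketch of that wiring, assuming a hypothetical `ToyEncoder` that is not part of this PR:

```python
import functools
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the per-model encoders touched by this PR."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # `__call__` (not `forward`) is checkpointed so module hooks still run.
                hidden_states = self.gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


encoder = ToyEncoder().train()
# Roughly what `_set_gradient_checkpointing(..., gradient_checkpointing_func=...)` now does per module:
encoder.gradient_checkpointing_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
encoder.gradient_checkpointing = True

out = encoder(torch.randn(4, 8, requires_grad=True))
out.sum().backward()
```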
@@ -646,20 +646,15 @@ class AltRobertaEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -960,18 +955,12 @@ class AltCLIPEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     causal_attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -1089,11 +1078,13 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, AltCLIPEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
         if isinstance(module, AltRobertaEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
...
@@ -336,17 +336,11 @@ class ASTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -395,9 +389,10 @@ class ASTPreTrainedModel(PreTrainedModel):
         module.weight.data.fill_(1.0)
     # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
-    def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, ASTEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
...
@@ -946,9 +946,10 @@ class AutoformerPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (AutoformerDecoder, AutoformerEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 AUTOFORMER_START_DOCSTRING = r"""
@@ -1207,18 +1208,12 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1425,16 +1420,8 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1442,6 +1429,8 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
...
@@ -313,9 +313,10 @@ class BarkPreTrainedModel(PreTrainedModel):
         return get_parameter_device(self)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 BARK_MODEL_START_DOCSTRING = """
@@ -637,20 +638,14 @@ class BarkCausalModel(BarkPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-                    return custom_forward
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self.gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
...
@@ -521,9 +521,10 @@ class BartPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (BartDecoder, BartEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     @property
     def dummy_inputs(self):
@@ -854,18 +855,12 @@ class BartEncoder(BartPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1110,16 +1105,8 @@ class BartDecoder(BartPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1127,6 +1114,8 @@ class BartDecoder(BartPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
...
@@ -510,17 +510,11 @@ class BeitEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 relative_position_bias = (
@@ -572,9 +566,10 @@ class BeitPreTrainedModel(PreTrainedModel):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BeitEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 BEIT_START_DOCSTRING = r"""
...
@@ -593,20 +593,15 @@ class BertEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -762,9 +757,10 @@ class BertPreTrainedModel(PreTrainedModel):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BertEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 @dataclass
...
@@ -401,20 +401,15 @@ class BertEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -607,9 +602,10 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BertEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 BERT_GENERATION_START_DOCSTRING = r"""
...
@@ -1617,15 +1617,8 @@ class BigBirdEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
@@ -1635,6 +1628,8 @@ class BigBirdEncoder(nn.Module):
                     from_mask,
                     to_mask,
                     blocked_encoder_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -1784,9 +1779,10 @@ class BigBirdPreTrainedModel(PreTrainedModel):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BigBirdEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 BIG_BIRD_START_DOCSTRING = r"""
...
@@ -1609,9 +1609,10 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     @property
     def dummy_inputs(self):
@@ -1943,15 +1944,8 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
@@ -1960,6 +1954,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
                         to_mask,
                         blocked_encoder_mask,
                         blocked_encoder_mask,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -2289,16 +2284,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -2306,6 +2293,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
...
@@ -376,9 +376,10 @@ class BioGptPreTrainedModel(PreTrainedModel):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BioGptModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 BIOGPT_START_DOCSTRING = r"""
@@ -590,20 +591,14 @@ class BioGptModel(BioGptPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
...
@@ -669,9 +669,10 @@ class BitPreTrainedModel(PreTrainedModel):
         nn.init.constant_(module.weight, 1)
         nn.init.constant_(module.bias, 0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BitModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 BIT_START_DOCSTRING = r"""
...
@@ -483,9 +483,10 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     @property
     def dummy_inputs(self):
@@ -777,18 +778,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1032,16 +1027,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1049,6 +1036,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
...
@@ -480,9 +480,10 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     @property
     def dummy_inputs(self):
@@ -775,18 +776,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1029,16 +1024,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1046,6 +1033,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
...
@@ -34,7 +34,7 @@ from ...utils import (
     replace_return_docstrings,
 )
 from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
-from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
+from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel
 logger = logging.get_logger(__name__)
@@ -461,9 +461,10 @@ class BlipPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BlipEncoder):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, (BlipEncoder, BlipTextEncoder)):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 BLIP_START_DOCSTRING = r"""
@@ -622,17 +623,11 @@ class BlipEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
...
@@ -422,20 +422,15 @@ class BlipTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
...
@@ -297,9 +297,14 @@ class Blip2PreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, Blip2Encoder):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
+        # Enable / disable GC for the language model as well
+        if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
+            self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
 BLIP_2_START_DOCSTRING = r"""
@@ -473,17 +478,11 @@ class Blip2Encoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -944,15 +943,8 @@ class Blip2QFormerEncoder(nn.Module):
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions, query_length)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
...
@@ -496,9 +496,10 @@ class BloomPreTrainedModel(PreTrainedModel):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
+    def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
         if isinstance(module, BloomModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     @staticmethod
     def _convert_to_standard_cache(
@@ -761,21 +762,15 @@ class BloomModel(BloomPreTrainedModel):
             all_hidden_states = all_hidden_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-                    return custom_forward
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self.gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
...