Unverified Commit 285a4801 authored by Younes Belkada, committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
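
Under fp16 autocast, the default (reentrant) implementation of `torch.utils.checkpoint.checkpoint` can silently leave the gradients of the last layers un-updated. Every checkpointed call site is therefore routed through a new `torch_custom_checkpointing` helper in `pytorch_utils.py`, which passes `use_reentrant=False` whenever the installed PyTorch version accepts that argument. A minimal sketch of the affected training setup (the model choice, shapes, and sanity check are illustrative assumptions, not code from this PR; requires a CUDA device):

```python
import torch
from transformers import AutoModelForCausalLM

# Any gradient-checkpointed model trained under fp16 autocast hits this path;
# OPT is used here only because the PR's PoC started there.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").cuda()
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (2, 128), device="cuda")

with torch.autocast("cuda", dtype=torch.float16):
    loss = model(input_ids=input_ids, labels=input_ids).loss

loss.backward()
# With the fix, every trainable parameter receives a gradient; before it,
# the last layers could silently end up with none.
assert all(p.grad is not None for p in model.parameters() if p.requires_grad)
```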
@@ -30,7 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -441,7 +441,7 @@ class VideoMAEEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
@@ -724,7 +724,7 @@ class VideoMAEDecoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     None,
...
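
Every hunk below repeats this same substitution at a checkpointed call site. For orientation, a self-contained re-creation of the pattern (the toy layer, tensor shapes, and the inlined `use_reentrant=False` are illustrative assumptions; the real models bind extra flags such as `output_attentions` inside the closure and call the helper rather than `checkpoint` directly):

```python
import torch
import torch.utils.checkpoint
from torch import nn

# Toy stand-ins for a transformer layer and its hidden states.
layer_module = nn.Linear(8, 8)
hidden_states = torch.randn(2, 8, requires_grad=True)


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Real models also forward bound non-tensor flags here.
        return module(*inputs)

    return custom_forward


# On PyTorch versions that accept the flag, torch_custom_checkpointing
# resolves to exactly this call.
layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer_module),
    hidden_states,
    use_reentrant=False,
)
layer_outputs.sum().backward()
```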
@@ -38,6 +38,7 @@ from ...pytorch_utils import (
     find_pruneable_heads_and_indices,
     meshgrid,
     prune_linear_layer,
+    torch_custom_checkpointing,
 )
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_vilt import ViltConfig
@@ -536,7 +537,7 @@ class ViltEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -32,7 +32,12 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -423,7 +428,7 @@ class VisualBertEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -32,7 +32,7 @@ from ...modeling_outputs import (
     MaskedImageModelingOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -404,7 +404,7 @@ class ViTEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
...
@@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from ..auto import AutoBackbone
 from .configuration_vit_hybrid import ViTHybridConfig
@@ -422,7 +422,7 @@ class ViTHybridEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
...
@@ -29,7 +29,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -543,7 +543,7 @@ class ViTMAEEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
@@ -800,7 +800,7 @@ class ViTMAEDecoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     None,
...
@@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_vit_msn import ViTMSNConfig
@@ -394,7 +394,7 @@ class ViTMSNEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
...
@@ -37,6 +37,7 @@ from ...modeling_outputs import (
     XVectorOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -458,7 +459,7 @@ class Wav2Vec2FeatureEncoder(nn.Module):

                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = torch_custom_checkpointing(
                     create_custom_forward(conv_layer),
                     hidden_states,
                 )
@@ -810,7 +811,7 @@ class Wav2Vec2Encoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
@@ -899,7 +900,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
...
@@ -35,6 +35,7 @@ from ...modeling_outputs import (
     XVectorOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -523,7 +524,7 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module):

                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = torch_custom_checkpointing(
                     create_custom_forward(conv_layer),
                     hidden_states,
                 )
@@ -916,7 +917,7 @@ class Wav2Vec2ConformerEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
...
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     XVectorOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_wavlm import WavLMConfig
@@ -361,7 +362,7 @@ class WavLMFeatureEncoder(nn.Module):

                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = torch_custom_checkpointing(
                     create_custom_forward(conv_layer),
                     hidden_states,
                 )
@@ -720,7 +721,7 @@ class WavLMEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
@@ -811,7 +812,7 @@ class WavLMEncoderStableLayerNorm(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
...
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -853,7 +854,7 @@ class WhisperEncoder(WhisperPreTrainedModel):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     None,
@@ -1085,7 +1086,7 @@ class WhisperDecoder(WhisperPreTrainedModel):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -26,6 +26,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -708,7 +709,7 @@ class XCLIPEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -955,7 +956,7 @@ class XCLIPVisionEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_xglm import XGLMConfig
@@ -683,7 +684,7 @@ class XGLMModel(XGLMPreTrainedModel):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -29,6 +29,7 @@ from torch.nn import LayerNorm
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -1356,7 +1357,7 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     extended_attention_mask,
@@ -1600,7 +1601,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     extended_attention_mask,
...
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -516,7 +521,7 @@ class XLMRobertaEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -504,7 +509,7 @@ class XLMRobertaXLEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_xmod import XmodConfig
@@ -578,7 +583,7 @@ class XmodEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     lang_ids,
...
@@ -27,7 +27,7 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -499,7 +499,7 @@ class YolosEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
...
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_yoso import YosoConfig
@@ -566,7 +571,7 @@ class YosoEncoder(nn.Module):

                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -285,3 +285,18 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
     non-overlapping lifetimes may have the same id.
     """
     return tensor.device, storage_ptr(tensor), storage_size(tensor)
+
+
+def torch_custom_checkpointing(*args):
+    r"""
+    A safe alternative to calling `torch.utils.checkpoint.checkpoint` directly: the default call leads to silent
+    bugs in which the gradients of the last layers are not updated. For an in-depth analysis of the issue, see:
+    https://github.com/huggingface/transformers/pull/24247
+    """
+    kwargs = {}
+    if "use_reentrant" in list(inspect.signature(torch.utils.checkpoint.checkpoint).parameters):
+        kwargs["use_reentrant"] = False
+    return torch.utils.checkpoint.checkpoint(
+        *args,
+        **kwargs,
+    )
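
The helper feature-detects the flag instead of pinning a minimum PyTorch version. A quick sketch (not part of this PR) of the same check, to see what the installed version will do:

```python
import inspect

import torch
import torch.utils.checkpoint

# True on PyTorch versions whose checkpoint() accepts use_reentrant; on older
# versions torch_custom_checkpointing falls back to the default behavior.
supports_flag = "use_reentrant" in inspect.signature(torch.utils.checkpoint.checkpoint).parameters
print(f"torch {torch.__version__}: use_reentrant supported = {supports_flag}")
```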