Unverified Commit 285a4801 authored by Younes Belkada, committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
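The change below is mechanical: every modeling file swaps its direct `torch.utils.checkpoint.checkpoint(...)` call for a shared `torch_custom_checkpointing` helper imported from `...pytorch_utils`, so the checkpointing behavior can be fixed in one place for fp16 autocast. The helper's definition is not part of this excerpt; what follows is a minimal sketch of what such a wrapper could look like, assuming it simply routes to PyTorch's non-reentrant checkpoint implementation (only the name `torch_custom_checkpointing` comes from the diff, the body is an assumption):

```python
# Hypothetical sketch -- the real definition lives in ...pytorch_utils and is
# not shown in this excerpt. Assumption: the helper forwards every call to
# PyTorch's non-reentrant checkpoint, which replays the forward pass under the
# same autograd/autocast state as the original call and so cooperates with
# fp16 autocast.
import torch
import torch.utils.checkpoint


def torch_custom_checkpointing(*args, **kwargs):
    # checkpoint(function, *inputs): the first positional argument is the
    # wrapped forward, the remaining ones are its tensor inputs
    return torch.utils.checkpoint.checkpoint(*args, use_reentrant=False, **kwargs)
```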
@@ -34,7 +34,11 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
...
@@ -30,7 +30,12 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -1100,7 +1105,7 @@ class AlignTextEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
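All the call sites this diff touches repeat the pattern visible in the AlignTextEncoder hunk above: the layer is wrapped in a `create_custom_forward` closure (so Python-level flags are captured at wrap time) and only tensors are passed positionally to the checkpoint helper. Below is a self-contained toy version of that pattern under the same assumption about the helper; apart from the names taken from the diff, everything is illustrative:

```python
# Toy, runnable illustration of the call-site pattern. Only the names
# create_custom_forward / layer_module / hidden_states / attention_mask and
# torch_custom_checkpointing come from the diff; the rest is illustrative.
import torch
import torch.nn as nn
import torch.utils.checkpoint


def torch_custom_checkpointing(*args, **kwargs):
    # assumed behavior -- see the sketch after the commit message above
    return torch.utils.checkpoint.checkpoint(*args, use_reentrant=False, **kwargs)


class ToyEncoder(nn.Module):
    def __init__(self, hidden_size=16, num_layers=2):
        super().__init__()
        self.layer = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.gradient_checkpointing = True

    def forward(self, hidden_states, attention_mask=None):
        for layer_module in self.layer:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # the closure captures the module itself; checkpoint()
                        # only sees (and recomputes from) the tensor inputs
                        return module(inputs[0])

                    return custom_forward

                hidden_states = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                )
            else:
                hidden_states = layer_module(hidden_states)
        return hidden_states


# gradient checkpointing under autocast -- the combination this PR fixes
x = torch.randn(2, 16, requires_grad=True)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = ToyEncoder().train()(x)
out.float().sum().backward()
```

The closure is needed because `checkpoint()` recomputes the wrapped callable from its positional tensor inputs alone, so any other arguments have to be baked in when the wrapper is created.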
@@ -30,7 +30,12 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndProjection,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
 
@@ -651,7 +656,7 @@ class AltRobertaEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
@@ -965,7 +970,7 @@ class AltCLIPEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_audio_spectrogram_transformer import ASTConfig
 
@@ -343,7 +343,7 @@ class ASTEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
...
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     Seq2SeqTSPredictionOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_autoformer import AutoformerConfig
@@ -1210,7 +1211,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1428,7 +1429,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -35,6 +35,7 @@ from ...modeling_outputs import (
     Seq2SeqSequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -849,7 +850,7 @@ class BartEncoder(BartPretrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1105,7 +1106,7 @@ class BartDecoder(BartPretrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
     SemanticSegmenterOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -517,7 +517,7 @@ class BeitEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
...
@@ -40,7 +40,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -598,7 +603,7 @@ class BertEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -25,7 +25,12 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -408,7 +413,7 @@ class BertEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -37,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -1622,7 +1622,7 @@ class BigBirdEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     Seq2SeqSequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -1945,7 +1946,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -2291,7 +2292,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -32,6 +32,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -594,7 +595,7 @@ class BioGptModel(BioGptPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
@@ -779,7 +780,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1034,7 +1035,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
@@ -777,7 +778,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1031,7 +1032,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -25,6 +25,7 @@ from torch.nn.functional import normalize
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -620,7 +621,7 @@ class BlipEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -34,6 +34,7 @@ from ...modeling_utils import (
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import logging
 from .configuration_blip import BlipTextConfig
 
@@ -427,7 +428,7 @@ class BlipTextEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -31,7 +31,12 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -492,7 +497,7 @@ class Blip2Encoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -963,7 +968,7 @@ class Blip2QFormerEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -33,6 +33,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import logging
 from .configuration_bloom import BloomConfig
 
@@ -775,7 +776,7 @@ class BloomModel(BloomPreTrainedModel):
 
                     return custom_forward
 
-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(block),
                     hidden_states,
                     alibi,
...
@@ -32,8 +32,13 @@ from ...modeling_outputs import (
     ModelOutput,
     SequenceClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
 
@@ -810,7 +815,7 @@ class BridgeTowerTextEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -529,7 +534,7 @@ class CamembertEncoder(nn.Module):
 
                     return custom_forward
 
-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...