"docs/vscode:/vscode.git/clone" did not exist on "cb6b56859a251b1f0e8e0ba5df05f8113e353b51"
Unverified Commit 285a4801 authored by Younes Belkada, committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
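
Note on the change below: every touched file swaps `torch.utils.checkpoint.checkpoint` for a `torch_custom_checkpointing` helper imported from `...pytorch_utils` (the diff for `src/transformers/pytorch_utils.py` itself is not part of this excerpt). A minimal sketch of what such a wrapper plausibly looks like, assuming the fix is to use non-reentrant checkpointing so that the forward recomputation during backward runs under the same fp16 autocast state as the original forward; the version guard and exact signature are assumptions, not copied from the commit:

```python
# Sketch only -- the real helper lives in src/transformers/pytorch_utils.py.
import torch
import torch.utils.checkpoint
from packaging import version


def torch_custom_checkpointing(*args, **kwargs):
    # `use_reentrant` only exists in PyTorch >= 1.11; this guard is an
    # assumption, not copied from the commit.
    if version.parse(torch.__version__) >= version.parse("1.11.0"):
        # Non-reentrant checkpointing replays the recomputed forward
        # under the same autocast state, fixing fp16 + checkpointing.
        return torch.utils.checkpoint.checkpoint(*args, use_reentrant=False, **kwargs)
    # Older PyTorch has no `use_reentrant` flag; fall back to the default.
    return torch.utils.checkpoint.checkpoint(*args, **kwargs)
```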
@@ -34,7 +34,11 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
......
@@ -30,7 +30,12 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -1100,7 +1105,7 @@ class AlignTextEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
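
Each call-site hunk in this commit repeats the same one-line swap seen above: the `create_custom_forward` closure is kept (it lets call sites bind a layer's extra non-tensor arguments, since checkpointing only threads positional tensors through), and only the checkpoint entry point changes. A self-contained toy version of the pattern, with the helper stubbed as in the sketch above (`use_reentrant=False` is an assumption):

```python
import torch
import torch.utils.checkpoint


def torch_custom_checkpointing(*args, **kwargs):
    # Stand-in for the helper from ...pytorch_utils (sketched earlier).
    return torch.utils.checkpoint.checkpoint(*args, use_reentrant=False, **kwargs)


layer_module = torch.nn.Linear(8, 8)
hidden_states = torch.randn(2, 8, requires_grad=True)


def create_custom_forward(module):
    # In the real encoders the closure also binds non-tensor arguments
    # such as output_attentions.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


# Before this commit: torch.utils.checkpoint.checkpoint(...)
layer_outputs = torch_custom_checkpointing(
    create_custom_forward(layer_module),
    hidden_states,
)
layer_outputs.sum().backward()
```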
@@ -30,7 +30,12 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndProjection,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
@@ -651,7 +656,7 @@ class AltRobertaEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
@@ -965,7 +970,7 @@ class AltCLIPEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_audio_spectrogram_transformer import ASTConfig
@@ -343,7 +343,7 @@ class ASTEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
......
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     Seq2SeqTSPredictionOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_autoformer import AutoformerConfig
@@ -1210,7 +1211,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1428,7 +1429,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -35,6 +35,7 @@ from ...modeling_outputs import (
     Seq2SeqSequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -849,7 +850,7 @@ class BartEncoder(BartPretrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1105,7 +1106,7 @@ class BartDecoder(BartPretrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
     SemanticSegmenterOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -517,7 +517,7 @@ class BeitEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
......
@@ -40,7 +40,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -598,7 +603,7 @@ class BertEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
@@ -25,7 +25,12 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -408,7 +413,7 @@ class BertEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
@@ -37,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -1622,7 +1622,7 @@ class BigBirdEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     Seq2SeqSequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -1945,7 +1946,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -2291,7 +2292,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -32,6 +32,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -594,7 +595,7 @@ class BioGptModel(BioGptPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
@@ -779,7 +780,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1034,7 +1035,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
@@ -777,7 +778,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1031,7 +1032,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -25,6 +25,7 @@ from torch.nn.functional import normalize
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -620,7 +621,7 @@ class BlipEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
......
@@ -34,6 +34,7 @@ from ...modeling_utils import (
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import logging
 from .configuration_blip import BlipTextConfig
@@ -427,7 +428,7 @@ class BlipTextEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
@@ -31,7 +31,12 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -492,7 +497,7 @@ class Blip2Encoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -963,7 +968,7 @@ class Blip2QFormerEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
@@ -33,6 +33,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import logging
 from .configuration_bloom import BloomConfig
@@ -775,7 +776,7 @@ class BloomModel(BloomPreTrainedModel):
                         return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(block),
                     hidden_states,
                     alibi,
......
@@ -32,8 +32,13 @@ from ...modeling_outputs import (
     ModelOutput,
     SequenceClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
@@ -810,7 +815,7 @@ class BridgeTowerTextEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -529,7 +534,7 @@ class CamembertEncoder(nn.Module):
                         return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
......
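
For context, the training recipe this commit targets pairs `gradient_checkpointing_enable()` with an fp16 autocast context. A sketch of that setup; the model choice, scaler, and hyperparameters are illustrative and not taken from the PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model; the PR fixes this combination for "most models".
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
model.gradient_checkpointing_enable()
model.train()

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scaler = torch.cuda.amp.GradScaler()

# fp16 autocast + gradient checkpointing: the pairing this commit fixes.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(**inputs, labels=inputs["input_ids"]).loss

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```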