"tests/models/vscode:/vscode.git/clone" did not exist on "fa7f3cf336eb5d93cfaa7611723c299e7851fb02"
Unverified Commit 285a4801 authored by Younes Belkada, committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
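
The change is mechanical: every modeling file that called `torch.utils.checkpoint.checkpoint` directly now imports and calls a shared `torch_custom_checkpointing` helper from `...pytorch_utils`, so checkpointing behavior can be fixed in one place. The helper's definition is not part of this diff; the sketch below is a minimal plausible version, assuming (as the commit title suggests) that the fix is to opt into non-reentrant checkpointing, which replays the forward pass under the same autocast state and therefore works with fp16 autocast. The version gate and `functools.partial` wrapper are illustrative assumptions, not the exact code added to `pytorch_utils.py`.

```python
# Hypothetical sketch of a torch_custom_checkpointing helper; the exact
# definition added in this commit may differ.
import functools

import torch
import torch.utils.checkpoint
from packaging import version

if version.parse(torch.__version__) >= version.parse("1.11"):
    # Non-reentrant checkpointing recomputes the forward pass under the
    # autocast state captured at forward time, fixing fp16 autocast + GC.
    torch_custom_checkpointing = functools.partial(
        torch.utils.checkpoint.checkpoint, use_reentrant=False
    )
else:
    # Older torch: fall back to the default (reentrant) implementation.
    torch_custom_checkpointing = torch.utils.checkpoint.checkpoint
```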
src/transformers/models/dpt/modeling_dpt.py
@@ -39,7 +39,7 @@ from ...file_utils import (
)
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
from ...utils import ModelOutput, logging
from ..auto import AutoBackbone
from .configuration_dpt import DPTConfig
@@ -535,7 +535,7 @@ class DPTViTEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
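
Every file below repeats this same one-line substitution inside the standard gradient-checkpointing block. For reference, here is a self-contained sketch of that block outside of any real model; `MyEncoder` and its sizes are hypothetical names, and `use_reentrant=False` stands in for whatever the shared helper does internally:

```python
# Minimal sketch of the checkpointing pattern these diffs patch (hypothetical model).
import torch
import torch.utils.checkpoint
from torch import nn


class MyEncoder(nn.Module):
    def __init__(self, num_layers=4, hidden_size=32):
        super().__init__()
        self.layer = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.gradient_checkpointing = True

    def forward(self, hidden_states):
        for layer_module in self.layer:
            if self.gradient_checkpointing and self.training:
                # checkpoint() wants a callable taking only positional tensors,
                # hence the closure over the layer module
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module), hidden_states, use_reentrant=False
                )
            else:
                hidden_states = layer_module(hidden_states)
        return hidden_states


# Activations are recomputed during backward instead of stored; with
# use_reentrant=False the recomputation sees the same autocast state.
model = MyEncoder().train()
x = torch.randn(2, 32, requires_grad=True)
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = model(x)
out.sum().backward()
```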
src/transformers/models/electra/modeling_electra.py
@@ -36,7 +36,12 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -576,7 +581,7 @@ class ElectraEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
src/transformers/models/ernie/modeling_ernie.py
@@ -38,7 +38,12 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -511,7 +516,7 @@ class ErnieEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
src/transformers/models/esm/modeling_esm.py
@@ -30,7 +30,8 @@ from ...modeling_outputs import (
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
from ...utils import logging
from .configuration_esm import EsmConfig
@@ -610,7 +611,7 @@ class EsmEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
src/transformers/models/flava/modeling_flava.py
@@ -26,7 +26,8 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -668,7 +669,7 @@ class FlavaEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
src/transformers/models/fnet/modeling_fnet.py
@@ -43,7 +43,7 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -297,7 +297,7 @@ class FNetEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states)
+                layer_outputs = torch_custom_checkpointing(create_custom_forward(layer_module), hidden_states)
            else:
                layer_outputs = layer_module(hidden_states)
src/transformers/models/focalnet/modeling_focalnet.py
@@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -593,7 +594,7 @@ class FocalNetEncoder(nn.Module):
                    return custom_forward

-                stage_outputs = torch.utils.checkpoint.checkpoint(
+                stage_outputs = torch_custom_checkpointing(
                    create_custom_forward(stage_module),
                    hidden_states,
                    input_dimensions,
src/transformers/models/git/modeling_git.py
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
    CausalLMOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_git import GitConfig, GitVisionConfig
@@ -457,7 +462,7 @@ class GitEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
@@ -883,7 +888,7 @@ class GitVisionEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
src/transformers/models/gpt2/modeling_gpt2.py
@@ -35,8 +35,12 @@ from ...modeling_outputs import (
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...modeling_utils import Conv1D, PreTrainedModel, SequenceSummary
+from ...pytorch_utils import (
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+    torch_custom_checkpointing,
+)
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -890,7 +894,7 @@ class GPT2Model(GPT2PreTrainedModel):
                    return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                    create_custom_forward(block),
                    hidden_states,
                    None,
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -28,6 +28,7 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -661,7 +662,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                    return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                    create_custom_forward(block),
                    hidden_states,
                    None,
src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_gpt_neo import GPTNeoConfig
@@ -613,7 +614,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
                    return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                    create_custom_forward(block),
                    hidden_states,
                    None,
src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import logging
from .configuration_gpt_neox import GPTNeoXConfig
@@ -557,7 +558,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
                    return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
src/transformers/models/gptj/modeling_gptj.py
@@ -31,6 +31,7 @@ from ...modeling_outputs import (
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -677,7 +678,7 @@ class GPTJModel(GPTJPreTrainedModel):
                    return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                    create_custom_forward(block),
                    hidden_states,
                    None,
src/transformers/models/groupvit/modeling_groupvit.py
@@ -28,6 +28,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
    ModelOutput,
    add_start_docstrings,
@@ -1037,7 +1038,7 @@ class GroupViTTextEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
src/transformers/models/hubert/modeling_hubert.py
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -353,7 +354,7 @@ class HubertFeatureEncoder(nn.Module):
                    return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = torch_custom_checkpointing(
                    create_custom_forward(conv_layer),
                    hidden_states,
                )
@@ -738,7 +739,7 @@ class HubertEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
@@ -828,7 +829,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -32,7 +32,7 @@ from ...modeling_outputs import (
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer, torch_custom_checkpointing
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_imagegpt import ImageGPTConfig
@@ -826,7 +826,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
                    return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                    create_custom_forward(block),
                    hidden_states,
                    None,
src/transformers/models/informer/modeling_informer.py
@@ -30,6 +30,7 @@ from ...modeling_outputs import (
    Seq2SeqTSPredictionOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_informer import InformerConfig
@@ -1217,14 +1218,14 @@ class InformerEncoder(InformerPreTrainedModel):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                )
                if conv_layer is not None:
-                    output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0])
+                    output = torch_custom_checkpointing(conv_layer, layer_outputs[0])
                    layer_outputs = (output,) + layer_outputs[1:]
            else:
                layer_outputs = encoder_layer(
@@ -1440,7 +1441,7 @@ class InformerDecoder(InformerPreTrainedModel):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -33,7 +33,12 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_layoutlm import LayoutLMConfig
@@ -492,7 +497,7 @@ class LayoutLMEncoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -31,7 +31,7 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
@@ -455,7 +455,7 @@ class LayoutLMv2Encoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -32,7 +32,7 @@ from ...modeling_outputs import (
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_layoutlmv3 import LayoutLMv3Config
@@ -671,7 +671,7 @@ class LayoutLMv3Encoder(nn.Module):
                    return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,