"git@developer.sourcefind.cn:sugon_wxj/megatron-lm.git" did not exist on "1d391bba132ac2cb6077ee10bc4138a7260d39f2"
Unverified Commit 285a4801 authored by Younes Belkada, committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
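
Every hunk in this commit makes the same two-part change: import a new `torch_custom_checkpointing` helper from `...pytorch_utils`, and route each gradient-checkpointing call through it instead of calling `torch.utils.checkpoint.checkpoint` directly. The helper's definition (in transformers' `pytorch_utils.py`) is not part of the truncated diff below; what follows is a minimal sketch of what such a wrapper plausibly looks like, assuming it simply forwards to PyTorch's non-reentrant checkpoint path:

```python
# Hypothetical sketch -- the real definition lives in transformers'
# pytorch_utils.py and is not shown in this truncated diff.
import torch.utils.checkpoint


def torch_custom_checkpointing(function, *args, **kwargs):
    # The non-reentrant implementation replays the forward pass during
    # backward under the same autocast (fp16/bf16) state that was active
    # when the checkpoint was taken, which the default reentrant
    # implementation did not guarantee at the time of this PR.
    return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False, **kwargs)
```

Centralizing the call in one helper also means any future fix to checkpointing semantics lands in a single place instead of in every modeling file.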
@@ -39,7 +39,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import ModelOutput, logging
 from ..auto import AutoBackbone
 from .configuration_dpt import DPTConfig
@@ -535,7 +535,7 @@ class DPTViTEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     layer_head_mask,
...
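
The DPT hunk above shows only a fragment of the call site; the same pattern repeats in every file below. Reconstructed in full for illustration (it sits inside an encoder's `forward` loop, and the variable names follow the transformers encoder convention; this is a sketch, not a verbatim excerpt):

```python
# Illustrative reconstruction of the gradient-checkpointing call site;
# layer_module, hidden_states, and layer_head_mask are the surrounding
# encoder-loop variables, not new API.
if self.gradient_checkpointing and self.training:

    def create_custom_forward(module):
        # checkpoint() tracks only positional tensor arguments, so the
        # layer call is wrapped in a closure that re-enters it with *inputs
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    layer_outputs = torch_custom_checkpointing(
        create_custom_forward(layer_module),
        hidden_states,
        layer_head_mask,
    )
else:
    layer_outputs = layer_module(hidden_states, layer_head_mask)
```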
@@ -36,7 +36,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -576,7 +581,7 @@ class ElectraEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -38,7 +38,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -511,7 +516,7 @@ class ErnieEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -30,7 +30,8 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import logging
 from .configuration_esm import EsmConfig
@@ -610,7 +611,7 @@ class EsmEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -26,7 +26,8 @@ from torch import nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -668,7 +669,7 @@ class FlavaEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -43,7 +43,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -297,7 +297,7 @@ class FNetEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states)
+                layer_outputs = torch_custom_checkpointing(create_custom_forward(layer_module), hidden_states)
             else:
                 layer_outputs = layer_module(hidden_states)
...
@@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -593,7 +594,7 @@ class FocalNetEncoder(nn.Module):
                     return custom_forward

-                stage_outputs = torch.utils.checkpoint.checkpoint(
+                stage_outputs = torch_custom_checkpointing(
                     create_custom_forward(stage_module),
                     hidden_states,
                     input_dimensions,
...
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
     CausalLMOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_git import GitConfig, GitVisionConfig
@@ -457,7 +462,7 @@ class GitEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
@@ -883,7 +888,7 @@ class GitVisionEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -35,8 +35,12 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...modeling_utils import Conv1D, PreTrainedModel, SequenceSummary
+from ...pytorch_utils import (
+    find_pruneable_heads_and_indices,
+    prune_conv1d_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -890,7 +894,7 @@ class GPT2Model(GPT2PreTrainedModel):
                     return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(block),
                     hidden_states,
                     None,
...
@@ -28,6 +28,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -661,7 +662,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                     return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(block),
                     hidden_states,
                     None,
...
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_gpt_neo import GPTNeoConfig
@@ -613,7 +614,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
                     return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(block),
                     hidden_states,
                     None,
...
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import logging
 from .configuration_gpt_neox import GPTNeoXConfig
@@ -557,7 +558,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
                     return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
...
@@ -31,6 +31,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -677,7 +678,7 @@ class GPTJModel(GPTJPreTrainedModel):
                     return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(block),
                     hidden_states,
                     None,
...
@@ -28,6 +28,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -1037,7 +1038,7 @@ class GroupViTTextEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
 from ...deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -353,7 +354,7 @@ class HubertFeatureEncoder(nn.Module):
                     return custom_forward

-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = torch_custom_checkpointing(
                     create_custom_forward(conv_layer),
                     hidden_states,
                 )
@@ -738,7 +739,7 @@ class HubertEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
@@ -828,7 +829,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer),
                     hidden_states,
                     attention_mask,
...
@@ -32,7 +32,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer, torch_custom_checkpointing
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_imagegpt import ImageGPTConfig
@@ -826,7 +826,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
                     return custom_forward

-                outputs = torch.utils.checkpoint.checkpoint(
+                outputs = torch_custom_checkpointing(
                     create_custom_forward(block),
                     hidden_states,
                     None,
...
@@ -30,6 +30,7 @@ from ...modeling_outputs import (
     Seq2SeqTSPredictionOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_informer import InformerConfig
@@ -1217,14 +1218,14 @@ class InformerEncoder(InformerPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
                     (head_mask[idx] if head_mask is not None else None),
                 )
                 if conv_layer is not None:
-                    output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0])
+                    output = torch_custom_checkpointing(conv_layer, layer_outputs[0])
                     layer_outputs = (output,) + layer_outputs[1:]
             else:
                 layer_outputs = encoder_layer(
@@ -1440,7 +1441,7 @@ class InformerDecoder(InformerPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
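
Note that the Informer encoder hunk above also checkpoints its distilling conv layers by passing the module directly, with no `create_custom_forward` closure: an `nn.Module` is itself a callable, so it can be handed straight to the wrapper when it takes a single tensor. A toy standalone illustration of that usage (the `partial` stand-in mirrors the hypothetical helper sketched earlier, not the real definition):

```python
from functools import partial

import torch
import torch.utils.checkpoint
from torch import nn

# stand-in for the helper sketched earlier (assumption, not the real definition)
torch_custom_checkpointing = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

conv = nn.Conv1d(16, 16, kernel_size=3, padding=1)
x = torch.randn(2, 16, 32, requires_grad=True)

# the module itself is the checkpointed callable; no closure needed
out = torch_custom_checkpointing(conv, x)
out.sum().backward()
```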
@@ -33,7 +33,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_layoutlm import LayoutLMConfig
@@ -492,7 +497,7 @@ class LayoutLMEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -31,7 +31,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -455,7 +455,7 @@ class LayoutLMv2Encoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -32,7 +32,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_layoutlmv3 import LayoutLMv3Config
@@ -671,7 +671,7 @@ class LayoutLMv3Encoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...