Unverified Commit 285a4801 authored by Younes Belkada, committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
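
Every call site in the diff below is rerouted from torch.utils.checkpoint.checkpoint (or a locally imported checkpoint alias) to a single shared helper, torch_custom_checkpointing, imported from ...pytorch_utils. The helper's definition is not part of this excerpt; what follows is a minimal sketch, assuming it wraps PyTorch's checkpoint and opts into the non-reentrant implementation (use_reentrant=False, available since PyTorch 1.11), which replays the fp16/bf16 autocast state when activations are recomputed during the backward pass:

from torch.utils.checkpoint import checkpoint


def torch_custom_checkpointing(*args, **kwargs):
    # Hypothetical sketch: route every gradient-checkpointing call through
    # one helper so the checkpoint variant can be chosen in a single place.
    # use_reentrant=False selects the non-reentrant implementation, whose
    # recomputation replays the surrounding autocast context; the reentrant
    # default is what misbehaves under fp16 autocast.
    kwargs.setdefault("use_reentrant", False)
    return checkpoint(*args, **kwargs)

With such a helper in place, each model only needs to change its import and the call site, as the hunks below show.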
@@ -22,7 +22,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
-from torch.utils.checkpoint import checkpoint

 from ...activations import ACT2FN
 from ...deepspeed import is_deepspeed_zero3_enabled
@@ -33,6 +32,7 @@ from ...modeling_outputs import (
     Seq2SeqMoEOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
@@ -1155,7 +1155,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1428,7 +1428,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
                     return custom_forward

-                layer_outputs = checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     combined_attention_mask,
...
@@ -33,7 +33,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_nystromformer import NystromformerConfig
@@ -375,7 +380,7 @@ class NystromformerEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -28,6 +28,7 @@ from ... import AutoBackbone
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -2619,7 +2620,7 @@ class OneFormerTextTransformer(nn.Module):
     def forward(self, hidden_states: torch.Tensor):
         for layer in self.layers:
             if self.use_checkpoint:
-                hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
+                hidden_states = torch_custom_checkpointing(layer, hidden_states)
             else:
                 hidden_states = layer(hidden_states)
         return hidden_states
...
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_open_llama import OpenLlamaConfig
@@ -603,7 +604,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -29,6 +29,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -700,7 +701,7 @@ class OPTDecoder(OPTPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     causal_attention_mask,
...
@@ -27,6 +27,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -754,7 +755,7 @@ class OwlViTEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
@@ -805,7 +806,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1089,7 +1090,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -33,6 +33,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_end_docstrings,
     add_start_docstrings,
@@ -1072,7 +1073,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     global_hidden_states,
@@ -1330,7 +1331,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.utils.checkpoint import checkpoint

 from ...activations import ACT2FN
 from ...modeling_outputs import (
@@ -31,7 +30,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, torch_custom_checkpointing
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
@@ -350,7 +349,7 @@ class Pix2StructVisionEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
@@ -1502,7 +1501,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
                     return custom_forward

-                layer_outputs = checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     extended_attention_mask,
...
@@ -33,6 +33,7 @@ from ...modeling_outputs import (
     Seq2SeqSequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -810,7 +811,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -1067,7 +1068,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
@@ -28,6 +28,7 @@ from torch.nn import LayerNorm
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -1336,7 +1337,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     extended_attention_mask,
@@ -1577,7 +1578,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     extended_attention_mask,
...
@@ -39,7 +39,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -586,7 +586,7 @@ class QDQBertEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -31,7 +31,12 @@ from ...modeling_outputs import (
     ModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_realm import RealmConfig
@@ -591,7 +596,7 @@ class RealmEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -36,7 +36,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -548,7 +553,7 @@ class RemBertEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -21,10 +21,10 @@ import math
 from typing import Optional

 import torch
-import torch.utils.checkpoint as checkpoint
 from torch import nn

 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import add_start_docstrings, logging
 from ..bert.modeling_bert import BertModel
 from .configuration_retribert import RetriBertConfig
@@ -141,7 +141,7 @@ class RetriBertModel(RetriBertPreTrainedModel):
             for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
                 b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                 b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
-                pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
+                pooled_output = torch_custom_checkpointing(partial_encode, b_embedding_output, b_attention_mask)
                 pooled_output_list.append(pooled_output)
             return torch.cat(pooled_output_list, dim=0)
...
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -515,7 +520,7 @@ class RobertaEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -517,7 +522,7 @@ class RobertaPreLayerNormEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -649,7 +654,7 @@ class RoCBertEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -36,7 +36,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_custom_checkpointing,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -585,7 +590,7 @@ class RoFormerEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
...
@@ -28,6 +28,7 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
@@ -1049,7 +1050,7 @@ class SamVisionEncoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                 )
...