"tests/vscode:/vscode.git/clone" did not exist on "5603fad2479ad22ca4689f6a4dbf56ef2f1f0973"
Unverified commit 285a4801, authored by Younes Belkada and committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
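The diff below makes the same mechanical change in every model file: each direct call to `torch.utils.checkpoint.checkpoint` (and each bare `checkpoint` import) is routed through a single `torch_custom_checkpointing` helper imported from `...pytorch_utils`, so the gradient-checkpointing behaviour under fp16/bf16 autocast can be fixed in one place instead of per model. The helper's definition is not part of this diff; the following is only a hypothetical sketch of what such a wrapper could look like, assuming it simply forwards to PyTorch's non-reentrant checkpoint implementation.

```python
# Hypothetical sketch only -- the real helper is defined in transformers'
# pytorch_utils module and its exact body is not shown in this diff.
import torch.utils.checkpoint


def torch_custom_checkpointing(function, *args, **kwargs):
    # Assumption: route all models through the non-reentrant checkpoint so that
    # activations recomputed during the backward pass run under the same
    # autocast (fp16/bf16) state as the original forward pass.
    # use_reentrant is an existing keyword of torch.utils.checkpoint.checkpoint
    # in recent PyTorch releases.
    return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False, **kwargs)
```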
@@ -22,7 +22,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
-from torch.utils.checkpoint import checkpoint
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
@@ -33,6 +32,7 @@ from ...modeling_outputs import (
Seq2SeqMoEOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
add_end_docstrings,
add_start_docstrings,
@@ -1155,7 +1155,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
@@ -1428,7 +1428,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
return custom_forward
-layer_outputs = checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(decoder_layer),
hidden_states,
combined_attention_mask,
......
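For context, the call sites edited throughout this commit all follow the same gradient-checkpointing pattern used across transformers encoders and decoders: the layer is wrapped in a small `create_custom_forward` closure so that non-tensor options can be captured while, roughly speaking, only tensors are passed positionally to the checkpoint call. A simplified, self-contained sketch of that pattern is below; the extra arguments (`output_attentions`, `head_mask`, masks, ...) vary from model to model and are omitted here.

```python
import torch
import torch.utils.checkpoint
from torch import nn

# Simplified stand-in for a transformer layer; the real models checkpoint
# their encoder/decoder layers instead of a single Linear.
layer = nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)


def create_custom_forward(module):
    # Non-tensor options (e.g. output_attentions) are captured in this closure
    # in the real models; only tensors flow through the checkpoint call itself.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


# Before this commit the models called torch.utils.checkpoint.checkpoint
# directly; after it they pass the same arguments to the shared
# torch_custom_checkpointing helper.
out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), hidden_states)
out.sum().backward()
```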
@@ -33,7 +33,12 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+apply_chunking_to_forward,
+find_pruneable_heads_and_indices,
+prune_linear_layer,
+torch_custom_checkpointing,
+)
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_nystromformer import NystromformerConfig
@@ -375,7 +380,7 @@ class NystromformerEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -28,6 +28,7 @@ from ... import AutoBackbone
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -2619,7 +2620,7 @@ class OneFormerTextTransformer(nn.Module):
def forward(self, hidden_states: torch.Tensor):
for layer in self.layers:
if self.use_checkpoint:
-hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
+hidden_states = torch_custom_checkpointing(layer, hidden_states)
else:
hidden_states = layer(hidden_states)
return hidden_states
......
@@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_open_llama import OpenLlamaConfig
@@ -603,7 +604,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
......
@@ -29,6 +29,7 @@ from ...modeling_outputs import (
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -700,7 +701,7 @@ class OPTDecoder(OPTPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
......
@@ -27,6 +27,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -754,7 +755,7 @@ class OwlViTEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
......
@@ -34,6 +34,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
add_end_docstrings,
add_start_docstrings,
@@ -805,7 +806,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
@@ -1089,7 +1090,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
......
@@ -33,6 +33,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
add_end_docstrings,
add_start_docstrings,
@@ -1072,7 +1073,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(encoder_layer),
hidden_states,
global_hidden_states,
@@ -1330,7 +1331,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
......
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.utils.checkpoint import checkpoint
from ...activations import ACT2FN
from ...modeling_outputs import (
@@ -31,7 +30,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, torch_custom_checkpointing
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
@@ -350,7 +349,7 @@ class Pix2StructVisionEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
@@ -1502,7 +1501,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
return custom_forward
-layer_outputs = checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
extended_attention_mask,
......
@@ -33,6 +33,7 @@ from ...modeling_outputs import (
Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
add_code_sample_docstrings,
add_end_docstrings,
@@ -810,7 +811,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
@@ -1067,7 +1068,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
......
@@ -28,6 +28,7 @@ from torch.nn import LayerNorm
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -1336,7 +1337,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(encoder_layer),
hidden_states,
extended_attention_mask,
@@ -1577,7 +1578,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(decoder_layer),
hidden_states,
extended_attention_mask,
......
@@ -39,7 +39,7 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -586,7 +586,7 @@ class QDQBertEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -31,7 +31,12 @@ from ...modeling_outputs import (
ModelOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+apply_chunking_to_forward,
+find_pruneable_heads_and_indices,
+prune_linear_layer,
+torch_custom_checkpointing,
+)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_realm import RealmConfig
@@ -591,7 +596,7 @@ class RealmEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -36,7 +36,12 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+apply_chunking_to_forward,
+find_pruneable_heads_and_indices,
+prune_linear_layer,
+torch_custom_checkpointing,
+)
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -548,7 +553,7 @@ class RemBertEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -21,10 +21,10 @@ import math
from typing import Optional
import torch
-import torch.utils.checkpoint as checkpoint
from torch import nn
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import add_start_docstrings, logging
from ..bert.modeling_bert import BertModel
from .configuration_retribert import RetriBertConfig
@@ -141,7 +141,7 @@ class RetriBertModel(RetriBertPreTrainedModel):
for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
-pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
+pooled_output = torch_custom_checkpointing(partial_encode, b_embedding_output, b_attention_mask)
pooled_output_list.append(pooled_output)
return torch.cat(pooled_output_list, dim=0)
......
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+apply_chunking_to_forward,
+find_pruneable_heads_and_indices,
+prune_linear_layer,
+torch_custom_checkpointing,
+)
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -515,7 +520,7 @@ class RobertaEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+apply_chunking_to_forward,
+find_pruneable_heads_and_indices,
+prune_linear_layer,
+torch_custom_checkpointing,
+)
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -517,7 +522,7 @@ class RobertaPreLayerNormEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -35,7 +35,12 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+apply_chunking_to_forward,
+find_pruneable_heads_and_indices,
+prune_linear_layer,
+torch_custom_checkpointing,
+)
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -649,7 +654,7 @@ class RoCBertEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -36,7 +36,12 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+apply_chunking_to_forward,
+find_pruneable_heads_and_indices,
+prune_linear_layer,
+torch_custom_checkpointing,
+)
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -585,7 +590,7 @@ class RoFormerEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
......
@@ -28,6 +28,7 @@ from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
@@ -1049,7 +1050,7 @@ class SamVisionEncoder(nn.Module):
return custom_forward
-layer_outputs = torch.utils.checkpoint.checkpoint(
+layer_outputs = torch_custom_checkpointing(
create_custom_forward(layer_module),
hidden_states,
)
......