"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "e36332e67d4a3152cd7c8cab46c141256d79af0f"
Unverified commit 0548af54 authored by Nate Cibik, committed by GitHub

Enable Gradient Checkpointing in Deformable DETR (#28686)

* Enabled gradient checkpointing in Deformable DETR

* Enabled gradient checkpointing in Deformable DETR encoder

* Removed # Copied from headers in modeling_deta.py to break dependence on Deformable DETR code
parent f72c7c22
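With this change, gradient checkpointing can be switched on for the whole Deformable DETR stack (encoder and decoder). A minimal usage sketch, assuming the standard `PreTrainedModel.gradient_checkpointing_enable()` API, the public `SenseTime/deformable-detr` checkpoint, and that `timm` is installed for the default ResNet-50 backbone:

```python
import torch
from transformers import DeformableDetrModel

# Load the reference checkpoint and turn on activation (gradient) checkpointing.
# gradient_checkpointing_enable() only succeeds because the model class now
# sets supports_gradient_checkpointing = True (see the diff below).
model = DeformableDetrModel.from_pretrained("SenseTime/deformable-detr")
model.gradient_checkpointing_enable()
model.train()

# During backward, each encoder/decoder layer recomputes its forward pass
# instead of keeping activations in memory: less memory, more compute.
pixel_values = torch.randn(1, 3, 800, 800)
outputs = model(pixel_values=pixel_values)
outputs.last_hidden_state.sum().backward()
```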
@@ -1048,6 +1048,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     config_class = DeformableDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]

     def _init_weights(self, module):
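For context: `supports_gradient_checkpointing` is the class-level flag that `PreTrainedModel.gradient_checkpointing_enable()` checks in current `transformers` versions; without it the call raises an error stating the architecture does not support gradient checkpointing. This one-line addition is therefore what exposes the feature to users, as in the usage sketch above.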
@@ -1143,6 +1144,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
+        self.gradient_checkpointing = False

         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -1235,15 +1237,27 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         for i, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-                attention_mask,
-                position_embeddings=position_embeddings,
-                reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
-                level_start_index=level_start_index,
-                output_attentions=output_attentions,
-            )
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings,
+                    reference_points,
+                    spatial_shapes,
+                    level_start_index,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings=position_embeddings,
+                    reference_points=reference_points,
+                    spatial_shapes=spatial_shapes,
+                    level_start_index=level_start_index,
+                    output_attentions=output_attentions,
+                )

             hidden_states = layer_outputs[0]
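Note on the pattern above: `self._gradient_checkpointing_func` is installed by `gradient_checkpointing_enable()` and, in recent `transformers` releases, essentially wraps `torch.utils.checkpoint.checkpoint`, which is why the layer is invoked through `encoder_layer.__call__` with positional arguments only. A standalone sketch of the same idea, using a plain `nn.Linear` as a stand-in layer (not Deformable DETR code):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(256, 256)  # stand-in for an encoder layer
hidden_states = torch.randn(2, 300, 256, requires_grad=True)

# checkpoint() does not store intermediate activations; it reruns the layer's
# forward during backward. Inputs are passed positionally, matching how the
# modeling code drops the keyword arguments in the checkpointed branch.
out = checkpoint(layer.__call__, hidden_states, use_reentrant=False)
out.sum().backward()
```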
@@ -1368,9 +1382,13 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     None,
                     output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
......
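The decoder already routed its layers through `_gradient_checkpointing_func`; the change here is that the deformable-attention inputs (position embeddings, reference points, spatial shapes, level start index) are now forwarded through the checkpointed call as well, so the recomputed forward receives the same arguments as the eager `decoder_layer(...)` branch.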
@@ -942,7 +942,6 @@ class DetaClassificationHead(nn.Module):
         return hidden_states


-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta
 class DetaPreTrainedModel(PreTrainedModel):
     config_class = DetaConfig
     base_model_prefix = "model"
@@ -1028,7 +1027,6 @@ DETA_INPUTS_DOCSTRING = r"""
 """


-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetr->Deta
 class DetaEncoder(DetaPreTrainedModel):
     """
     Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
@@ -1159,7 +1157,6 @@ class DetaEncoder(DetaPreTrainedModel):
         )


-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoder with DeformableDetr->Deta,Deformable DETR->DETA
 class DetaDecoder(DetaPreTrainedModel):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetaDecoderLayer`].
......
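On the modeling_deta.py side: the `# Copied from ...` headers are consumed by the library's repo-consistency tooling (e.g. `utils/check_copies.py`, run via `make fix-copies`), which keeps the annotated DETA classes in sync with their Deformable DETR counterparts. Dropping the headers breaks that link, so the new gradient-checkpointing code paths in Deformable DETR are not force-propagated into DETA.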