Unverified Commit 235be085 authored by Sangbum Daniel Choi's avatar Sangbum Daniel Choi Committed by GitHub
Browse files

[DETA] fix backbone freeze/unfreeze function (#27843)



* [DETA] fix freeze/unfreeze function

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* add freeze/unfreeze test case in DETA

* fix type

* fix typo 2

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent df5c5c62
......@@ -1414,14 +1414,12 @@ class DetaModel(DetaPreTrainedModel):
def get_decoder(self):
return self.decoder
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.freeze_backbone
def freeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for name, param in self.backbone.model.named_parameters():
param.requires_grad_(False)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.unfreeze_backbone
def unfreeze_backbone(self):
for name, param in self.backbone.conv_encoder.model.named_parameters():
for name, param in self.backbone.model.named_parameters():
param.requires_grad_(True)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel.get_valid_ratio
......
......@@ -162,6 +162,26 @@ class DetaModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.hidden_size))
def create_and_check_deta_freeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()
model.freeze_backbone()
for _, param in model.backbone.model.named_parameters():
self.parent.assertEqual(False, param.requires_grad)
def create_and_check_deta_unfreeze_backbone(self, config, pixel_values, pixel_mask, labels):
model = DetaModel(config=config)
model.to(torch_device)
model.eval()
model.unfreeze_backbone()
for _, param in model.backbone.model.named_parameters():
self.parent.assertEqual(True, param.requires_grad)
def create_and_check_deta_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
model = DetaForObjectDetection(config=config)
model.to(torch_device)
......@@ -250,6 +270,14 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_model(*config_and_inputs)
def test_deta_freeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
def test_deta_unfreeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
def test_deta_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment