Unverified commit f370bebd, authored by Pedro Gabriel Gengo Lourenço, committed by GitHub

Bugfix device map detr model (#26849)



* Fixed replace_batch_norm when on meta device

* lint fix

* Adding coauthor
Co-authored-by: default avatarPi Esposito <piero.skywalker@gmail.com>

* Removed tests

* Remove unused deps

* Try to fix copy issue

* try fix copy one more time

* Reverted import changes

---------
Co-authored-by: Pi Esposito <piero.skywalker@gmail.com>
parent b0d1d7f7
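For context, a minimal sketch of the scenario this commit fixes (the checkpoint name is an illustrative assumption; any model in the DETR family applies). Before the meta-device guard below, loading with a device map failed inside replace_batch_norm, because accelerate first instantiates the model on the meta device, whose parameters hold no data to copy.

    # Hypothetical repro, assuming accelerate is installed.
    from transformers import DetrForObjectDetection

    # device_map="auto" materializes the model on the meta device first and
    # dispatches real weights afterwards; the new guard lets replace_batch_norm
    # skip copy_() while parameters are still data-less meta tensors.
    model = DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50",  # assumed checkpoint
        device_map="auto",
    )
    print(model.hf_device_map)  # inspect how modules were placed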
src/transformers/models/conditional_detr/modeling_conditional_detr.py
@@ -322,6 +322,7 @@ def replace_batch_norm(model):
         if isinstance(module, nn.BatchNorm2d):
             new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features)
 
+            if not module.weight.device == torch.device("meta"):
                 new_module.weight.data.copy_(module.weight)
                 new_module.bias.data.copy_(module.bias)
                 new_module.running_mean.data.copy_(module.running_mean)
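A standalone sketch (not from the commit) of why the guard is needed: tensors on PyTorch's meta device are shape-only placeholders with no storage, so copying out of them raises.

    import torch
    import torch.nn as nn

    # Put a BatchNorm's parameters on the meta device, as accelerate's
    # init_empty_weights() does during device_map loading.
    bn = nn.BatchNorm2d(3).to_empty(device="meta")

    target = torch.empty(3)
    try:
        target.copy_(bn.weight)  # no underlying data to copy from
    except (NotImplementedError, RuntimeError) as err:
        print(f"copy from a meta tensor fails: {err}")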
@@ -1145,6 +1146,7 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
     config_class = ConditionalDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
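What _no_split_modules buys: when from_pretrained builds a device map, these class names are handed to accelerate as no_split_module_classes, so none of the listed modules is ever sharded across devices. A sketch of the same computation done by hand, assuming accelerate's public API (the checkpoint is illustrative):

    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("microsoft/conditional-detr-resnet-50")
    with init_empty_weights():
        # Instantiating on meta also exercises the replace_batch_norm fix above.
        model = AutoModel.from_config(config)

    # Respect the boundaries declared by this commit: the conv encoder and each
    # encoder/decoder layer always land whole on a single device.
    device_map = infer_auto_device_map(
        model, no_split_module_classes=model._no_split_modules
    )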
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -369,6 +369,7 @@ def replace_batch_norm(model):
         if isinstance(module, nn.BatchNorm2d):
             new_module = DeformableDetrFrozenBatchNorm2d(module.num_features)
 
+            if not module.weight.device == torch.device("meta"):
                 new_module.weight.data.copy_(module.weight)
                 new_module.bias.data.copy_(module.bias)
                 new_module.running_mean.data.copy_(module.running_mean)
@@ -1061,6 +1062,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     config_class = DeformableDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
src/transformers/models/deta/modeling_deta.py
@@ -307,6 +307,7 @@ def replace_batch_norm(model):
         if isinstance(module, nn.BatchNorm2d):
             new_module = DetaFrozenBatchNorm2d(module.num_features)
 
+            if not module.weight.device == torch.device("meta"):
                 new_module.weight.data.copy_(module.weight)
                 new_module.bias.data.copy_(module.bias)
                 new_module.running_mean.data.copy_(module.running_mean)
@@ -947,11 +948,12 @@ class DetaClassificationHead(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetr->Deta
+# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta
 class DetaPreTrainedModel(PreTrainedModel):
     config_class = DetaConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    _no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
src/transformers/models/detr/modeling_detr.py
@@ -316,6 +316,7 @@ def replace_batch_norm(model):
         if isinstance(module, nn.BatchNorm2d):
             new_module = DetrFrozenBatchNorm2d(module.num_features)
 
+            if not module.weight.device == torch.device("meta"):
                 new_module.weight.data.copy_(module.weight)
                 new_module.bias.data.copy_(module.bias)
                 new_module.running_mean.data.copy_(module.running_mean)
@@ -901,6 +902,7 @@ class DetrPreTrainedModel(PreTrainedModel):
     config_class = DetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
src/transformers/models/table_transformer/modeling_table_transformer.py
@@ -251,6 +251,7 @@ def replace_batch_norm(model):
         if isinstance(module, nn.BatchNorm2d):
             new_module = TableTransformerFrozenBatchNorm2d(module.num_features)
 
+            if not module.weight.device == torch.device("meta"):
                 new_module.weight.data.copy_(module.weight)
                 new_module.bias.data.copy_(module.bias)
                 new_module.running_mean.data.copy_(module.running_mean)
@@ -813,6 +814,11 @@ class TableTransformerPreTrainedModel(PreTrainedModel):
     config_class = TableTransformerConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    _no_split_modules = [
+        r"TableTransformerConvEncoder",
+        r"TableTransformerEncoderLayer",
+        r"TableTransformerDecoderLayer",
+    ]
 
     def _init_weights(self, module):
         std = self.config.init_std