Unverified Commit f370bebd authored by Pedro Gabriel Gengo Lourenço, committed by GitHub
Browse files

Bugfix device map detr model (#26849)



* Fixed replace_batch_norm when on meta device

* lint fix

* Adding coauthor
Co-authored-by: Pi Esposito <piero.skywalker@gmail.com>

* Removed tests

* Remove unused deps

* Try to fix copy issue

* try fix copy one more time

* Reverted import changes

---------
Co-authored-by: Pi Esposito <piero.skywalker@gmail.com>
parent b0d1d7f7
...@@ -322,10 +322,11 @@ def replace_batch_norm(model): ...@@ -322,10 +322,11 @@ def replace_batch_norm(model):
if isinstance(module, nn.BatchNorm2d): if isinstance(module, nn.BatchNorm2d):
new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features) new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight) if not module.weight.device == torch.device("meta"):
new_module.bias.data.copy_(module.bias) new_module.weight.data.copy_(module.weight)
new_module.running_mean.data.copy_(module.running_mean) new_module.bias.data.copy_(module.bias)
new_module.running_var.data.copy_(module.running_var) new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module model._modules[name] = new_module
...@@ -1145,6 +1146,7 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel): ...@@ -1145,6 +1146,7 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
config_class = ConditionalDetrConfig config_class = ConditionalDetrConfig
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"]
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
......
...@@ -369,10 +369,11 @@ def replace_batch_norm(model): ...@@ -369,10 +369,11 @@ def replace_batch_norm(model):
if isinstance(module, nn.BatchNorm2d): if isinstance(module, nn.BatchNorm2d):
new_module = DeformableDetrFrozenBatchNorm2d(module.num_features) new_module = DeformableDetrFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight) if not module.weight.device == torch.device("meta"):
new_module.bias.data.copy_(module.bias) new_module.weight.data.copy_(module.weight)
new_module.running_mean.data.copy_(module.running_mean) new_module.bias.data.copy_(module.bias)
new_module.running_var.data.copy_(module.running_var) new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module model._modules[name] = new_module
...@@ -1061,6 +1062,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel): ...@@ -1061,6 +1062,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
config_class = DeformableDetrConfig config_class = DeformableDetrConfig
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
......
...@@ -307,10 +307,11 @@ def replace_batch_norm(model): ...@@ -307,10 +307,11 @@ def replace_batch_norm(model):
if isinstance(module, nn.BatchNorm2d): if isinstance(module, nn.BatchNorm2d):
new_module = DetaFrozenBatchNorm2d(module.num_features) new_module = DetaFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight) if not module.weight.device == torch.device("meta"):
new_module.bias.data.copy_(module.bias) new_module.weight.data.copy_(module.weight)
new_module.running_mean.data.copy_(module.running_mean) new_module.bias.data.copy_(module.bias)
new_module.running_var.data.copy_(module.running_var) new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module model._modules[name] = new_module
...@@ -947,11 +948,12 @@ class DetaClassificationHead(nn.Module): ...@@ -947,11 +948,12 @@ class DetaClassificationHead(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetr->Deta # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta
class DetaPreTrainedModel(PreTrainedModel): class DetaPreTrainedModel(PreTrainedModel):
config_class = DetaConfig config_class = DetaConfig
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"]
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
......
...@@ -316,10 +316,11 @@ def replace_batch_norm(model): ...@@ -316,10 +316,11 @@ def replace_batch_norm(model):
if isinstance(module, nn.BatchNorm2d): if isinstance(module, nn.BatchNorm2d):
new_module = DetrFrozenBatchNorm2d(module.num_features) new_module = DetrFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight) if not module.weight.device == torch.device("meta"):
new_module.bias.data.copy_(module.bias) new_module.weight.data.copy_(module.weight)
new_module.running_mean.data.copy_(module.running_mean) new_module.bias.data.copy_(module.bias)
new_module.running_var.data.copy_(module.running_var) new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module model._modules[name] = new_module
...@@ -901,6 +902,7 @@ class DetrPreTrainedModel(PreTrainedModel): ...@@ -901,6 +902,7 @@ class DetrPreTrainedModel(PreTrainedModel):
config_class = DetrConfig config_class = DetrConfig
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"]
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
......
...@@ -251,10 +251,11 @@ def replace_batch_norm(model): ...@@ -251,10 +251,11 @@ def replace_batch_norm(model):
if isinstance(module, nn.BatchNorm2d): if isinstance(module, nn.BatchNorm2d):
new_module = TableTransformerFrozenBatchNorm2d(module.num_features) new_module = TableTransformerFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight) if not module.weight.device == torch.device("meta"):
new_module.bias.data.copy_(module.bias) new_module.weight.data.copy_(module.weight)
new_module.running_mean.data.copy_(module.running_mean) new_module.bias.data.copy_(module.bias)
new_module.running_var.data.copy_(module.running_var) new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module model._modules[name] = new_module
...@@ -813,6 +814,11 @@ class TableTransformerPreTrainedModel(PreTrainedModel): ...@@ -813,6 +814,11 @@ class TableTransformerPreTrainedModel(PreTrainedModel):
config_class = TableTransformerConfig config_class = TableTransformerConfig
base_model_prefix = "model" base_model_prefix = "model"
main_input_name = "pixel_values" main_input_name = "pixel_values"
_no_split_modules = [
r"TableTransformerConvEncoder",
r"TableTransformerEncoderLayer",
r"TableTransformerDecoderLayer",
]
def _init_weights(self, module): def _init_weights(self, module):
std = self.config.init_std std = self.config.init_std
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment