Unverified Commit 3170af71 authored by Younes Belkada, committed by GitHub

[`Detr`] Fix detr BatchNorm replacement issue (#25230)



* fix detr weird issue

* Update src/transformers/models/conditional_detr/modeling_conditional_detr.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix copies

* fix copies

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 05ebb026
src/transformers/models/conditional_detr/modeling_conditional_detr.py
...
@@ -310,19 +310,27 @@ class ConditionalDetrFrozenBatchNorm2d(nn.Module):
 # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->ConditionalDetr
-def replace_batch_norm(m, name=""):
-    for attr_str in dir(m):
-        target_attr = getattr(m, attr_str)
-        if isinstance(target_attr, nn.BatchNorm2d):
-            frozen = ConditionalDetrFrozenBatchNorm2d(target_attr.num_features)
-            bn = getattr(m, attr_str)
-            frozen.weight.data.copy_(bn.weight)
-            frozen.bias.data.copy_(bn.bias)
-            frozen.running_mean.data.copy_(bn.running_mean)
-            frozen.running_var.data.copy_(bn.running_var)
-            setattr(m, attr_str, frozen)
-    for n, ch in m.named_children():
-        replace_batch_norm(ch, n)
+def replace_batch_norm(model):
+    r"""
+    Recursively replace all `torch.nn.BatchNorm2d` with `ConditionalDetrFrozenBatchNorm2d`.
+
+    Args:
+        model (torch.nn.Module):
+            input model
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.BatchNorm2d):
+            new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features)
+
+            new_module.weight.data.copy_(module.weight)
+            new_module.bias.data.copy_(module.bias)
+            new_module.running_mean.data.copy_(module.running_mean)
+            new_module.running_var.data.copy_(module.running_var)
+
+            model._modules[name] = new_module
+
+        if len(list(module.children())) > 0:
+            replace_batch_norm(module)
 
 # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
...
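
A hedged side note on why the traversal change matters (the snippet below is illustrative, not part of the commit): the old implementation scanned `dir(m)` and called `getattr` on every attribute name. `torch.nn.Module.__dir__` filters out keys that start with a digit, so children registered under digit names (for example the entries of an `nn.Sequential`) never show up in the scan, and the blanket `getattr` can also trigger arbitrary property getters as a side effect. The new implementation walks `named_children()`, which yields exactly the registered submodules:

# Illustrative sketch with a toy container; not part of the commit.
import torch.nn as nn

toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),  # registered as child "1"; digit names are filtered out of dir()
)

# Old-style attribute scan: the BatchNorm2d child is invisible.
print([a for a in dir(toy) if isinstance(getattr(toy, a, None), nn.BatchNorm2d)])  # []

# New-style child walk: the BatchNorm2d child is found under its registered name.
print([n for n, m in toy.named_children() if isinstance(m, nn.BatchNorm2d)])  # ['1']
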
src/transformers/models/deformable_detr/modeling_deformable_detr.py
...
@@ -357,19 +357,27 @@ class DeformableDetrFrozenBatchNorm2d(nn.Module):
 # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr
-def replace_batch_norm(m, name=""):
-    for attr_str in dir(m):
-        target_attr = getattr(m, attr_str)
-        if isinstance(target_attr, nn.BatchNorm2d):
-            frozen = DeformableDetrFrozenBatchNorm2d(target_attr.num_features)
-            bn = getattr(m, attr_str)
-            frozen.weight.data.copy_(bn.weight)
-            frozen.bias.data.copy_(bn.bias)
-            frozen.running_mean.data.copy_(bn.running_mean)
-            frozen.running_var.data.copy_(bn.running_var)
-            setattr(m, attr_str, frozen)
-    for n, ch in m.named_children():
-        replace_batch_norm(ch, n)
+def replace_batch_norm(model):
+    r"""
+    Recursively replace all `torch.nn.BatchNorm2d` with `DeformableDetrFrozenBatchNorm2d`.
+
+    Args:
+        model (torch.nn.Module):
+            input model
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.BatchNorm2d):
+            new_module = DeformableDetrFrozenBatchNorm2d(module.num_features)
+
+            new_module.weight.data.copy_(module.weight)
+            new_module.bias.data.copy_(module.bias)
+            new_module.running_mean.data.copy_(module.running_mean)
+            new_module.running_var.data.copy_(module.running_var)
+
+            model._modules[name] = new_module
+
+        if len(list(module.children())) > 0:
+            replace_batch_norm(module)
 
 class DeformableDetrConvEncoder(nn.Module):
...
src/transformers/models/deta/modeling_deta.py
...
@@ -295,19 +295,27 @@ class DetaFrozenBatchNorm2d(nn.Module):
 # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->Deta
-def replace_batch_norm(m, name=""):
-    for attr_str in dir(m):
-        target_attr = getattr(m, attr_str)
-        if isinstance(target_attr, nn.BatchNorm2d):
-            frozen = DetaFrozenBatchNorm2d(target_attr.num_features)
-            bn = getattr(m, attr_str)
-            frozen.weight.data.copy_(bn.weight)
-            frozen.bias.data.copy_(bn.bias)
-            frozen.running_mean.data.copy_(bn.running_mean)
-            frozen.running_var.data.copy_(bn.running_var)
-            setattr(m, attr_str, frozen)
-    for n, ch in m.named_children():
-        replace_batch_norm(ch, n)
+def replace_batch_norm(model):
+    r"""
+    Recursively replace all `torch.nn.BatchNorm2d` with `DetaFrozenBatchNorm2d`.
+
+    Args:
+        model (torch.nn.Module):
+            input model
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.BatchNorm2d):
+            new_module = DetaFrozenBatchNorm2d(module.num_features)
+
+            new_module.weight.data.copy_(module.weight)
+            new_module.bias.data.copy_(module.bias)
+            new_module.running_mean.data.copy_(module.running_mean)
+            new_module.running_var.data.copy_(module.running_var)
+
+            model._modules[name] = new_module
+
+        if len(list(module.children())) > 0:
+            replace_batch_norm(module)
 
 class DetaBackboneWithPositionalEncodings(nn.Module):
...
src/transformers/models/detr/modeling_detr.py
...
@@ -304,19 +304,27 @@ class DetrFrozenBatchNorm2d(nn.Module):
         return x * scale + bias
 
-def replace_batch_norm(m, name=""):
-    for attr_str in dir(m):
-        target_attr = getattr(m, attr_str)
-        if isinstance(target_attr, nn.BatchNorm2d):
-            frozen = DetrFrozenBatchNorm2d(target_attr.num_features)
-            bn = getattr(m, attr_str)
-            frozen.weight.data.copy_(bn.weight)
-            frozen.bias.data.copy_(bn.bias)
-            frozen.running_mean.data.copy_(bn.running_mean)
-            frozen.running_var.data.copy_(bn.running_var)
-            setattr(m, attr_str, frozen)
-    for n, ch in m.named_children():
-        replace_batch_norm(ch, n)
+def replace_batch_norm(model):
+    r"""
+    Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.
+
+    Args:
+        model (torch.nn.Module):
+            input model
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.BatchNorm2d):
+            new_module = DetrFrozenBatchNorm2d(module.num_features)
+
+            new_module.weight.data.copy_(module.weight)
+            new_module.bias.data.copy_(module.bias)
+            new_module.running_mean.data.copy_(module.running_mean)
+            new_module.running_var.data.copy_(module.running_var)
+
+            model._modules[name] = new_module
+
+        if len(list(module.children())) > 0:
+            replace_batch_norm(module)
 
 class DetrConvEncoder(nn.Module):
...
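
Since modeling_detr.py is the origin that the other four files are copied from, a minimal usage sketch of the new function may help (assuming `replace_batch_norm` remains importable as a module-level name from `transformers.models.detr.modeling_detr`, as in this diff). The swap happens in place and recurses into nested containers; in eval mode the frozen variant applies the same running statistics with the same eps (1e-5), so outputs should be numerically unchanged:

# Hedged usage sketch; import path assumes the file shown above.
import torch
import torch.nn as nn
from transformers.models.detr.modeling_detr import replace_batch_norm

backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    # Nested container: the rewritten function recurses into it.
    nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32)),
).eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    before = backbone(x)
    replace_batch_norm(backbone)  # replaces BatchNorm2d modules in place
    after = backbone(x)

# Every BatchNorm2d is gone, and eval-mode outputs should be unchanged.
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())
assert torch.allclose(before, after, atol=1e-5)
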
src/transformers/models/table_transformer/modeling_table_transformer.py
...
@@ -239,19 +239,27 @@ class TableTransformerFrozenBatchNorm2d(nn.Module):
 # Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->TableTransformer
-def replace_batch_norm(m, name=""):
-    for attr_str in dir(m):
-        target_attr = getattr(m, attr_str)
-        if isinstance(target_attr, nn.BatchNorm2d):
-            frozen = TableTransformerFrozenBatchNorm2d(target_attr.num_features)
-            bn = getattr(m, attr_str)
-            frozen.weight.data.copy_(bn.weight)
-            frozen.bias.data.copy_(bn.bias)
-            frozen.running_mean.data.copy_(bn.running_mean)
-            frozen.running_var.data.copy_(bn.running_var)
-            setattr(m, attr_str, frozen)
-    for n, ch in m.named_children():
-        replace_batch_norm(ch, n)
+def replace_batch_norm(model):
+    r"""
+    Recursively replace all `torch.nn.BatchNorm2d` with `TableTransformerFrozenBatchNorm2d`.
+
+    Args:
+        model (torch.nn.Module):
+            input model
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.BatchNorm2d):
+            new_module = TableTransformerFrozenBatchNorm2d(module.num_features)
+
+            new_module.weight.data.copy_(module.weight)
+            new_module.bias.data.copy_(module.bias)
+            new_module.running_mean.data.copy_(module.running_mean)
+            new_module.running_var.data.copy_(module.running_var)
+
+            model._modules[name] = new_module
+
+        if len(list(module.children())) > 0:
+            replace_batch_norm(module)
 
 # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->TableTransformer
...