Unverified Commit 3170af71 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Detr`] Fix detr BatchNorm replacement issue (#25230)



* fix detr weird issue

* Update src/transformers/models/conditional_detr/modeling_conditional_detr.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix copies

* fix copies

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 05ebb026
......@@ -310,19 +310,27 @@ class ConditionalDetrFrozenBatchNorm2d(nn.Module):
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->ConditionalDetr
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, nn.BatchNorm2d):
frozen = ConditionalDetrFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight)
frozen.bias.data.copy_(bn.bias)
frozen.running_mean.data.copy_(bn.running_mean)
frozen.running_var.data.copy_(bn.running_var)
setattr(m, attr_str, frozen)
for n, ch in m.named_children():
replace_batch_norm(ch, n)
def replace_batch_norm(model):
r"""
Recursively replace all `torch.nn.BatchNorm2d` with `ConditionalDetrFrozenBatchNorm2d`.
Args:
model (torch.nn.Module):
input model
"""
for name, module in model.named_children():
if isinstance(module, nn.BatchNorm2d):
new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight)
new_module.bias.data.copy_(module.bias)
new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module
if len(list(module.children())) > 0:
replace_batch_norm(module)
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
......
......@@ -357,19 +357,27 @@ class DeformableDetrFrozenBatchNorm2d(nn.Module):
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->DeformableDetr
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, nn.BatchNorm2d):
frozen = DeformableDetrFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight)
frozen.bias.data.copy_(bn.bias)
frozen.running_mean.data.copy_(bn.running_mean)
frozen.running_var.data.copy_(bn.running_var)
setattr(m, attr_str, frozen)
for n, ch in m.named_children():
replace_batch_norm(ch, n)
def replace_batch_norm(model):
r"""
Recursively replace all `torch.nn.BatchNorm2d` with `DeformableDetrFrozenBatchNorm2d`.
Args:
model (torch.nn.Module):
input model
"""
for name, module in model.named_children():
if isinstance(module, nn.BatchNorm2d):
new_module = DeformableDetrFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight)
new_module.bias.data.copy_(module.bias)
new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module
if len(list(module.children())) > 0:
replace_batch_norm(module)
class DeformableDetrConvEncoder(nn.Module):
......
......@@ -295,19 +295,27 @@ class DetaFrozenBatchNorm2d(nn.Module):
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->Deta
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, nn.BatchNorm2d):
frozen = DetaFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight)
frozen.bias.data.copy_(bn.bias)
frozen.running_mean.data.copy_(bn.running_mean)
frozen.running_var.data.copy_(bn.running_var)
setattr(m, attr_str, frozen)
for n, ch in m.named_children():
replace_batch_norm(ch, n)
def replace_batch_norm(model):
r"""
Recursively replace all `torch.nn.BatchNorm2d` with `DetaFrozenBatchNorm2d`.
Args:
model (torch.nn.Module):
input model
"""
for name, module in model.named_children():
if isinstance(module, nn.BatchNorm2d):
new_module = DetaFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight)
new_module.bias.data.copy_(module.bias)
new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module
if len(list(module.children())) > 0:
replace_batch_norm(module)
class DetaBackboneWithPositionalEncodings(nn.Module):
......
......@@ -304,19 +304,27 @@ class DetrFrozenBatchNorm2d(nn.Module):
return x * scale + bias
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, nn.BatchNorm2d):
frozen = DetrFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight)
frozen.bias.data.copy_(bn.bias)
frozen.running_mean.data.copy_(bn.running_mean)
frozen.running_var.data.copy_(bn.running_var)
setattr(m, attr_str, frozen)
for n, ch in m.named_children():
replace_batch_norm(ch, n)
def replace_batch_norm(model):
r"""
Recursively replace all `torch.nn.BatchNorm2d` with `DetrFrozenBatchNorm2d`.
Args:
model (torch.nn.Module):
input model
"""
for name, module in model.named_children():
if isinstance(module, nn.BatchNorm2d):
new_module = DetrFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight)
new_module.bias.data.copy_(module.bias)
new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module
if len(list(module.children())) > 0:
replace_batch_norm(module)
class DetrConvEncoder(nn.Module):
......
......@@ -239,19 +239,27 @@ class TableTransformerFrozenBatchNorm2d(nn.Module):
# Copied from transformers.models.detr.modeling_detr.replace_batch_norm with Detr->TableTransformer
def replace_batch_norm(m, name=""):
for attr_str in dir(m):
target_attr = getattr(m, attr_str)
if isinstance(target_attr, nn.BatchNorm2d):
frozen = TableTransformerFrozenBatchNorm2d(target_attr.num_features)
bn = getattr(m, attr_str)
frozen.weight.data.copy_(bn.weight)
frozen.bias.data.copy_(bn.bias)
frozen.running_mean.data.copy_(bn.running_mean)
frozen.running_var.data.copy_(bn.running_var)
setattr(m, attr_str, frozen)
for n, ch in m.named_children():
replace_batch_norm(ch, n)
def replace_batch_norm(model):
r"""
Recursively replace all `torch.nn.BatchNorm2d` with `TableTransformerFrozenBatchNorm2d`.
Args:
model (torch.nn.Module):
input model
"""
for name, module in model.named_children():
if isinstance(module, nn.BatchNorm2d):
new_module = TableTransformerFrozenBatchNorm2d(module.num_features)
new_module.weight.data.copy_(module.weight)
new_module.bias.data.copy_(module.bias)
new_module.running_mean.data.copy_(module.running_mean)
new_module.running_var.data.copy_(module.running_var)
model._modules[name] = new_module
if len(list(module.children())) > 0:
replace_batch_norm(module)
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->TableTransformer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment