Unverified Commit 3b6943e7 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[DETR] Add num_channels attribute (#18714)



* Add num_channels attribute

* Fix code quality
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 811c4c9f
......@@ -42,8 +42,9 @@ class DetrConfig(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 100):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
detect in a single image. For COCO, we recommend 100 queries.
......@@ -132,6 +133,7 @@ class DetrConfig(PretrainedConfig):
def __init__(
self,
num_channels=3,
num_queries=100,
max_position_embeddings=1024,
encoder_layers=6,
......@@ -167,6 +169,7 @@ class DetrConfig(PretrainedConfig):
eos_coefficient=0.1,
**kwargs
):
self.num_channels = num_channels
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
......
......@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
"""
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
super().__init__()
kwargs = {}
......@@ -336,7 +336,12 @@ class DetrTimmConvEncoder(nn.Module):
requires_backends(self, ["timm"])
backbone = create_model(
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
name,
pretrained=use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=num_channels,
**kwargs,
)
# replace batch norm by frozen batch norm
with torch.no_grad():
......@@ -1179,7 +1184,9 @@ class DetrModel(DetrPreTrainedModel):
super().__init__(config)
# Create backbone + positional encoding
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
backbone = DetrTimmConvEncoder(
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
position_embeddings = build_position_encoding(config)
self.backbone = DetrConvModel(backbone, position_embeddings)
......
......@@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.assertTrue(outputs)
def test_greyscale_images(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# use greyscale pixel values
inputs_dict["pixel_values"] = floats_tensor(
[self.model_tester.batch_size, 1, self.model_tester.min_size, self.model_tester.max_size]
)
# let's set num_channels to 1
config.num_channels = 1
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertTrue(outputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment