Unverified commit 3b6943e7, authored by NielsRogge and committed by GitHub
Browse files

[DETR] Add num_channels attribute (#18714)



* Add num_channels attribute

* Fix code quality
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 811c4c9f
...@@ -42,8 +42,9 @@ class DetrConfig(PretrainedConfig): ...@@ -42,8 +42,9 @@ class DetrConfig(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 100): num_queries (`int`, *optional*, defaults to 100):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
detect in a single image. For COCO, we recommend 100 queries. detect in a single image. For COCO, we recommend 100 queries.
...@@ -132,6 +133,7 @@ class DetrConfig(PretrainedConfig): ...@@ -132,6 +133,7 @@ class DetrConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
num_channels=3,
num_queries=100, num_queries=100,
max_position_embeddings=1024, max_position_embeddings=1024,
encoder_layers=6, encoder_layers=6,
...@@ -167,6 +169,7 @@ class DetrConfig(PretrainedConfig): ...@@ -167,6 +169,7 @@ class DetrConfig(PretrainedConfig):
eos_coefficient=0.1, eos_coefficient=0.1,
**kwargs **kwargs
): ):
self.num_channels = num_channels
self.num_queries = num_queries self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.d_model = d_model self.d_model = d_model
......
...@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module): ...@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
""" """
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool): def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3):
super().__init__() super().__init__()
kwargs = {} kwargs = {}
...@@ -336,7 +336,12 @@ class DetrTimmConvEncoder(nn.Module): ...@@ -336,7 +336,12 @@ class DetrTimmConvEncoder(nn.Module):
requires_backends(self, ["timm"]) requires_backends(self, ["timm"])
backbone = create_model( backbone = create_model(
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs name,
pretrained=use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=num_channels,
**kwargs,
) )
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():
...@@ -1179,7 +1184,9 @@ class DetrModel(DetrPreTrainedModel): ...@@ -1179,7 +1184,9 @@ class DetrModel(DetrPreTrainedModel):
super().__init__(config) super().__init__(config)
# Create backbone + positional encoding # Create backbone + positional encoding
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone) backbone = DetrTimmConvEncoder(
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
position_embeddings = build_position_encoding(config) position_embeddings = build_position_encoding(config)
self.backbone = DetrConvModel(backbone, position_embeddings) self.backbone = DetrConvModel(backbone, position_embeddings)
......
...@@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -416,6 +416,26 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.assertTrue(outputs) self.assertTrue(outputs)
def test_greyscale_images(self):
    """Run a forward pass with single-channel inputs on every model class.

    The common config is switched to ``num_channels = 1`` and the
    ``pixel_values`` entry of the common inputs is replaced by a tensor whose
    channel dimension is 1. Each class in ``self.all_model_classes`` is then
    instantiated from that config and called once in eval mode without
    gradient tracking; the test asserts that the returned outputs are truthy.
    """
    tester = self.model_tester
    config, inputs_dict = tester.prepare_config_and_inputs_for_common()

    # Tell the model to expect exactly one input channel ...
    config.num_channels = 1

    # ... and swap in greyscale pixel values with a matching channel dimension.
    greyscale_shape = [tester.batch_size, 1, tester.min_size, tester.max_size]
    inputs_dict["pixel_values"] = floats_tensor(greyscale_shape)

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            class_inputs = self._prepare_for_class(inputs_dict, model_class)
            outputs = model(**class_inputs)

        # Checked per model class so a failure points at the offending class.
        self.assertTrue(outputs)
def test_initialization(self): def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment