Unverified Commit 27e90738 authored by Anugunj Naman's avatar Anugunj Naman Committed by GitHub
Browse files

Fix Automatic Download of Pretrained Weights in DETR (#17712)



* added use_backbone_pretrained

* style fixes

* update

* Update detr.mdx

* Update detr.mdx

* Update detr.mdx

* update using doc py

* Update detr.mdx

* Update src/transformers/models/detr/configuration_detr.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent b681e12d
......@@ -113,6 +113,28 @@ Tips:
- The size of the images will determine the amount of memory being used, and will thus determine the `batch_size`.
It is advised to use a batch size of 2 per GPU. See [this Github thread](https://github.com/facebookresearch/detr/issues/150) for more info.
There are three ways to instantiate a DETR model (depending on what you prefer):
Option 1: Instantiate DETR with pre-trained weights for entire model
```py
>>> from transformers import DetrForObjectDetection
>>> model = DetrForObjectDetection.from_pretrained("facebook/resnet-50")
```
Option 2: Instantiate DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
```py
>>> from transformers import DetrConfig, DetrForObjectDetection
>>> config = DetrConfig()
>>> model = DetrForObjectDetection(config)
```
Option 3: Instantiate DETR with randomly initialized weights for backbone + Transformer
```py
>>> config = DetrConfig(use_pretrained_backbone=False)
>>> model = DetrForObjectDetection(config)
```
As a summary, consider the following table:
| Task | Object detection | Instance segmentation | Panoptic segmentation |
......@@ -166,4 +188,4 @@ mean Average Precision (mAP) and Panoptic Quality (PQ). The latter objects are i
## DetrForSegmentation
[[autodoc]] DetrForSegmentation
- forward
\ No newline at end of file
- forward
......@@ -82,6 +82,8 @@ class DetrConfig(PretrainedConfig):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5).
class_cost (`float`, *optional*, defaults to 1):
......@@ -147,6 +149,7 @@ class DetrConfig(PretrainedConfig):
auxiliary_loss=False,
position_embedding_type="sine",
backbone="resnet50",
use_pretrained_backbone=True,
dilation=False,
class_cost=1,
bbox_cost=5,
......@@ -180,6 +183,7 @@ class DetrConfig(PretrainedConfig):
self.auxiliary_loss = auxiliary_loss
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.dilation = dilation
# Hungarian matcher
self.class_cost = class_cost
......
......@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):
"""
def __init__(self, name: str, dilation: bool):
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
super().__init__()
kwargs = {}
......@@ -335,7 +335,9 @@ class DetrTimmConvEncoder(nn.Module):
requires_backends(self, ["timm"])
backbone = create_model(name, pretrained=True, features_only=True, out_indices=(1, 2, 3, 4), **kwargs)
backbone = create_model(
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
......@@ -1177,7 +1179,7 @@ class DetrModel(DetrPreTrainedModel):
super().__init__(config)
# Create backbone + positional encoding
backbone = DetrTimmConvEncoder(config.backbone, config.dilation)
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
position_embeddings = build_position_encoding(config)
self.backbone = DetrConvModel(backbone, position_embeddings)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment