Unverified Commit 899d8351 authored by Sangbum Daniel Choi's avatar Sangbum Daniel Choi Committed by GitHub
Browse files

[DETA] Improvement and Sync from DETA especially for training (#27990)



* [DETA] fix freeze/unfreeze function

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* add freeze/unfreeze test case in DETA

* fix type

* fix typo 2

* fix : enable aux and enc loss in training pipeline

* Add unsynced variables from original DETA for training

* modification for passing CI test

* make style

* make fix

* manual make fix

* change deta_modeling_test of configuration 'two_stage' default to TRUE and minor change of dist checking

* remove print

* divide configuration in DetaModel and DetaForObjectDetection

* image smaller size than 224 will give topk error

* pred_boxes and logits should be equivalent to two_stage_num_proposals

* add missing part in DetaConfig

* Update src/transformers/models/deta/modeling_deta.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add docstring in configure and prettify TO DO part

* change distribute related code to accelerate

* Update src/transformers/models/deta/configuration_deta.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/deta/test_modeling_deta.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* protect importing accelerate

* change variable name to specific value

* wrong import

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 57e9c832
......@@ -109,6 +109,10 @@ class DetaConfig(PretrainedConfig):
based on the predictions from the previous layer.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.
assign_first_stage (`bool`, *optional*, defaults to `True`):
Whether to assign each prediction i to the highest overlapping ground truth object if the overlap is larger than a threshold 0.7.
assign_second_stage (`bool`, *optional*, defaults to `True`):
Whether to assign second assignment procedure in the second stage closely follows the first stage assignment procedure.
Examples:
......@@ -161,6 +165,7 @@ class DetaConfig(PretrainedConfig):
two_stage_num_proposals=300,
with_box_refine=True,
assign_first_stage=True,
assign_second_stage=True,
class_cost=1,
bbox_cost=5,
giou_cost=2,
......@@ -208,6 +213,7 @@ class DetaConfig(PretrainedConfig):
self.two_stage_num_proposals = two_stage_num_proposals
self.with_box_refine = with_box_refine
self.assign_first_stage = assign_first_stage
self.assign_second_stage = assign_second_stage
if two_stage is True and with_box_refine is False:
raise ValueError("If two_stage is True, with_box_refine must be True.")
# Hungarian matcher
......
......@@ -1052,7 +1052,7 @@ class DetaImageProcessor(BaseImageProcessor):
score = all_scores[b]
lbls = all_labels[b]
pre_topk = score.topk(min(10000, len(score))).indices
pre_topk = score.topk(min(10000, num_queries * num_labels)).indices
box = box[pre_topk]
score = score[pre_topk]
lbls = lbls[pre_topk]
......
......@@ -38,7 +38,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_torchvision_available, logging, requires_backends
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
from ..auto import AutoBackbone
from .configuration_deta import DetaConfig
......@@ -46,6 +46,10 @@ from .configuration_deta import DetaConfig
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce
if is_vision_available():
from transformers.image_transforms import center_to_corners_format
......@@ -105,7 +109,6 @@ class DetaDecoderOutput(ModelOutput):
@dataclass
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModelOutput with DeformableDetr->Deta,Deformable DETR->DETA
class DetaModelOutput(ModelOutput):
"""
Base class for outputs of the Deformable DETR encoder-decoder model.
......@@ -147,6 +150,8 @@ class DetaModelOutput(ModelOutput):
foreground and background).
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Logits of predicted bounding boxes coordinates in the first stage.
output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
Logits of proposal bounding boxes coordinates in the gen_encoder_output_proposals.
"""
init_reference_points: torch.FloatTensor = None
......@@ -161,10 +166,10 @@ class DetaModelOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
enc_outputs_class: Optional[torch.FloatTensor] = None
enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
output_proposals: Optional[torch.FloatTensor] = None
@dataclass
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrObjectDetectionOutput with DeformableDetr->Deta
class DetaObjectDetectionOutput(ModelOutput):
"""
Output type of [`DetaForObjectDetection`].
......@@ -223,6 +228,8 @@ class DetaObjectDetectionOutput(ModelOutput):
foreground and background).
enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
Logits of predicted bounding boxes coordinates in the first stage.
output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`):
Logits of proposal bounding boxes coordinates in the gen_encoder_output_proposals.
"""
loss: Optional[torch.FloatTensor] = None
......@@ -242,6 +249,7 @@ class DetaObjectDetectionOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
enc_outputs_class: Optional = None
enc_outputs_coord_logits: Optional = None
output_proposals: Optional[torch.FloatTensor] = None
def _get_clones(module, N):
......@@ -1632,6 +1640,7 @@ class DetaModel(DetaPreTrainedModel):
batch_size, _, num_channels = encoder_outputs[0].shape
enc_outputs_class = None
enc_outputs_coord_logits = None
output_proposals = None
if self.config.two_stage:
object_query_embedding, output_proposals, level_ids = self.gen_encoder_output_proposals(
encoder_outputs[0], ~mask_flatten, spatial_shapes
......@@ -1746,6 +1755,7 @@ class DetaModel(DetaPreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
enc_outputs_class=enc_outputs_class,
enc_outputs_coord_logits=enc_outputs_coord_logits,
output_proposals=output_proposals,
)
......@@ -1804,12 +1814,15 @@ class DetaForObjectDetection(DetaPreTrainedModel):
self.post_init()
@torch.jit.unused
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection._set_aux_loss
def _set_aux_loss(self, outputs_class, outputs_coord):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
aux_loss = [
{"logits": logits, "pred_boxes": pred_boxes}
for logits, pred_boxes in zip(outputs_class.transpose(0, 1)[:-1], outputs_coord.transpose(0, 1)[:-1])
]
return aux_loss
@add_start_docstrings_to_model_forward(DETA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DetaObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
......@@ -1929,21 +1942,25 @@ class DetaForObjectDetection(DetaPreTrainedModel):
focal_alpha=self.config.focal_alpha,
losses=losses,
num_queries=self.config.num_queries,
assign_first_stage=self.config.assign_first_stage,
assign_second_stage=self.config.assign_second_stage,
)
criterion.to(logits.device)
# Third: compute the losses, based on outputs and labels
outputs_loss = {}
outputs_loss["logits"] = logits
outputs_loss["pred_boxes"] = pred_boxes
outputs_loss["init_reference"] = init_reference
if self.config.auxiliary_loss:
intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4]
outputs_class = self.class_embed(intermediate)
outputs_coord = self.bbox_embed(intermediate).sigmoid()
auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
if self.config.two_stage:
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
outputs_loss["enc_outputs"] = {
"logits": outputs.enc_outputs_class,
"pred_boxes": enc_outputs_coord,
"anchors": outputs.output_proposals.sigmoid(),
}
loss_dict = criterion(outputs_loss, labels)
# Fourth: compute total loss, as a weighted sum of the various losses
......@@ -1953,6 +1970,7 @@ class DetaForObjectDetection(DetaPreTrainedModel):
aux_weight_dict = {}
for i in range(self.config.decoder_layers - 1):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
......@@ -1983,6 +2001,7 @@ class DetaForObjectDetection(DetaPreTrainedModel):
init_reference_points=outputs.init_reference_points,
enc_outputs_class=outputs.enc_outputs_class,
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
output_proposals=outputs.output_proposals,
)
return dict_outputs
......@@ -2192,7 +2211,7 @@ class DetaLoss(nn.Module):
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
losses applied, see each loss' doc.
"""
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
outputs_without_aux = {k: v for k, v in outputs.items() if k not in ("auxiliary_outputs", "enc_outputs")}
# Retrieve the matching between the outputs of the last layer and the targets
if self.assign_second_stage:
......@@ -2203,11 +2222,12 @@ class DetaLoss(nn.Module):
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()
# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
# Compute all the requested losses
losses = {}
......@@ -2228,17 +2248,13 @@ class DetaLoss(nn.Module):
enc_outputs = outputs["enc_outputs"]
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
bt["labels"] = torch.zeros_like(bt["labels"])
bt["class_labels"] = torch.zeros_like(bt["class_labels"])
if self.assign_first_stage:
indices = self.stg1_assigner(enc_outputs, bin_targets)
else:
indices = self.matcher(enc_outputs, bin_targets)
for loss in self.losses:
kwargs = {}
if loss == "labels":
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
l_dict = {k + "_enc": v for k, v in l_dict.items()}
losses.update(l_dict)
......@@ -2662,7 +2678,7 @@ class DetaStage2Assigner(nn.Module):
sampled_idxs,
sampled_gt_classes,
) = self._sample_proposals( # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label]
matched_idxs, matched_labels, targets[b]["labels"]
matched_idxs, matched_labels, targets[b]["class_labels"]
)
pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label]
pos_gt_inds = matched_idxs[pos_pr_inds]
......@@ -2727,7 +2743,7 @@ class DetaStage1Assigner(nn.Module):
) # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 ow]
matched_labels = self._subsample_labels(matched_labels)
all_pr_inds = torch.arange(len(anchors))
all_pr_inds = torch.arange(len(anchors), device=matched_labels.device)
pos_pr_inds = all_pr_inds[matched_labels == 1]
pos_gt_inds = matched_idxs[pos_pr_inds]
pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
......
......@@ -57,14 +57,17 @@ class DetaModelTester:
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
num_queries=12,
two_stage_num_proposals=12,
num_channels=3,
image_size=196,
image_size=224,
n_targets=8,
num_labels=91,
num_feature_levels=4,
encoder_n_points=2,
decoder_n_points=6,
two_stage=False,
two_stage=True,
assign_first_stage=True,
assign_second_stage=True,
):
self.parent = parent
self.batch_size = batch_size
......@@ -78,6 +81,7 @@ class DetaModelTester:
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_queries = num_queries
self.two_stage_num_proposals = two_stage_num_proposals
self.num_channels = num_channels
self.image_size = image_size
self.n_targets = n_targets
......@@ -86,6 +90,8 @@ class DetaModelTester:
self.encoder_n_points = encoder_n_points
self.decoder_n_points = decoder_n_points
self.two_stage = two_stage
self.assign_first_stage = assign_first_stage
self.assign_second_stage = assign_second_stage
# we also set the expected seq length for both encoder and decoder
self.encoder_seq_length = (
......@@ -96,7 +102,7 @@ class DetaModelTester:
)
self.decoder_seq_length = self.num_queries
def prepare_config_and_inputs(self):
def prepare_config_and_inputs(self, model_class_name):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
......@@ -114,10 +120,10 @@ class DetaModelTester:
target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
labels.append(target)
config = self.get_config()
config = self.get_config(model_class_name)
return config, pixel_values, pixel_mask, labels
def get_config(self):
def get_config(self, model_class_name):
resnet_config = ResNetConfig(
num_channels=3,
embeddings_size=10,
......@@ -128,6 +134,9 @@ class DetaModelTester:
out_features=["stage2", "stage3", "stage4"],
out_indices=[2, 3, 4],
)
two_stage = model_class_name == "DetaForObjectDetection"
assign_first_stage = model_class_name == "DetaForObjectDetection"
assign_second_stage = model_class_name == "DetaForObjectDetection"
return DetaConfig(
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
......@@ -139,16 +148,19 @@ class DetaModelTester:
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
num_queries=self.num_queries,
two_stage_num_proposals=self.two_stage_num_proposals,
num_labels=self.num_labels,
num_feature_levels=self.num_feature_levels,
encoder_n_points=self.encoder_n_points,
decoder_n_points=self.decoder_n_points,
two_stage=self.two_stage,
two_stage=two_stage,
assign_first_stage=assign_first_stage,
assign_second_stage=assign_second_stage,
backbone_config=resnet_config,
)
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs(model_class_name)
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
return config, inputs_dict
......@@ -190,14 +202,14 @@ class DetaModelTester:
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4))
@require_torchvision
......@@ -267,19 +279,19 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
self.config_tester.check_config_can_be_init_without_params()
def test_deta_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
self.model_tester.create_and_check_deta_model(*config_and_inputs)
def test_deta_freeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs)
def test_deta_unfreeze_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel")
self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs)
def test_deta_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaForObjectDetection")
self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs)
@unittest.skip(reason="DETA does not use inputs_embeds")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment