Unverified Commit 5569552c authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Update output of SuperPointForKeypointDetection (#29809)

* Remove auto class

* Update ImagePointDescriptionOutput

* Update model outputs

* Rename output class

* Revert "Remove auto class"

This reverts commit ed4a8f549d79cdb0cdf7aa74205a185c41471519.

* Address comments
parent 386ef34e
...@@ -79,7 +79,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor: ...@@ -79,7 +79,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
@dataclass @dataclass
class ImagePointDescriptionOutput(ModelOutput): class SuperPointKeypointDescriptionOutput(ModelOutput):
""" """
Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images, keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
...@@ -88,8 +88,8 @@ class ImagePointDescriptionOutput(ModelOutput): ...@@ -88,8 +88,8 @@ class ImagePointDescriptionOutput(ModelOutput):
and which are padding. and which are padding.
Args: Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Sequence of hidden-states at the output of the last layer of the decoder of the model. Loss computed during training.
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
Relative (x, y) coordinates of predicted keypoints in a given image. Relative (x, y) coordinates of predicted keypoints in a given image.
scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`): scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
...@@ -105,7 +105,7 @@ class ImagePointDescriptionOutput(ModelOutput): ...@@ -105,7 +105,7 @@ class ImagePointDescriptionOutput(ModelOutput):
(also called feature maps) of the model at the output of each stage. (also called feature maps) of the model at the output of each stage.
""" """
last_hidden_state: torch.FloatTensor = None loss: Optional[torch.FloatTensor] = None
keypoints: Optional[torch.IntTensor] = None keypoints: Optional[torch.IntTensor] = None
scores: Optional[torch.FloatTensor] = None scores: Optional[torch.FloatTensor] = None
descriptors: Optional[torch.FloatTensor] = None descriptors: Optional[torch.FloatTensor] = None
...@@ -414,11 +414,11 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): ...@@ -414,11 +414,11 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
@add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, ImagePointDescriptionOutput]: ) -> Union[Tuple, SuperPointKeypointDescriptionOutput]:
""" """
Examples: Examples:
...@@ -437,20 +437,15 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): ...@@ -437,20 +437,15 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
>>> inputs = processor(image, return_tensors="pt") >>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
```""" ```"""
loss = None
if labels is not None: if labels is not None:
raise ValueError( raise ValueError("SuperPoint does not support training for now.")
f"SuperPoint is not trainable, no labels should be provided.Therefore, labels should be None but were {type(labels)}"
)
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
pixel_values = self.extract_one_channel_pixel_values(pixel_values) pixel_values = self.extract_one_channel_pixel_values(pixel_values)
batch_size = pixel_values.shape[0] batch_size = pixel_values.shape[0]
...@@ -493,12 +488,10 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel): ...@@ -493,12 +488,10 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
hidden_states = encoder_outputs[1] if output_hidden_states else None hidden_states = encoder_outputs[1] if output_hidden_states else None
if not return_dict: if not return_dict:
return tuple( return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
v for v in [last_hidden_state, keypoints, scores, descriptors, mask, hidden_states] if v is not None
)
return ImagePointDescriptionOutput( return SuperPointKeypointDescriptionOutput(
last_hidden_state=last_hidden_state, loss=loss,
keypoints=keypoints, keypoints=keypoints,
scores=scores, scores=scores,
descriptors=descriptors, descriptors=descriptors,
......
...@@ -85,13 +85,17 @@ class SuperPointModelTester: ...@@ -85,13 +85,17 @@ class SuperPointModelTester:
border_removal_distance=self.border_removal_distance, border_removal_distance=self.border_removal_distance,
) )
def create_and_check_model(self, config, pixel_values): def create_and_check_keypoint_detection(self, config, pixel_values):
model = SuperPointForKeypointDetection(config=config) model = SuperPointForKeypointDetection(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.keypoints.shape[0], self.batch_size)
self.parent.assertEqual(result.keypoints.shape[-1], 2)
result = model(pixel_values, output_hidden_states=True)
self.parent.assertEqual( self.parent.assertEqual(
result.last_hidden_state.shape, result.hidden_states[-1].shape,
( (
self.batch_size, self.batch_size,
self.encoder_hidden_sizes[-1], self.encoder_hidden_sizes[-1],
...@@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training(self): def test_training(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing(self): def test_training_gradient_checkpointing(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant(self):
pass pass
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection does not support training")
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
...@@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
pass pass
def test_model(self): def test_keypoint_detection(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_keypoint_detection(*config_and_inputs)
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs() config, _ = self.model_tester.prepare_config_and_inputs()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment