Unverified Commit 3c17c529 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

SuperPointModel -> SuperPointForKeypointDetection (#29757)

parent 1248f092
...@@ -250,6 +250,10 @@ The following auto classes are available for the following computer vision tasks ...@@ -250,6 +250,10 @@ The following auto classes are available for the following computer vision tasks
[[autodoc]] AutoModelForVideoClassification [[autodoc]] AutoModelForVideoClassification
### AutoModelForKeypointDetection
[[autodoc]] AutoModelForKeypointDetection
### AutoModelForMaskedImageModeling ### AutoModelForMaskedImageModeling
[[autodoc]] AutoModelForMaskedImageModeling [[autodoc]] AutoModelForMaskedImageModeling
......
...@@ -113,10 +113,8 @@ The original code can be found [here](https://github.com/magicleap/SuperPointPre ...@@ -113,10 +113,8 @@ The original code can be found [here](https://github.com/magicleap/SuperPointPre
- preprocess - preprocess
## SuperPointModel ## SuperPointForKeypointDetection
[[autodoc]] SuperPointModel [[autodoc]] SuperPointForKeypointDetection
- forward - forward
...@@ -1487,6 +1487,7 @@ else: ...@@ -1487,6 +1487,7 @@ else:
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MASK_GENERATION_MAPPING", "MODEL_FOR_MASK_GENERATION_MAPPING",
...@@ -1527,6 +1528,7 @@ else: ...@@ -1527,6 +1528,7 @@ else:
"AutoModelForImageSegmentation", "AutoModelForImageSegmentation",
"AutoModelForImageToImage", "AutoModelForImageToImage",
"AutoModelForInstanceSegmentation", "AutoModelForInstanceSegmentation",
"AutoModelForKeypointDetection",
"AutoModelForMaskedImageModeling", "AutoModelForMaskedImageModeling",
"AutoModelForMaskedLM", "AutoModelForMaskedLM",
"AutoModelForMaskGeneration", "AutoModelForMaskGeneration",
...@@ -3341,7 +3343,7 @@ else: ...@@ -3341,7 +3343,7 @@ else:
_import_structure["models.superpoint"].extend( _import_structure["models.superpoint"].extend(
[ [
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST", "SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SuperPointModel", "SuperPointForKeypointDetection",
"SuperPointPreTrainedModel", "SuperPointPreTrainedModel",
] ]
) )
...@@ -6319,6 +6321,7 @@ if TYPE_CHECKING: ...@@ -6319,6 +6321,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING, MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MASKED_LM_MAPPING,
...@@ -6359,6 +6362,7 @@ if TYPE_CHECKING: ...@@ -6359,6 +6362,7 @@ if TYPE_CHECKING:
AutoModelForImageSegmentation, AutoModelForImageSegmentation,
AutoModelForImageToImage, AutoModelForImageToImage,
AutoModelForInstanceSegmentation, AutoModelForInstanceSegmentation,
AutoModelForKeypointDetection,
AutoModelForMaskedImageModeling, AutoModelForMaskedImageModeling,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForMaskGeneration, AutoModelForMaskGeneration,
...@@ -7852,7 +7856,7 @@ if TYPE_CHECKING: ...@@ -7852,7 +7856,7 @@ if TYPE_CHECKING:
) )
from .models.superpoint import ( from .models.superpoint import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST, SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel, SuperPointForKeypointDetection,
SuperPointPreTrainedModel, SuperPointPreTrainedModel,
) )
from .models.swiftformer import ( from .models.swiftformer import (
......
...@@ -52,6 +52,7 @@ else: ...@@ -52,6 +52,7 @@ else:
"MODEL_FOR_IMAGE_MAPPING", "MODEL_FOR_IMAGE_MAPPING",
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING", "MODEL_FOR_MASKED_LM_MAPPING",
...@@ -92,6 +93,7 @@ else: ...@@ -92,6 +93,7 @@ else:
"AutoModelForImageSegmentation", "AutoModelForImageSegmentation",
"AutoModelForImageToImage", "AutoModelForImageToImage",
"AutoModelForInstanceSegmentation", "AutoModelForInstanceSegmentation",
"AutoModelForKeypointDetection",
"AutoModelForMaskGeneration", "AutoModelForMaskGeneration",
"AutoModelForTextEncoding", "AutoModelForTextEncoding",
"AutoModelForMaskedImageModeling", "AutoModelForMaskedImageModeling",
...@@ -117,7 +119,6 @@ else: ...@@ -117,7 +119,6 @@ else:
"AutoModelWithLMHead", "AutoModelWithLMHead",
"AutoModelForZeroShotImageClassification", "AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection", "AutoModelForZeroShotObjectDetection",
"AutoModelForKeypointDetection",
] ]
try: try:
...@@ -239,6 +240,7 @@ if TYPE_CHECKING: ...@@ -239,6 +240,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING, MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MASKED_LM_MAPPING,
......
...@@ -207,7 +207,6 @@ MODEL_MAPPING_NAMES = OrderedDict( ...@@ -207,7 +207,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"), ("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"), ("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"), ("starcoder2", "Starcoder2Model"),
("superpoint", "SuperPointModel"),
("swiftformer", "SwiftFormerModel"), ("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"), ("swin", "SwinModel"),
("swin2sr", "Swin2SRModel"), ("swin2sr", "Swin2SRModel"),
...@@ -1225,6 +1224,14 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( ...@@ -1225,6 +1224,14 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
] ]
) )
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("superpoint", "SuperPointForKeypointDetection"),
]
)
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
[ [
("albert", "AlbertModel"), ("albert", "AlbertModel"),
...@@ -1360,6 +1367,10 @@ MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BA ...@@ -1360,6 +1367,10 @@ MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BA
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES) MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES) MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping( MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
...@@ -1377,6 +1388,10 @@ class AutoModelForMaskGeneration(_BaseAutoModelClass): ...@@ -1377,6 +1388,10 @@ class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
class AutoModelForKeypointDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
class AutoModelForTextEncoding(_BaseAutoModelClass): class AutoModelForTextEncoding(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
......
...@@ -40,7 +40,7 @@ except OptionalDependencyNotAvailable: ...@@ -40,7 +40,7 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["modeling_superpoint"] = [ _import_structure["modeling_superpoint"] = [
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST", "SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SuperPointModel", "SuperPointForKeypointDetection",
"SuperPointPreTrainedModel", "SuperPointPreTrainedModel",
] ]
...@@ -67,7 +67,7 @@ if TYPE_CHECKING: ...@@ -67,7 +67,7 @@ if TYPE_CHECKING:
else: else:
from .modeling_superpoint import ( from .modeling_superpoint import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST, SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel, SuperPointForKeypointDetection,
SuperPointPreTrainedModel, SuperPointPreTrainedModel,
) )
......
...@@ -26,7 +26,7 @@ SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -26,7 +26,7 @@ SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class SuperPointConfig(PretrainedConfig): class SuperPointConfig(PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`SuperPointModel`]. It is used to instantiate a This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a
SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the SuperPoint configuration with the defaults will yield a similar configuration to that of the SuperPoint
[magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture. [magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
...@@ -53,12 +53,12 @@ class SuperPointConfig(PretrainedConfig): ...@@ -53,12 +53,12 @@ class SuperPointConfig(PretrainedConfig):
Example: Example:
```python ```python
>>> from transformers import SuperPointConfig, SuperPointModel >>> from transformers import SuperPointConfig, SuperPointForKeypointDetection
>>> # Initializing a SuperPoint superpoint style configuration >>> # Initializing a SuperPoint superpoint style configuration
>>> configuration = SuperPointConfig() >>> configuration = SuperPointConfig()
>>> # Initializing a model from the superpoint style configuration >>> # Initializing a model from the superpoint style configuration
>>> model = SuperPointModel(configuration) >>> model = SuperPointForKeypointDetection(configuration)
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
```""" ```"""
......
...@@ -18,7 +18,7 @@ import requests ...@@ -18,7 +18,7 @@ import requests
import torch import torch
from PIL import Image from PIL import Image
from transformers import SuperPointConfig, SuperPointImageProcessor, SuperPointModel from transformers import SuperPointConfig, SuperPointForKeypointDetection, SuperPointImageProcessor
def get_superpoint_config(): def get_superpoint_config():
...@@ -106,7 +106,7 @@ def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save ...@@ -106,7 +106,7 @@ def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save
rename_key(new_state_dict, src, dest) rename_key(new_state_dict, src, dest)
# Load HuggingFace model # Load HuggingFace model
model = SuperPointModel(config) model = SuperPointForKeypointDetection(config)
model.load_state_dict(new_state_dict) model.load_state_dict(new_state_dict)
model.eval() model.eval()
print("Successfully loaded weights in the model") print("Successfully loaded weights in the model")
......
...@@ -390,7 +390,7 @@ Args: ...@@ -390,7 +390,7 @@ Args:
"SuperPoint model outputting keypoints and descriptors.", "SuperPoint model outputting keypoints and descriptors.",
SUPERPOINT_START_DOCSTRING, SUPERPOINT_START_DOCSTRING,
) )
class SuperPointModel(SuperPointPreTrainedModel): class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
""" """
SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and
......
...@@ -606,6 +606,9 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = None ...@@ -606,6 +606,9 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = None
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = None
MODEL_FOR_MASK_GENERATION_MAPPING = None MODEL_FOR_MASK_GENERATION_MAPPING = None
...@@ -778,6 +781,13 @@ class AutoModelForInstanceSegmentation(metaclass=DummyObject): ...@@ -778,6 +781,13 @@ class AutoModelForInstanceSegmentation(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class AutoModelForKeypointDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForMaskedImageModeling(metaclass=DummyObject): class AutoModelForMaskedImageModeling(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -8029,7 +8039,7 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject): ...@@ -8029,7 +8039,7 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject):
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = None SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class SuperPointModel(metaclass=DummyObject): class SuperPointForKeypointDetection(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
...@@ -28,7 +28,7 @@ if is_torch_available(): ...@@ -28,7 +28,7 @@ if is_torch_available():
from transformers import ( from transformers import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST, SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel, SuperPointForKeypointDetection,
) )
if is_vision_available(): if is_vision_available():
...@@ -86,7 +86,7 @@ class SuperPointModelTester: ...@@ -86,7 +86,7 @@ class SuperPointModelTester:
) )
def create_and_check_model(self, config, pixel_values): def create_and_check_model(self, config, pixel_values):
model = SuperPointModel(config=config) model = SuperPointForKeypointDetection(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
...@@ -109,7 +109,7 @@ class SuperPointModelTester: ...@@ -109,7 +109,7 @@ class SuperPointModelTester:
@require_torch @require_torch
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase): class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (SuperPointModel,) if is_torch_available() else () all_model_classes = (SuperPointForKeypointDetection,) if is_torch_available() else ()
all_generative_model_classes = () if is_torch_available() else () all_generative_model_classes = () if is_torch_available() else ()
fx_compatible = False fx_compatible = False
...@@ -134,31 +134,31 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -134,31 +134,31 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self): def create_and_test_config_common_properties(self):
return return
@unittest.skip(reason="SuperPointModel does not use inputs_embeds") @unittest.skip(reason="SuperPointForKeypointDetection does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
@unittest.skip(reason="SuperPointModel does not support input and output embeddings") @unittest.skip(reason="SuperPointForKeypointDetection does not support input and output embeddings")
def test_model_common_attributes(self): def test_model_common_attributes(self):
pass pass
@unittest.skip(reason="SuperPointModel does not use feedforward chunking") @unittest.skip(reason="SuperPointForKeypointDetection does not use feedforward chunking")
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass
@unittest.skip(reason="SuperPointModel is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training(self): def test_training(self):
pass pass
@unittest.skip(reason="SuperPointModel is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing(self): def test_training_gradient_checkpointing(self):
pass pass
@unittest.skip(reason="SuperPointModel is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant(self):
pass pass
@unittest.skip(reason="SuperPointModel is not trainable") @unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
...@@ -219,7 +219,7 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -219,7 +219,7 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SuperPointModel.from_pretrained(model_name) model = SuperPointForKeypointDetection.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
def test_forward_labels_should_be_none(self): def test_forward_labels_should_be_none(self):
...@@ -254,7 +254,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase): ...@@ -254,7 +254,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
@slow @slow
def test_inference(self): def test_inference(self):
model = SuperPointModel.from_pretrained("magic-leap-community/superpoint").to(torch_device) model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint").to(torch_device)
preprocessor = self.default_image_processor preprocessor = self.default_image_processor
images = prepare_imgs() images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment