Unverified commit 3c17c529 authored by amyeroberts, committed by GitHub

SuperPointModel -> SuperPointForKeypointDetection (#29757)

parent 1248f092
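
After this rename, the keypoint-detection task is exposed through the new `AutoModelForKeypointDetection` auto class (mapped to `SuperPointForKeypointDetection` in the diff below). The following is a minimal usage sketch, not part of the commit itself; it assumes the `magic-leap-community/superpoint` checkpoint referenced in the diffed docs and that `AutoImageProcessor` resolves to `SuperPointImageProcessor` for this model.

```python
# Hedged sketch: load the SuperPoint checkpoint through the new auto class added
# by this change and run keypoint detection on a single image.
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, AutoModelForKeypointDetection

# Any RGB image works; this URL is only an illustrative example.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
model = AutoModelForKeypointDetection.from_pretrained("magic-leap-community/superpoint")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    # The model outputs keypoints, scores and descriptors for the image,
    # as described in the SuperPoint model doc.
    outputs = model(**inputs)
```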
......@@ -250,6 +250,10 @@ The following auto classes are available for the following computer vision tasks
[[autodoc]] AutoModelForVideoClassification
### AutoModelForKeypointDetection
[[autodoc]] AutoModelForKeypointDetection
### AutoModelForMaskedImageModeling
[[autodoc]] AutoModelForMaskedImageModeling
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the MIT License; you may not use this file except in compliance with
the License.
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
......@@ -113,10 +113,8 @@ The original code can be found [here](https://github.com/magicleap/SuperPointPre
- preprocess
## SuperPointModel
## SuperPointForKeypointDetection
[[autodoc]] SuperPointModel
[[autodoc]] SuperPointForKeypointDetection
- forward
......@@ -1487,6 +1487,7 @@ else:
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MASK_GENERATION_MAPPING",
......@@ -1527,6 +1528,7 @@ else:
"AutoModelForImageSegmentation",
"AutoModelForImageToImage",
"AutoModelForInstanceSegmentation",
"AutoModelForKeypointDetection",
"AutoModelForMaskedImageModeling",
"AutoModelForMaskedLM",
"AutoModelForMaskGeneration",
......@@ -3341,7 +3343,7 @@ else:
_import_structure["models.superpoint"].extend(
[
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SuperPointModel",
"SuperPointForKeypointDetection",
"SuperPointPreTrainedModel",
]
)
......@@ -6319,6 +6321,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
......@@ -6359,6 +6362,7 @@ if TYPE_CHECKING:
AutoModelForImageSegmentation,
AutoModelForImageToImage,
AutoModelForInstanceSegmentation,
AutoModelForKeypointDetection,
AutoModelForMaskedImageModeling,
AutoModelForMaskedLM,
AutoModelForMaskGeneration,
......@@ -7852,7 +7856,7 @@ if TYPE_CHECKING:
)
from .models.superpoint import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointForKeypointDetection,
SuperPointPreTrainedModel,
)
from .models.swiftformer import (
......
......@@ -52,6 +52,7 @@ else:
"MODEL_FOR_IMAGE_MAPPING",
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
......@@ -92,6 +93,7 @@ else:
"AutoModelForImageSegmentation",
"AutoModelForImageToImage",
"AutoModelForInstanceSegmentation",
"AutoModelForKeypointDetection",
"AutoModelForMaskGeneration",
"AutoModelForTextEncoding",
"AutoModelForMaskedImageModeling",
......@@ -117,7 +119,6 @@ else:
"AutoModelWithLMHead",
"AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection",
"AutoModelForKeypointDetection",
]
try:
......@@ -239,6 +240,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
......
......@@ -207,7 +207,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"),
("superpoint", "SuperPointModel"),
("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"),
("swin2sr", "Swin2SRModel"),
......@@ -1225,6 +1224,14 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("superpoint", "SuperPointForKeypointDetection"),
]
)
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
[
("albert", "AlbertModel"),
......@@ -1360,6 +1367,10 @@ MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BA
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
......@@ -1377,6 +1388,10 @@ class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
class AutoModelForKeypointDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
class AutoModelForTextEncoding(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
......
......@@ -40,7 +40,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modeling_superpoint"] = [
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SuperPointModel",
"SuperPointForKeypointDetection",
"SuperPointPreTrainedModel",
]
......@@ -67,7 +67,7 @@ if TYPE_CHECKING:
else:
from .modeling_superpoint import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointForKeypointDetection,
SuperPointPreTrainedModel,
)
......
......@@ -26,7 +26,7 @@ SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class SuperPointConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SuperPointModel`]. It is used to instantiate a
This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a
SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the SuperPoint
[magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
......@@ -53,12 +53,12 @@ class SuperPointConfig(PretrainedConfig):
Example:
```python
>>> from transformers import SuperPointConfig, SuperPointModel
>>> from transformers import SuperPointConfig, SuperPointForKeypointDetection
>>> # Initializing a SuperPoint superpoint style configuration
>>> configuration = SuperPointConfig()
>>> # Initializing a model from the superpoint style configuration
>>> model = SuperPointModel(configuration)
>>> model = SuperPointForKeypointDetection(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
......
......@@ -18,7 +18,7 @@ import requests
import torch
from PIL import Image
from transformers import SuperPointConfig, SuperPointImageProcessor, SuperPointModel
from transformers import SuperPointConfig, SuperPointForKeypointDetection, SuperPointImageProcessor
def get_superpoint_config():
......@@ -106,7 +106,7 @@ def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save
rename_key(new_state_dict, src, dest)
# Load HuggingFace model
model = SuperPointModel(config)
model = SuperPointForKeypointDetection(config)
model.load_state_dict(new_state_dict)
model.eval()
print("Successfully loaded weights in the model")
......
......@@ -390,7 +390,7 @@ Args:
"SuperPoint model outputting keypoints and descriptors.",
SUPERPOINT_START_DOCSTRING,
)
class SuperPointModel(SuperPointPreTrainedModel):
class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
"""
SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and
......
......@@ -606,6 +606,9 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = None
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = None
MODEL_FOR_MASK_GENERATION_MAPPING = None
......@@ -778,6 +781,13 @@ class AutoModelForInstanceSegmentation(metaclass=DummyObject):
requires_backends(self, ["torch"])
class AutoModelForKeypointDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForMaskedImageModeling(metaclass=DummyObject):
_backends = ["torch"]
......@@ -8029,7 +8039,7 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject):
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class SuperPointModel(metaclass=DummyObject):
class SuperPointForKeypointDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
......
......@@ -28,7 +28,7 @@ if is_torch_available():
from transformers import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointForKeypointDetection,
)
if is_vision_available():
......@@ -86,7 +86,7 @@ class SuperPointModelTester:
)
def create_and_check_model(self, config, pixel_values):
model = SuperPointModel(config=config)
model = SuperPointForKeypointDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
......@@ -109,7 +109,7 @@ class SuperPointModelTester:
@require_torch
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (SuperPointModel,) if is_torch_available() else ()
all_model_classes = (SuperPointForKeypointDetection,) if is_torch_available() else ()
all_generative_model_classes = () if is_torch_available() else ()
fx_compatible = False
......@@ -134,31 +134,31 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="SuperPointModel does not use inputs_embeds")
@unittest.skip(reason="SuperPointForKeypointDetection does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="SuperPointModel does not support input and output embeddings")
@unittest.skip(reason="SuperPointForKeypointDetection does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="SuperPointModel does not use feedforward chunking")
@unittest.skip(reason="SuperPointForKeypointDetection does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
......@@ -219,7 +219,7 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SuperPointModel.from_pretrained(model_name)
model = SuperPointForKeypointDetection.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_forward_labels_should_be_none(self):
......@@ -254,7 +254,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
model = SuperPointModel.from_pretrained("magic-leap-community/superpoint").to(torch_device)
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint").to(torch_device)
preprocessor = self.default_image_processor
images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
......
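
Because the diff registers `("superpoint", "SuperPointForKeypointDetection")` in `MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES`, the auto class now resolves a `SuperPointConfig` to the renamed class. A small sketch of that routing, assuming `from_config` behaves as it does for the other `AutoModelFor*` classes:

```python
# Hedged sketch: the new keypoint-detection mapping routes SuperPointConfig to
# the renamed class without loading any pretrained weights.
from transformers import AutoModelForKeypointDetection, SuperPointConfig

config = SuperPointConfig()
model = AutoModelForKeypointDetection.from_config(config)
print(type(model).__name__)  # SuperPointForKeypointDetection
```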