Unverified Commit 56baa033 authored by StevenBucaille's avatar StevenBucaille Committed by GitHub
Browse files

Implementation of SuperPoint and AutoModelForKeypointDetection (#28966)



* Added SuperPoint docs

* Added tests

* Removed commented part

* Commit to create and fix add_superpoint branch with a new branch

* Fixed dummy_pt_objects

* Committed missing files

* Fixed README.md

* Apply suggestions from code review

Fixed small changes
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Moved ImagePointDescriptionOutput from modeling_outputs.py to modeling_superpoint.py

* Removed AutoModelForKeypointDetection and related stuff

* Fixed inconsistencies in image_processing_superpoint.py

* Moved infer_on_model logic simply in test_inference

* Fixed bugs, added labels to forward method with checks whether it is properly a None value, also added tests about this logic in test_modeling_superpoint.py

* Added tests to SuperPointImageProcessor to ensure that images are properly converted to grayscale

* Removed remaining mentions of MODEL_FOR_KEYPOINT_DETECTION_MAPPING

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fixed from (w, h) to (h, w) as input for tests

* Removed unnecessary condition

* Moved last_hidden_state to be the first returned

* Moved last_hidden_state to be the first returned (bis)

* Moved last_hidden_state to be the first returned (ter)

* Switched image_width and image_height in tests to match recent changes

* Added config as first SuperPointConvBlock init argument

* Reordered README's after merge

* Added missing first config argument to SuperPointConvBlock instantiations

* Removed formatting error

* Added SuperPoint to README's de, pt-br, ru, te and vi

* Checked out README_fr.md

* Fixed README_fr.md

* Test fix README_fr.md

* Test fix README_fr.md

* Last make fix-copies !

* Updated checkpoint path

* Removed unused SuperPoint doc

* Added missing image

* Update src/transformers/models/superpoint/modeling_superpoint.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Removed unnecessary import

* Update src/transformers/models/superpoint/modeling_superpoint.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Added SuperPoint to _toctree.yml

---------
Co-authored-by: default avatarsteven <steven.bucaillle@gmail.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avatarSteven Bucaille <steven.bucaille@buawei.com>
parent 2f9a3edb
...@@ -207,6 +207,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ...@@ -207,6 +207,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"), ("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"), ("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"), ("starcoder2", "Starcoder2Model"),
("superpoint", "SuperPointModel"),
("swiftformer", "SwiftFormerModel"), ("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"), ("swin", "SwinModel"),
("swin2sr", "Swin2SRModel"), ("swin2sr", "Swin2SRModel"),
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_superpoint": [
"SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SuperPointConfig",
]
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_superpoint"] = ["SuperPointImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_superpoint"] = [
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SuperPointModel",
"SuperPointPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_superpoint import (
SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP,
SuperPointConfig,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_superpoint import SuperPointImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_superpoint import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"magic-leap-community/superpoint": "https://huggingface.co/magic-leap-community/superpoint/blob/main/config.json"
}
class SuperPointConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SuperPointModel`]. It is used to instantiate a
SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the SuperPoint
[magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
encoder_hidden_sizes (`List`, *optional*, defaults to `[64, 64, 128, 128]`):
The number of channels in each convolutional layer in the encoder.
decoder_hidden_size (`int`, *optional*, defaults to 256): The hidden size of the decoder.
keypoint_decoder_dim (`int`, *optional*, defaults to 65): The output dimension of the keypoint decoder.
descriptor_decoder_dim (`int`, *optional*, defaults to 256): The output dimension of the descriptor decoder.
keypoint_threshold (`float`, *optional*, defaults to 0.005):
The threshold to use for extracting keypoints.
max_keypoints (`int`, *optional*, defaults to -1):
The maximum number of keypoints to extract. If `-1`, will extract all keypoints.
nms_radius (`int`, *optional*, defaults to 4):
The radius for non-maximum suppression.
border_removal_distance (`int`, *optional*, defaults to 4):
The distance from the border to remove keypoints.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import SuperPointConfig, SuperPointModel
>>> # Initializing a SuperPoint superpoint style configuration
>>> configuration = SuperPointConfig()
>>> # Initializing a model from the superpoint style configuration
>>> model = SuperPointModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "superpoint"
def __init__(
self,
encoder_hidden_sizes: List[int] = [64, 64, 128, 128],
decoder_hidden_size: int = 256,
keypoint_decoder_dim: int = 65,
descriptor_decoder_dim: int = 256,
keypoint_threshold: float = 0.005,
max_keypoints: int = -1,
nms_radius: int = 4,
border_removal_distance: int = 4,
initializer_range=0.02,
**kwargs,
):
self.encoder_hidden_sizes = encoder_hidden_sizes
self.decoder_hidden_size = decoder_hidden_size
self.keypoint_decoder_dim = keypoint_decoder_dim
self.descriptor_decoder_dim = descriptor_decoder_dim
self.keypoint_threshold = keypoint_threshold
self.max_keypoints = max_keypoints
self.nms_radius = nms_radius
self.border_removal_distance = border_removal_distance
self.initializer_range = initializer_range
super().__init__(**kwargs)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import requests
import torch
from PIL import Image
from transformers import SuperPointConfig, SuperPointImageProcessor, SuperPointModel
def get_superpoint_config():
config = SuperPointConfig(
encoder_hidden_sizes=[64, 64, 128, 128],
decoder_hidden_size=256,
keypoint_decoder_dim=65,
descriptor_decoder_dim=256,
keypoint_threshold=0.005,
max_keypoints=-1,
nms_radius=4,
border_removal_distance=4,
initializer_range=0.02,
)
return config
def create_rename_keys(config, state_dict):
rename_keys = []
# Encoder weights
rename_keys.append(("conv1a.weight", "encoder.conv_blocks.0.conv_a.weight"))
rename_keys.append(("conv1b.weight", "encoder.conv_blocks.0.conv_b.weight"))
rename_keys.append(("conv2a.weight", "encoder.conv_blocks.1.conv_a.weight"))
rename_keys.append(("conv2b.weight", "encoder.conv_blocks.1.conv_b.weight"))
rename_keys.append(("conv3a.weight", "encoder.conv_blocks.2.conv_a.weight"))
rename_keys.append(("conv3b.weight", "encoder.conv_blocks.2.conv_b.weight"))
rename_keys.append(("conv4a.weight", "encoder.conv_blocks.3.conv_a.weight"))
rename_keys.append(("conv4b.weight", "encoder.conv_blocks.3.conv_b.weight"))
rename_keys.append(("conv1a.bias", "encoder.conv_blocks.0.conv_a.bias"))
rename_keys.append(("conv1b.bias", "encoder.conv_blocks.0.conv_b.bias"))
rename_keys.append(("conv2a.bias", "encoder.conv_blocks.1.conv_a.bias"))
rename_keys.append(("conv2b.bias", "encoder.conv_blocks.1.conv_b.bias"))
rename_keys.append(("conv3a.bias", "encoder.conv_blocks.2.conv_a.bias"))
rename_keys.append(("conv3b.bias", "encoder.conv_blocks.2.conv_b.bias"))
rename_keys.append(("conv4a.bias", "encoder.conv_blocks.3.conv_a.bias"))
rename_keys.append(("conv4b.bias", "encoder.conv_blocks.3.conv_b.bias"))
# Keypoint Decoder weights
rename_keys.append(("convPa.weight", "keypoint_decoder.conv_score_a.weight"))
rename_keys.append(("convPb.weight", "keypoint_decoder.conv_score_b.weight"))
rename_keys.append(("convPa.bias", "keypoint_decoder.conv_score_a.bias"))
rename_keys.append(("convPb.bias", "keypoint_decoder.conv_score_b.bias"))
# Descriptor Decoder weights
rename_keys.append(("convDa.weight", "descriptor_decoder.conv_descriptor_a.weight"))
rename_keys.append(("convDb.weight", "descriptor_decoder.conv_descriptor_b.weight"))
rename_keys.append(("convDa.bias", "descriptor_decoder.conv_descriptor_a.bias"))
rename_keys.append(("convDb.bias", "descriptor_decoder.conv_descriptor_b.bias"))
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
def prepare_imgs():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im1 = Image.open(requests.get(url, stream=True).raw)
url = "http://images.cocodataset.org/test-stuff2017/000000004016.jpg"
im2 = Image.open(requests.get(url, stream=True).raw)
return [im1, im2]
@torch.no_grad()
def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub, test_mode=False):
"""
Copy/paste/tweak model's weights to our SuperPoint structure.
"""
print("Downloading original model from checkpoint...")
config = get_superpoint_config()
# load original state_dict from URL
original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
print("Converting model parameters...")
# rename keys
rename_keys = create_rename_keys(config, original_state_dict)
new_state_dict = original_state_dict.copy()
for src, dest in rename_keys:
rename_key(new_state_dict, src, dest)
# Load HuggingFace model
model = SuperPointModel(config)
model.load_state_dict(new_state_dict)
model.eval()
print("Successfully loaded weights in the model")
# Check model outputs
preprocessor = SuperPointImageProcessor()
inputs = preprocessor(images=prepare_imgs(), return_tensors="pt")
outputs = model(**inputs)
# If test_mode is True, we check that the model outputs match the original results
if test_mode:
torch.count_nonzero(outputs.mask[0])
expected_keypoints_shape = (2, 830, 2)
expected_scores_shape = (2, 830)
expected_descriptors_shape = (2, 830, 256)
expected_keypoints_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]])
expected_scores_values = torch.tensor([0.0064, 0.0140, 0.0595, 0.0728, 0.5170, 0.0175, 0.1523, 0.2055, 0.0336])
expected_descriptors_value = torch.tensor(-0.1096)
assert outputs.keypoints.shape == expected_keypoints_shape
assert outputs.scores.shape == expected_scores_shape
assert outputs.descriptors.shape == expected_descriptors_shape
assert torch.allclose(outputs.keypoints[0, :3], expected_keypoints_values, atol=1e-3)
assert torch.allclose(outputs.scores[0, :9], expected_scores_values, atol=1e-3)
assert torch.allclose(outputs.descriptors[0, 0, 0], expected_descriptors_value, atol=1e-3)
print("Model outputs match the original results!")
if save_model:
print("Saving model to local...")
# Create folder to save model
if not os.path.isdir(pytorch_dump_folder_path):
os.mkdir(pytorch_dump_folder_path)
model.save_pretrained(pytorch_dump_folder_path)
preprocessor.save_pretrained(pytorch_dump_folder_path)
model_name = "superpoint"
if push_to_hub:
print(f"Pushing {model_name} to the hub...")
model.push_to_hub(model_name)
preprocessor.push_to_hub(model_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_url",
default="https://github.com/magicleap/SuperPointPretrainedNetwork/raw/master/superpoint_v1.pth",
type=str,
help="URL of the original SuperPoint checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="model",
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--save_model", action="store_true", help="Save model to local")
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub")
args = parser.parse_args()
convert_superpoint_checkpoint(
args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub
)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for SuperPoint."""
from typing import Dict, Optional, Union
import numpy as np
from ... import is_vision_available, requires_backends
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from ...utils import TensorType, logging
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
def is_grayscale(
image: ImageInput,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if input_data_format == ChannelDimension.FIRST:
return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
elif input_data_format == ChannelDimension.LAST:
return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])
def convert_to_grayscale(
image: ImageInput,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> ImageInput:
"""
Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch
and tensorflow grayscale conversion
This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
channel, because of an issue that is discussed in :
https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
Args:
image (Image):
The image to convert.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image.
"""
requires_backends(convert_to_grayscale, ["vision"])
if isinstance(image, np.ndarray):
if input_data_format == ChannelDimension.FIRST:
gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
gray_image = np.stack([gray_image] * 3, axis=0)
elif input_data_format == ChannelDimension.LAST:
gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140
gray_image = np.stack([gray_image] * 3, axis=-1)
return gray_image
if not isinstance(image, PIL.Image.Image):
return image
image = image.convert("L")
return image
class SuperPointImageProcessor(BaseImageProcessor):
r"""
Constructs a SuperPoint image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden
by `do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`):
Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
`True`. Can be overriden by `size` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess`
method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 480, "width": 640}
size = get_size_dict(size, default_to_square=False)
self.do_resize = do_resize
self.size = size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
"""
Resize an image.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the output image. If not provided, it will be inferred from the input
image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
size = get_size_dict(size, default_to_square=False)
return resize(
image,
size=(size["height"], size["width"]),
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images,
do_resize: bool = None,
size: Dict[str, int] = None,
do_rescale: bool = None,
rescale_factor: float = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
`(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [self.resize(image=image, size=size, input_data_format=input_data_format) for image in images]
if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
# Checking if image is RGB or grayscale
for i in range(len(images)):
if not is_grayscale(images[i], input_data_format):
images[i] = convert_to_grayscale(images[i], input_data_format=input_data_format)
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch SuperPoint model."""
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithNoAttention,
)
from transformers.models.superpoint.configuration_superpoint import SuperPointConfig
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "SuperPointConfig"
_CHECKPOINT_FOR_DOC = "magic-leap-community/superpoint"
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = ["magic-leap-community/superpoint"]
def remove_keypoints_from_borders(
keypoints: torch.Tensor, scores: torch.Tensor, border: int, height: int, width: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Removes keypoints (and their associated scores) that are too close to the border"""
mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
mask = mask_h & mask_w
return keypoints[mask], scores[mask]
def top_k_keypoints(keypoints: torch.Tensor, scores: torch.Tensor, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Keeps the k keypoints with highest score"""
if k >= len(keypoints):
return keypoints, scores
scores, indices = torch.topk(scores, k, dim=0)
return keypoints[indices], scores
def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
"""Applies non-maximum suppression on scores"""
if nms_radius < 0:
raise ValueError("Expected positive values for nms_radius")
def max_pool(x):
return nn.functional.max_pool2d(x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
zeros = torch.zeros_like(scores)
max_mask = scores == max_pool(scores)
for _ in range(2):
supp_mask = max_pool(max_mask.float()) > 0
supp_scores = torch.where(supp_mask, zeros, scores)
new_max_mask = supp_scores == max_pool(supp_scores)
max_mask = max_mask | (new_max_mask & (~supp_mask))
return torch.where(max_mask, scores, zeros)
@dataclass
class ImagePointDescriptionOutput(ModelOutput):
"""
Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
the maximum number of keypoints is set as the dimension of the keypoints, scores and descriptors tensors. The mask
tensor is used to indicate which values in the keypoints, scores and descriptors tensors are keypoint information
and which are padding.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
Relative (x, y) coordinates of predicted keypoints in a given image.
scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
Scores of predicted keypoints.
descriptors (`torch.FloatTensor` of shape `(batch_size, num_keypoints, descriptor_size)`):
Descriptors of predicted keypoints.
mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
Mask indicating which values in keypoints, scores and descriptors are keypoint information.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
(also called feature maps) of the model at the output of each stage.
"""
last_hidden_state: torch.FloatTensor = None
keypoints: Optional[torch.IntTensor] = None
scores: Optional[torch.FloatTensor] = None
descriptors: Optional[torch.FloatTensor] = None
mask: Optional[torch.BoolTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class SuperPointConvBlock(nn.Module):
def __init__(
self, config: SuperPointConfig, in_channels: int, out_channels: int, add_pooling: bool = False
) -> None:
super().__init__()
self.conv_a = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
)
self.conv_b = nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
)
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2) if add_pooling else None
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.relu(self.conv_a(hidden_states))
hidden_states = self.relu(self.conv_b(hidden_states))
if self.pool is not None:
hidden_states = self.pool(hidden_states)
return hidden_states
class SuperPointEncoder(nn.Module):
"""
SuperPoint encoder module. It is made of 4 convolutional layers with ReLU activation and max pooling, reducing the
dimensionality of the image.
"""
def __init__(self, config: SuperPointConfig) -> None:
super().__init__()
# SuperPoint uses 1 channel images
self.input_dim = 1
conv_blocks = []
conv_blocks.append(
SuperPointConvBlock(config, self.input_dim, config.encoder_hidden_sizes[0], add_pooling=True)
)
for i in range(1, len(config.encoder_hidden_sizes) - 1):
conv_blocks.append(
SuperPointConvBlock(
config, config.encoder_hidden_sizes[i - 1], config.encoder_hidden_sizes[i], add_pooling=True
)
)
conv_blocks.append(
SuperPointConvBlock(
config, config.encoder_hidden_sizes[-2], config.encoder_hidden_sizes[-1], add_pooling=False
)
)
self.conv_blocks = nn.ModuleList(conv_blocks)
def forward(
self,
input,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
all_hidden_states = () if output_hidden_states else None
for conv_block in self.conv_blocks:
input = conv_block(input)
if output_hidden_states:
all_hidden_states = all_hidden_states + (input,)
output = input
if not return_dict:
return tuple(v for v in [output, all_hidden_states] if v is not None)
return BaseModelOutputWithNoAttention(
last_hidden_state=output,
hidden_states=all_hidden_states,
)
class SuperPointInterestPointDecoder(nn.Module):
"""
The SuperPointInterestPointDecoder uses the output of the SuperPointEncoder to compute the keypoint with scores.
The scores are first computed by a convolutional layer, then a softmax is applied to get a probability distribution
over the 65 possible keypoint classes. The keypoints are then extracted from the scores by thresholding and
non-maximum suppression. Post-processing is then applied to remove keypoints too close to the image borders as well
as to keep only the k keypoints with highest score.
"""
def __init__(self, config: SuperPointConfig) -> None:
super().__init__()
self.keypoint_threshold = config.keypoint_threshold
self.max_keypoints = config.max_keypoints
self.nms_radius = config.nms_radius
self.border_removal_distance = config.border_removal_distance
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv_score_a = nn.Conv2d(
config.encoder_hidden_sizes[-1],
config.decoder_hidden_size,
kernel_size=3,
stride=1,
padding=1,
)
self.conv_score_b = nn.Conv2d(
config.decoder_hidden_size, config.keypoint_decoder_dim, kernel_size=1, stride=1, padding=0
)
def forward(self, encoded: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scores = self._get_pixel_scores(encoded)
keypoints, scores = self._extract_keypoints(scores)
return keypoints, scores
def _get_pixel_scores(self, encoded: torch.Tensor) -> torch.Tensor:
"""Based on the encoder output, compute the scores for each pixel of the image"""
scores = self.relu(self.conv_score_a(encoded))
scores = self.conv_score_b(scores)
scores = nn.functional.softmax(scores, 1)[:, :-1]
batch_size, _, height, width = scores.shape
scores = scores.permute(0, 2, 3, 1).reshape(batch_size, height, width, 8, 8)
scores = scores.permute(0, 1, 3, 2, 4).reshape(batch_size, height * 8, width * 8)
scores = simple_nms(scores, self.nms_radius)
return scores
def _extract_keypoints(self, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation"""
_, height, width = scores.shape
# Threshold keypoints by score value
keypoints = torch.nonzero(scores[0] > self.keypoint_threshold)
scores = scores[0][tuple(keypoints.t())]
# Discard keypoints near the image borders
keypoints, scores = remove_keypoints_from_borders(
keypoints, scores, self.border_removal_distance, height * 8, width * 8
)
# Keep the k keypoints with highest score
if self.max_keypoints >= 0:
keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints)
# Convert (y, x) to (x, y)
keypoints = torch.flip(keypoints, [1]).float()
return keypoints, scores
class SuperPointDescriptorDecoder(nn.Module):
"""
The SuperPointDescriptorDecoder uses the outputs of both the SuperPointEncoder and the
SuperPointInterestPointDecoder to compute the descriptors at the keypoints locations.
The descriptors are first computed by a convolutional layer, then normalized to have a norm of 1. The descriptors
are then interpolated at the keypoints locations.
"""
def __init__(self, config: SuperPointConfig) -> None:
super().__init__()
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv_descriptor_a = nn.Conv2d(
config.encoder_hidden_sizes[-1],
config.decoder_hidden_size,
kernel_size=3,
stride=1,
padding=1,
)
self.conv_descriptor_b = nn.Conv2d(
config.decoder_hidden_size,
config.descriptor_decoder_dim,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, encoded: torch.Tensor, keypoints: torch.Tensor) -> torch.Tensor:
"""Based on the encoder output and the keypoints, compute the descriptors for each keypoint"""
descriptors = self.conv_descriptor_b(self.relu(self.conv_descriptor_a(encoded)))
descriptors = nn.functional.normalize(descriptors, p=2, dim=1)
descriptors = self._sample_descriptors(keypoints[None], descriptors[0][None], 8)[0]
# [descriptor_dim, num_keypoints] -> [num_keypoints, descriptor_dim]
descriptors = torch.transpose(descriptors, 0, 1)
return descriptors
@staticmethod
def _sample_descriptors(keypoints, descriptors, scale: int = 8) -> torch.Tensor:
"""Interpolate descriptors at keypoint locations"""
batch_size, num_channels, height, width = descriptors.shape
keypoints = keypoints - scale / 2 + 0.5
divisor = torch.tensor([[(width * scale - scale / 2 - 0.5), (height * scale - scale / 2 - 0.5)]])
divisor = divisor.to(keypoints)
keypoints /= divisor
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
kwargs = {"align_corners": True} if is_torch_greater_or_equal_than_1_13 else {}
# [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2]
keypoints = keypoints.view(batch_size, 1, -1, 2)
descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs)
# [batch_size, descriptor_decoder_dim, num_channels, num_keypoints] -> [batch_size, descriptor_decoder_dim, num_keypoints]
descriptors = descriptors.reshape(batch_size, num_channels, -1)
descriptors = nn.functional.normalize(descriptors, p=2, dim=1)
return descriptors
class SuperPointPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SuperPointConfig
base_model_prefix = "superpoint"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
"""
Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same,
extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for SuperPoint. This is
a workaround for the issue discussed in :
https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
Args:
pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)
Returns:
pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
"""
return pixel_values[:, 0, :, :][:, None, :, :]
SUPERPOINT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`SuperPointConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SUPERPOINT_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`SuperPointImageProcessor`]. See
[`SuperPointImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more
detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"SuperPoint model outputting keypoints and descriptors.",
SUPERPOINT_START_DOCSTRING,
)
class SuperPointModel(SuperPointPreTrainedModel):
"""
SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and
Description <https://arxiv.org/abs/1712.07629>`__ by Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. It
is a fully convolutional neural network that extracts keypoints and descriptors from an image. It is trained in a
self-supervised manner, using a combination of a photometric loss and a loss based on the homographic adaptation of
keypoints. It is made of a convolutional encoder and two decoders: one for keypoints and one for descriptors.
"""
def __init__(self, config: SuperPointConfig) -> None:
super().__init__(config)
self.config = config
self.encoder = SuperPointEncoder(config)
self.keypoint_decoder = SuperPointInterestPointDecoder(config)
self.descriptor_decoder = SuperPointDescriptorDecoder(config)
self.post_init()
@add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImagePointDescriptionOutput]:
"""
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoModel
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
>>> model = AutoModel.from_pretrained("magic-leap-community/superpoint")
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
```"""
if labels is not None:
raise ValueError(
f"SuperPoint is not trainable, no labels should be provided.Therefore, labels should be None but were {type(labels)}"
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
batch_size = pixel_values.shape[0]
encoder_outputs = self.encoder(
pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
list_keypoints_scores = [
self.keypoint_decoder(last_hidden_state[None, ...]) for last_hidden_state in last_hidden_state
]
list_keypoints = [keypoints_scores[0] for keypoints_scores in list_keypoints_scores]
list_scores = [keypoints_scores[1] for keypoints_scores in list_keypoints_scores]
list_descriptors = [
self.descriptor_decoder(last_hidden_state[None, ...], keypoints[None, ...])
for last_hidden_state, keypoints in zip(last_hidden_state, list_keypoints)
]
maximum_num_keypoints = max(keypoints.shape[0] for keypoints in list_keypoints)
keypoints = torch.zeros((batch_size, maximum_num_keypoints, 2), device=pixel_values.device)
scores = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device)
descriptors = torch.zeros(
(batch_size, maximum_num_keypoints, self.config.descriptor_decoder_dim),
device=pixel_values.device,
)
mask = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device, dtype=torch.int)
for i, (_keypoints, _scores, _descriptors) in enumerate(zip(list_keypoints, list_scores, list_descriptors)):
keypoints[i, : _keypoints.shape[0]] = _keypoints
scores[i, : _scores.shape[0]] = _scores
descriptors[i, : _descriptors.shape[0]] = _descriptors
mask[i, : _scores.shape[0]] = 1
hidden_states = encoder_outputs[1] if output_hidden_states else None
if not return_dict:
return tuple(
v for v in [last_hidden_state, keypoints, scores, descriptors, mask, hidden_states] if v is not None
)
return ImagePointDescriptionOutput(
last_hidden_state=last_hidden_state,
keypoints=keypoints,
scores=scores,
descriptors=descriptors,
mask=mask,
hidden_states=hidden_states,
)
...@@ -8026,6 +8026,23 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject): ...@@ -8026,6 +8026,23 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class SuperPointModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SuperPointPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
...@@ -485,6 +485,13 @@ class SiglipImageProcessor(metaclass=DummyObject): ...@@ -485,6 +485,13 @@ class SiglipImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"]) requires_backends(self, ["vision"])
class SuperPointImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class Swin2SRImageProcessor(metaclass=DummyObject): class Swin2SRImageProcessor(metaclass=DummyObject):
_backends = ["vision"] _backends = ["vision"]
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_image_processing_common import (
ImageProcessingTestMixin,
prepare_image_inputs,
)
if is_vision_available():
from transformers import SuperPointImageProcessor
class SuperPointImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
):
size = size if size is not None else {"height": 480, "width": 640}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SuperPointImageProcessor if is_vision_available() else None
def setUp(self) -> None:
self.image_processor_tester = SuperPointImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processing(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 480, "width": 640})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size={"height": 42, "width": 42}
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
@unittest.skip(reason="SuperPointImageProcessor is always supposed to return a grayscaled image")
def test_call_numpy_4_channels(self):
pass
def test_input_image_properly_converted_to_grayscale(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs()
pre_processed_images = image_processor.preprocess(image_inputs)
for image in pre_processed_images["pixel_values"]:
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from typing import List
from transformers.models.superpoint.configuration_superpoint import SuperPointConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
if is_torch_available():
import torch
from transformers import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
)
if is_vision_available():
from PIL import Image
from transformers import AutoImageProcessor
class SuperPointModelTester:
def __init__(
self,
parent,
batch_size=3,
image_width=80,
image_height=60,
encoder_hidden_sizes: List[int] = [32, 32, 64, 64],
decoder_hidden_size: int = 128,
keypoint_decoder_dim: int = 65,
descriptor_decoder_dim: int = 128,
keypoint_threshold: float = 0.005,
max_keypoints: int = -1,
nms_radius: int = 4,
border_removal_distance: int = 4,
):
self.parent = parent
self.batch_size = batch_size
self.image_width = image_width
self.image_height = image_height
self.encoder_hidden_sizes = encoder_hidden_sizes
self.decoder_hidden_size = decoder_hidden_size
self.keypoint_decoder_dim = keypoint_decoder_dim
self.descriptor_decoder_dim = descriptor_decoder_dim
self.keypoint_threshold = keypoint_threshold
self.max_keypoints = max_keypoints
self.nms_radius = nms_radius
self.border_removal_distance = border_removal_distance
def prepare_config_and_inputs(self):
# SuperPoint expects a grayscale image as input
pixel_values = floats_tensor([self.batch_size, 3, self.image_height, self.image_width])
config = self.get_config()
return config, pixel_values
def get_config(self):
return SuperPointConfig(
encoder_hidden_sizes=self.encoder_hidden_sizes,
decoder_hidden_size=self.decoder_hidden_size,
keypoint_decoder_dim=self.keypoint_decoder_dim,
descriptor_decoder_dim=self.descriptor_decoder_dim,
keypoint_threshold=self.keypoint_threshold,
max_keypoints=self.max_keypoints,
nms_radius=self.nms_radius,
border_removal_distance=self.border_removal_distance,
)
def create_and_check_model(self, config, pixel_values):
model = SuperPointModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.last_hidden_state.shape,
(
self.batch_size,
self.encoder_hidden_sizes[-1],
self.image_height // 8,
self.image_width // 8,
),
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (SuperPointModel,) if is_torch_available() else ()
all_generative_model_classes = () if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
def setUp(self):
self.model_tester = SuperPointModelTester(self)
self.config_tester = ConfigTester(self, config_class=SuperPointConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="SuperPointModel does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="SuperPointModel does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="SuperPointModel does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
def test_training(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="SuperPoint does not output any loss term in the forward pass")
def test_retain_grad_hidden_states_attentions(self):
pass
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
# SuperPoint's feature maps are of shape (batch_size, num_channels, width, height)
for i, conv_layer_size in enumerate(self.model_tester.encoder_hidden_sizes[:-1]):
self.assertListEqual(
list(hidden_states[i].shape[-3:]),
[
conv_layer_size,
self.model_tester.image_height // (2 ** (i + 1)),
self.model_tester.image_width // (2 ** (i + 1)),
],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
@slow
def test_model_from_pretrained(self):
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SuperPointModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_forward_labels_should_be_none(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
model_inputs = self._prepare_for_class(inputs_dict, model_class)
# Provide an arbitrary sized Tensor as labels to model inputs
model_inputs["labels"] = torch.rand((128, 128))
with self.assertRaises(ValueError) as cm:
model(**model_inputs)
self.assertEqual(ValueError, cm.exception.__class__)
def prepare_imgs():
image1 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
image2 = Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")
return [image1, image2]
@require_torch
@require_vision
class SuperPointModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return AutoImageProcessor.from_pretrained("magic-leap-community/superpoint") if is_vision_available() else None
@slow
def test_inference(self):
model = SuperPointModel.from_pretrained("magic-leap-community/superpoint").to(torch_device)
preprocessor = self.default_image_processor
images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
expected_number_keypoints_image0 = 567
expected_number_keypoints_image1 = 830
expected_max_number_keypoints = max(expected_number_keypoints_image0, expected_number_keypoints_image1)
expected_keypoints_shape = torch.Size((len(images), expected_max_number_keypoints, 2))
expected_scores_shape = torch.Size(
(
len(images),
expected_max_number_keypoints,
)
)
expected_descriptors_shape = torch.Size((len(images), expected_max_number_keypoints, 256))
# Check output shapes
self.assertEqual(outputs.keypoints.shape, expected_keypoints_shape)
self.assertEqual(outputs.scores.shape, expected_scores_shape)
self.assertEqual(outputs.descriptors.shape, expected_descriptors_shape)
expected_keypoints_image0_values = torch.tensor([[480.0, 9.0], [494.0, 9.0], [489.0, 16.0]]).to(torch_device)
expected_scores_image0_values = torch.tensor(
[0.0064, 0.0137, 0.0589, 0.0723, 0.5166, 0.0174, 0.1515, 0.2054, 0.0334]
).to(torch_device)
expected_descriptors_image0_value = torch.tensor(-0.1096).to(torch_device)
predicted_keypoints_image0_values = outputs.keypoints[0, :3]
predicted_scores_image0_values = outputs.scores[0, :9]
predicted_descriptors_image0_value = outputs.descriptors[0, 0, 0]
# Check output values
self.assertTrue(
torch.allclose(
predicted_keypoints_image0_values,
expected_keypoints_image0_values,
atol=1e-4,
)
)
self.assertTrue(torch.allclose(predicted_scores_image0_values, expected_scores_image0_values, atol=1e-4))
self.assertTrue(
torch.allclose(
predicted_descriptors_image0_value,
expected_descriptors_image0_value,
atol=1e-4,
)
)
# Check mask values
self.assertTrue(outputs.mask[0, expected_number_keypoints_image0 - 1].item() == 1)
self.assertTrue(outputs.mask[0, expected_number_keypoints_image0].item() == 0)
self.assertTrue(torch.all(outputs.mask[0, : expected_number_keypoints_image0 - 1]))
self.assertTrue(torch.all(torch.logical_not(outputs.mask[0, expected_number_keypoints_image0:])))
self.assertTrue(torch.all(outputs.mask[1]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment