Unverified Commit 06fd7972 authored by NielsRogge, committed by GitHub

Add ZoeDepth (#30136)



* First draft

* Add docs

* Clean up code

* Convert model

* Add image processor

* Convert Zoe_K

* More improvements

* Improve variable names and docstrings

* Improve variable names

* Improve variable names

* Replace nn.sequential

* More improvements

* Convert ZoeD_NK

* Fix most tests

* Verify pixel values

* Verify pixel values

* Add squeeze

* Update beit to support arbitrary window sizes

* Improve image processor

* Improve docstring

* Improve beit

* Improve model outputs

* Add figure

* Fix beit

* Update checkpoint

* Fix repo id

* Add _keys_to_ignore_on_load_unexpected

* More improvements

* Address comments

* Address comments

* Address comments

* Address comments

* Rename variable name

* Add backbone_hidden_size

* Vectorize

* Vectorize more

* Address comments

* Clarify docstring

* Remove backbone_hidden_size

* Fix image processor

* Remove print statements

* Remove print statement

* Add integration test

* Address comments

* Address comments

* Address comments

* Address comments

* Add requires_backends

* Clean up

* Simplify conversion script

* Simplify more

* Simplify more

* Simplify more

* Clean up

* Make sure beit is loaded correctly

* Address comment

* Address bin_configurations

* Use bin_configurations

* Convert models, add integration tests

* Fix doc test

* Address comments

* Unify regressor classes

* Clarify arguments

* Improve resize_image

* Add num_relative_features

* Address comment

* [run-slow]beit,data2vec,zoedepth

* [run-slow]beit,data2vec,zoedepth

* Address comments

* Address comment

* Address comment

* Replace nn.TransformerEncoderLayer and nn.TransformerEncoder

* Replace nn.MultiheadAttention

* Add attributes for patch transformer to config

* Add tests for ensure_multiple_of

* Update organization

* Add tests

* [run-slow] beit data2vec

* Update ruff

* [run-slow] beit data2vec

* Add comment

* Improve docstrings, add test

* Fix interpolate_pos_encoding

* Fix slow tests

* Add docstring

* Update src/transformers/models/zoedepth/image_processing_zoedepth.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/zoedepth/image_processing_zoedepth.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Improve tests and docstrings

* Use run_common_tests

* Improve docstrings

* Improve docstrings

* Improve tests

* Improve tests

* Remove print statements

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 1082361a
@@ -667,6 +667,8 @@
      title: ViTMSN
    - local: model_doc/yolos
      title: YOLOS
+   - local: model_doc/zoedepth
+     title: ZoeDepth
    title: Vision models
  - isExpanded: false
    sections:
......
@@ -343,5 +343,6 @@ Flax), PyTorch, and/or TensorFlow.
| [XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2) | ✅ | ✅ | ✅ |
| [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
| [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
+| [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |
<!-- End table-->
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# ZoeDepth
## Overview
The ZoeDepth model was proposed in [ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth](https://arxiv.org/abs/2302.12288) by Shariq Farooq Bhat, Reiner Birkl, Diana Wofk, Peter Wonka, Matthias Müller. ZoeDepth extends the [DPT](dpt) framework for metric (also called absolute) depth estimation. ZoeDepth is pre-trained on 12 datasets using relative depth and fine-tuned on two domains (NYU and KITTI) using metric depth. A lightweight head is used with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier.
The abstract from the paper is the following:
*This paper tackles the problem of depth estimation from a single image. Existing work either focuses on generalization performance disregarding metric scale, i.e. relative depth estimation, or state-of-the-art results on specific datasets, i.e. metric depth estimation. We propose the first approach that combines both worlds, leading to a model with excellent generalization performance while maintaining metric scale. Our flagship model, ZoeD-M12-NK, is pre-trained on 12 datasets using relative depth and fine-tuned on two datasets using metric depth. We use a lightweight head with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier. Our framework admits multiple configurations depending on the datasets used for relative depth pre-training and metric fine-tuning. Without pre-training, we can already significantly improve the state of the art (SOTA) on the NYU Depth v2 indoor dataset. Pre-training on twelve datasets and fine-tuning on the NYU Depth v2 indoor dataset, we can further improve SOTA for a total of 21% in terms of relative absolute error (REL). Finally, ZoeD-M12-NK is the first model that can jointly train on multiple datasets (NYU Depth v2 and KITTI) without a significant drop in performance and achieve unprecedented zero-shot generalization performance to eight unseen datasets from both indoor and outdoor domains.*
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/zoedepth_architecture_bis.png"
alt="drawing" width="600"/>
<small> ZoeDepth architecture. Taken from the <a href="https://arxiv.org/abs/2302.12288">original paper.</a> </small>
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/isl-org/ZoeDepth).
## Usage tips
- ZoeDepth is an absolute (also called metric) depth estimation model, unlike DPT which is a relative depth estimation model. This means that ZoeDepth is able to estimate depth in metric units like meters.
The easiest way to perform inference with ZoeDepth is by leveraging the [pipeline API](../main_classes/pipelines.md):
```python
from transformers import pipeline
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
result = pipe(image)
depth = result["depth"]
```
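The pipeline returns a dictionary: `"predicted_depth"` holds the raw depth tensor and `"depth"` a PIL image rescaled for visualization. For example (the output filename below is just for illustration):
```python
print(result["predicted_depth"].shape)
depth.save("zoedepth_depth.png")  # save the visualization; the filename is arbitrary
```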
Alternatively, one can also perform inference using the classes:
```python
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
import torch
import numpy as np
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# interpolate to original size
prediction = torch.nn.functional.interpolate(
predicted_depth.unsqueeze(1),
size=image.size[::-1],
mode="bicubic",
align_corners=False,
)
# visualize the prediction
output = prediction.squeeze().cpu().numpy()
formatted = (output * 255 / np.max(output)).astype("uint8")
depth = Image.fromarray(formatted)
```
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ZoeDepth.
- A demo notebook regarding inference with ZoeDepth models can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/ZoeDepth). 🌎
## ZoeDepthConfig
[[autodoc]] ZoeDepthConfig
## ZoeDepthImageProcessor
[[autodoc]] ZoeDepthImageProcessor
- preprocess
## ZoeDepthForDepthEstimation
[[autodoc]] ZoeDepthForDepthEstimation
- forward
\ No newline at end of file
@@ -807,6 +807,7 @@ _import_structure = {
    "models.xmod": ["XmodConfig"],
    "models.yolos": ["YolosConfig"],
    "models.yoso": ["YosoConfig"],
+   "models.zoedepth": ["ZoeDepthConfig"],
    "onnx": [],
    "pipelines": [
        "AudioClassificationPipeline",
@@ -1182,6 +1183,7 @@ else:
    _import_structure["models.vitmatte"].append("VitMatteImageProcessor")
    _import_structure["models.vivit"].append("VivitImageProcessor")
    _import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
+   _import_structure["models.zoedepth"].append("ZoeDepthImageProcessor")

try:
    if not is_torchvision_available():
@@ -3586,6 +3588,12 @@ else:
            "YosoPreTrainedModel",
        ]
    )
+   _import_structure["models.zoedepth"].extend(
+       [
+           "ZoeDepthForDepthEstimation",
+           "ZoeDepthPreTrainedModel",
+       ]
+   )
    _import_structure["optimization"] = [
        "Adafactor",
        "AdamW",
@@ -5497,6 +5505,7 @@ if TYPE_CHECKING:
    from .models.xmod import XmodConfig
    from .models.yolos import YolosConfig
    from .models.yoso import YosoConfig
+   from .models.zoedepth import ZoeDepthConfig

    # Pipelines
    from .pipelines import (
@@ -5872,6 +5881,7 @@ if TYPE_CHECKING:
    from .models.vitmatte import VitMatteImageProcessor
    from .models.vivit import VivitImageProcessor
    from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
+   from .models.zoedepth import ZoeDepthImageProcessor

    try:
        if not is_torchvision_available():
@@ -7798,6 +7808,10 @@ if TYPE_CHECKING:
            YosoModel,
            YosoPreTrainedModel,
        )
+       from .models.zoedepth import (
+           ZoeDepthForDepthEstimation,
+           ZoeDepthPreTrainedModel,
+       )

        # Optimization
        from .optimization import (
......
@@ -409,22 +409,22 @@ def validate_preprocess_arguments(
    """
    if do_rescale and rescale_factor is None:
-       raise ValueError("rescale_factor must be specified if do_rescale is True.")
+       raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")

    if do_pad and size_divisibility is None:
        # Here, size_divisor might be passed as the value of size
        raise ValueError(
-           "Depending on moel, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
+           "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
        )

    if do_normalize and (image_mean is None or image_std is None):
-       raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")
+       raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")

    if do_center_crop and crop_size is None:
-       raise ValueError("crop_size must be specified if do_center_crop is True.")
+       raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")

    if do_resize and (size is None or resample is None):
-       raise ValueError("size and resample must be specified if do_resize is True.")
+       raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")

    # In the future we can add a TF implementation here when we have TF models.
......
@@ -263,4 +263,5 @@ from . import (
    xmod,
    yolos,
    yoso,
+   zoedepth,
)
@@ -291,6 +291,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("xmod", "XmodConfig"),
        ("yolos", "YolosConfig"),
        ("yoso", "YosoConfig"),
+       ("zoedepth", "ZoeDepthConfig"),
    ]
)
@@ -589,6 +590,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("xmod", "X-MOD"),
        ("yolos", "YOLOS"),
        ("yoso", "YOSO"),
+       ("zoedepth", "ZoeDepth"),
    ]
)
......
@@ -142,6 +142,7 @@ else:
            ("vitmatte", ("VitMatteImageProcessor",)),
            ("xclip", ("CLIPImageProcessor",)),
            ("yolos", ("YolosImageProcessor",)),
+           ("zoedepth", ("ZoeDepthImageProcessor",)),
        ]
    )
......
@@ -792,6 +792,7 @@ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
        ("depth_anything", "DepthAnythingForDepthEstimation"),
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
+       ("zoedepth", "ZoeDepthForDepthEstimation"),
    ]
)

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
......
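With the `("zoedepth", "ZoeDepthForDepthEstimation")` entry registered in the depth-estimation mapping above, the auto classes resolve ZoeDepth checkpoints without naming the class explicitly. A minimal sketch, using the `Intel/zoedepth-nyu-kitti` checkpoint referenced in the documentation:
```python
from transformers import AutoModelForDepthEstimation

model = AutoModelForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
print(type(model).__name__)  # ZoeDepthForDepthEstimation
```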
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -193,12 +193,6 @@ class BeitEmbeddings(nn.Module):
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
        embeddings, (patch_height, patch_width) = self.patch_embeddings(
            pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
        )
@@ -280,6 +274,7 @@ class BeitPatchEmbeddings(nn.Module):
class BeitSelfAttention(nn.Module):
    def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
+        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
@@ -313,6 +308,7 @@ class BeitSelfAttention(nn.Module):
        output_attentions: bool = False,
        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)
@@ -327,9 +323,11 @@ class BeitSelfAttention(nn.Module):
        # Add relative position bias if present.
        if self.relative_position_bias is not None:
+            height, width = resolution
+            window_size = (height // self.config.patch_size, width // self.config.patch_size)
            attention_scores = attention_scores + self.relative_position_bias(
-                interpolate_pos_encoding, attention_scores.shape[2]
-            ).unsqueeze(0)
+                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+            )

        # Add shared relative position bias if provided.
        if relative_position_bias is not None:
@@ -407,9 +405,10 @@ class BeitAttention(nn.Module):
        output_attentions: bool = False,
        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_outputs = self.attention(
-            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
+            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
        )
        attention_output = self.output(self_outputs[0], hidden_states)
@@ -475,6 +474,7 @@ class BeitLayer(nn.Module):
        output_attentions: bool = False,
        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in BEiT, layernorm is applied before self-attention
@@ -482,6 +482,7 @@ class BeitLayer(nn.Module):
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            interpolate_pos_encoding=interpolate_pos_encoding,
+            resolution=resolution,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
@@ -520,32 +521,71 @@ class BeitRelativePositionBias(nn.Module):
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

+        self.relative_position_indices = {}
+
+    def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
+        """
+        This method creates the relative position index, modified to support arbitrary window sizes,
+        as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
+        """
+        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        # cls to token & token 2 cls & cls to cls
        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(window_size[0])
-        coords_w = torch.arange(window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
+        window_area = window_size[0] * window_size[1]
+        grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
+        coords = torch.stack(grid)  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
-        relative_position_index = torch.zeros(
-            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
-        )
+        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        relative_position_index[0, 0:] = self.num_relative_distance - 3
-        relative_position_index[0:, 0] = self.num_relative_distance - 2
-        relative_position_index[0, 0] = self.num_relative_distance - 1
-
-        self.register_buffer("relative_position_index", relative_position_index, persistent=False)
-
-    def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
-        )  # Wh*Ww,Wh*Ww,nH
-
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = num_relative_distance - 3
+        relative_position_index[0:, 0] = num_relative_distance - 2
+        relative_position_index[0, 0] = num_relative_distance - 1
+        return relative_position_index
+
+    def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
+        """
+        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
+        """
+        old_height = 2 * self.window_size[0] - 1
+        old_width = 2 * self.window_size[1] - 1
+        new_height = 2 * window_size[0] - 1
+        new_width = 2 * window_size[1] - 1
+
+        old_relative_position_bias_table = self.relative_position_bias_table
+        old_num_relative_distance = self.num_relative_distance
+        new_num_relative_distance = new_height * new_width + 3
+
+        old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
+        old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
+        new_sub_table = nn.functional.interpolate(
+            old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
+        )
+        new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
+
+        new_relative_position_bias_table = torch.cat(
+            [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
+        )
+
+        key = window_size
+        if key not in self.relative_position_indices.keys():
+            self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
+
+        relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
+        # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
+        relative_position_bias = relative_position_bias.view(
+            window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
+        )
+        # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        if interpolate_pos_encoding:
            relative_position_bias = nn.functional.interpolate(
                relative_position_bias.unsqueeze(1),
@@ -554,7 +594,7 @@ class BeitRelativePositionBias(nn.Module):
                align_corners=False,
            ).squeeze(1)
-        return relative_position_bias
+        return relative_position_bias.unsqueeze(0)


class BeitEncoder(nn.Module):
@@ -587,6 +627,7 @@ class BeitEncoder(nn.Module):
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
@@ -606,13 +647,22 @@ class BeitEncoder(nn.Module):
                    output_attentions,
                )
            else:
+                height, width = resolution
+                window_size = (height // self.config.patch_size, width // self.config.patch_size)
                relative_position_bias = (
-                    self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
+                    self.relative_position_bias(
+                        window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+                    )
                    if self.relative_position_bias is not None
                    else None
                )
                layer_outputs = layer_module(
-                    hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
+                    hidden_states,
+                    layer_head_mask,
+                    output_attentions,
+                    relative_position_bias,
+                    interpolate_pos_encoding,
+                    resolution,
                )

            hidden_states = layer_outputs[0]
@@ -643,6 +693,7 @@ class BeitPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BeitLayer"]
+    _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]

    def _init_weights(self, module):
        """Initialize the weights"""
@@ -738,7 +789,7 @@ class BeitModel(BeitPreTrainedModel):
    )
    def forward(
        self,
-        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
@@ -756,9 +807,6 @@ class BeitModel(BeitPreTrainedModel):
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if pixel_values is None:
-            raise ValueError("You have to specify pixel_values")
-
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
@@ -766,15 +814,17 @@ class BeitModel(BeitPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

-        embedding_output, (patch_height, patch_width) = self.embeddings(
+        embedding_output, _ = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )
+        resolution = pixel_values.shape[2:]

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
+            resolution=resolution,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
@@ -1477,9 +1527,14 @@ class BeitBackbone(BeitPreTrainedModel, BackboneMixin):
        batch_size = pixel_values.shape[0]

        embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)
+        resolution = pixel_values.shape[2:]

        outputs = self.encoder(
-            embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
+            embedding_output,
+            output_hidden_states=True,
+            output_attentions=output_attentions,
+            resolution=resolution,
+            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]
......
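The BEiT hunks above (mirrored for Data2VecVision below) let the relative position bias follow the input resolution: the pretrained bias table, covering `(2*Wh-1)*(2*Ww-1)` relative offsets plus three cls-related entries, is bilinearly interpolated to the window implied by the image size, and the relative position index is regenerated (and cached) per window. The following is a standalone sketch of that resizing step, written for illustration only; `resize_relative_position_bias_table` is a hypothetical helper name and not the library class itself:
```python
import torch
import torch.nn as nn


def resize_relative_position_bias_table(
    table: torch.Tensor, old_window_size: tuple, new_window_size: tuple
) -> torch.Tensor:
    # table has shape (old_num_relative_distance, num_attention_heads); the last 3 rows are the
    # cls-to-token / token-to-cls / cls-to-cls entries and are kept untouched.
    old_height, old_width = 2 * old_window_size[0] - 1, 2 * old_window_size[1] - 1
    new_height, new_width = 2 * new_window_size[0] - 1, 2 * new_window_size[1] - 1
    old_num = old_height * old_width + 3
    new_num = new_height * new_width + 3

    # Bilinearly interpolate only the grid part of the table to the new relative-offset grid.
    grid_part = table[: old_num - 3].reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
    grid_part = nn.functional.interpolate(grid_part, size=(new_height, new_width), mode="bilinear")
    grid_part = grid_part.permute(0, 2, 3, 1).reshape(new_num - 3, -1)
    return torch.cat([grid_part, table[old_num - 3 :]])


# Example: a table pretrained for a 32x32 patch window, resized for a 24x24 window, 12 heads.
old_table = torch.randn((2 * 32 - 1) ** 2 + 3, 12)
new_table = resize_relative_position_bias_table(old_table, (32, 32), (24, 24))
print(new_table.shape)  # torch.Size([2212, 12]) == (2 * 24 - 1) ** 2 + 3 rows
```
The resized table is then indexed with the per-window relative position index, which is what the new `generate_relative_position_index` method produces and caches in `self.relative_position_indices`.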
@@ -32,7 +32,7 @@ from ...modeling_outputs import (
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
@@ -192,12 +192,6 @@ class Data2VecVisionEmbeddings(nn.Module):
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
-        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
-            raise ValueError(
-                f"Input image size ({height}*{width}) doesn't match model"
-                f" ({self.image_size[0]}*{self.image_size[1]})."
-            )
        embeddings, (patch_height, patch_width) = self.patch_embeddings(
            pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
        )
@@ -281,6 +275,7 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
class Data2VecVisionSelfAttention(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
+        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
@@ -314,6 +309,7 @@ class Data2VecVisionSelfAttention(nn.Module):
        output_attentions: bool = False,
        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)
@@ -328,9 +324,11 @@ class Data2VecVisionSelfAttention(nn.Module):
        # Add relative position bias if present.
        if self.relative_position_bias is not None:
+            height, width = resolution
+            window_size = (height // self.config.patch_size, width // self.config.patch_size)
            attention_scores = attention_scores + self.relative_position_bias(
-                interpolate_pos_encoding, attention_scores.shape[2]
-            ).unsqueeze(0)
+                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+            )

        # Add shared relative position bias if provided.
        if relative_position_bias is not None:
@@ -410,9 +408,10 @@ class Data2VecVisionAttention(nn.Module):
        output_attentions: bool = False,
        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_outputs = self.attention(
-            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
+            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
        )
        attention_output = self.output(self_outputs[0], hidden_states)
@@ -483,6 +482,7 @@ class Data2VecVisionLayer(nn.Module):
        output_attentions: bool = False,
        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in Data2VecVision, layernorm is applied before self-attention
@@ -490,6 +490,7 @@ class Data2VecVisionLayer(nn.Module):
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            interpolate_pos_encoding=interpolate_pos_encoding,
+            resolution=resolution,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
@@ -529,32 +530,71 @@ class Data2VecVisionRelativePositionBias(nn.Module):
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

+        self.relative_position_indices = {}
+
+    def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
+        """
+        This method creates the relative position index, modified to support arbitrary window sizes,
+        as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
+        """
+        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        # cls to token & token 2 cls & cls to cls
        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(window_size[0])
-        coords_w = torch.arange(window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
+        window_area = window_size[0] * window_size[1]
+        grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
+        coords = torch.stack(grid)  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
-        relative_position_index = torch.zeros(
-            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
-        )
+        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        relative_position_index[0, 0:] = self.num_relative_distance - 3
-        relative_position_index[0:, 0] = self.num_relative_distance - 2
-        relative_position_index[0, 0] = self.num_relative_distance - 1
-
-        self.register_buffer("relative_position_index", relative_position_index, persistent=False)
-
-    def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
-        )  # Wh*Ww,Wh*Ww,nH
-
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = num_relative_distance - 3
+        relative_position_index[0:, 0] = num_relative_distance - 2
+        relative_position_index[0, 0] = num_relative_distance - 1
+        return relative_position_index
+
+    def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
+        """
+        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
+        """
+        old_height = 2 * self.window_size[0] - 1
+        old_width = 2 * self.window_size[1] - 1
+        new_height = 2 * window_size[0] - 1
+        new_width = 2 * window_size[1] - 1
+
+        old_relative_position_bias_table = self.relative_position_bias_table
+        old_num_relative_distance = self.num_relative_distance
+        new_num_relative_distance = new_height * new_width + 3
+
+        old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
+        old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
+        new_sub_table = nn.functional.interpolate(
+            old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
+        )
+        new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
+
+        new_relative_position_bias_table = torch.cat(
+            [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
+        )
+
+        key = window_size
+        if key not in self.relative_position_indices.keys():
+            self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
+
+        relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
+        # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
+        relative_position_bias = relative_position_bias.view(
+            window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
+        )
+        # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        if interpolate_pos_encoding:
            relative_position_bias = nn.functional.interpolate(
                relative_position_bias.unsqueeze(1),
@@ -563,7 +603,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
                align_corners=False,
            ).squeeze(1)
-        return relative_position_bias
+        return relative_position_bias.unsqueeze(0)


# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
@@ -597,6 +637,7 @@ class Data2VecVisionEncoder(nn.Module):
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        interpolate_pos_encoding: bool = False,
+        resolution: Optional[Tuple[int]] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
@@ -616,13 +657,22 @@ class Data2VecVisionEncoder(nn.Module):
                    output_attentions,
                )
            else:
+                height, width = resolution
+                window_size = (height // self.config.patch_size, width // self.config.patch_size)
                relative_position_bias = (
-                    self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
+                    self.relative_position_bias(
+                        window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
+                    )
                    if self.relative_position_bias is not None
                    else None
                )
                layer_outputs = layer_module(
-                    hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
+                    hidden_states,
+                    layer_head_mask,
+                    output_attentions,
+                    relative_position_bias,
+                    interpolate_pos_encoding,
+                    resolution,
                )

            hidden_states = layer_outputs[0]
@@ -654,6 +704,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Data2VecVisionLayer"]
+    _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]

    def _init_weights(self, module):
        """Initialize the weights"""
@@ -750,7 +801,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
    )
    def forward(
        self,
-        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
@@ -768,9 +819,6 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if pixel_values is None:
-            raise ValueError("You have to specify pixel_values")
-
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
@@ -778,15 +826,17 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

-        embedding_output, (patch_height, patch_width) = self.embeddings(
+        embedding_output, _ = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )
+        resolution = pixel_values.shape[2:]

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
+            resolution=resolution,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
......
@@ -58,7 +58,7 @@ def get_resize_output_image_size(
    multiple: int,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
-   def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
+   def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
        x = round(val / multiple) * multiple

        if max_val is not None and x > max_val:
@@ -87,8 +87,8 @@ def get_resize_output_image_size(
        # fit height
        scale_width = scale_height

-   new_height = constraint_to_multiple_of(scale_height * input_height, multiple=multiple)
-   new_width = constraint_to_multiple_of(scale_width * input_width, multiple=multiple)
+   new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
+   new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)

    return (new_height, new_width)
......
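The only change above is the spelling fix in the helper's name (`constraint_to_multiple_of` → `constrain_to_multiple_of`); the helper snaps resized dimensions to a multiple of `ensure_multiple_of`. A quick sketch of its behaviour, re-implemented here for illustration (the floor/ceil fallback branches are not visible in the hunk and are assumed from the existing DPT implementation):
```python
import math


def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
    # Round to the nearest multiple; fall back to floor if that exceeds max_val, or ceil if below min_val.
    x = round(val / multiple) * multiple
    if max_val is not None and x > max_val:
        x = math.floor(val / multiple) * multiple
    if x < min_val:
        x = math.ceil(val / multiple) * multiple
    return x


print(constrain_to_multiple_of(518, 32))               # 512 -> nearest multiple of 32
print(constrain_to_multiple_of(518, 32, min_val=544))  # 544 -> bumped up to respect the minimum
```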
@@ -1021,7 +1021,7 @@ class DPTNeck(nn.Module):
class DPTDepthEstimationHead(nn.Module):
    """
-   Output head head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
+   Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
    the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
    supplementary material).
    """
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available, is_vision_available
from ...utils import OptionalDependencyNotAvailable
_import_structure = {"configuration_zoedepth": ["ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP", "ZoeDepthConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_zoedepth"] = [
"ZoeDepthForDepthEstimation",
"ZoeDepthPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_zoedepth"] = ["ZoeDepthImageProcessor"]
if TYPE_CHECKING:
from .configuration_zoedepth import ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP, ZoeDepthConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_zoedepth import (
ZoeDepthForDepthEstimation,
ZoeDepthPreTrainedModel,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_zoedepth import ZoeDepthImageProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
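With the lazy module above in place (together with the top-level `__init__.py` entries earlier in this diff), the new symbols are importable from the package root. A minimal smoke test, assuming PyTorch and the vision dependencies are installed:
```python
from transformers import ZoeDepthConfig, ZoeDepthForDepthEstimation, ZoeDepthImageProcessor

config = ZoeDepthConfig()  # defaults to a BEiT backbone configuration
model = ZoeDepthForDepthEstimation(config)  # randomly initialized, no download needed
image_processor = ZoeDepthImageProcessor()
print(type(model).__name__, config.bin_configurations)
```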
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ZoeDepth model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"Intel/zoedepth-nyu": "https://huggingface.co/Intel/zoedepth-nyu/resolve/main/config.json",
}
class ZoeDepthConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ZoeDepthForDepthEstimation`]. It is used to instantiate a ZoeDepth
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the ZoeDepth
[Intel/zoedepth-nyu](https://huggingface.co/Intel/zoedepth-nyu) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*, defaults to `BeitConfig()`):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the batch normalization layers.
readout_type (`str`, *optional*, defaults to `"project"`):
The readout type to use when processing the readout token (CLS token) of the intermediate hidden states of
the ViT backbone. Can be one of [`"ignore"`, `"add"`, `"project"`].
- "ignore" simply ignores the CLS token.
- "add" passes the information from the CLS token to all other tokens by adding the representations.
- "project" passes information to the other tokens by concatenating the readout to all other tokens before
projecting the
representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
The up/downsampling factors of the reassemble layers.
neck_hidden_sizes (`List[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
The hidden sizes to project to for the feature maps of the backbone.
fusion_hidden_size (`int`, *optional*, defaults to 256):
The number of channels before fusion.
head_in_index (`int`, *optional*, defaults to -1):
The index of the features to use in the heads.
use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
use_bias_in_fusion_residual (`bool`, *optional*, defaults to `None`):
Whether to use bias in the pre-activate residual units of the fusion blocks.
num_relative_features (`int`, *optional*, defaults to 32):
The number of features to use in the relative depth estimation head.
add_projection (`bool`, *optional*, defaults to `False`):
Whether to add a projection layer before the depth estimation head.
bottleneck_features (`int`, *optional*, defaults to 256):
The number of features in the bottleneck layer.
num_attractors (`List[int]`, *optional*, defaults to `[16, 8, 4, 1]`):
The number of attractors to use in each stage.
bin_embedding_dim (`int`, *optional*, defaults to 128):
The dimension of the bin embeddings.
attractor_alpha (`int`, *optional*, defaults to 1000):
The alpha value to use in the attractor.
attractor_gamma (`int`, *optional*, defaults to 2):
The gamma value to use in the attractor.
attractor_kind (`str`, *optional*, defaults to `"mean"`):
The kind of attractor to use. Can be one of [`"mean"`, `"sum"`].
min_temp (`float`, *optional*, defaults to 0.0212):
The minimum temperature value to consider.
max_temp (`float`, *optional*, defaults to 50.0):
The maximum temperature value to consider.
bin_centers_type (`str`, *optional*, defaults to `"softplus"`):
Activation type used for bin centers. Can be "normed" or "softplus". For "normed", a linear normalization trick
is applied, resulting in bounded bin centers. For "softplus", a softplus activation is used, so the bin centers are unbounded.
bin_configurations (`List[dict]`, *optional*, defaults to `[{'n_bins': 64, 'min_depth': 0.001, 'max_depth': 10.0}]`):
Configuration for each of the bin heads.
Each configuration should consist of the following keys:
- `name` (`str`): The name of the bin head - only required in case of multiple bin configurations.
- `n_bins` (`int`): The number of bins to use.
- `min_depth` (`float`): The minimum depth value to consider.
- `max_depth` (`float`): The maximum depth value to consider.
In case only a single configuration is passed, the model will use a single head with the specified configuration.
In case multiple configurations are passed, the model will use multiple heads with the specified configurations.
num_patch_transformer_layers (`int`, *optional*):
The number of transformer layers to use in the patch transformer. Only used in case of multiple bin configurations.
patch_transformer_hidden_size (`int`, *optional*):
The hidden size to use in the patch transformer. Only used in case of multiple bin configurations.
patch_transformer_intermediate_size (`int`, *optional*):
The intermediate size to use in the patch transformer. Only used in case of multiple bin configurations.
patch_transformer_num_attention_heads (`int`, *optional*):
The number of attention heads to use in the patch transformer. Only used in case of multiple bin configurations.
Example:
```python
>>> from transformers import ZoeDepthConfig, ZoeDepthForDepthEstimation
>>> # Initializing a ZoeDepth Intel/zoedepth-nyu style configuration
>>> configuration = ZoeDepthConfig()
>>> # Initializing a model from the Intel/zoedepth-nyu style configuration
>>> model = ZoeDepthForDepthEstimation(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "zoedepth"
def __init__(
self,
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
backbone_kwargs=None,
hidden_act="gelu",
initializer_range=0.02,
batch_norm_eps=1e-05,
readout_type="project",
reassemble_factors=[4, 2, 1, 0.5],
neck_hidden_sizes=[96, 192, 384, 768],
fusion_hidden_size=256,
head_in_index=-1,
use_batch_norm_in_fusion_residual=False,
use_bias_in_fusion_residual=None,
num_relative_features=32,
add_projection=False,
bottleneck_features=256,
num_attractors=[16, 8, 4, 1],
bin_embedding_dim=128,
attractor_alpha=1000,
attractor_gamma=2,
attractor_kind="mean",
min_temp=0.0212,
max_temp=50.0,
bin_centers_type="softplus",
bin_configurations=[{"n_bins": 64, "min_depth": 0.001, "max_depth": 10.0}],
num_patch_transformer_layers=None,
patch_transformer_hidden_size=None,
patch_transformer_intermediate_size=None,
patch_transformer_num_attention_heads=None,
**kwargs,
):
super().__init__(**kwargs)
if readout_type not in ["ignore", "add", "project"]:
raise ValueError("Readout_type must be one of ['ignore', 'add', 'project']")
if attractor_kind not in ["mean", "sum"]:
raise ValueError("Attractor_kind must be one of ['mean', 'sum']")
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `BEiT` backbone.")
backbone_config = CONFIG_MAPPING["beit"](
image_size=384,
num_hidden_layers=24,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
use_relative_position_bias=True,
reshape_hidden_states=False,
out_features=["stage6", "stage12", "stage18", "stage24"],
)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone_config = backbone_config
self.backbone = backbone
self.hidden_act = hidden_act
self.use_pretrained_backbone = use_pretrained_backbone
self.initializer_range = initializer_range
self.batch_norm_eps = batch_norm_eps
self.readout_type = readout_type
self.reassemble_factors = reassemble_factors
self.neck_hidden_sizes = neck_hidden_sizes
self.fusion_hidden_size = fusion_hidden_size
self.head_in_index = head_in_index
self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
self.num_relative_features = num_relative_features
self.add_projection = add_projection
self.bottleneck_features = bottleneck_features
self.num_attractors = num_attractors
self.bin_embedding_dim = bin_embedding_dim
self.attractor_alpha = attractor_alpha
self.attractor_gamma = attractor_gamma
self.attractor_kind = attractor_kind
self.min_temp = min_temp
self.max_temp = max_temp
self.bin_centers_type = bin_centers_type
self.bin_configurations = bin_configurations
self.num_patch_transformer_layers = num_patch_transformer_layers
self.patch_transformer_hidden_size = patch_transformer_hidden_size
self.patch_transformer_intermediate_size = patch_transformer_intermediate_size
self.patch_transformer_num_attention_heads = patch_transformer_num_attention_heads
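# A minimal usage sketch: a two-headed (NYU + KITTI) configuration mirroring the ZoeD_NK
# settings used in the conversion script below; the values are illustrative, not extra defaults.
#
#     from transformers import ZoeDepthConfig, ZoeDepthForDepthEstimation
#
#     config = ZoeDepthConfig(
#         bin_centers_type="softplus",
#         bin_configurations=[
#             {"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
#             {"name": "kitti", "n_bins": 64, "min_depth": 1e-3, "max_depth": 80.0},
#         ],
#         num_patch_transformer_layers=4,
#         patch_transformer_hidden_size=128,
#         patch_transformer_intermediate_size=1024,
#         patch_transformer_num_attention_heads=4,
#     )
#     model = ZoeDepthForDepthEstimation(config)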
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ZoeDepth checkpoints from the original repository. URL: https://github.com/isl-org/ZoeDepth.
Original logits were obtained by running the following code:
!git clone -b understanding_zoedepth https://github.com/NielsRogge/ZoeDepth
!python inference.py
"""
import argparse
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import BeitConfig, ZoeDepthConfig, ZoeDepthForDepthEstimation, ZoeDepthImageProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_zoedepth_config(model_name):
image_size = 384
backbone_config = BeitConfig(
image_size=image_size,
num_hidden_layers=24,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
use_relative_position_bias=True,
reshape_hidden_states=False,
out_features=["stage6", "stage12", "stage18", "stage24"], # beit-large-512 uses [5, 11, 17, 23],
)
neck_hidden_sizes = [256, 512, 1024, 1024]
bin_centers_type = "softplus" if model_name in ["ZoeD_N", "ZoeD_NK"] else "normed"
if model_name == "ZoeD_NK":
bin_configurations = [
{"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
{"name": "kitti", "n_bins": 64, "min_depth": 1e-3, "max_depth": 80.0},
]
elif model_name in ["ZoeD_N", "ZoeD_K"]:
bin_configurations = [
{"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
]
config = ZoeDepthConfig(
backbone_config=backbone_config,
neck_hidden_sizes=neck_hidden_sizes,
bin_centers_type=bin_centers_type,
bin_configurations=bin_configurations,
num_patch_transformer_layers=4 if model_name == "ZoeD_NK" else None,
patch_transformer_hidden_size=128 if model_name == "ZoeD_NK" else None,
patch_transformer_intermediate_size=1024 if model_name == "ZoeD_NK" else None,
patch_transformer_num_attention_heads=4 if model_name == "ZoeD_NK" else None,
)
return config, image_size
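# For instance, get_zoedepth_config("ZoeD_NK") yields a config with two bin heads (NYU + KITTI)
# and a 4-layer patch transformer, while get_zoedepth_config("ZoeD_N") yields a single NYU head.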
def rename_key(name):
# Transformer backbone
if "core.core.pretrained.model.blocks" in name:
name = name.replace("core.core.pretrained.model.blocks", "backbone.encoder.layer")
if "core.core.pretrained.model.patch_embed.proj" in name:
name = name.replace(
"core.core.pretrained.model.patch_embed.proj", "backbone.embeddings.patch_embeddings.projection"
)
if "core.core.pretrained.model.cls_token" in name:
name = name.replace("core.core.pretrained.model.cls_token", "backbone.embeddings.cls_token")
if "norm1" in name and "patch_transformer" not in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name and "patch_transformer" not in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
if "gamma_1" in name:
name = name.replace("gamma_1", "lambda_1")
if "gamma_2" in name:
name = name.replace("gamma_2", "lambda_2")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn.relative_position_bias_table" in name:
name = name.replace(
"attn.relative_position_bias_table",
"attention.attention.relative_position_bias.relative_position_bias_table",
)
if "attn.relative_position_index" in name:
name = name.replace(
"attn.relative_position_index", "attention.attention.relative_position_bias.relative_position_index"
)
# activation postprocessing (readout projections + resize blocks)
if "core.core.pretrained.act_postprocess1.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess1.0.project", "neck.reassemble_stage.readout_projects.0"
)
if "core.core.pretrained.act_postprocess2.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess2.0.project", "neck.reassemble_stage.readout_projects.1"
)
if "core.core.pretrained.act_postprocess3.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess3.0.project", "neck.reassemble_stage.readout_projects.2"
)
if "core.core.pretrained.act_postprocess4.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess4.0.project", "neck.reassemble_stage.readout_projects.3"
)
if "core.core.pretrained.act_postprocess1.3" in name:
name = name.replace("core.core.pretrained.act_postprocess1.3", "neck.reassemble_stage.layers.0.projection")
if "core.core.pretrained.act_postprocess2.3" in name:
name = name.replace("core.core.pretrained.act_postprocess2.3", "neck.reassemble_stage.layers.1.projection")
if "core.core.pretrained.act_postprocess3.3" in name:
name = name.replace("core.core.pretrained.act_postprocess3.3", "neck.reassemble_stage.layers.2.projection")
if "core.core.pretrained.act_postprocess4.3" in name:
name = name.replace("core.core.pretrained.act_postprocess4.3", "neck.reassemble_stage.layers.3.projection")
if "core.core.pretrained.act_postprocess1.4" in name:
name = name.replace("core.core.pretrained.act_postprocess1.4", "neck.reassemble_stage.layers.0.resize")
if "core.core.pretrained.act_postprocess2.4" in name:
name = name.replace("core.core.pretrained.act_postprocess2.4", "neck.reassemble_stage.layers.1.resize")
if "core.core.pretrained.act_postprocess4.4" in name:
name = name.replace("core.core.pretrained.act_postprocess4.4", "neck.reassemble_stage.layers.3.resize")
# scratch convolutions
if "core.core.scratch.layer1_rn.weight" in name:
name = name.replace("core.core.scratch.layer1_rn.weight", "neck.convs.0.weight")
if "core.core.scratch.layer2_rn.weight" in name:
name = name.replace("core.core.scratch.layer2_rn.weight", "neck.convs.1.weight")
if "core.core.scratch.layer3_rn.weight" in name:
name = name.replace("core.core.scratch.layer3_rn.weight", "neck.convs.2.weight")
if "core.core.scratch.layer4_rn.weight" in name:
name = name.replace("core.core.scratch.layer4_rn.weight", "neck.convs.3.weight")
# fusion layers
# tricky here: mapping = {1:3, 2:2, 3:1, 4:0}
if "core.core.scratch.refinenet1" in name:
name = name.replace("core.core.scratch.refinenet1", "neck.fusion_stage.layers.3")
if "core.core.scratch.refinenet2" in name:
name = name.replace("core.core.scratch.refinenet2", "neck.fusion_stage.layers.2")
if "core.core.scratch.refinenet3" in name:
name = name.replace("core.core.scratch.refinenet3", "neck.fusion_stage.layers.1")
if "core.core.scratch.refinenet4" in name:
name = name.replace("core.core.scratch.refinenet4", "neck.fusion_stage.layers.0")
if "resConfUnit1" in name:
name = name.replace("resConfUnit1", "residual_layer1")
if "resConfUnit2" in name:
name = name.replace("resConfUnit2", "residual_layer2")
if "conv1" in name:
name = name.replace("conv1", "convolution1")
if "conv2" in name and "residual_layer" in name:
name = name.replace("conv2", "convolution2")
if "out_conv" in name:
name = name.replace("out_conv", "projection")
# relative depth estimation head
if "core.core.scratch.output_conv.0" in name:
name = name.replace("core.core.scratch.output_conv.0", "relative_head.conv1")
if "core.core.scratch.output_conv.2" in name:
name = name.replace("core.core.scratch.output_conv.2", "relative_head.conv2")
if "core.core.scratch.output_conv.4" in name:
name = name.replace("core.core.scratch.output_conv.4", "relative_head.conv3")
# patch transformer
if "patch_transformer" in name:
name = name.replace("patch_transformer", "metric_head.patch_transformer")
if "mlp_classifier.0" in name:
name = name.replace("mlp_classifier.0", "metric_head.mlp_classifier.linear1")
if "mlp_classifier.2" in name:
name = name.replace("mlp_classifier.2", "metric_head.mlp_classifier.linear2")
if "projectors" in name:
name = name.replace("projectors", "metric_head.projectors")
if "seed_bin_regressors" in name:
name = name.replace("seed_bin_regressors", "metric_head.seed_bin_regressors")
if "seed_bin_regressor" in name and "seed_bin_regressors" not in name:
name = name.replace("seed_bin_regressor", "metric_head.seed_bin_regressor")
if "seed_projector" in name:
name = name.replace("seed_projector", "metric_head.seed_projector")
if "_net.0" in name:
name = name.replace("_net.0", "conv1")
if "_net.2" in name:
name = name.replace("_net.2", "conv2")
if "attractors" in name:
name = name.replace("attractors", "metric_head.attractors")
if "conditional_log_binomial" in name:
name = name.replace("conditional_log_binomial", "metric_head.conditional_log_binomial")
# metric depth estimation head
if "conv2" in name and "metric_head" not in name and "attractors" not in name and "relative_head" not in name:
name = name.replace("conv2", "metric_head.conv2")
if "transformer_encoder.layers" in name:
name = name.replace("transformer_encoder.layers", "transformer_encoder")
return name
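# Worked example of the renaming above, using the fusion mapping {1: 3, 2: 2, 3: 1, 4: 0}:
#   rename_key("core.core.scratch.refinenet1.resConfUnit1.conv1.weight")
#   -> "neck.fusion_stage.layers.3.residual_layer1.convolution1.weight"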
def read_in_q_k_v_metric_head(state_dict):
hidden_size = 128
for i in range(4):
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"patch_transformer.transformer_encoder.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"patch_transformer.transformer_encoder.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.query.weight"] = in_proj_weight[
:hidden_size, :
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.query.bias"] = in_proj_bias[:hidden_size]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.key.weight"] = in_proj_weight[
hidden_size : hidden_size * 2, :
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.key.bias"] = in_proj_bias[
hidden_size : hidden_size * 2
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.value.weight"] = in_proj_weight[
-hidden_size:, :
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.value.bias"] = in_proj_bias[-hidden_size:]
def convert_state_dict(orig_state_dict):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
# rename key
new_name = rename_key(key)
orig_state_dict[new_name] = val
return orig_state_dict
def remove_ignore_keys(state_dict):
for key, _ in state_dict.copy().items():
if (
"fc_norm" in key
or "relative_position_index" in key
or "k_idx" in key
or "K_minus_1" in key
or "core.core.pretrained.model.head" in key
):
state_dict.pop(key, None)
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
hidden_size = config.backbone_config.hidden_size
for i in range(config.backbone_config.num_hidden_layers):
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.v_bias")
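        # note: BEiT keeps the key projection bias fixed at zero (it is not a learned parameter),
        # so only q_bias and v_bias are transferred; the key weight below gets no bias entry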
# next, add query, keys and values (in that order) to the state dict
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
hidden_size : hidden_size * 2, :
]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
# We will verify our results on an image
def prepare_img():
filepath = hf_hub_download(repo_id="shariqfarooq/ZoeDepth", filename="examples/person_1.jpeg", repo_type="space")
image = Image.open(filepath).convert("RGB")
return image
@torch.no_grad()
def convert_zoedepth_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
"""
Copy/paste/tweak model's weights to our ZoeDepth structure.
"""
# define ZoeDepth configuration based on URL
config, _ = get_zoedepth_config(model_name)
# load original model
original_model = torch.hub.load(
"NielsRogge/ZoeDepth:understanding_zoedepth", model_name, pretrained=True, force_reload=True
)
original_model.eval()
state_dict = original_model.state_dict()
print("Original state dict:")
for name, param in state_dict.items():
print(name, param.shape)
# read in qkv matrices
read_in_q_k_v(state_dict, config)
if model_name == "ZoeD_NK":
read_in_q_k_v_metric_head(state_dict)
# rename keys
state_dict = convert_state_dict(state_dict)
# remove certain keys
remove_ignore_keys(state_dict)
# load HuggingFace model
model = ZoeDepthForDepthEstimation(config)
model.load_state_dict(state_dict)
model.eval()
# verify image processor
image = prepare_img()
image_processor = ZoeDepthImageProcessor()
pixel_values = image_processor(image, return_tensors="pt").pixel_values
filepath = hf_hub_download(
repo_id="nielsr/test-image",
filename="zoedepth_pixel_values.pt",
repo_type="dataset",
)
original_pixel_values = torch.load(filepath, map_location="cpu")
assert torch.allclose(pixel_values, original_pixel_values)
# verify logits
# this was done on a resized version of the cats image (384x384)
filepath = hf_hub_download(
repo_id="nielsr/test-image",
filename="zoedepth_pixel_values.pt",
repo_type="dataset",
revision="1865dbb81984f01c89e83eec10f8d07efd10743d",
)
cats_pixel_values = torch.load(filepath, map_location="cpu")
depth = model(cats_pixel_values).predicted_depth
# Verify logits
# These were obtained by inserting the pixel_values at the patch embeddings of BEiT
if model_name == "ZoeD_N":
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor([[1.0328, 1.0604, 1.0747], [1.0816, 1.1293, 1.1456], [1.1117, 1.1629, 1.1766]])
elif model_name == "ZoeD_K":
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor([[1.6567, 1.6852, 1.7065], [1.6707, 1.6764, 1.6713], [1.7195, 1.7166, 1.7118]])
elif model_name == "ZoeD_NK":
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor([[1.1228, 1.1079, 1.1382], [1.1807, 1.1658, 1.1891], [1.2344, 1.2094, 1.2317]])
print("Shape of depth:", depth.shape)
print("First 3x3 slice of depth:", depth[0, :3, :3])
assert depth.shape == torch.Size(expected_shape)
assert torch.allclose(depth[0, :3, :3], expected_slice, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model_name_to_repo_id = {
"ZoeD_N": "zoedepth-nyu",
"ZoeD_K": "zoedepth-kitti",
"ZoeD_NK": "zoedepth-nyu-kitti",
}
print("Pushing model and processor to the hub...")
repo_id = model_name_to_repo_id[model_name]
model.push_to_hub(f"Intel/{repo_id}")
image_processor = ZoeDepthImageProcessor()
image_processor.push_to_hub(f"Intel/{repo_id}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="ZoeD_N",
choices=["ZoeD_N", "ZoeD_K", "ZoeD_NK"],
type=str,
help="Name of the original ZoeDepth checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=False,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
)
args = parser.parse_args()
convert_zoedepth_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
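# Example usage (a sketch: the script filename is an assumption, the repo ids follow the
# push_to_hub mapping above):
#
#   python convert_zoedepth_to_hf.py --model_name ZoeD_NK --pytorch_dump_folder_path ./zoedepth-nyu-kitti --push_to_hub
#
# and afterwards, running inference with the converted weights:
#
#   from transformers import ZoeDepthForDepthEstimation, ZoeDepthImageProcessor
#
#   image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
#   model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
#   inputs = image_processor(images=prepare_img(), return_tensors="pt")
#   with torch.no_grad():
#       predicted_depth = model(**inputs).predicted_depth  # shape (batch_size, height, width)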
...@@ -9660,6 +9660,20 @@ class YosoPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class ZoeDepthForDepthEstimation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ZoeDepthPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Adafactor(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -651,3 +651,10 @@ class YolosImageProcessor(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class ZoeDepthImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])