Unverified commit 7c5eaf9e authored by Younes Belkada, committed by GitHub

Add `dpt-hybrid` support (#20645)



* add `dpt-hybrid` support

* refactor

* final changes, all tests pass

* final cleanups

* final changes

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix docstring

* fix typo

* change `vit_hybrid` to `hybrid`

* replace dataclass

* add docstring

* move dataclasses

* fix test

* add `PretrainedConfig` support for `backbone_config`

* fix docstring

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove `embedding_type` and replace it by `is_hybrid`
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 3ac040bc
@@ -14,8 +14,11 @@
# limitations under the License.
""" DPT model configuration"""
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..bit import BitConfig
logger = logging.get_logger(__name__)
@@ -76,6 +79,8 @@ class DPTConfig(PretrainedConfig):
- "project" passes information to the other tokens by concatenating the readout to all other tokens before
projecting the
representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
is_hybrid (`bool`, *optional*, defaults to `False`):
Whether to use a hybrid backbone. Useful in the context of loading DPT-Hybrid models.
reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
The up/downsampling factors of the reassemble layers.
neck_hidden_sizes (`List[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
@@ -94,6 +99,12 @@ class DPTConfig(PretrainedConfig):
The index that is ignored by the loss function of the semantic segmentation model.
semantic_classifier_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the semantic classification head.
backbone_featmap_shape (`List[int]`, *optional*, defaults to `[1, 1024, 24, 24]`):
Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.
neck_ignore_stages (`List[int]`, *optional*, defaults to `[0, 1]`):
Used only for the `hybrid` embedding type. The stages of the readout layers to ignore.
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
Used only for the `hybrid` embedding type. The configuration of the backbone in a dictionary.
Example:
@@ -125,6 +136,7 @@ class DPTConfig(PretrainedConfig):
image_size=384,
patch_size=16,
num_channels=3,
is_hybrid=False,
qkv_bias=True,
backbone_out_indices=[2, 5, 8, 11],
readout_type="project",
@@ -137,11 +149,47 @@ class DPTConfig(PretrainedConfig):
auxiliary_loss_weight=0.4,
semantic_loss_ignore_index=255,
semantic_classifier_dropout=0.1,
backbone_featmap_shape=[1, 1024, 24, 24],
neck_ignore_stages=[0, 1],
backbone_config=None,
**kwargs
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.is_hybrid = is_hybrid

if self.is_hybrid:
    if backbone_config is None:
        logger.info("Initializing the config with a `BiT` backbone.")
        backbone_config = {
            "global_padding": "same",
            "layer_type": "bottleneck",
            "depths": [3, 4, 9],
            "out_features": ["stage1", "stage2", "stage3"],
            "embedding_dynamic_padding": True,
        }
        self.backbone_config = BitConfig(**backbone_config)
    elif isinstance(backbone_config, dict):
        logger.info("Initializing the config with a `BiT` backbone.")
        self.backbone_config = BitConfig(**backbone_config)
    elif isinstance(backbone_config, PretrainedConfig):
        self.backbone_config = backbone_config
    else:
        raise ValueError(
            f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}."
        )

    self.backbone_featmap_shape = backbone_featmap_shape
    self.neck_ignore_stages = neck_ignore_stages

    if readout_type != "project":
        raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.")
else:
    self.backbone_config = None
    self.backbone_featmap_shape = None
    self.neck_ignore_stages = []
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
@@ -168,3 +216,16 @@ class DPTConfig(PretrainedConfig):
self.auxiliary_loss_weight = auxiliary_loss_weight
self.semantic_loss_ignore_index = semantic_loss_ignore_index
self.semantic_classifier_dropout = semantic_classifier_dropout
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

Returns:
    `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
if output["backbone_config"] is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
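# Usage sketch (not part of this diff): how the new hybrid configuration path above can be
# exercised. Assumes a transformers build that includes this change; the dict values mirror
# the defaults set in `DPTConfig.__init__`.
from transformers import DPTConfig

# Default hybrid setup: a BiT backbone config is created automatically.
config = DPTConfig(is_hybrid=True)
print(type(config.backbone_config).__name__)  # expected: BitConfig

# A dict (or any PretrainedConfig) can be passed instead to override the default backbone.
config = DPTConfig(
    is_hybrid=True,
    backbone_config={
        "global_padding": "same",
        "layer_type": "bottleneck",
        "depths": [3, 4, 9],
        "out_features": ["stage1", "stage2", "stage3"],
        "embedding_dynamic_padding": True,
    },
)

# `to_dict` serializes the nested backbone config to a plain dictionary as well.
assert isinstance(config.to_dict()["backbone_config"], dict)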
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DPT checkpoints from the original repository. URL: https://github.com/isl-org/DPT"""
import argparse
import json
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import cached_download, hf_hub_url
from transformers import DPTConfig, DPTFeatureExtractor, DPTForDepthEstimation, DPTForSemanticSegmentation
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_dpt_config(checkpoint_url):
config = DPTConfig(embedding_type="hybrid")
if "large" in checkpoint_url:
config.hidden_size = 1024
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
config.backbone_out_indices = [5, 11, 17, 23]
config.neck_hidden_sizes = [256, 512, 1024, 1024]
expected_shape = (1, 384, 384)
if "nyu" or "midas" in checkpoint_url:
config.hidden_size = 768
config.reassemble_factors = [1, 1, 1, 0.5]
config.neck_hidden_sizes = [256, 512, 768, 768]
config.num_labels = 150
config.patch_size = 16
expected_shape = (1, 384, 384)
config.use_batch_norm_in_fusion_residual = False
config.readout_type = "project"
if "ade" in checkpoint_url:
config.use_batch_norm_in_fusion_residual = True
config.hidden_size = 768
config.reassemble_factors = [1, 1, 1, 0.5]
config.num_labels = 150
config.patch_size = 16
repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
expected_shape = [1, 150, 480, 480]
return config, expected_shape
def remove_ignore_keys_(state_dict):
ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(name):
if (
"pretrained.model" in name
and "cls_token" not in name
and "pos_embed" not in name
and "patch_embed" not in name
):
name = name.replace("pretrained.model", "dpt.encoder")
if "pretrained.model" in name:
name = name.replace("pretrained.model", "dpt.embeddings")
if "patch_embed" in name:
name = name.replace("patch_embed", "")
if "pos_embed" in name:
name = name.replace("pos_embed", "position_embeddings")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "proj" in name and "project" not in name:
name = name.replace("proj", "projection")
if "blocks" in name:
name = name.replace("blocks", "layer")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
if "norm1" in name and "backbone" not in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name and "backbone" not in name:
name = name.replace("norm2", "layernorm_after")
if "scratch.output_conv" in name:
name = name.replace("scratch.output_conv", "head")
if "scratch" in name:
name = name.replace("scratch", "neck")
if "layer1_rn" in name:
name = name.replace("layer1_rn", "convs.0")
if "layer2_rn" in name:
name = name.replace("layer2_rn", "convs.1")
if "layer3_rn" in name:
name = name.replace("layer3_rn", "convs.2")
if "layer4_rn" in name:
name = name.replace("layer4_rn", "convs.3")
if "refinenet" in name:
layer_idx = int(name[len("neck.refinenet") : len("neck.refinenet") + 1])
# tricky here: we need to map 4 to 0, 3 to 1, 2 to 2 and 1 to 3
name = name.replace(f"refinenet{layer_idx}", f"fusion_stage.layers.{abs(layer_idx-4)}")
if "out_conv" in name:
name = name.replace("out_conv", "projection")
if "resConfUnit1" in name:
name = name.replace("resConfUnit1", "residual_layer1")
if "resConfUnit2" in name:
name = name.replace("resConfUnit2", "residual_layer2")
if "conv1" in name:
name = name.replace("conv1", "convolution1")
if "conv2" in name:
name = name.replace("conv2", "convolution2")
# readout blocks
if "pretrained.act_postprocess1.0.project.0" in name:
name = name.replace("pretrained.act_postprocess1.0.project.0", "neck.reassemble_stage.readout_projects.0.0")
if "pretrained.act_postprocess2.0.project.0" in name:
name = name.replace("pretrained.act_postprocess2.0.project.0", "neck.reassemble_stage.readout_projects.1.0")
if "pretrained.act_postprocess3.0.project.0" in name:
name = name.replace("pretrained.act_postprocess3.0.project.0", "neck.reassemble_stage.readout_projects.2.0")
if "pretrained.act_postprocess4.0.project.0" in name:
name = name.replace("pretrained.act_postprocess4.0.project.0", "neck.reassemble_stage.readout_projects.3.0")
# resize blocks
if "pretrained.act_postprocess1.3" in name:
name = name.replace("pretrained.act_postprocess1.3", "neck.reassemble_stage.layers.0.projection")
if "pretrained.act_postprocess1.4" in name:
name = name.replace("pretrained.act_postprocess1.4", "neck.reassemble_stage.layers.0.resize")
if "pretrained.act_postprocess2.3" in name:
name = name.replace("pretrained.act_postprocess2.3", "neck.reassemble_stage.layers.1.projection")
if "pretrained.act_postprocess2.4" in name:
name = name.replace("pretrained.act_postprocess2.4", "neck.reassemble_stage.layers.1.resize")
if "pretrained.act_postprocess3.3" in name:
name = name.replace("pretrained.act_postprocess3.3", "neck.reassemble_stage.layers.2.projection")
if "pretrained.act_postprocess4.3" in name:
name = name.replace("pretrained.act_postprocess4.3", "neck.reassemble_stage.layers.3.projection")
if "pretrained.act_postprocess4.4" in name:
name = name.replace("pretrained.act_postprocess4.4", "neck.reassemble_stage.layers.3.resize")
if "pretrained" in name:
name = name.replace("pretrained", "dpt")
if "bn" in name:
name = name.replace("bn", "batch_norm")
if "head" in name:
name = name.replace("head", "head.head")
if "encoder.norm" in name:
name = name.replace("encoder.norm", "layernorm")
if "auxlayer" in name:
name = name.replace("auxlayer", "auxiliary_head.head")
if "backbone" in name:
name = name.replace("backbone", "backbone.bit.encoder")
if ".." in name:
name = name.replace("..", ".")
if "stem.conv" in name:
name = name.replace("stem.conv", "bit.embedder.convolution")
if "blocks" in name:
name = name.replace("blocks", "layers")
if "convolution" in name and "backbone" in name:
name = name.replace("convolution", "conv")
if "layer" in name and "backbone" in name:
name = name.replace("layer", "layers")
if "backbone.bit.encoder.bit" in name:
name = name.replace("backbone.bit.encoder.bit", "backbone.bit")
if "embedder.conv" in name:
name = name.replace("embedder.conv", "embedder.convolution")
if "backbone.bit.encoder.stem.norm" in name:
name = name.replace("backbone.bit.encoder.stem.norm", "backbone.bit.embedder.norm")
return name
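# Illustration (not part of the script): tracing `rename_key` on a hypothetical timm-style key
# shows the mapping to the Hugging Face layout consumed by `read_in_q_k_v` below.
sample_key = "pretrained.model.blocks.0.attn.proj.weight"
print(rename_key(sample_key))  # expected: dpt.encoder.layer.0.attention.output.dense.weight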
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
for i in range(config.num_hidden_layers):
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"dpt.encoder.layer.{i}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"dpt.encoder.layer.{i}.attn.qkv.bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"dpt.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :]
state_dict[f"dpt.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
state_dict[f"dpt.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"dpt.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
config.hidden_size : config.hidden_size * 2
]
state_dict[f"dpt.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"dpt.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_dpt_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub, model_name, show_prediction):
"""
Copy/paste/tweak model's weights to our DPT structure.
"""
# define DPT configuration based on URL
config, expected_shape = get_dpt_config(checkpoint_url)
# load original state_dict from URL
# state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
state_dict = torch.load(checkpoint_url, map_location="cpu")
# remove certain keys
remove_ignore_keys_(state_dict)
# rename keys
for key in state_dict.copy().keys():
val = state_dict.pop(key)
state_dict[rename_key(key)] = val
# read in qkv matrices
read_in_q_k_v(state_dict, config)
# load HuggingFace model
model = DPTForSemanticSegmentation(config) if "ade" in checkpoint_url else DPTForDepthEstimation(config)
model.load_state_dict(state_dict)
model.eval()
# Check outputs on an image
size = 480 if "ade" in checkpoint_url else 384
feature_extractor = DPTFeatureExtractor(size=size)
image = prepare_img()
encoding = feature_extractor(image, return_tensors="pt")
# forward pass
outputs = model(**encoding).logits if "ade" in checkpoint_url else model(**encoding).predicted_depth
if show_prediction:
prediction = (
torch.nn.functional.interpolate(
outputs.unsqueeze(1),
size=(image.size[1], image.size[0]),
mode="bicubic",
align_corners=False,
)
.squeeze()
.cpu()
.numpy()
)
Image.fromarray(((prediction / prediction.max()) * 255).astype("uint8")).show()
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub("ybelkada/dpt-hybrid-midas")
feature_extractor.push_to_hub("ybelkada/dpt-hybrid-midas")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_url",
default="https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
type=str,
help="URL of the original DPT checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=False,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
)
parser.add_argument(
"--model_name",
default="dpt-large",
type=str,
help="Name of the model, in case you're pushing to the hub.",
)
parser.add_argument(
"--show_prediction",
action="store_true",
)
args = parser.parse_args()
convert_dpt_checkpoint(
args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub, args.model_name, args.show_prediction
)
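# Usage sketch (not part of the script): loading a converted / published DPT-Hybrid checkpoint
# for inference. "Intel/dpt-hybrid-midas" is the hub id referenced elsewhere in this commit;
# substitute a local `pytorch_dump_folder_path` if you converted the weights yourself.
import torch
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")

image = prepare_img()  # the COCO cats image used above
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    predicted_depth = model(**inputs).predicted_depth  # expected shape: (1, 384, 384)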
@@ -22,6 +22,7 @@ https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_hea
import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Set, Tuple, Union
import torch
@@ -36,15 +37,11 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging
from ..auto import AutoBackbone
from .configuration_dpt import DPTConfig
@@ -61,10 +58,165 @@ _EXPECTED_OUTPUT_SHAPE = [1, 577, 1024]
DPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"Intel/dpt-large",
"Intel/dpt-hybrid-midas",
# See all DPT models at https://huggingface.co/models?filter=dpt
]
@dataclass
class BaseModelOutputWithIntermediateActivations(ModelOutput):
"""
Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
in the context of vision models.

Args:
last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
Intermediate activations that can be used to compute hidden states of the model at various layers.
"""
last_hidden_states: torch.FloatTensor = None
intermediate_activations: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
activations that can be used by the model at later stages.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) after further processing
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
the classification token after processing through a linear layer and a tanh activation function. The linear
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
Intermediate activations that can be used to compute hidden states of the model at various layers.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
intermediate_activations: Optional[Tuple[torch.FloatTensor]] = None
class DPTViTHybridEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config, feature_size=None):
super().__init__()
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.backbone = AutoBackbone.from_config(config.backbone_config)
feature_dim = self.backbone.channels[-1]
if len(config.backbone_config.out_features) != 3:
raise ValueError(
f"Expected backbone to have 3 output features, got {len(config.backbone_config.out_features)}"
)
self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage
if feature_size is None:
feat_map_shape = config.backbone_featmap_shape
feature_size = feat_map_shape[-2:]
feature_dim = feat_map_shape[1]
else:
feature_size = (
feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
)
feature_dim = self.backbone.channels[-1]
self.image_size = image_size
self.patch_size = patch_size[0]
self.num_channels = num_channels
self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
posemb_tok = posemb[:, :start_index]
posemb_grid = posemb[0, start_index:]
old_grid_size = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward(
self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, return_dict: bool = False
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if not interpolate_pos_encoding:
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
position_embeddings = self._resize_pos_embed(
self.position_embeddings, height // self.patch_size, width // self.patch_size
)
backbone_output = self.backbone(pixel_values)
features = backbone_output.feature_maps[-1]
# Retrieve also the intermediate activations to use them at later stages
output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]
embeddings = self.projection(features).flatten(2).transpose(1, 2)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
# add positional encoding to each token
embeddings = embeddings + position_embeddings
if not return_dict:
return (embeddings, output_hidden_states)
# Return hidden states and intermediate activations
return BaseModelOutputWithIntermediateActivations(
last_hidden_states=embeddings,
intermediate_activations=output_hidden_states,
)
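# Illustration (not part of the diff): the `_resize_pos_embed` helper above interpolates the grid
# portion of the (1, 1 + N, C) position table to a new patch grid. Standalone sketch with toy sizes
# (hidden size 32, a 24x24 grid resized to 12x12), mirroring the math above.
import math
import torch
from torch import nn

posemb = torch.randn(1, 1 + 24 * 24, 32)  # [CLS] token + 24*24 grid tokens
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
old_size = int(math.sqrt(posemb_grid.shape[0]))  # 24
posemb_grid = posemb_grid.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
posemb_grid = nn.functional.interpolate(posemb_grid, size=(12, 12), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, 12 * 12, -1)
resized = torch.cat([posemb_tok, posemb_grid], dim=1)
assert resized.shape == (1, 145, 32)  # 1 [CLS] token + 12*12 grid tokens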
class DPTViTEmbeddings(nn.Module):
"""
Construct the CLS token, position and patch embeddings.
@@ -95,7 +247,7 @@ class DPTViTEmbeddings(nn.Module):
return posemb
def forward(self, pixel_values, return_dict=False):
batch_size, num_channels, height, width = pixel_values.shape
# possibly interpolate position encodings to handle varying image sizes
@@ -117,7 +269,10 @@ class DPTViTEmbeddings(nn.Module):
embeddings = self.dropout(embeddings)
if not return_dict:
    return (embeddings,)
return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings)
class DPTViTPatchEmbeddings(nn.Module):
@@ -429,6 +584,39 @@ class DPTReassembleStage(nn.Module):
self.config = config
self.layers = nn.ModuleList()
if config.is_hybrid:
self._init_reassemble_dpt_hybrid(config)
else:
self._init_reassemble_dpt(config)
self.neck_ignore_stages = config.neck_ignore_stages
def _init_reassemble_dpt_hybrid(self, config):
r""" "
For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
for more details.
"""
for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
if i <= 1:
self.layers.append(nn.Identity())
elif i > 1:
self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
if config.readout_type != "project":
raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")
# When using DPT-Hybrid the readout type is set to "project". The sanity check is done on the config file
self.readout_projects = nn.ModuleList()
for i in range(len(config.neck_hidden_sizes)):
if i <= 1:
self.readout_projects.append(nn.Sequential(nn.Identity()))
elif i > 1:
self.readout_projects.append(
nn.Sequential(nn.Linear(2 * config.hidden_size, config.hidden_size), ACT2FN[config.hidden_act])
)
def _init_reassemble_dpt(self, config):
for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
    self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
@@ -448,26 +636,27 @@ class DPTReassembleStage(nn.Module):
out = []
for i, hidden_state in enumerate(hidden_states):
    if i not in self.neck_ignore_stages:
        # reshape to (B, C, H, W)
        hidden_state, cls_token = hidden_state[:, 1:], hidden_state[:, 0]
        batch_size, sequence_length, num_channels = hidden_state.shape
        size = int(math.sqrt(sequence_length))
        hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
        hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()

        feature_shape = hidden_state.shape
        if self.config.readout_type == "project":
            # reshape to (B, H*W, C)
            hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
            readout = cls_token.unsqueeze(1).expand_as(hidden_state)
            # concatenate the readout token to the hidden states and project
            hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
            # reshape back to (B, C, H, W)
            hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
        elif self.config.readout_type == "add":
            hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
            hidden_state = hidden_state.reshape(feature_shape)
        hidden_state = self.layers[i](hidden_state)
    out.append(hidden_state)
return out
@@ -681,7 +870,10 @@ class DPTModel(DPTPreTrainedModel):
self.config = config
# vit encoder
if config.is_hybrid:
    self.embeddings = DPTViTHybridEmbeddings(config)
else:
    self.embeddings = DPTViTEmbeddings(config)
self.encoder = DPTViTEncoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -691,7 +883,10 @@ class DPTModel(DPTPreTrainedModel):
self.post_init()
def get_input_embeddings(self):
    if self.config.is_hybrid:
        return self.embeddings
    else:
        return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
@@ -705,7 +900,7 @@ class DPTModel(DPTPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndIntermediateActivations,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -717,7 +912,7 @@ class DPTModel(DPTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPoolingAndIntermediateActivations]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -731,10 +926,12 @@ class DPTModel(DPTPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(pixel_values, return_dict=return_dict)
embedding_last_hidden_states = embedding_output[0] if not return_dict else embedding_output.last_hidden_states
encoder_outputs = self.encoder(
embedding_last_hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -747,13 +944,14 @@ class DPTModel(DPTPreTrainedModel):
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
return head_outputs + encoder_outputs[1:] + embedding_output[1:]
return BaseModelOutputWithPoolingAndIntermediateActivations(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
intermediate_activations=embedding_output.intermediate_activations,
)
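# Usage sketch (not part of the diff): with a hybrid config, `DPTModel` additionally returns
# the backbone's intermediate activations, which the task heads below consume. Shapes depend
# on the default backbone and are illustrative only.
import torch
from transformers import DPTConfig, DPTModel

config = DPTConfig(is_hybrid=True)
model = DPTModel(config)
model.eval()

pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
with torch.no_grad():
    outputs = model(pixel_values, return_dict=True)

print(outputs.last_hidden_state.shape)        # (1, 577, 768) with the defaults
print(len(outputs.intermediate_activations))  # 2: features from the first two backbone stages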
@@ -787,7 +985,6 @@ class DPTNeck(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# postprocessing
@@ -939,9 +1136,17 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
# only keep certain features based on config.backbone_out_indices
# note that the hidden_states also include the initial embeddings
if not self.config.is_hybrid:
    hidden_states = [
        feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
    ]
else:
    backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
    backbone_hidden_states.extend(
        feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
    )
    hidden_states = backbone_hidden_states
hidden_states = self.neck(hidden_states)
@@ -1084,9 +1289,17 @@ class DPTForSemanticSegmentation(DPTPreTrainedModel):
# only keep certain features based on config.backbone_out_indices
# note that the hidden_states also include the initial embeddings
if not self.config.is_hybrid:
    hidden_states = [
        feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
    ]
else:
    backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
    backbone_hidden_states.extend(
        feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
    )
    hidden_states = backbone_hidden_states
hidden_states = self.neck(hidden_states)
...
@@ -61,6 +61,7 @@ class DPTModelTester:
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
num_labels=3,
is_hybrid=False,
scope=None,
):
self.parent = parent
@@ -81,6 +82,7 @@ class DPTModelTester:
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
self.is_hybrid = is_hybrid
# sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
@@ -111,6 +113,7 @@ class DPTModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
is_hybrid=self.is_hybrid,
)
def create_and_check_model(self, config, pixel_values, labels):
...
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch DPT model. """
import inspect
import unittest
from transformers import DPTConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import MODEL_MAPPING, DPTForDepthEstimation, DPTForSemanticSegmentation, DPTModel
from transformers.models.dpt.modeling_dpt import DPT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import DPTFeatureExtractor
class DPTModelTester:
def __init__(
self,
parent,
batch_size=2,
image_size=32,
patch_size=16,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=4,
backbone_out_indices=[0, 1, 2, 3],
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
num_labels=3,
backbone_featmap_shape=[1, 384, 24, 24],
is_hybrid=True,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.backbone_out_indices = backbone_out_indices
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.num_labels = num_labels
self.backbone_featmap_shape = backbone_featmap_shape
self.scope = scope
self.is_hybrid = is_hybrid
# sequence length of DPT = num_patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
backbone_config = {
"global_padding": "same",
"layer_type": "bottleneck",
"depths": [3, 4, 9],
"out_features": ["stage1", "stage2", "stage3"],
"embedding_dynamic_padding": True,
"hidden_sizes": [96, 192, 384, 768],
"num_groups": 2,
}
return DPTConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
backbone_out_indices=self.backbone_out_indices,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
is_hybrid=self.is_hybrid,
backbone_config=backbone_config,
backbone_featmap_shape=self.backbone_featmap_shape,
)
def create_and_check_model(self, config, pixel_values, labels):
model = DPTModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = DPTForDepthEstimation(config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size))
def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = DPTForSemanticSegmentation(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class DPTModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as DPT does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (DPTModel, DPTForDepthEstimation, DPTForSemanticSegmentation) if is_torch_available() else ()
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = DPTModelTester(self)
self.config_tester = ConfigTester(self, config_class=DPTConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="DPT does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_depth_estimation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)
def test_for_semantic_segmentation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)
def test_training(self):
for model_class in self.all_model_classes:
if model_class.__name__ == "DPTForDepthEstimation":
continue
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
if model_class in get_values(MODEL_MAPPING):
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
for model_class in self.all_model_classes:
if model_class.__name__ == "DPTForDepthEstimation":
continue
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable()
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
@slow
def test_model_from_pretrained(self):
for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]:
model = DPTModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_raise_readout_type(self):
# We do this test only for DPTForDepthEstimation since it is the only model that uses readout_type
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.readout_type = "add"
with self.assertRaises(ValueError):
_ = DPTForDepthEstimation(config)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
@slow
class DPTModelIntegrationTest(unittest.TestCase):
def test_inference_depth_estimation(self):
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(torch_device)
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# verify the predicted depth
expected_shape = torch.Size((1, 384, 384))
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[[5.6437, 5.6146, 5.6511], [5.4371, 5.5649, 5.5958], [5.5215, 5.5184, 5.5293]]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[:3, :3, :3] / 100, expected_slice, atol=1e-4))
@@ -376,6 +376,10 @@ SPECIAL_MODULE_TO_TEST_MAP = {
"models/gpt2/test_modeling_gpt2.py",
"models/megatron_gpt2/test_modeling_megatron_gpt2.py",
],
"models/dpt/modeling_dpt.py": [
"models/dpt/test_modeling_dpt.py",
"models/dpt/test_modeling_dpt_hybrid.py",
],
"optimization.py": "optimization/test_optimization.py", "optimization.py": "optimization/test_optimization.py",
"optimization_tf.py": "optimization/test_optimization_tf.py", "optimization_tf.py": "optimization/test_optimization_tf.py",
"pipelines/__init__.py": "pipelines/test_pipelines_*.py", "pipelines/__init__.py": "pipelines/test_pipelines_*.py",
......