# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Type

from .common import LayerNorm2d, MLPBlock


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py  # noqa
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
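
        Example (illustrative; the defaults correspond to the SAM ViT-B/16
        image encoder, and the output shape assumes the neck defined later
        in this module):
            >>> encoder = ImageEncoderViT()
            >>> x = torch.randn(1, 3, 1024, 1024)
            >>> encoder(x).shape  # (B, out_chans, img_size // 16, img_size // 16)
            torch.Size([1, 256, 64, 64])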
"""
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
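        # With the defaults above, each 16x16 patch becomes one token, so a
        # 1024x1024 input yields a 64x64 grid of embed_dim-dimensional tokens.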
        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.