Unverified commit 1b14829c authored by Yiwen Song, committed by GitHub

[ViT] Support fine-tuning with different image resolution (#5025)



* add from_checkpoint method for vit

* remove useless change

* Making interpolate_embeddings a utility function

* remove logging

* fix type hint

* fix return type check

* add returns in docstring & unify type hint

* remove useless import

* fix issue: 'type' object is not subscriptable

* Fixing typing issues

* Making interpolation mode configurable

* formatting
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 8aa31740
@@ -379,3 +379,78 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
        progress=progress,
        **kwargs,
    )

def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with a different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, the state of the classification heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    # One position per patch, plus one for the class token.
    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the position embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it off.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            # Drop the pre-trained classification heads so the new model starts fresh.
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
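
For illustration, a minimal usage sketch of the new utility (not part of the diff): the checkpoint path below is hypothetical, and the VisionTransformer arguments are the standard ViT-B/16 configuration from this module.

    # Minimal usage sketch. Assumptions: "vit_b_16.pth" is a hypothetical path to a
    # ViT-B/16 state dict pre-trained at 224x224 (14x14 patch grid + class token).
    import torch
    from torchvision.models.vision_transformer import VisionTransformer, interpolate_embeddings

    state = torch.load("vit_b_16.pth")

    # Interpolate the positional embeddings to the 24x24 grid of a 384x384 input,
    # dropping the pre-trained classification heads for fine-tuning on a new task.
    new_state = interpolate_embeddings(
        image_size=384,
        patch_size=16,
        model_state=state,
        reset_heads=True,
    )

    # Build a ViT-B/16 model at the new resolution and load the adapted weights.
    model = VisionTransformer(
        image_size=384,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
    )
    model.load_state_dict(new_state, strict=False)  # strict=False because the heads were reset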