Commit 9e4126a0 authored by chenych's avatar chenych
Browse files

add .float()

parent 88df767f
...@@ -571,7 +571,7 @@ def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix): ...@@ -571,7 +571,7 @@ def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix):
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated # only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2).float()
pos_tokens = torch.nn.functional.interpolate( pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment