Unverified commit 6e4d3f08, authored by NielsRogge and committed by GitHub
Browse files

[GIT] Convert more checkpoints (#21245)



* Extend conversion script

* Remove print statement
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 66459ce3
...@@ -246,6 +246,11 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal ...@@ -246,6 +246,11 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
"git-large-msrvtt-qa": ( "git-large-msrvtt-qa": (
"https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_MSRVTT_QA/snapshot/model.pt" "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_MSRVTT_QA/snapshot/model.pt"
), ),
"git-large-r": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R/snapshot/model.pt",
"git-large-r-coco": "https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_COCO/snapshot/model.pt",
"git-large-r-textcaps": (
"https://publicgit.blob.core.windows.net/data/output/GIT_LARGE_R_TEXTCAPS/snapshot/model.pt"
),
} }
model_name_to_path = { model_name_to_path = {
...@@ -258,7 +263,7 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal ...@@ -258,7 +263,7 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
# define GIT configuration based on model name # define GIT configuration based on model name
config, image_size, is_video = get_git_config(model_name) config, image_size, is_video = get_git_config(model_name)
if "large" in model_name and not is_video: if "large" in model_name and not is_video and "large-r" not in model_name:
# large checkpoints take way too long to download # large checkpoints take way too long to download
checkpoint_path = model_name_to_path[model_name] checkpoint_path = model_name_to_path[model_name]
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
...@@ -349,6 +354,12 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal ...@@ -349,6 +354,12 @@ def convert_git_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=Fal
expected_slice_logits = torch.tensor([-1.0113, -1.0114, -1.0113]) expected_slice_logits = torch.tensor([-1.0113, -1.0114, -1.0113])
elif model_name == "git-large-msrvtt-qa": elif model_name == "git-large-msrvtt-qa":
expected_slice_logits = torch.tensor([0.0130, 0.0134, 0.0131]) expected_slice_logits = torch.tensor([0.0130, 0.0134, 0.0131])
elif model_name == "git-large-r":
expected_slice_logits = torch.tensor([-1.1283, -1.1285, -1.1286])
elif model_name == "git-large-r-coco":
expected_slice_logits = torch.tensor([-0.9641, -0.9641, -0.9641])
elif model_name == "git-large-r-textcaps":
expected_slice_logits = torch.tensor([-1.1121, -1.1120, -1.1124])
assert torch.allclose(logits[0, -1, :3], expected_slice_logits, atol=1e-4) assert torch.allclose(logits[0, -1, :3], expected_slice_logits, atol=1e-4)
print("Looks ok!") print("Looks ok!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment