Unverified Commit 20e54e49 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Indexing fix - CLIP checkpoint conversion (#22776)

* Indexing fix - CLIP checkpoint conversion

* Fix up
parent 895ae3b5
...@@ -127,9 +127,9 @@ def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_pa ...@@ -127,9 +127,9 @@ def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_pa
input_ids = torch.arange(0, 77).unsqueeze(0) input_ids = torch.arange(0, 77).unsqueeze(0)
pixel_values = torch.randn(1, 3, 224, 224) pixel_values = torch.randn(1, 3, 224, 224)
hf_logits_per_image, hf_logits_per_text = hf_model( hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
input_ids=input_ids, pixel_values=pixel_values, return_dict=True hf_logits_per_image = hf_outputs.logits_per_image
)[1:3] hf_logits_per_text = hf_outputs.logits_per_text
pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids) pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3) assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment