"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "44c7857b873e535d8a200000b1da2ec23cf74273"
Unverified commit 0a22335e, authored by Eduardo Gonzalez Ponferrada and committed by GitHub

[Flax/run_hybrid_clip] Fix duplicating images when captions_per_image exceeds the number of captions, enable truncation
parent c1c2d68d
@@ -224,8 +224,9 @@ class ImageTextDataset(VisionDataset):
         self.image_paths = []
 
         for example in examples:
-            self.captions.extend(example["captions"][:captions_per_image])
-            self.image_paths.extend([example["image_path"]] * captions_per_image)
+            captions_subset = example["captions"][:captions_per_image]
+            self.captions.extend(captions_subset)
+            self.image_paths.extend([example["image_path"]] * len(captions_subset))
 
     def _load_image(self, idx: int):
         path = self.image_paths[idx]
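The first hunk fixes a caption/image misalignment: when an example has fewer captions than captions_per_image, the old code still appended captions_per_image copies of the image path, so the captions and image_paths lists drifted out of sync. A minimal sketch of the effect, using made-up data (the example dict and values are hypothetical, not taken from the repository):

# Hypothetical example with 3 captions while captions_per_image = 5.
example = {"image_path": "img_0001.jpg", "captions": ["a dog", "a brown dog", "a dog outside"]}
captions_per_image = 5

# Old behavior: list lengths diverge, misaligning every later (caption, image) pair.
captions, image_paths = [], []
captions.extend(example["captions"][:captions_per_image])         # appends 3 captions
image_paths.extend([example["image_path"]] * captions_per_image)  # appends 5 paths
assert len(captions) != len(image_paths)  # 3 vs 5

# Fixed behavior: duplicate the path once per caption actually kept.
captions_subset = example["captions"][:captions_per_image]
image_paths_fixed = [example["image_path"]] * len(captions_subset)
assert len(captions_subset) == len(image_paths_fixed)  # 3 vs 3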
@@ -373,7 +374,9 @@ def main():
     def collate_fn(examples):
         pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
         captions = [example[1] for example in examples]
-        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
+        inputs = tokenizer(
+            captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np"
+        )
 
         batch = {
             "pixel_values": pixel_values,
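The second hunk passes truncation=True so that captions longer than max_seq_length are cut to fit; without it, over-long captions are not shortened and mixed-length batches cannot be packed into the fixed-size numpy arrays the collator returns. A minimal standalone sketch, assuming a small illustrative max_seq_length and the roberta-base checkpoint (the script's actual tokenizer depends on its configuration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
max_seq_length = 8  # deliberately tiny so the capping is visible

long_caption = ["a very long caption " * 20]  # far more than 8 tokens

# truncation=True caps every sequence at max_length; padding="max_length" pads
# shorter ones up to it, so the output arrays are always (batch, max_seq_length).
inputs = tokenizer(
    long_caption, max_length=max_seq_length, padding="max_length", truncation=True, return_tensors="np"
)
print(inputs["input_ids"].shape)  # (1, 8)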