Unverified Commit 111228cb authored by wfng92's avatar wfng92 Committed by GitHub
Browse files

Fix torchvision.transforms and transforms function naming clash (#2274)



* Fix torchvision.transforms and transforms function naming clash

* Update unconditional script for onnx

* Apply suggestions from code review
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent bbb46ad3
......@@ -386,13 +386,13 @@ def main(args):
]
)
def transforms(examples):
def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
logger.info(f"Dataset size: {len(dataset)}")
dataset.set_transform(transforms)
dataset.set_transform(transform_images)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)
......
......@@ -386,13 +386,13 @@ def main(args):
]
)
def transforms(examples):
def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
logger.info(f"Dataset size: {len(dataset)}")
dataset.set_transform(transforms)
dataset.set_transform(transform_images)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment