Unverified Commit 28557e0c authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix copypaste collate pickle issues (#6181)

parent d0d7058a
...@@ -35,6 +35,11 @@ from torchvision.transforms import InterpolationMode ...@@ -35,6 +35,11 @@ from torchvision.transforms import InterpolationMode
from transforms import SimpleCopyPaste from transforms import SimpleCopyPaste
def copypaste_collate_fn(batch):
copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
return copypaste(*utils.collate_fn(batch))
def get_dataset(name, image_set, transform, data_path): def get_dataset(name, image_set, transform, data_path):
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
p, ds_fn, num_classes = paths[name] p, ds_fn, num_classes = paths[name]
...@@ -194,11 +199,6 @@ def main(args): ...@@ -194,11 +199,6 @@ def main(args):
if args.data_augmentation != "lsj": if args.data_augmentation != "lsj":
raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies") raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")
copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)
def copypaste_collate_fn(batch):
return copypaste(*utils.collate_fn(batch))
train_collate_fn = copypaste_collate_fn train_collate_fn = copypaste_collate_fn
data_loader = torch.utils.data.DataLoader( data_loader = torch.utils.data.DataLoader(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment