You need to sign in or sign up before continuing.
Unverified Commit 3b5e6fc4 authored by Zhu Lin Ch'ng's avatar Zhu Lin Ch'ng Committed by GitHub
Browse files

Fix CutMix and MixUp arguments in transforms.py (#8287)

parent 2c127da8
......@@ -222,7 +222,7 @@ def main(args):
num_classes = len(dataset.classes)
mixup_cutmix = get_mixup_cutmix(
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_categories=num_classes, use_v2=args.use_v2
mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2
)
if mixup_cutmix is not None:
......
......@@ -7,21 +7,21 @@ from torch import Tensor
from torchvision.transforms import functional as F
def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_categories, use_v2):
def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
transforms_module = get_module(use_v2)
mixup_cutmix = []
if mixup_alpha > 0:
mixup_cutmix.append(
transforms_module.MixUp(alpha=mixup_alpha, num_categories=num_categories)
transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes)
if use_v2
else RandomMixUp(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
)
if cutmix_alpha > 0:
mixup_cutmix.append(
transforms_module.CutMix(alpha=mixup_alpha, num_categories=num_categories)
transforms_module.CutMix(alpha=mixup_alpha, num_classes=num_classes)
if use_v2
else RandomCutMix(num_classes=num_categories, p=1.0, alpha=mixup_alpha)
else RandomCutMix(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
)
if not mixup_cutmix:
return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment