Unverified Commit a35be97a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Remove the unused/buggy `--train-center-crop` flag from Classification preset (#6642)

* Fixing inverted center_crop check on Classification preset

* Remove the `--train-center-crop` flag.
parent 21f70c17
......@@ -248,7 +248,7 @@ Note that `--val-resize-size` was optimized in a post-training step, see their `
### MaxViT
```
torchrun --nproc_per_node=8 --n_nodes=4 train.py\
--model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --policy-magnitude 15 --train-center-crop --model-ema --val-resize-size 224
--model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --policy-magnitude 15 --model-ema --val-resize-size 224\
--val-crop-size 224 --train-crop-size 224 --amp --model-ema-steps 32 --transformer-embedding-decay 0 --sync-bn
```
Here `$MODEL` is `maxvit_t`.
......
......@@ -16,13 +16,8 @@ class ClassificationPresetTrain:
ra_magnitude=9,
augmix_severity=3,
random_erase_prob=0.0,
center_crop=False,
):
trans = (
[transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if center_crop
else [transforms.CenterCrop(crop_size)]
)
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
......
......@@ -113,11 +113,10 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
val_resize_size, val_crop_size, train_crop_size, center_crop = (
val_resize_size, val_crop_size, train_crop_size = (
args.val_resize_size,
args.val_crop_size,
args.train_crop_size,
args.train_center_crop,
)
interpolation = InterpolationMode(args.interpolation)
......@@ -136,7 +135,6 @@ def load_data(traindir, valdir, args):
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
center_crop=center_crop,
crop_size=train_crop_size,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
......@@ -501,11 +499,6 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument(
"--train-center-crop",
action="store_true",
help="use center crop instead of random crop for training (default: False)",
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment