Unverified Commit b1fc2903 authored by Vasilis Vryniotis, committed by GitHub

Update Reference scripts to support the prototype models (#4837)

* Adding prototype preprocessing on segmentation references.

* Adding prototype preprocessing on video references.
parent 140322f9
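
The same pattern is applied in both reference scripts below: keep the legacy boolean `pretrained` path for the stable API, and route the new `--weights` value through the prototype builders. A minimal standalone sketch of that dispatch, assuming a torchvision nightly that ships torchvision.prototype.models (the model name in the usage comment is a placeholder, not part of this diff):

import torchvision

try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None  # released torchvision builds do not ship the prototype module


def build_segmentation_model(model_name, weights=None, pretrained=False, **kwargs):
    # Legacy path: stable API addressed by the boolean `pretrained` flag.
    if not weights:
        return torchvision.models.segmentation.__dict__[model_name](pretrained=pretrained, **kwargs)
    # Prototype path: weights are addressed by their enum name.
    if PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    return PM.segmentation.__dict__[model_name](weights=weights, **kwargs)


# e.g. build_segmentation_model("fcn_resnet50", pretrained=True, num_classes=21, aux_loss=True)
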
@@ -11,6 +11,12 @@ from coco_utils import get_coco
 from torch import nn
 
 
+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def get_dataset(dir_path, name, image_set, transform):
     def sbd(*args, **kwargs):
         return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
@@ -26,11 +32,15 @@ def get_dataset(dir_path, name, image_set, transform):
     return ds, num_classes
 
 
-def get_transform(train):
-    base_size = 520
-    crop_size = 480
-
-    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
+def get_transform(train, args):
+    if train:
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+    elif not args.weights:
+        return presets.SegmentationPresetEval(base_size=520)
+    else:
+        fn = PM.segmentation.__dict__[args.model]
+        weights = PM._api.get_weight(fn, args.weights)
+        return weights.transforms()
 
 
 def criterion(inputs, target):
@@ -90,8 +100,8 @@ def main(args):
     device = torch.device(args.device)
 
-    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
-    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))
+    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
+    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -113,9 +123,18 @@ def main(args):
         dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
     )
 
-    model = torchvision.models.segmentation.__dict__[args.model](
-        num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained
-    )
+    if not args.weights:
+        model = torchvision.models.segmentation.__dict__[args.model](
+            pretrained=args.pretrained,
+            num_classes=num_classes,
+            aux_loss=args.aux_loss,
+        )
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.segmentation.__dict__[args.model](
+            weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
+        )
     model.to(device)
     if args.distributed:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -247,6 +266,9 @@ def get_args_parser(add_help=True):
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     return parser
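
For the segmentation script above, the --weights path resolves the weights enum from its name and reuses the preprocessing bundled with it instead of the hand-written SegmentationPresetEval. A hedged sketch of that flow, assuming the prototype module is installed; the model and weights names here are illustrative placeholders only:

from torchvision.prototype import models as PM

model_name = "fcn_resnet50"              # placeholder model name
weights_name = "CocoWithVocLabels_V1"    # placeholder weights enum name
builder = PM.segmentation.__dict__[model_name]
weights = PM._api.get_weight(builder, weights_name)   # resolve the enum from its string name
eval_transform = weights.transforms()                 # inference preprocessing shipped with the weights
model = builder(weights=weights_name, num_classes=21, aux_loss=True)
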
@@ -18,6 +18,12 @@ except ImportError:
     amp = None
 
 
+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
@@ -149,7 +155,12 @@ def main(args):
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
 
-    transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
+    if not args.weights:
+        transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
+    else:
+        fn = PM.video.__dict__[args.model]
+        weights = PM._api.get_weight(fn, args.weights)
+        transform_test = weights.transforms()
 
     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")
@@ -200,7 +211,12 @@ def main(args):
     )
 
     print("Creating model")
-    model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
+    if not args.weights:
+        model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.video.__dict__[args.model](weights=args.weights)
     model.to(device)
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -363,6 +379,9 @@ def parse_args():
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     args = parser.parse_args()
     return args
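
The video-classification script follows the same split for its evaluation transform: keep the hard-coded VideoClassificationPresetEval when no prototype weights are requested, otherwise take the transforms attached to the weights. A small sketch under the same assumptions (prototype module available; the model name in the usage comment is a placeholder):

from torchvision.prototype import models as PM
import presets  # the reference scripts' local presets module


def build_eval_transform(model_name, weights_name=None):
    if not weights_name:
        # Manual preset with the resize/crop sizes used by the existing script.
        return presets.VideoClassificationPresetEval((128, 171), (112, 112))
    builder = PM.video.__dict__[model_name]
    weights = PM._api.get_weight(builder, weights_name)
    # The prototype weights carry their own inference-time preprocessing.
    return weights.transforms()


# e.g. build_eval_transform("r2plus1d_18", weights_name=None)
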