Unverified Commit b1fc2903 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Update Reference scripts to support the prototype models (#4837)

* Adding prototype preprocessing on segmentation references.

* Adding prototype preprocessing on video references.
parent 140322f9
...@@ -11,6 +11,12 @@ from coco_utils import get_coco ...@@ -11,6 +11,12 @@ from coco_utils import get_coco
from torch import nn from torch import nn
try:
from torchvision.prototype import models as PM
except ImportError:
PM = None
def get_dataset(dir_path, name, image_set, transform): def get_dataset(dir_path, name, image_set, transform):
def sbd(*args, **kwargs): def sbd(*args, **kwargs):
return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
...@@ -26,11 +32,15 @@ def get_dataset(dir_path, name, image_set, transform): ...@@ -26,11 +32,15 @@ def get_dataset(dir_path, name, image_set, transform):
return ds, num_classes return ds, num_classes
def get_transform(train): def get_transform(train, args):
base_size = 520 if train:
crop_size = 480 return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
elif not args.weights:
return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size) return presets.SegmentationPresetEval(base_size=520)
else:
fn = PM.segmentation.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
return weights.transforms()
def criterion(inputs, target): def criterion(inputs, target):
...@@ -90,8 +100,8 @@ def main(args): ...@@ -90,8 +100,8 @@ def main(args):
device = torch.device(args.device) device = torch.device(args.device)
dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True)) dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False)) dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
if args.distributed: if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
...@@ -113,9 +123,18 @@ def main(args): ...@@ -113,9 +123,18 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
) )
model = torchvision.models.segmentation.__dict__[args.model]( if not args.weights:
num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained model = torchvision.models.segmentation.__dict__[args.model](
) pretrained=args.pretrained,
num_classes=num_classes,
aux_loss=args.aux_loss,
)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.segmentation.__dict__[args.model](
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
)
model.to(device) model.to(device)
if args.distributed: if args.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
...@@ -247,6 +266,9 @@ def get_args_parser(add_help=True): ...@@ -247,6 +266,9 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser return parser
......
...@@ -18,6 +18,12 @@ except ImportError: ...@@ -18,6 +18,12 @@ except ImportError:
amp = None amp = None
try:
from torchvision.prototype import models as PM
except ImportError:
PM = None
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False): def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
...@@ -149,7 +155,12 @@ def main(args): ...@@ -149,7 +155,12 @@ def main(args):
print("Loading validation data") print("Loading validation data")
cache_path = _get_cache_path(valdir) cache_path = _get_cache_path(valdir)
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) if not args.weights:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
else:
fn = PM.video.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
transform_test = weights.transforms()
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}") print(f"Loading dataset_test from {cache_path}")
...@@ -200,7 +211,12 @@ def main(args): ...@@ -200,7 +211,12 @@ def main(args):
) )
print("Creating model") print("Creating model")
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) if not args.weights:
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.video.__dict__[args.model](weights=args.weights)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
...@@ -363,6 +379,9 @@ def parse_args(): ...@@ -363,6 +379,9 @@ def parse_args():
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
args = parser.parse_args() args = parser.parse_args()
return args return args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment