Unverified Commit 96d30baf authored by Jiahang Xu, committed by GitHub

Refine weight loading in SPOS (#4755)

parent 047a9e22
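The change below moves the ImageNet transform and dataset construction out of `_main` and into `evaluate_acc`, so the `FunctionalEvaluator` only carries `criterion` and `args`, and the `nni.trace` wrappers around the transforms are no longer needed. A minimal sketch of the resulting pattern, assuming the NNI 2.x import path `nni.retiarii.evaluator.FunctionalEvaluator`; the paths, batch sizes, and loss below are illustrative and the trial body is trimmed, so this is not the repository code:

```python
import argparse
import torch
from torchvision import datasets, transforms
from nni.retiarii.evaluator import FunctionalEvaluator


def evaluate_acc(class_cls, criterion, args):
    """Trial function: builds its own data pipeline from plain `args`,
    so nothing unpicklable has to cross the process boundary."""
    model = class_cls().cuda()  # assumes a CUDA device, as in the example
    val_trans = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
    ])
    val_dataset = datasets.ImageNet(args.imagenet_dir, split='val', transform=val_trans)
    test_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.test_batch_size,
        num_workers=args.workers, shuffle=True)
    # ... evaluate `model` on `test_loader` and report the metric via nni ...


# Only simple, picklable keyword arguments are attached to the evaluator;
# datasets and transforms are created inside the trial instead.
args = argparse.Namespace(imagenet_dir='/data/imagenet', test_batch_size=512, workers=4)
evaluator = FunctionalEvaluator(evaluate_acc, criterion=torch.nn.CrossEntropyLoss(), args=args)
```

Building the data pipeline inside the trial keeps the evaluator's payload limited to picklable arguments, which matters when trials run in other processes or on remote servers.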
@@ -66,14 +66,33 @@ def test_acc(model, criterion, log_freq, loader):
     return meters.acc1.avg


-def evaluate_acc(class_cls, criterion, args, train_dataset, val_dataset):
+def evaluate_acc(class_cls, criterion, args):
     model = class_cls()
     with original_state_dict_hooks(model):
         model.load_state_dict(load_and_parse_state_dict(args.checkpoint), strict=False)
     model.cuda()
-    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, num_workers=args.workers)
-    test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.test_batch_size, num_workers=args.workers)
+    if args.spos_preprocessing:
+        train_trans = transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
+            transforms.RandomHorizontalFlip(0.5),
+            ToBGRTensor()
+        ])
+    else:
+        train_trans = transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.ToTensor()
+        ])
+    val_trans = transforms.Compose([
+        transforms.RandomResizedCrop(224),
+        ToBGRTensor()
+    ])
+    train_dataset = datasets.ImageNet(args.imagenet_dir, split='train', transform=train_trans)
+    val_dataset = datasets.ImageNet(args.imagenet_dir, split='val', transform=val_trans)
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, num_workers=args.workers, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.test_batch_size, num_workers=args.workers, shuffle=True)
     acc_before = test_acc(model, criterion, args.log_frequency, test_loader)
     nni.report_intermediate_result(acc_before)
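The hunk above keeps the non-strict checkpoint loading (`strict=False`) inside `original_state_dict_hooks`; that hook and `load_and_parse_state_dict` are helpers from this example and are not reproduced here. For reference, a minimal sketch of the underlying standard PyTorch pattern for loading supernet weights whose keys may not all match the sampled subnet (the function name and checkpoint handling are assumptions, not the example's code):

```python
import torch


def load_supernet_weights(model: torch.nn.Module, checkpoint_path: str):
    """Load supernet weights without requiring an exact key match."""
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    # With strict=False, parameters missing from the checkpoint keep their
    # initialized values and extra checkpoint entries are ignored; the call
    # returns the mismatches so they can be inspected instead of raising.
    result = model.load_state_dict(state_dict, strict=False)
    print('missing keys:', result.missing_keys)
    print('unexpected keys:', result.unexpected_keys)
    return result
```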
@@ -144,29 +163,13 @@ def _main():
     base_model = ShuffleNetV2OneShot()
     criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
-    if args.spos_preprocessing:
-        # ``nni.trace`` is used to make transforms serializable, so that the trials can run other processes or on remote servers.
-        trans = nni.trace(transforms.Compose)([
-            nni.trace(transforms.RandomResizedCrop)(224),
-            nni.trace(transforms.ColorJitter)(brightness=0.4, contrast=0.4, saturation=0.4),
-            nni.trace(transforms.RandomHorizontalFlip)(0.5),
-            nni.trace(ToBGRTensor)(),
-        ])
-    else:
-        # ``nni.trace`` is used to make transforms serializable, so that the trials can run other processes or on remote servers.
-        trans = nni.trace(transforms.Compose)([
-            nni.trace(transforms.RandomResizedCrop)(224),
-            nni.trace(transforms.ToTensor)()
-        ])
-    train_dataset = nni.trace(datasets.ImageNet)(args.imagenet_dir, split='train', transform=trans)
-    val_dataset = nni.trace(datasets.ImageNet)(args.imagenet_dir, split='val', transform=trans)
     if args.latency_filter:
         latency_filter = LatencyFilter(threshold=args.latency_threshold, predictor=args.latency_filter)
     else:
         latency_filter = None
-    evaluator = FunctionalEvaluator(evaluate_acc, criterion=criterion, args=args, train_dataset=train_dataset, val_dataset=val_dataset)
+    evaluator = FunctionalEvaluator(evaluate_acc, criterion=criterion, args=args)
     evolution_strategy = strategy.RegularizedEvolution(
         model_filter=latency_filter,
         sample_size=args.evolution_sample_size, population_size=args.evolution_population_size, cycles=args.evolution_cycles)
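Further down in `_main` (outside this hunk), the evaluator and evolution strategy are handed to a Retiarii experiment. A hedged sketch of that wiring, assuming the NNI 2.x `RetiariiExperiment`/`RetiariiExeConfig` API with a local training service; the experiment name, trial counts, and port are illustrative, not values from the example:

```python
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig

# `base_model`, `evaluator`, and `evolution_strategy` are the objects built above.
experiment = RetiariiExperiment(base_model, evaluator, [], evolution_strategy)

config = RetiariiExeConfig('local')
config.experiment_name = 'spos_search'  # illustrative
config.trial_concurrency = 2            # illustrative
config.max_trial_number = 100           # illustrative
config.trial_gpu_number = 1
config.training_service.use_active_gpu = True

experiment.run(config, 8081)            # port is illustrative
```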