"vscode:/vscode.git/clone" did not exist on "d798c9b8c6a20bffef4fb9632fc3cb6b60daf6a3"
Unverified Commit e1cb8faa authored by J-shang's avatar J-shang Committed by GitHub
Browse files

update exclude example (#4031)

parent 26c58399
...@@ -20,6 +20,7 @@ from torchvision import datasets, transforms ...@@ -20,6 +20,7 @@ from torchvision import datasets, transforms
sys.path.append('../models') sys.path.append('../models')
from mnist.lenet import LeNet from mnist.lenet import LeNet
from cifar10.vgg import VGG from cifar10.vgg import VGG
from cifar10.resnet import ResNet18
from nni.compression.pytorch.utils.counter import count_flops_params from nni.compression.pytorch.utils.counter import count_flops_params
...@@ -119,6 +120,12 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite ...@@ -119,6 +120,12 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = MultiStepLR( scheduler = MultiStepLR(
optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1) optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
elif args.model == 'resnet18':
model = ResNet18().to(device)
if args.pretrained_model_dir is None:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = MultiStepLR(
optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
else: else:
raise ValueError("model not recognized") raise ValueError("model not recognized")
...@@ -253,14 +260,19 @@ def main(args): ...@@ -253,14 +260,19 @@ def main(args):
'sparsity': args.sparsity, 'sparsity': args.sparsity,
'op_types': ['BatchNorm2d'], 'op_types': ['BatchNorm2d'],
}] }]
else: elif args.model == 'resnet18':
config_list = [{ config_list = [{
'sparsity': args.sparsity, 'sparsity': args.sparsity,
'op_types': ['Conv2d'], 'op_types': ['Conv2d']
'op_names': ['feature.0', 'feature.10', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}, { }, {
'exclude': True, 'exclude': True,
'op_names': ['feature.10'] 'op_names': ['layer1.0.conv1', 'layer1.0.conv2']
}]
else:
config_list = [{
'sparsity': args.sparsity,
'op_types': ['Conv2d'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}] }]
pruner = pruner_cls(model, config_list, **kw_args) pruner = pruner_cls(model, config_list, **kw_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment