Commit e1aa783c authored by sugon_cxj

first commit

parent 2e9800bb
import numbers
from abc import ABC, abstractmethod
from typing import Callable, Union, Any, Mapping, Sequence

import numpy as np
import torch

__all__ = ['Accuracy', 'TopkAccuracy']
class Metric(ABC):
@abstractmethod
def update(self, pred, target):
""" Overridden by subclasses """
raise NotImplementedError()
@abstractmethod
def get_results(self):
""" Overridden by subclasses """
raise NotImplementedError()
@abstractmethod
def reset(self):
""" Overridden by subclasses """
raise NotImplementedError()
class MetricCompose(dict):
def __init__(self, metric_dict: Mapping):
self._metric_dict = metric_dict
@property
def metrics(self):
return self._metric_dict
@torch.no_grad()
def update(self, outputs, targets):
for key, metric in self._metric_dict.items():
if isinstance(metric, Metric):
metric.update(outputs, targets)
def get_results(self):
results = {}
for key, metric in self._metric_dict.items():
if isinstance(metric, Metric):
results[key] = metric.get_results()
return results
def reset(self):
for key, metric in self._metric_dict.items():
if isinstance(metric, Metric):
metric.reset()
def __getitem__(self, name):
return self._metric_dict[name]
class Accuracy(Metric):
def __init__(self):
self.reset()
@torch.no_grad()
def update(self, outputs, targets):
outputs = outputs.max(1)[1]
self._correct += ( outputs.view(-1)==targets.view(-1) ).sum()
self._cnt += torch.numel( targets )
def get_results(self):
return (self._correct / self._cnt * 100.).detach().cpu()
def reset(self):
self._correct = self._cnt = 0.0
class TopkAccuracy(Metric):
def __init__(self, topk=(1, 5)):
self._topk = topk
self.reset()
@torch.no_grad()
def update(self, outputs, targets):
for k in self._topk:
_, topk_outputs = outputs.topk(k, dim=1, largest=True, sorted=True)
correct = topk_outputs.eq( targets.view(-1, 1).expand_as(topk_outputs) )
self._correct[k] += correct[:, :k].view(-1).float().sum(0).item()
self._cnt += len(targets)
def get_results(self):
return tuple( self._correct[k] / self._cnt * 100. for k in self._topk )
def reset(self):
self._correct = {k: 0 for k in self._topk}
self._cnt = 0.0
class RunningLoss(Metric):
def __init__(self, loss_fn, is_batch_average=False):
self.reset()
self.loss_fn = loss_fn
self.is_batch_average = is_batch_average
@torch.no_grad()
def update(self, outputs, targets):
self._accum_loss += self.loss_fn(outputs, targets)
if self.is_batch_average:
self._cnt += 1
else:
self._cnt += len(outputs)
def get_results(self):
return (self._accum_loss / self._cnt).detach().cpu()
def reset(self):
self._accum_loss = self._cnt = 0.0
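
# Minimal usage sketch (illustrative, not part of the training pipeline): the shapes
# below are assumptions -- logits of shape [N, num_classes] and integer targets of shape [N].
if __name__ == "__main__":
    metrics = MetricCompose({
        'Acc': Accuracy(),
        'TopkAcc': TopkAccuracy(topk=(1, 5)),
        'Loss': RunningLoss(torch.nn.CrossEntropyLoss(reduction='sum')),
    })
    logits = torch.randn(8, 10)            # fake predictions for 10 classes
    targets = torch.randint(0, 10, (8,))   # fake labels
    metrics.update(logits, targets)        # accumulate one batch
    print(metrics.get_results())           # {'Acc': ..., 'TopkAcc': (...), 'Loss': ...}
    metrics.reset()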
import logging
import os, sys

import torch
from termcolor import colored
class MagnitudeRecover():
    """Record the initial mean squared magnitude of every parameter and gently
    push parameters that have shrunk below that level back up again."""
    def __init__(self, model, reg=1e-3):
        self.rec = {}
        self.reg = reg
        self.cnt = 0
        with torch.no_grad():
            for name, p in model.named_parameters():
                self.rec[name] = p.pow(2).mean()  # reference magnitude per parameter

    def regularize(self, model):
        with torch.no_grad():
            for name, p in model.named_parameters():
                if name in self.rec:
                    target_norm = self.rec[name]
                    if p.data.pow(2).mean() > target_norm:
                        # magnitude has recovered; stop tracking this parameter
                        self.rec.pop(name)
                        continue
                    # add -reg * w to the gradient so the next SGD step scales the weights up
                    p.grad.data += -self.reg * p.data
                    if self.cnt % 1000 == 0:
                        print(name, p.pow(2).mean(), target_norm)
                    self.cnt += 1
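
# Usage sketch (an assumption about the intended call site -- this helper is not wired
# into the training loop in this commit): snapshot magnitudes once, then call
# regularize() between loss.backward() and optimizer.step().
#
#   recover = MagnitudeRecover(model, reg=1e-3)
#   ...
#   loss.backward()
#   recover.regularize(model)  # nudges shrunken parameters back toward their recorded magnitude
#   optimizer.step()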
def flatten_dict(dic):
    """Flatten a nested dict, joining nested keys with '/'."""
    flattened = dict()
    def _flatten(prefix, d):
        for k, v in d.items():
            key = k if prefix is None else prefix + '/%s' % k
            if isinstance(v, dict):
                _flatten(key, v)
            else:
                flattened[key] = v
    _flatten(None, dic)
    return flattened
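
# Example (illustrative): nested keys are joined with '/'.
#   flatten_dict({'opt': {'lr': 0.01, 'sched': {'gamma': 0.1}}, 'seed': 1})
#   -> {'opt/lr': 0.01, 'opt/sched/gamma': 0.1, 'seed': 1}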
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "yellow", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
def get_logger(name='train', output=None, color=True):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
# STDOUT
stdout_handler = logging.StreamHandler( stream=sys.stdout )
stdout_handler.setLevel( logging.DEBUG )
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" )
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S")
else:
formatter = plain_formatter
stdout_handler.setFormatter(formatter)
logger.addHandler(stdout_handler)
# FILE
if output is not None:
        if output.endswith('.txt') or output.endswith('.log'):
            os.makedirs(os.path.dirname(output) or '.', exist_ok=True)
            filename = output
else:
os.makedirs(output, exist_ok=True)
filename = os.path.join(output, "log.txt")
file_handler = logging.FileHandler(filename)
file_handler.setFormatter(plain_formatter)
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
return logger
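
# Usage sketch (assumes termcolor is installed):
#   logger = get_logger('train', output='run/exp1')  # console output plus run/exp1/log.txt
#   logger.info('starting run')
#   logger.warning('colored WARNING prefix on stdout, plain text in the log file')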
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from functools import partial
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
import engine.utils as utils
import registry
import time
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("--mode", type=str, required=True, choices=["prune", "test"])
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--dataset", type=str, default="cifar100", choices=['cifar10', 'cifar100'])
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--total-epochs", type=int, default=100)
parser.add_argument("--lr-decay-milestones", default="60,80", type=str, help="milestones for learning rate decay")
parser.add_argument("--lr-decay-gamma", default=0.1, type=float)
parser.add_argument("--lr", default=0.01, type=float, help="learning rate")
parser.add_argument("--restore", type=str, default=None)
parser.add_argument('--output-dir', default='run', help='path where to save')
# For pruning
parser.add_argument("--method", type=str, default=None)
parser.add_argument("--speed-up", type=float, default=2)
parser.add_argument("--max-sparsity", type=float, default=1.0)
parser.add_argument("--soft-keeping-ratio", type=float, default=0.0)
parser.add_argument("--reg", type=float, default=5e-4)
parser.add_argument("--delta_reg", type=float, default=1e-4, help='for growing regularization')
parser.add_argument("--weight-decay", type=float, default=5e-4)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--global-pruning", action="store_true", default=False)
parser.add_argument("--sl-total-epochs", type=int, default=100, help="epochs for sparsity learning")
parser.add_argument("--sl-lr", default=0.01, type=float, help="learning rate for sparsity learning")
parser.add_argument("--sl-lr-decay-milestones", default="60,80", type=str, help="milestones for sparsity learning")
parser.add_argument("--sl-reg-warmup", type=int, default=0, help="epochs for sparsity learning")
parser.add_argument("--sl-restore", type=str, default=None)
parser.add_argument("--iterative-steps", default=400, type=int)
args = parser.parse_args()
def progressive_pruning(pruner, model, speed_up, example_inputs):
model.eval()
base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
current_speed_up = 1
while current_speed_up < speed_up:
pruner.step(interactive=False)
pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
current_speed_up = float(base_ops) / pruned_ops
if pruner.current_step == pruner.iterative_steps:
break
return current_speed_up
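
# Worked example (illustrative numbers): if the dense model reports 300M ops and
# --speed-up is 2.0, pruner.step() keeps removing channels until the pruned model
# reports at most 150M ops (current_speed_up = 300M / 150M = 2.0), or until the
# pruner exhausts its iterative_steps budget.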
def eval(model, test_loader, device=None):
    """Return (accuracy, mean cross-entropy loss, eval time per sample)."""
    correct = 0
    total = 0
    loss = 0
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    t1 = time.perf_counter()
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            out = model(data)
            loss += F.cross_entropy(out, target, reduction="sum")
            pred = out.max(1)[1]
            correct += (pred == target).sum()
            total += len(target)
    t2 = time.perf_counter()
    eval_time = t2 - t1
    return (correct / total).item(), (loss / total).item(), (eval_time / total)
def train_model(
model,
train_loader,
test_loader,
epochs,
lr,
lr_decay_milestones,
lr_decay_gamma=0.1,
save_as=None,
# For pruning
weight_decay=5e-4,
save_state_dict_only=True,
pruner=None,
device=None,
):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.SGD(
model.parameters(),
lr=lr,
momentum=0.9,
weight_decay=weight_decay if pruner is None else 0,
)
milestones = [int(ms) for ms in lr_decay_milestones.split(",")]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=milestones, gamma=lr_decay_gamma
)
model.to(device)
best_acc = -1
for epoch in range(epochs):
model.train()
for i, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
out = model(data)
loss = F.cross_entropy(out, target)
loss.backward()
if pruner is not None:
pruner.regularize(model) # for sparsity learning
optimizer.step()
if i % 10 == 0 and args.verbose:
args.logger.info(
"Epoch {:d}/{:d}, iter {:d}/{:d}, loss={:.4f}, lr={:.4f}".format(
epoch,
epochs,
i,
len(train_loader),
loss.item(),
optimizer.param_groups[0]["lr"],
)
)
if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner):
pruner.update_reg() # increase the strength of regularization
#print(pruner.group_reg[pruner._groups[0]])
model.eval()
        acc, val_loss, _ = eval(model, test_loader, device=device)
args.logger.info(
"Epoch {:d}/{:d}, Acc={:.4f}, Val Loss={:.4f}, lr={:.4f}".format(
epoch, epochs, acc, val_loss, optimizer.param_groups[0]["lr"]
)
)
if best_acc < acc:
os.makedirs(args.output_dir, exist_ok=True)
if args.mode == "prune":
if save_as is None:
save_as = os.path.join( args.output_dir, "{}_{}_{}.pt".format(args.dataset, args.model, args.method) )
if save_state_dict_only:
torch.save(model.state_dict(), save_as)
else:
torch.save(model, save_as)
best_acc = acc
scheduler.step()
args.logger.info("Best Acc=%.4f" % (best_acc))
def get_pruner(model, example_inputs):
args.sparsity_learning = False
if args.method == "random":
imp = tp.importance.RandomImportance()
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "l1":
imp = tp.importance.MagnitudeImportance(p=1)
pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=args.global_pruning)
elif args.method == "lamp":
imp = tp.importance.LAMPImportance(p=2)
pruner_entry = partial(tp.pruner.BNScalePruner, global_pruning=args.global_pruning)
elif args.method == "slim":
args.sparsity_learning = True
imp = tp.importance.BNScaleImportance()
pruner_entry = partial(tp.pruner.BNScalePruner, reg=args.reg, global_pruning=args.global_pruning)
elif args.method == "group_slim":
args.sparsity_learning = True
imp = tp.importance.BNScaleImportance()
pruner_entry = partial(tp.pruner.BNScalePruner, reg=args.reg, global_pruning=args.global_pruning, group_lasso=True)
elif args.method == "group_norm":
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning)
elif args.method == "group_sl":
args.sparsity_learning = True
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GroupNormPruner, reg=args.reg, global_pruning=args.global_pruning)
elif args.method == "growing_reg":
args.sparsity_learning = True
imp = tp.importance.GroupNormImportance(p=2)
pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg, global_pruning=args.global_pruning)
else:
raise NotImplementedError
#args.is_accum_importance = is_accum_importance
unwrapped_parameters = []
ignored_layers = []
ch_sparsity_dict = {}
# ignore output layers
for m in model.modules():
if isinstance(m, torch.nn.Linear) and m.out_features == args.num_classes:
ignored_layers.append(m)
elif isinstance(m, torch.nn.modules.conv._ConvNd) and m.out_channels == args.num_classes:
ignored_layers.append(m)
    # Prune progressively in small steps (args.iterative_steps, 400 by default)
    # until the required speed-up is achieved.
pruner = pruner_entry(
model,
example_inputs,
importance=imp,
iterative_steps=args.iterative_steps,
ch_sparsity=1.0,
ch_sparsity_dict=ch_sparsity_dict,
max_ch_sparsity=args.max_sparsity,
ignored_layers=ignored_layers,
unwrapped_parameters=unwrapped_parameters,
)
return pruner
def main():
if args.seed is not None:
torch.manual_seed(args.seed)
# Logger
if args.mode == "prune":
prefix = 'global' if args.global_pruning else 'local'
logger_name = "{}-{}-{}-{}".format(args.dataset, prefix, args.method, args.model)
args.output_dir = os.path.join(args.output_dir, args.dataset, args.mode, logger_name)
log_file = "{}/{}.txt".format(args.output_dir, logger_name)
elif args.mode == "test":
log_file = None
logger_name = None
args.logger = utils.get_logger(logger_name, output=log_file)
# Model & Dataset
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes, train_dst, val_dst, input_size = registry.get_dataset(
args.dataset, data_root="data"
)
args.num_classes = num_classes
model = registry.get_model(args.model, num_classes=num_classes, pretrained=True, target_dataset=args.dataset)
train_loader = torch.utils.data.DataLoader(
train_dst,
batch_size=args.batch_size,
num_workers=4,
drop_last=True,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
val_dst, batch_size=args.batch_size, num_workers=4
)
for k, v in utils.utils.flatten_dict(vars(args)).items(): # print args
args.logger.info("%s: %s" % (k, v))
if args.restore is not None:
loaded = torch.load(args.restore, map_location="cpu")
if isinstance(loaded, nn.Module):
model = loaded
else:
model.load_state_dict(loaded)
args.logger.info("Loading model from {restore}".format(restore=args.restore))
model = model.to(args.device)
######################################################
# Pruning / Testing
example_inputs = train_dst[0][0].unsqueeze(0).to(args.device)
if args.mode == "prune":
pruner = get_pruner(model, example_inputs=example_inputs)
# 0. Sparsity Learning
if args.sparsity_learning:
reg_pth = "reg_{}_{}_{}_{}.pth".format(args.dataset, args.model, args.method, args.reg)
            reg_pth = os.path.join(args.output_dir, reg_pth)
if not args.sl_restore:
args.logger.info("Regularizing...")
train_model(
model,
train_loader=train_loader,
test_loader=test_loader,
epochs=args.sl_total_epochs,
lr=args.sl_lr,
lr_decay_milestones=args.sl_lr_decay_milestones,
lr_decay_gamma=args.lr_decay_gamma,
pruner=pruner,
save_state_dict_only=True,
save_as = reg_pth,
)
args.logger.info("Loading the sparse model from {}...".format(reg_pth))
model.load_state_dict( torch.load( reg_pth, map_location=args.device) )
# 1. Pruning
model.eval()
ori_ops, ori_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        ori_acc, ori_val_loss, _ = eval(model, test_loader, device=args.device)
args.logger.info("Pruning...")
progressive_pruning(pruner, model, speed_up=args.speed_up, example_inputs=example_inputs)
del pruner # remove reference
args.logger.info(model)
pruned_ops, pruned_size = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        pruned_acc, pruned_val_loss, _ = eval(model, test_loader, device=args.device)
args.logger.info(
"Params: {:.2f} M => {:.2f} M ({:.2f}%)".format(
ori_size / 1e6, pruned_size / 1e6, pruned_size / ori_size * 100
)
)
args.logger.info(
"FLOPs: {:.2f} M => {:.2f} M ({:.2f}%, {:.2f}X )".format(
ori_ops / 1e6,
pruned_ops / 1e6,
pruned_ops / ori_ops * 100,
ori_ops / pruned_ops,
)
)
args.logger.info("Acc: {:.4f} => {:.4f}".format(ori_acc, pruned_acc))
args.logger.info(
"Val Loss: {:.4f} => {:.4f}".format(ori_val_loss, pruned_val_loss)
)
# 2. Finetuning
args.logger.info("Finetuning...")
train_model(
model,
epochs=args.total_epochs,
lr=args.lr,
lr_decay_milestones=args.lr_decay_milestones,
train_loader=train_loader,
test_loader=test_loader,
device=args.device,
save_state_dict_only=False,
)
elif args.mode == "test":
model.eval()
ops, params = tp.utils.count_ops_and_params(
model, example_inputs=example_inputs,
)
args.logger.info("Params: {:.2f} M".format(params / 1e6))
args.logger.info("ops: {:.2f} M".format(ops / 1e6))
acc, val_loss, eval_time = eval(model, test_loader)
args.logger.info("Acc: {:.4f} Val Loss: {:.4f} eval time: {}\n".format(acc, val_loss, eval_time))
if __name__ == "__main__":
main()
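
# Example invocations (illustrative; the script name below is an assumption):
#   python main.py --mode prune --model mobilenetv2 --dataset cifar100 \
#       --method group_norm --speed-up 2.0 --global-pruning
#   python main.py --mode test --model mobilenetv2 --dataset cifar100 \
#       --restore <path-to-saved-model.pt>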
# Model name
modelName=mobilenet_prune
# Model description
modelDescription=mobilenet_prune is a demo for pruning mobilenet and efficientnet
# Application scenario
appScenario=pruning, fine-tuning
# Framework type
frameType=pytorch
import os, sys
from functools import partial

from torchvision import datasets, transforms as T
from PIL import PngImagePlugin

import engine.models as models
import engine.utils as utils

# Raise PIL's PNG text-chunk limit so large metadata chunks do not abort dataset loading.
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
NORMALIZE_DICT = {
'cifar10': dict( mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010) ),
'cifar100': dict( mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761) ),
'cifar10_224': dict( mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010) ),
'cifar100_224': dict( mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761) ),
}
MODEL_DICT = {
'mobilenetv2': models.cifar.mobilenetv2.mobilenetv2,
'mobilenetv3': models.cifar.mobilenetv3.mobilenetv3,
'efficientnet': models.cifar.efficientnet.efficientnet,
}
def get_model(name: str, num_classes, pretrained=False, target_dataset='cifar', **kwargs):
    # 'pretrained' is accepted for interface compatibility; CIFAR models are built from scratch here.
    if 'cifar' in target_dataset:
        model = MODEL_DICT[name](num_classes=num_classes)
        return model
    raise NotImplementedError(target_dataset)
def get_dataset(name: str, data_root: str='data', return_transform=False):
name = name.lower()
data_root = os.path.expanduser( data_root )
if name=='cifar10':
num_classes = 10
train_transform = T.Compose([
T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize( **NORMALIZE_DICT[name] ),
])
val_transform = T.Compose([
T.ToTensor(),
T.Normalize( **NORMALIZE_DICT[name] ),
])
data_root = os.path.join( data_root, 'torchdata' )
train_dst = datasets.CIFAR10(data_root, train=True, download=True, transform=train_transform)
val_dst = datasets.CIFAR10(data_root, train=False, download=False, transform=val_transform)
input_size = (1, 3, 32, 32)
elif name=='cifar100':
num_classes = 100
train_transform = T.Compose([
T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize( **NORMALIZE_DICT[name] ),
])
val_transform = T.Compose([
T.ToTensor(),
T.Normalize( **NORMALIZE_DICT[name] ),
])
data_root = os.path.join( data_root, 'torchdata' )
train_dst = datasets.CIFAR100(data_root, train=True, download=True, transform=train_transform)
val_dst = datasets.CIFAR100(data_root, train=False, download=True, transform=val_transform)
input_size = (1, 3, 32, 32)
else:
raise NotImplementedError
if return_transform:
return num_classes, train_dst, val_dst, input_size, train_transform, val_transform
return num_classes, train_dst, val_dst, input_size
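
# Minimal usage sketch:
#   num_classes, train_dst, val_dst, input_size = get_dataset('cifar100', data_root='data')
#   model = get_model('mobilenetv2', num_classes=num_classes, target_dataset='cifar100')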