Commit 78e8e038 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
from math import pi, cos, log, floor
from torch.optim.lr_scheduler import _LRScheduler
class CosineWarmupLR(_LRScheduler):
    '''
    Cosine learning-rate decay with an optional linear warmup phase.
    Ref: https://github.com/PistonY/torch-toolbox/blob/master/torchtoolbox/optimizer/lr_scheduler.py
    https://github.com/Randl/MobileNetV3-pytorch/blob/master/cosine_with_warmup.py
    Lr warmup is proposed by
    `Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour`
    `https://arxiv.org/pdf/1706.02677.pdf`
    Cosine decay is proposed by
    `Stochastic Gradient Descent with Warm Restarts`
    `https://arxiv.org/abs/1608.03983`
    Args:
        optimizer (Optimizer): optimizer of a model.
        epochs (int): number of epochs to train.
        iter_in_one_epoch (int): number of iterations in one epoch.
        lr_min (float): minimum (final) lr.
        warmup_epochs (int): warmup epochs before cosine decay.
        last_epoch (int): init iteration; despite the name, this counts iterations.
    Attributes:
        niters (int): total number of iterations over all epochs.
        warmup_iters (int): number of warmup iterations.
        cosine_iters (int): number of iterations in the cosine phase.
    '''

    def __init__(self, optimizer, epochs, iter_in_one_epoch, lr_min=0, warmup_epochs=0, last_epoch=-1):
        self.lr_min = lr_min
        self.niters = epochs * iter_in_one_epoch
        self.warmup_iters = iter_in_one_epoch * warmup_epochs
        self.cosine_iters = iter_in_one_epoch * (epochs - warmup_epochs)
        super(CosineWarmupLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        '''Return the lr for every param group at the current iteration.'''
        # NOTE: `self.last_epoch` is an iteration counter here (see class docstring).
        if self.last_epoch < self.warmup_iters:
            # Linear ramp from lr_min up to the base lr.
            return [self.lr_min + (base - self.lr_min) * self.last_epoch / self.warmup_iters
                    for base in self.base_lrs]
        # Half-cosine from base lr down to lr_min over the remaining iterations.
        anneal = (1 + cos(pi * (self.last_epoch - self.warmup_iters) / self.cosine_iters)) / 2
        return [self.lr_min + (base - self.lr_min) * anneal for base in self.base_lrs]
class CosineAnnealingWarmRestarts(_LRScheduler):
    r'''
    copied from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#CosineAnnealingWarmRestarts
    Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:
    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{i}}\pi))
    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0`(after restart), set :math:`\eta_t=\eta_{max}`.
    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        warmup_epochs (int, optional): linear-warmup epochs before the first cosine cycle. Default: 0.
        decay_rate (float, optional): factor applied to the peak lr after each restart
            (cycle k peaks at base_lr * decay_rate**k). Default: 0.5.
    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    '''
    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, warmup_epochs=0, decay_rate=0.5):
        # Validate hyper-parameters before any scheduler state is created.
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if warmup_epochs < 0 or not isinstance(warmup_epochs, int):
            raise ValueError("Expected positive integer warmup_epochs, but got {}".format(warmup_epochs))
        self.T_0 = T_0          # length (in epochs) of the first cosine cycle
        self.T_i = T_0          # length of the current cycle; grows by T_mult after each restart
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.warmup_epochs = warmup_epochs
        self.decay_rate = decay_rate    # peak-lr decay factor applied at restarts
        self.decay_power = 0            # number of completed restarts so far
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)
        # Epochs elapsed since the last restart; may become fractional when
        # step() is called with epoch = epoch + i / iters.
        self.T_cur = self.last_epoch
    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup from eta_min up to the base lr.
            return [(self.eta_min + (base_lr - self.eta_min) * self.T_cur / self.warmup_epochs) for base_lr in self.base_lrs]
        else:
            # Cosine annealing within the current cycle; the cycle's peak lr is
            # base_lr * decay_rate**decay_power.
            return [self.eta_min + (base_lr * (self.decay_rate**self.decay_power) - self.eta_min) * (1 + cos(pi * self.T_cur / self.T_i)) / 2
                    for base_lr in self.base_lrs]
    def step(self, epoch=None):
        '''Step could be called after every batch update
        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         scheduler.step(epoch + i / iters)
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
        This function can be called in an interleaved way.
        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
        '''
        if epoch is None:
            # Implicit mode: advance one epoch and restart when the cycle ends.
            # NOTE(review): this branch never advances decay_power, so the peak
            # lr only decays when step(epoch) is called explicitly — confirm intended.
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                # Warm restart: wrap the cycle position and stretch the next cycle.
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            # Explicit mode: recompute cycle index and position from the absolute epoch.
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch < self.warmup_epochs:
                self.T_cur = epoch
            else:
                epoch_cur = epoch - self.warmup_epochs  # epochs since warmup finished
                if epoch_cur >= self.T_0:
                    if self.T_mult == 1:
                        # Equal-length cycles: position/count via modular arithmetic.
                        self.T_cur = epoch_cur % self.T_0
                        self.decay_power = epoch_cur // self.T_0
                    else:
                        # Geometric cycles: n restarts completed when
                        # T_0 * (T_mult**n - 1) / (T_mult - 1) <= epoch_cur.
                        n = int(log((epoch_cur / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                        self.T_cur = epoch_cur - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                        self.T_i = self.T_0 * self.T_mult ** (n)
                        self.decay_power = n
                else:
                    # Still inside the first cycle.
                    self.T_i = self.T_0
                    self.T_cur = epoch_cur
        self.last_epoch = floor(epoch)
        # Push the freshly computed lrs into every parameter group.
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
\ No newline at end of file
'''
A new dataloader using NVIDIA DALI in order to speed up the dataloader in pytorch
Ref: https://github.com/d-li14/mobilenetv2.pytorch/blob/master/utils/dataloaders.py
https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/resnet50/main.py
'''
import os
import torch
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from math import ceil
try:
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
except ImportError:
print("Please install DALI from https://www.github.com/NVIDIA/DALI to run DataLoader.")
class TinyImageNetHybridTrainPipe(Pipeline):
    '''
    DALI training pipeline for Tiny-ImageNet: file reading, JPEG decode,
    random-resized crop, random horizontal mirror and mean/std normalization.
    Args:
        batch_size (int): samples per batch.
        num_threads (int): CPU threads used by the pipeline.
        device_id (int): GPU device id this pipeline runs on.
        data_dir (str): root folder with one subfolder per class.
        crop (int): output spatial size of the random-resized crop.
        seed (int): random seed for the pipeline.
        dali_cpu (bool): if True, decode and crop on CPU instead of GPU.
    '''
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed, dali_cpu=False):
        super(TinyImageNetHybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed)
        # Shard the dataset across distributed workers; single shard otherwise.
        if torch.distributed.is_initialized():
            local_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            local_rank = 0
            world_size = 1
        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False,
            shuffle_after_epoch=True)  # reshuffle with a different order every epoch
        # decide to work on cpu or gpu
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
        self.res = ops.RandomResizedCrop(device=dali_device, size=crop, random_aspect_ratio=[0.75, 4./3],
                                         random_area=[0.08, 1.0], num_attempts=100, interp_type=types.INTERP_TRIANGULAR)
        # Normalize with ImageNet statistics (pixel values are still 0-255 here).
        self.cmnp = ops.CropMirrorNormalize(device='gpu',
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485*255, 0.456*255, 0.406*255],
                                            std=[0.229*255, 0.224*255, 0.225*255])
        self.coin = ops.CoinFlip(probability=0.5)  # per-sample horizontal-flip decision

    def define_graph(self):
        '''Assemble the DALI graph: read -> decode -> random crop -> normalize(+mirror).'''
        rng = self.coin()
        self.jpegs, self.labels = self.input(name='Reader')
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images.gpu(), mirror = rng)
        return [output, self.labels]
class TinyImageNetHybridValPipe(Pipeline):
    '''
    DALI validation pipeline for Tiny-ImageNet: file reading, JPEG decode and
    center-crop + mean/std normalization (no augmentation).
    Args:
        batch_size (int): samples per batch.
        num_threads (int): CPU threads used by the pipeline.
        device_id (int): GPU device id this pipeline runs on.
        data_dir (str): root folder with one subfolder per class.
        crop (int): output spatial crop size.
        seed (int): random seed for the pipeline.
    '''
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed):
        super(TinyImageNetHybridValPipe, self).__init__(batch_size, num_threads, device_id, seed)
        # Shard the dataset across distributed workers; single shard otherwise.
        if torch.distributed.is_initialized():
            local_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            local_rank = 0
            world_size = 1
        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False)  # deterministic order for evaluation
        self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
        # Normalize with ImageNet statistics (pixel values are still 0-255 here).
        self.cmnp = ops.CropMirrorNormalize(device='gpu',
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485*255, 0.456*255, 0.406*255],
                                            std=[0.229*255, 0.224*255, 0.225*255])

    def define_graph(self):
        '''Assemble the DALI graph: read -> decode -> center crop + normalize.'''
        self.jpegs, self.labels = self.input(name='Reader')
        images = self.decode(self.jpegs)
        output = self.cmnp(images)
        return [output, self.labels]
class ImageNetHybridTrainPipe(Pipeline):
    '''
    DALI training pipeline for full-size ImageNet: file reading, JPEG decode,
    random-resized crop, random horizontal mirror and mean/std normalization.
    Args:
        batch_size (int): samples per batch.
        num_threads (int): CPU threads used by the pipeline.
        device_id (int): GPU device id this pipeline runs on.
        data_dir (str): root folder with one subfolder per class.
        crop (int): output spatial size of the random-resized crop.
        seed (int): random seed for the pipeline.
        dali_cpu (bool): if True, decode and crop on CPU instead of GPU.
    '''
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed, dali_cpu=False):
        super(ImageNetHybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed = seed)
        # Shard the dataset across distributed workers; single shard otherwise.
        if torch.distributed.is_initialized():
            local_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            local_rank = 0
            world_size = 1
        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False,
            shuffle_after_epoch=True)  # reshuffle with a different order every epoch
        # decide to work on cpu or gpu
        dali_device = 'cpu' if dali_cpu else 'gpu'
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
        # without additional reallocations
        device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
        host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
        # Disabled alternative: fused decode+random-crop followed by a plain resize.
        '''
        self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
                                                 device_memory_padding=device_memory_padding,
                                                 host_memory_padding=host_memory_padding,
                                                 random_aspect_ratio=[0.75, 1.25],
                                                 random_area=[0.08, 1.0],
                                                 num_attempts=100)
        self.res = ops.Resize(device=dali_device, resize_x=crop, resize_y=crop, interp_type=types.INTERP_TRIANGULAR)
        '''
        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB,
                                       device_memory_padding=device_memory_padding,
                                       host_memory_padding=host_memory_padding,)
        self.res = ops.RandomResizedCrop(device=dali_device, size=crop, random_aspect_ratio=[0.75, 4./3],
                                         random_area=[0.08, 1.0], num_attempts=100, interp_type=types.INTERP_TRIANGULAR)
        # Normalize with ImageNet statistics (pixel values are still 0-255 here).
        self.cmnp = ops.CropMirrorNormalize(device='gpu',
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485 * 255,0.456 * 255,0.406 * 255],
                                            std=[0.229 * 255,0.224 * 255,0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)  # per-sample horizontal-flip decision

    def define_graph(self):
        '''Assemble the DALI graph: read -> decode -> random crop -> normalize(+mirror).'''
        rng = self.coin()
        self.jpegs, self.labels = self.input(name='Reader')
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images.gpu(), mirror = rng)
        return [output, self.labels]
class ImageNetHybridValPipe(Pipeline):
    '''
    DALI validation pipeline for full-size ImageNet: file reading, JPEG decode,
    shorter-side resize, then center crop + mean/std normalization.
    Args:
        batch_size (int): samples per batch.
        num_threads (int): CPU threads used by the pipeline.
        device_id (int): GPU device id this pipeline runs on.
        data_dir (str): root folder with one subfolder per class.
        crop (int): output spatial crop size.
        size (int): length the shorter image side is resized to before cropping.
        seed (int): random seed for the pipeline.
    '''
    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed):
        super(ImageNetHybridValPipe, self).__init__(batch_size, num_threads, device_id, seed = seed)
        # Shard the dataset across distributed workers; single shard otherwise.
        if torch.distributed.is_initialized():
            local_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            local_rank = 0
            world_size = 1
        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False)  # deterministic order for evaluation
        self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
        self.res = ops.Resize(device='gpu', resize_shorter=size, interp_type=types.INTERP_TRIANGULAR)
        # Normalize with ImageNet statistics (pixel values are still 0-255 here).
        self.cmnp = ops.CropMirrorNormalize(device='gpu',
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(crop, crop),
                                            image_type=types.RGB,
                                            mean=[0.485*255, 0.456*255, 0.406*255],
                                            std=[0.229*255, 0.224*255, 0.225*255])

    def define_graph(self):
        '''Assemble the DALI graph: read -> decode -> resize -> center crop + normalize.'''
        self.jpegs, self.labels = self.input(name='Reader')
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]
class DALIWrapper(object):
    '''
    Iterable adapter around a DALIClassificationIterator that yields
    (input, target) pairs, with the target squeezed, moved to GPU and
    cast to long so it matches what loss functions expect.
    '''

    @staticmethod  # fix: was a plain function accessed via the class; breaks if called on an instance
    def gen_wrapper(dali_pipeline):
        '''Generator over the DALI iterator, unpacking its dict-style batches.'''
        for data in dali_pipeline:
            input = data[0]['data']
            target = data[0]['label'].squeeze().cuda().long()
            yield input, target

    def __init__(self, dali_pipeline):
        # The wrapped DALIClassificationIterator (re-iterable thanks to auto_reset).
        self.dali_pipeline = dali_pipeline

    def __iter__(self):
        return DALIWrapper.gen_wrapper(self.dali_pipeline)
def get_dali_tinyImageNet_train_loader(data_path, batch_size, seed, num_threads=4, dali_cpu=False):
    '''
    Build the DALI training loader for Tiny-ImageNet.
    Returns (iterable yielding (input, target) batches, batches per epoch).
    '''
    if torch.distributed.is_initialized():
        shard_id = torch.distributed.get_rank()
        num_shards = torch.distributed.get_world_size()
    else:
        shard_id, num_shards = 0, 1
    pipe = TinyImageNetHybridTrainPipe(batch_size=batch_size, num_threads=num_threads,
                                       device_id=shard_id,
                                       data_dir=os.path.join(data_path, 'train'),
                                       crop=56, seed=seed, dali_cpu=dali_cpu)
    pipe.build()
    samples = pipe.epoch_size('Reader')
    loader = DALIClassificationIterator(pipe, size=int(samples / num_shards),
                                        fill_last_batch=False, last_batch_padded=True,
                                        auto_reset=True)
    return DALIWrapper(loader), ceil(samples / (num_shards * batch_size))
def get_dali_tinyImageNet_val_loader(data_path, batch_size, seed, num_threads=4):
    '''
    Build the DALI validation loader for Tiny-ImageNet.
    Returns (iterable yielding (input, target) batches, batches per epoch).
    '''
    if torch.distributed.is_initialized():
        shard_id = torch.distributed.get_rank()
        num_shards = torch.distributed.get_world_size()
    else:
        shard_id, num_shards = 0, 1
    pipe = TinyImageNetHybridValPipe(batch_size=batch_size, num_threads=num_threads,
                                     device_id=shard_id,
                                     data_dir=os.path.join(data_path, 'val'),
                                     crop=56, seed=seed)
    pipe.build()
    samples = pipe.epoch_size('Reader')
    loader = DALIClassificationIterator(pipe, size=int(samples / num_shards),
                                        fill_last_batch=False, last_batch_padded=True,
                                        auto_reset=True)
    return DALIWrapper(loader), ceil(samples / (num_shards * batch_size))
def get_dali_imageNet_train_loader(data_path, batch_size, seed, num_threads=4, dali_cpu=False):
    '''
    Build the DALI training loader for ImageNet.
    Returns (iterable yielding (input, target) batches, batches per epoch).
    '''
    if torch.distributed.is_initialized():
        shard_id = torch.distributed.get_rank()
        num_shards = torch.distributed.get_world_size()
    else:
        shard_id, num_shards = 0, 1
    pipe = ImageNetHybridTrainPipe(batch_size=batch_size, num_threads=num_threads,
                                   device_id=shard_id,
                                   data_dir=os.path.join(data_path, 'ILSVRC2012_img_train'),
                                   crop=224, seed=seed, dali_cpu=dali_cpu)
    pipe.build()
    samples = pipe.epoch_size('Reader')
    loader = DALIClassificationIterator(pipe, size=int(samples / num_shards),
                                        fill_last_batch=False, last_batch_padded=True,
                                        auto_reset=True)
    return DALIWrapper(loader), ceil(samples / (num_shards * batch_size))
def get_dali_imageNet_val_loader(data_path, batch_size, seed, num_threads=4):
    '''
    Build the DALI validation loader for ImageNet.
    Returns (iterable yielding (input, target) batches, batches per epoch).
    '''
    if torch.distributed.is_initialized():
        shard_id = torch.distributed.get_rank()
        num_shards = torch.distributed.get_world_size()
    else:
        shard_id, num_shards = 0, 1
    pipe = ImageNetHybridValPipe(batch_size=batch_size, num_threads=num_threads,
                                 device_id=shard_id,
                                 data_dir=os.path.join(data_path, 'ILSVRC2012_img_val'),
                                 crop=224, size=256, seed=seed)
    pipe.build()
    samples = pipe.epoch_size('Reader')
    loader = DALIClassificationIterator(pipe, size=int(samples / num_shards),
                                        fill_last_batch=False, last_batch_padded=True,
                                        auto_reset=True)
    return DALIWrapper(loader), ceil(samples / (num_shards * batch_size))
\ No newline at end of file
# -*- coding: UTF-8 -*-
'''
Image dataset loader
'''
from torchvision import transforms, datasets
import os
import torch
from PIL import Image
import scipy.io as scio
def Cifar10DataLoader(args):
    '''
    Build train/val DataLoaders for CIFAR-10 (downloads the dataset if needed).
    Fix: the original created DistributedSampler unconditionally, which raises
    when torch.distributed is not initialized; now samplers are only used in
    distributed mode (consistent with the other loaders in this file).
    Returns:
        dict: {'train': DataLoader, 'val': DataLoader}
    '''
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
        ])
    }
    image_datasets = {}
    image_datasets['train'] = datasets.CIFAR10(root=args.data_dir, train=True, download=True, transform=data_transforms['train'])
    image_datasets['val'] = datasets.CIFAR10(root=args.data_dir, train=False, download=True, transform=data_transforms['val'])
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        train_sampler = torch.utils.data.distributed.DistributedSampler(image_datasets['train'])
        val_sampler = torch.utils.data.distributed.DistributedSampler(image_datasets['val'], shuffle=False, drop_last=True)
    else:
        # Single-process run: no samplers, shuffle the training set directly.
        train_sampler = None
        val_sampler = None
    dataloders = {x: torch.utils.data.DataLoader(
        image_datasets[x],
        batch_size=args.batch_size,
        # shuffle and sampler are mutually exclusive; shuffle only when no sampler
        shuffle=(x == 'train' and train_sampler is None),
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler if x == 'train' else val_sampler) for x in ['train', 'val']}
    return dataloders
def Cifar100DataLoader(args):
    '''
    Build train/val DataLoaders for CIFAR-100 (downloads the dataset if needed).
    Returns:
        dict: {'train': DataLoader, 'val': DataLoader}
    '''
    normalize = transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))
    pipelines = {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
    }
    splits = {
        'train': datasets.CIFAR100(root=args.data_dir, train=True, download=True, transform=pipelines['train']),
        'val': datasets.CIFAR100(root=args.data_dir, train=False, download=True, transform=pipelines['val']),
    }
    return {split: torch.utils.data.DataLoader(ds,
                                               batch_size=args.batch_size,
                                               shuffle=(split == 'train'),
                                               num_workers=args.num_workers,
                                               pin_memory=True)
            for split, ds in splits.items()}
def ImageNetDataLoader(args):
    '''
    Build train/val DataLoaders for ImageNet from ImageFolder directories.
    Returns:
        dict: {'train': DataLoader, 'val': DataLoader}
    '''
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    pipelines = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    folders = {'train': 'ILSVRC2012_img_train', 'val': 'ILSVRC2012_img_val'}
    splits = {split: datasets.ImageFolder(root=os.path.join(args.data_dir, folders[split]),
                                          transform=pipelines[split])
              for split in ['train', 'val']}
    return {split: torch.utils.data.DataLoader(ds,
                                               batch_size=args.batch_size,
                                               shuffle=(split == 'train'),
                                               num_workers=args.num_workers,
                                               pin_memory=True)
            for split, ds in splits.items()}
def TinyImageNetDataLoader(args):
    '''
    Build train/val DataLoaders for Tiny-ImageNet from ImageFolder directories.
    Returns:
        dict: {'train': DataLoader, 'val': DataLoader}
    '''
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    pipelines = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(56),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.CenterCrop(56),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    splits = {split: datasets.ImageFolder(root=os.path.join(args.data_dir, split),
                                          transform=pipelines[split])
              for split in ['train', 'val']}
    return {split: torch.utils.data.DataLoader(ds,
                                               batch_size=args.batch_size,
                                               shuffle=(split == 'train'),
                                               num_workers=args.num_workers,
                                               pin_memory=True)
            for split, ds in splits.items()}
def SVHNDataLoader(args):
    '''
    Build train/val DataLoaders for SVHN using the project-local SVHN dataset class.
    Returns:
        dict: {'train': DataLoader, 'val': DataLoader}
    '''
    from SVHN import SVHN
    pipelines = {
        'train': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4309, 0.4302, 0.4463), (0.1965, 0.1983, 0.1994)),
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4524, 0.4525, 0.4690), (0.2194, 0.2266, 0.2285)),
        ]),
    }
    root = os.path.join(args.data_dir, 'SVHN')
    splits = {
        'train': SVHN(root=root, split='train', download=False, transform=pipelines['train']),
        'val': SVHN(root=root, split='test', download=False, transform=pipelines['val']),
    }
    return {split: torch.utils.data.DataLoader(ds,
                                               batch_size=args.batch_size,
                                               shuffle=(split == 'train'),
                                               num_workers=args.num_workers,
                                               pin_memory=True)
            for split, ds in splits.items()}
def dataloaders(args):
    '''
    Dispatch to the dataset-specific loader based on args.dataset
    (case-insensitive). Raises AssertionError for unknown datasets.
    '''
    dataset = args.dataset.lower()
    assert dataset in ('imagenet', 'tinyimagenet', 'cifar10', 'cifar100', 'svhn')
    dispatch = {
        'imagenet': ImageNetDataLoader,
        'tinyimagenet': TinyImageNetDataLoader,
        'cifar10': Cifar10DataLoader,
        'cifar100': Cifar100DataLoader,
        'svhn': SVHNDataLoader,
    }
    return dispatch[dataset](args)
\ No newline at end of file
# -*- coding: UTF-8 -*-
'''
exponential moving average
Ref: https://blog.csdn.net/zhang2010hao/article/details/91599411
'''
class EMA():
    '''
    Exponential moving average of model parameters.
    Ref: https://blog.csdn.net/zhang2010hao/article/details/91599411
    '''

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}  # name -> EMA value of the parameter
        self.backup = {}  # name -> stashed raw weights while shadow is applied

    def register(self):
        '''Snapshot every trainable parameter as the initial EMA value.'''
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            self.shadow[name] = param.data.clone()

    def update(self):
        '''Blend current weights into the EMA: s = (1 - decay) * p + decay * s.'''
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.shadow
            blended = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
            self.shadow[name] = blended.clone()

    def apply_shadow(self):
        '''Swap EMA weights into the model, stashing the raw weights in backup.'''
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.shadow
            self.backup[name] = param.data
            param.data = self.shadow[name]

    def restore(self):
        '''Put the stashed raw weights back and clear the stash.'''
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.backup
            param.data = self.backup[name]
        self.backup = {}
# -*- coding: UTF-8 -*-
'''
Label Smoothing described in "Rethinking the Inception Architecture for Computer Vision"
Ref: https://github.com/PistonY/torch-toolbox/blob/master/torchtoolbox/nn/loss.py
https://github.com/whr94621/NJUNMT-pytorch/blob/master/src/modules/criterions.py
'''
import torch
from torch import nn
from torch.autograd import Variable
class LabelSmoothingLoss(nn.Module):
    '''
    Label Smoothing Loss function.
    Computes KL divergence (batchmean) between the log-softmax of the
    predictions and a smoothed one-hot target distribution; with
    label_smoothing=0 this equals standard cross entropy.
    Args:
        classes_num (int): number of classes.
        label_smoothing (float): smoothing mass spread over non-target classes.
        dim (int): softmax dimension.
    '''
    def __init__(self, classes_num, label_smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - label_smoothing
        self.label_smoothing = label_smoothing
        self.classes_num = classes_num
        self.dim = dim
        self.criterion = nn.KLDivLoss(reduction='batchmean')

    def forward(self, pred, target):
        '''
        Args:
            pred: raw logits, shape (batch, classes_num).
            target: class indices, shape (batch,).
        Returns:
            scalar loss tensor.
        '''
        pred = pred.log_softmax(dim=self.dim)
        # Build the smoothed target without tracking gradients (replaces the
        # deprecated torch.autograd.Variable wrapper). full_like keeps pred's
        # dtype AND device (the original used the default dtype on target's
        # device, which could mismatch a half-precision or multi-device pred).
        with torch.no_grad():
            smooth_label = torch.full_like(pred, self.label_smoothing / (self.classes_num - 1))
            # NOTE: scatter dim is hard-coded to 1 (class dim), matching the original.
            smooth_label.scatter_(1, target.unsqueeze(1), self.confidence)
        return self.criterion(pred, smooth_label)
if __name__ == "__main__":
    # Sanity check: with label_smoothing=0 the loss should match nn.CrossEntropyLoss.
    # (Removed the deprecated torch.autograd.Variable wrappers — plain tensors
    # have been sufficient since PyTorch 0.4.)
    predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.1, 0], [1, 0.2, 0.7, 0.1, 0]])
    targets = torch.LongTensor([2, 1, 0])
    loss1 = LabelSmoothingLoss(5, 0.0)
    v1 = loss1(predict, targets)
    print(v1)
    loss2 = nn.CrossEntropyLoss()
    v2 = loss2(predict, targets)
    print(v2.data)
\ No newline at end of file
'''
mixup
Ref: https://github.com/BIGBALLON/CIFAR-ZOO/blob/master/utils.py
'''
import torch
import numpy as np
def mixup_data(x, y, alpha):
    '''
    Return mixed inputs, the pair of targets, and the mixing coefficient lambda.
    lambda ~ Beta(alpha, alpha) when alpha > 0, otherwise 1 (no mixing).
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[perm, :]
    return mixed_x, y, y[perm], lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''Mixup loss: convex combination of the losses against both target sets.'''
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
\ No newline at end of file
import torch.nn as nn
def noBiasDecay(model, lr, weight_decay):
    '''
    no bias decay : only apply weight decay to the weights in convolution and fully-connected layers
    In paper [Bag of Tricks for Image Classification with Convolutional Neural Networks](https://arxiv.org/abs/1812.01187)
    Ref: https://github.com/weiaicunzai/Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks/blob/master/utils.py
    Fix: the original used hasattr(), which appended None for parameters that
    are registered but absent (e.g. BatchNorm with affine=False, or bias=None),
    producing invalid param groups and breaking the sanity assert.
    Returns:
        list of param-group dicts suitable for a torch optimizer.
    '''
    decay, bias_no_decay, weight_no_decay = [], [], []
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            decay.append(m.weight)
            if m.bias is not None:
                bias_no_decay.append(m.bias)
        else:
            # Only collect real parameters; registered-but-None attrs are skipped.
            weight = getattr(m, 'weight', None)
            if isinstance(weight, nn.Parameter):
                weight_no_decay.append(weight)
            bias = getattr(m, 'bias', None)
            if isinstance(bias, nn.Parameter):
                bias_no_decay.append(bias)
    # Sanity check: every parameter of the model is assigned to exactly one group.
    assert len(list(model.parameters())) == len(decay) + len(bias_no_decay) + len(weight_no_decay)
    # bias using 2*lr
    return [
        {'params': bias_no_decay, 'lr': 2 * lr, 'weight_decay': 0.0},
        {'params': weight_no_decay, 'lr': lr, 'weight_decay': 0.0},
        {'params': decay, 'lr': lr, 'weight_decay': weight_decay},
    ]
\ No newline at end of file
# MobileNetV3
An implementation of MobileNetV3 with pyTorch
# Theory
&emsp;You can find the paper of MobileNetV3 at [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244).
# Prepare data
* CIFAR-10
* CIFAR-100
* SVHN
* Tiny-ImageNet
* ImageNet: Please move validation images to labeled subfolders, you can use the script [here](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
# Train
* Train from scratch:
```
CUDA_VISIBLE_DEVICES=3 python train.py --batch-size=128 --mode=small \
--print-freq=100 --dataset=CIFAR100 --ema-decay=0 --label-smoothing=0.1 \
--lr=0.3 --save-epoch-freq=1000 --lr-decay=cos --lr-min=0 \
--warmup-epochs=5 --weight-decay=6e-5 --num-epochs=200 --width-multiplier=1 \
-nbd -zero-gamma -mixup
```
where the meanings of the parameters are as follows:
```
batch-size
mode: using MobileNetV3-Small(if set to small) or MobileNetV3-Large(if set to large).
dataset: which dataset to use(CIFAR10, CIFAR100, SVHN, TinyImageNet or ImageNet).
ema-decay: decay of EMA; if set to 0, EMA is not used.
label-smoothing: the $epsilon$ used in label smoothing; if set to 0, label smoothing is not used.
lr-decay: learning rate decay schedule, step or cos.
lr-min: minimum lr in cos lr decay.
warmup-epochs: warmup epochs used in cos lr decay.
num-epochs: total training epochs.
nbd: no bias decay.
zero-gamma: zero the $gamma$ of the last BN in each block.
mixup: use Mixup.
```
# Pretrained models
&emsp;We have provided the pretrained MobileNetV3-Small model in `pretrained`.
# Experiments
## Training setting
### on ImageNet
```
CUDA_VISIBLE_DEVICES=5 python train.py --batch-size=128 --mode=small --print-freq=2000 --dataset=imagenet \
--ema-decay=0.99 --label-smoothing=0.1 --lr=0.1 --save-epoch-freq=50 --lr-decay=cos --lr-min=0 --warmup-epochs=5 \
--weight-decay=1e-5 --num-epochs=250 --num-workers=2 --width-multiplier=1 -dali -nbd -mixup -zero-gamma -save
```
### on CIFAR-10
```
CUDA_VISIBLE_DEVICES=1 python train.py --batch-size=128 --mode=small --print-freq=100 --dataset=CIFAR10\
--ema-decay=0 --label-smoothing=0 --lr=0.35 --save-epoch-freq=1000 --lr-decay=cos --lr-min=0\
--warmup-epochs=5 --weight-decay=6e-5 --num-epochs=400 --num-workers=2 --width-multiplier=1
```
### on CIFAR-100
```
CUDA_VISIBLE_DEVICES=1 python train.py --batch-size=128 --mode=small --print-freq=100 --dataset=CIFAR100\
--ema-decay=0 --label-smoothing=0 --lr=0.35 --save-epoch-freq=1000 --lr-decay=cos --lr-min=0\
--warmup-epochs=5 --weight-decay=6e-5 --num-epochs=400 --num-workers=2 --width-multiplier=1
```
&emsp;Using more tricks:
```
CUDA_VISIBLE_DEVICES=1 python train.py --batch-size=128 --mode=small --print-freq=100 --dataset=CIFAR100\
--ema-decay=0.999 --label-smoothing=0.1 --lr=0.35 --save-epoch-freq=1000 --lr-decay=cos --lr-min=0\
--warmup-epochs=5 --weight-decay=6e-5 --num-epochs=400 --num-workers=2 --width-multiplier=1\
-zero-gamma -nbd -mixup
```
### on SVHN
```
CUDA_VISIBLE_DEVICES=3 python train.py --batch-size=128 --mode=small --print-freq=1000 --dataset=SVHN\
--ema-decay=0 --label-smoothing=0 --lr=0.35 --save-epoch-freq=1000 --lr-decay=cos --lr-min=0\
--warmup-epochs=5 --weight-decay=6e-5 --num-epochs=20 --num-workers=2 --width-multiplier=1
```
### on Tiny-ImageNet
```
CUDA_VISIBLE_DEVICES=7 python train.py --batch-size=128 --mode=small --print-freq=100 --dataset=tinyimagenet\
--data-dir=/media/data2/chenjiarong/ImageData/tiny-imagenet --ema-decay=0 --label-smoothing=0 --lr=0.15\
--save-epoch-freq=1000 --lr-decay=cos --lr-min=0 --warmup-epochs=5 --weight-decay=6e-5 --num-epochs=200\
--num-workers=2 --width-multiplier=1 -dali
```
&emsp;Using more tricks:
```
CUDA_VISIBLE_DEVICES=7 python train.py --batch-size=128 --mode=small --print-freq=100 --dataset=tinyimagenet\
--data-dir=/media/data2/chenjiarong/ImageData/tiny-imagenet --ema-decay=0.999 --label-smoothing=0.1 --lr=0.15\
--save-epoch-freq=1000 --lr-decay=cos --lr-min=0 --warmup-epochs=5 --weight-decay=6e-5 --num-epochs=200\
--num-workers=2 --width-multiplier=1 -dali -nbd -mixup
```
## MobileNetV3-Large
### on ImageNet
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Official 1.0 | 219 M | 5.4 M | 75.2% | - |
| Ours 1.0 | 216.6 M | 5.47 M | - | - |
### on CIFAR-10
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Ours 1.0 | 66.47 M | 4.21 M | - | - |
### on CIFAR-100
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Ours 1.0 | 66.58 M | 4.32 M | - | - |
## MobileNetV3-Small
### on ImageNet
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Official 1.0 | 56.5 M | 2.53 M | 67.4% | - |
| Ours 1.0 | 56.51 M | 2.53 M | 67.52% | 87.58% |
&emsp;The pretrained model with top-1 accuracy 67.52% is provided in the folder [pretrained](https://github.com/ShowLo/MobileNetV3/tree/master/pretrained).
### on CIFAR-10 (Average accuracy of 5 runs)
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Ours 1.0 | 17.51 M | 1.52 M | 92.97% | - |
### on CIFAR-100 (Average accuracy of 5 runs)
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Ours 1.0 | 17.60 M | 1.61 M | 73.69% | 92.31% |
| More Tricks | same | same | 76.24% | 92.58% |
### on SVHN (Average accuracy of 5 runs)
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Ours 1.0 | 17.51 M | 1.52 M | 97.92% | - |
### on Tiny-ImageNet (Average accuracy of 5 runs)
| | Madds | Parameters | Top1-acc | Top5-acc |
| ----------- | --------- | ---------- | --------- | --------- |
| Ours 1.0 | 51.63 M | 1.71 M | 59.32% | 81.38% |
| More Tricks | same | same | 62.62% | 84.04% |
## Dependency
&emsp;This project uses Python 3.7 and PyTorch 1.1.0. The FLOPs and parameters are measured using [torchsummaryX](https://github.com/nmhkahn/torchsummaryX).
# -*- coding: UTF-8 -*-
'''
Write the result(every epoch) into file
'''
import os
import csv
class ResultWriter():
    '''
    Append per-epoch results to a CSV file.
    Args:
        save_folder (str): directory the CSV is written into.
        file_name (str): CSV file name.
    '''
    def __init__(self, save_folder, file_name):
        super(ResultWriter, self).__init__()
        self.save_path = os.path.join(save_folder, file_name)
        # kept for backward compatibility; writers are created per call instead
        self.csv_writer = None

    def create_csv(self, csv_head):
        '''(Re)create the file and write the header row.'''
        # newline='' is required by the csv module to avoid blank rows on Windows
        with open(self.save_path, 'w', newline='') as f:
            csv.writer(f).writerow(csv_head)

    def write_csv(self, data_row):
        '''Append one data row to the file.'''
        with open(self.save_path, 'a', newline='') as f:
            csv.writer(f).writerow(data_row)
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import os
import os.path
import numpy as np
class SVHN(VisionDataset):
    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
    expect the class labels to be in the range `[0, C-1]`
    Args:
        root (string): Root directory of dataset where directory
            ``SVHN`` exists.
        split (string): One of {'train', 'test', 'extra'}.
            Accordingly dataset is selected. 'extra' is Extra training set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    NOTE(review): despite the docstring above, ``split_list`` below only
    contains 'train' and 'test' (passing 'extra' raises ValueError), and the
    ``download`` flag is currently ignored -- nothing is ever downloaded.
    """
    # Per-split [url, local filename].
    # NOTE(review): the 'train' entry maps to "merge_32x32.mat", which is not
    # the file the listed URL serves -- presumably a locally pre-merged
    # train+extra file; confirm before relying on it.
    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "merge_32x32.mat"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "test_32x32.mat",]}
    def __init__(self, root, split='train',
                 transform=None, target_transform=None, download=False):
        super(SVHN, self).__init__(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # training set or test set or extra set
        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" '
                             'or split="extra" or split="test"')
        self.url = self.split_list[split][0]
        self.filename = self.split_list[split][1]
        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio
        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
        self.data = loaded_mat['X']
        # loading from the .mat file gives an np array of type np.uint8
        # converting to np.int64, so that we have a LongTensor after
        # the conversion from the numpy array
        # the squeeze is needed to obtain a 1D tensor
        self.labels = loaded_mat['y'].astype(np.int64).squeeze()
        # the svhn dataset assigns the class label "10" to the digit 0
        # this makes it inconsistent with several loss functions
        # which expect the class labels to be in the range [0, C-1]
        np.place(self.labels, self.labels == 10, 0)
        # .mat layout is (H, W, C, N); rearrange to (N, C, H, W)
        self.data = np.transpose(self.data, (3, 2, 0, 1))
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.labels[index])
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
    def __len__(self):
        # Number of samples, i.e. N of the (N, C, H, W) data array.
        return len(self.data)
\ No newline at end of file
# -*- coding: UTF-8 -*-
'''
MobileNetV3 From <Searching for MobileNetV3>, arXiv:1905.02244.
Ref: https://github.com/d-li14/mobilenetv3.pytorch/blob/master/mobilenetv3.py
https://github.com/kuan-wang/pytorch-mobilenet-v3/blob/master/mobilenetv3.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import OrderedDict
def _ensure_divisible(number, divisor, min_value=None):
'''
Ensure that 'number' can be 'divisor' divisible
Reference from original tensorflow repo:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
'''
if min_value is None:
min_value = divisor
new_num = max(min_value, int(number + divisor / 2) // divisor * divisor)
if new_num < 0.9 * number:
new_num += divisor
return new_num
class H_sigmoid(nn.Module):
    '''
    Hard sigmoid: relu6(x + 3) / 6, a cheap piecewise-linear
    approximation of the sigmoid function.
    '''
    def __init__(self, inplace=True):
        super(H_sigmoid, self).__init__()
        # whether relu6 may overwrite its input tensor
        self.inplace = inplace
    def forward(self, x):
        shifted = F.relu6(x + 3, inplace=self.inplace)
        return shifted / 6
class H_swish(nn.Module):
    '''
    Hard swish: x * relu6(x + 3) / 6, a cheap piecewise-linear
    approximation of the swish activation.
    '''
    def __init__(self, inplace=True):
        super(H_swish, self).__init__()
        # whether relu6 may overwrite its input tensor
        self.inplace = inplace
    def forward(self, x):
        gate = F.relu6(x + 3, inplace=self.inplace) / 6
        return x * gate
class SEModule(nn.Module):
    '''
    Squeeze-and-Excitation module: global average pooling followed by a
    two-layer bottleneck MLP that produces per-channel gating factors.
    Ref: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py
    '''
    def __init__(self, in_channels_num, reduction_ratio=4):
        super(SEModule, self).__init__()
        if in_channels_num % reduction_ratio != 0:
            raise ValueError('in_channels_num must be divisible by reduction_ratio(default = 4)')
        squeezed_channels = in_channels_num // reduction_ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels_num, squeezed_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(squeezed_channels, in_channels_num, bias=False),
            H_sigmoid()
        )
    def forward(self, x):
        n = x.size(0)
        c = x.size(1)
        # Squeeze: global spatial average -> (n, c)
        gates = self.avg_pool(x).view(n, c)
        # Excite: per-channel gates in [0, 1], broadcast over H and W
        gates = self.fc(gates).view(n, c, 1, 1)
        return x * gates
class Bottleneck(nn.Module):
    '''
    The basic unit of MobileNetV3: an inverted-residual block
    (optional pointwise expansion -> depthwise conv -> linear pointwise
    projection) with optional SE module and ReLU / hard-swish nonlinearity.
    '''
    def __init__(self, in_channels_num, exp_size, out_channels_num, kernel_size, stride, use_SE, NL, BN_momentum):
        '''
        in_channels_num (int): number of input channels.
        exp_size (int): channel count of the expanded (depthwise) stage.
        out_channels_num (int): number of output channels.
        kernel_size (int): depthwise convolution kernel size.
        stride (int): depthwise convolution stride, 1 or 2.
        use_SE: True or False -- use SE Module or not
        NL: nonlinearity, 'RE' or 'HS'
        BN_momentum (float): momentum for all BatchNorm layers.
        '''
        super(Bottleneck, self).__init__()
        assert stride in [1, 2]
        NL = NL.upper()
        assert NL in ['RE', 'HS']
        use_HS = NL == 'HS'
        # Residual shortcut only when spatial size and channel count are preserved
        self.use_residual = (stride == 1 and in_channels_num == out_channels_num)
        if exp_size == in_channels_num:
            # Without expansion, the first pointwise convolution is omitted
            self.conv1 = nn.Sequential(
                # Depthwise Convolution
                nn.Conv2d(in_channels=in_channels_num, out_channels=exp_size, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, groups=in_channels_num, bias=False),
                nn.BatchNorm2d(num_features=exp_size, momentum=BN_momentum),
                # SE Module
                SEModule(exp_size) if use_SE else nn.Sequential(),
                H_swish() if use_HS else nn.ReLU(inplace=True))
            self.conv2 = nn.Sequential(
                # Linear Pointwise Convolution
                nn.Conv2d(in_channels=exp_size, out_channels=out_channels_num, kernel_size=1, stride=1, padding=0, bias=False),
                # The last BN of a residual block is wrapped in a named
                # Sequential so zero-gamma initialization can find it through
                # the 'lastBN' attribute.
                # Fix: pass BN_momentum here too -- it was previously dropped,
                # silently falling back to the default momentum (state-dict
                # compatible change; momentum is not a saved parameter).
                nn.Sequential(OrderedDict([('lastBN', nn.BatchNorm2d(num_features=out_channels_num, momentum=BN_momentum))])) if self.use_residual else
                nn.BatchNorm2d(num_features=out_channels_num, momentum=BN_momentum)
            )
        else:
            # With expansion
            self.conv1 = nn.Sequential(
                # Pointwise Convolution for expansion
                nn.Conv2d(in_channels=in_channels_num, out_channels=exp_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(num_features=exp_size, momentum=BN_momentum),
                H_swish() if use_HS else nn.ReLU(inplace=True))
            self.conv2 = nn.Sequential(
                # Depthwise Convolution
                nn.Conv2d(in_channels=exp_size, out_channels=exp_size, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, groups=exp_size, bias=False),
                nn.BatchNorm2d(num_features=exp_size, momentum=BN_momentum),
                # SE Module
                SEModule(exp_size) if use_SE else nn.Sequential(),
                H_swish() if use_HS else nn.ReLU(inplace=True),
                # Linear Pointwise Convolution
                nn.Conv2d(in_channels=exp_size, out_channels=out_channels_num, kernel_size=1, stride=1, padding=0, bias=False),
                # Same named 'lastBN' wrapping and momentum fix as above.
                nn.Sequential(OrderedDict([('lastBN', nn.BatchNorm2d(num_features=out_channels_num, momentum=BN_momentum))])) if self.use_residual else
                nn.BatchNorm2d(num_features=out_channels_num, momentum=BN_momentum)
            )
    def forward(self, x, expand=False):
        '''
        Run the block. When `expand` is True, also return the output of the
        first stage (the pre-projection feature map).
        '''
        out1 = self.conv1(x)
        out = self.conv2(out1)
        if self.use_residual:
            out = out + x
        if expand:
            return out, out1
        else:
            return out
class MobileNetV3(nn.Module):
    '''
    MobileNetV3 image-classification network ('large' or 'small' variant):
    conv stem, a stack of Bottleneck blocks, a 1x1 last stage, global
    pooling, and a dropout + linear classifier.
    '''
    def __init__(self, mode='small', classes_num=1000, input_size=224, width_multiplier=1.0, dropout=0.2, BN_momentum=0.1, zero_gamma=False):
        '''
        mode: type of the model, 'large' or 'small'
        classes_num (int): number of output classes.
        input_size (int): input image side length; 32 or 56 switches the
            stem (and the first large-model downsampling block) to stride 1
            for small-image datasets.
        width_multiplier (float): uniform channel width scaling factor.
        dropout (float): dropout probability before the final linear layer.
        BN_momentum (float): momentum for the BatchNorm layers.
        zero_gamma (bool): if True, zero-initialize the gamma of each
            residual block's last BatchNorm (found via the 'lastBN' attribute).
        '''
        super(MobileNetV3, self).__init__()
        mode = mode.lower()
        assert mode in ['large', 'small']
        s = 2
        if input_size == 32 or input_size == 56:
            # using cifar-10, cifar-100 or Tiny-ImageNet
            s = 1
        # setting of the model
        if mode == 'large':
            # Configuration of a MobileNetV3-Large Model
            configs = [
                #kernel_size, exp_size, out_channels_num, use_SE, NL, stride
                [3, 16, 16, False, 'RE', 1],
                [3, 64, 24, False, 'RE', s],
                [3, 72, 24, False, 'RE', 1],
                [5, 72, 40, True, 'RE', 2],
                [5, 120, 40, True, 'RE', 1],
                [5, 120, 40, True, 'RE', 1],
                [3, 240, 80, False, 'HS', 2],
                [3, 200, 80, False, 'HS', 1],
                [3, 184, 80, False, 'HS', 1],
                [3, 184, 80, False, 'HS', 1],
                [3, 480, 112, True, 'HS', 1],
                [3, 672, 112, True, 'HS', 1],
                [5, 672, 160, True, 'HS', 2],
                [5, 960, 160, True, 'HS', 1],
                [5, 960, 160, True, 'HS', 1]
            ]
        elif mode == 'small':
            # Configuration of a MobileNetV3-Small Model
            configs = [
                #kernel_size, exp_size, out_channels_num, use_SE, NL, stride
                [3, 16, 16, True, 'RE', s],
                [3, 72, 24, False, 'RE', 2],
                [3, 88, 24, False, 'RE', 1],
                [5, 96, 40, True, 'HS', 2],
                [5, 240, 40, True, 'HS', 1],
                [5, 240, 40, True, 'HS', 1],
                [5, 120, 48, True, 'HS', 1],
                [5, 144, 48, True, 'HS', 1],
                [5, 288, 96, True, 'HS', 2],
                [5, 576, 96, True, 'HS', 1],
                [5, 576, 96, True, 'HS', 1]
            ]
        first_channels_num = 16
        # last_channels_num = 1280
        # according to https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
        # if small -- 1024, if large -- 1280
        last_channels_num = 1280 if mode == 'large' else 1024
        divisor = 8
        ########################################################################################################################
        # feature extraction part
        # input layer
        input_channels_num = _ensure_divisible(first_channels_num * width_multiplier, divisor)
        # Widen the head only when scaling up; never shrink below the default
        last_channels_num = _ensure_divisible(last_channels_num * width_multiplier, divisor) if width_multiplier > 1 else last_channels_num
        feature_extraction_layers = []
        first_layer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=input_channels_num, kernel_size=3, stride=s, padding=1, bias=False),
            nn.BatchNorm2d(num_features=input_channels_num, momentum=BN_momentum),
            H_swish()
        )
        feature_extraction_layers.append(first_layer)
        # Overlay of multiple bottleneck structures
        for kernel_size, exp_size, out_channels_num, use_SE, NL, stride in configs:
            output_channels_num = _ensure_divisible(out_channels_num * width_multiplier, divisor)
            exp_size = _ensure_divisible(exp_size * width_multiplier, divisor)
            feature_extraction_layers.append(Bottleneck(input_channels_num, exp_size, output_channels_num, kernel_size, stride, use_SE, NL, BN_momentum))
            input_channels_num = output_channels_num
        # the last stage
        # NOTE(review): exp_size left over from the final loop iteration was
        # already multiplied by width_multiplier and rounded; multiplying by
        # width_multiplier again here double-scales the last-stage width when
        # width_multiplier != 1 -- confirm intended.
        last_stage_channels_num = _ensure_divisible(exp_size * width_multiplier, divisor)
        last_stage_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=input_channels_num, out_channels=last_stage_channels_num, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_features=last_stage_channels_num, momentum=BN_momentum),
            H_swish()
        )
        feature_extraction_layers.append(last_stage_layer1)
        self.featureList = nn.ModuleList(feature_extraction_layers)
        # SE Module
        # remove the last SE Module according to https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
        # feature_extraction_layers.append(SEModule(last_stage_channels_num) if mode == 'small' else nn.Sequential())
        last_stage = []
        last_stage.append(nn.AdaptiveAvgPool2d(1))
        last_stage.append(nn.Conv2d(in_channels=last_stage_channels_num, out_channels=last_channels_num, kernel_size=1, stride=1, padding=0, bias=False))
        last_stage.append(H_swish())
        self.last_stage_layers = nn.Sequential(*last_stage)
        ########################################################################################################################
        # Classification part
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(last_channels_num, classes_num)
        )
        '''
        self.extras = nn.ModuleList([
            InvertedResidual(576, 512, 2, 0.2),
            InvertedResidual(512, 256, 2, 0.25),
            InvertedResidual(256, 256, 2, 0.5),
            InvertedResidual(256, 64, 2, 0.25)
        ])
        '''
        ########################################################################################################################
        # Initialize the weights
        self._initialize_weights(zero_gamma)
    def forward(self, x):
        # NOTE(review): the feature list is iterated in two chunks split at
        # index 9 -- presumably a leftover tap point for intermediate features
        # (e.g. detection heads); for classification this is equivalent to a
        # single loop over all modules.
        for i in range(9):
            x = self.featureList[i](x)
        x = self.featureList[9](x)
        for i in range(10, len(self.featureList)):
            x = self.featureList[i](x)
        x = self.last_stage_layers(x)
        # Flatten (N, C, 1, 1) to (N, C) for the linear classifier
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    def _initialize_weights(self, zero_gamma):
        '''
        Initialize the weights: He init for convolutions, (1, 0) for
        BatchNorm (gamma, beta), small normal init for linear layers.
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        if zero_gamma:
            # Zero the gamma of the last BN in every residual block (tagged
            # 'lastBN' in Bottleneck) so each block starts as near-identity.
            for m in self.modules():
                if hasattr(m, 'lastBN'):
                    nn.init.constant_(m.lastBN.weight, 0.0)
if __name__ == "__main__":
width_multiplier = 1
from torchsummaryX import summary
# cifar10
#model_large = MobileNetV3(mode='large', classes_num=10, input_size=32, width_multiplier=width_multiplier)
#model_large.eval()
'''
model_small = MobileNetV3(mode='small', classes_num=10, input_size=32, width_multiplier=width_multiplier)
model_small.eval()
summary(model_small, torch.zeros((1, 3, 32, 32)))
print('MobileNetV3-Small-%.2f cifar10-summaryX\n' % width_multiplier)
'''
'''
# cifar100
model_large = MobileNetV3(mode='large', classes_num=100, input_size=32, width_multiplier=width_multiplier)
model_large.eval()
model_small = MobileNetV3(mode='small', classes_num=100, input_size=32, width_multiplier=width_multiplier)
model_small.eval()
input = torch.randn(1, 3, 32, 32)
from thop import profile
FLOPs_large, params_large = profile(model_large, inputs=(input,))
FLOPs_small, params_small = profile(model_small, inputs=(input,))
print('\nOn cifar100 using thop')
print('MobileNetV3-Large-%.2f:' % width_multiplier)
print('Total flops: %.4fM' % (FLOPs_large/1000000.0))
print('Total params: %.4fM' % (params_large/1000000.0))
print()
print('MobileNetV3-Small-%.2f:' % width_multiplier)
print('Total flops: %.4fM' % (FLOPs_small/1000000.0))
print('Total params: %.4fM' % (params_small/1000000.0))
summary(model_large, torch.zeros((1, 3, 32, 32)))
print('MobileNetV3-Large-%.2f cifar100-summaryX\n' % width_multiplier)
summary(model_small, torch.zeros((1, 3, 32, 32)))
print('MobileNetV3-Small-%.2f cifar100-summaryX\n' % width_multiplier)
# ImageNet
model_large = MobileNetV3(mode='large', classes_num=1000, input_size=224, width_multiplier=width_multiplier)
model_large.eval()
model_small = MobileNetV3(mode='small', classes_num=1000, input_size=224, width_multiplier=width_multiplier)
model_small.eval()
input = torch.randn(1, 3, 224, 224)
from thop import profile
FLOPs_large, params_large = profile(model_large, inputs=(input,))
FLOPs_small, params_small = profile(model_small, inputs=(input,))
print('\nOn ImageNet using thop')
print('MobileNetV3-Large-%.2f:' % width_multiplier)
print('Total flops: %.4fM' % (FLOPs_large/1000000.0))
print('Total params: %.4fM' % (params_large/1000000.0))
print()
print('MobileNetV3-Small-%.2f:' % width_multiplier)
print('Total flops: %.4fM' % (FLOPs_small/1000000.0))
print('Total params: %.4fM' % (params_small/1000000.0))
'''
'''
summary(model_large, torch.zeros((1, 3, 224, 224)))
print('MobileNetV3-Large-%.2f ImageNet-summaryX\n' % width_multiplier)
summary(model_small, torch.zeros((1, 3, 224, 224)))
print('MobileNetV3-Small-%.2f ImageNet-summaryX\n' % width_multiplier)
'''
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment