Unverified Commit 4cf67f53 authored by Mingyao Li, committed by GitHub

Increase SPOS accuracy (#2902)

* 🚧 fix the bug where SPOS blocks always used affine=False; change the default number of data-loading workers from 12 to 4 to avoid DALI crashing; add dump_checkpoint to save the model trained from scratch by scratch.py

* 🐛 fix missing import of os
Co-authored-by: limingyao <limingyao@ainirobot.com>
parent 8c718139
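For context, the accuracy issue comes from nn.BatchNorm2d being hard-coded with affine=False throughout the supernet, so batch-norm layers never learn a scale/shift even when the searched architecture is retrained. A minimal illustration of what the flag controls (this snippet is not part of the commit):

import torch.nn as nn

# With affine=False the BatchNorm layer has no learnable gamma/beta parameters;
# with affine=True (the new default used for retraining) it does.
bn_fixed = nn.BatchNorm2d(16, affine=False)
bn_learnable = nn.BatchNorm2d(16, affine=True)
print(bn_fixed.weight, bn_fixed.bias)                      # None None
print(bn_learnable.weight.shape, bn_learnable.bias.shape)  # torch.Size([16]) torch.Size([16])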
@@ -10,7 +10,7 @@ class ShuffleNetBlock(nn.Module):
     When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels.
     """
 
-    def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"):
+    def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp", affine=True):
         super().__init__()
         assert stride in [1, 2]
         assert ksize in [3, 5, 7]
@@ -22,6 +22,7 @@ class ShuffleNetBlock(nn.Module):
         self.stride = stride
         self.pad = ksize // 2
         self.oup_main = oup - self.channels
+        self._affine = affine
         assert self.oup_main > 0
 
         self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))
@@ -31,10 +32,10 @@ class ShuffleNetBlock(nn.Module):
                 # dw
                 nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
                           groups=self.channels, bias=False),
-                nn.BatchNorm2d(self.channels, affine=False),
+                nn.BatchNorm2d(self.channels, affine=affine),
                 # pw-linear
                 nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
-                nn.BatchNorm2d(self.channels, affine=False),
+                nn.BatchNorm2d(self.channels, affine=affine),
                 nn.ReLU(inplace=True)
             )
@@ -61,12 +62,12 @@ class ShuffleNetBlock(nn.Module):
                 assert pc == c, "Depth-wise conv must not change channels."
                 result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad,
                                         groups=c, bias=False))
-                result.append(nn.BatchNorm2d(c, affine=False))
+                result.append(nn.BatchNorm2d(c, affine=self._affine))
                 first_depth = False
             elif token == "p":
                 # point-wise conv
                 result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
-                result.append(nn.BatchNorm2d(c, affine=False))
+                result.append(nn.BatchNorm2d(c, affine=self._affine))
                 result.append(nn.ReLU(inplace=True))
                 first_point = False
             else:
@@ -85,5 +86,5 @@ class ShuffleNetBlock(nn.Module):
 
 class ShuffleXceptionBlock(ShuffleNetBlock):
-    def __init__(self, inp, oup, mid_channels, stride):
-        super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp")
+    def __init__(self, inp, oup, mid_channels, stride, affine=True):
+        super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp", affine)
@@ -21,7 +21,7 @@ class ShuffleNetV2OneShot(nn.Module):
     ]
 
     def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000,
-                 op_flops_path="./data/op_flops_dict.pkl"):
+                 op_flops_path="./data/op_flops_dict.pkl", affine=False):
         super().__init__()
         assert input_size % 32 == 0
@@ -36,11 +36,12 @@
         self._first_conv_channels = first_conv_channels
         self._last_conv_channels = last_conv_channels
         self._n_classes = n_classes
+        self._affine = affine
 
         # building first layer
         self.first_conv = nn.Sequential(
             nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
-            nn.BatchNorm2d(first_conv_channels, affine=False),
+            nn.BatchNorm2d(first_conv_channels, affine=affine),
             nn.ReLU(inplace=True),
         )
         self._feature_map_size //= 2
@@ -54,7 +55,7 @@
         self.conv_last = nn.Sequential(
             nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
-            nn.BatchNorm2d(last_conv_channels, affine=False),
+            nn.BatchNorm2d(last_conv_channels, affine=affine),
             nn.ReLU(inplace=True),
         )
         self.globalpool = nn.AvgPool2d(self._feature_map_size)
@@ -75,10 +76,10 @@
             base_mid_channels = channels // 2
             mid_channels = int(base_mid_channels)  # prepare for scale
             choice_block = mutables.LayerChoice([
-                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride),
-                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride),
-                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride),
-                ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride)
+                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
+                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
+                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
+                ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
             ])
             result.append(choice_block)
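Taken together, the network-level flag defaults to affine=False, so supernet search keeps its previous behaviour, while the training-from-scratch script below opts into affine=True. A hypothetical usage sketch (the import path is an assumption; it refers to the module defining ShuffleNetV2OneShot):

from network import ShuffleNetV2OneShot  # assumed module name for this example

# Supernet used during the search phase: affine BN stays off (the new default).
supernet = ShuffleNetV2OneShot()

# Fixed-architecture retraining, as done in the scratch script below: enable affine BN.
scratch_model = ShuffleNetV2OneShot(affine=True)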
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import os
 import argparse
 import logging
 import random
@@ -70,12 +71,24 @@ def validate(epoch, model, criterion, loader, writer, args):
     logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1, meters.acc1.avg, meters.acc5.avg)
 
 
+def dump_checkpoint(model, epoch, checkpoint_dir):
+    if isinstance(model, nn.DataParallel):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    if not os.path.exists(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+    dest_path = os.path.join(checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
+    logger.info("Saving model to %s", dest_path)
+    torch.save(state_dict, dest_path)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("SPOS Training From Scratch")
     parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet")
     parser.add_argument("--tb-dir", type=str, default="runs")
     parser.add_argument("--architecture", type=str, default="architecture_final.json")
-    parser.add_argument("--workers", type=int, default=12)
+    parser.add_argument("--workers", type=int, default=4)
     parser.add_argument("--batch-size", type=int, default=1024)
     parser.add_argument("--epochs", type=int, default=240)
     parser.add_argument("--learning-rate", type=float, default=0.5)
@@ -96,7 +109,7 @@ if __name__ == "__main__":
     random.seed(args.seed)
     torch.backends.cudnn.deterministic = True
 
-    model = ShuffleNetV2OneShot()
+    model = ShuffleNetV2OneShot(affine=True)
     model.cuda()
     apply_fixed_architecture(model, args.architecture)
     if torch.cuda.device_count() > 1:  # exclude last gpu, saving for data preprocessing on gpu
@@ -124,5 +137,6 @@ if __name__ == "__main__":
         train(epoch, model, criterion, optimizer, train_loader, writer, args)
         validate(epoch, model, criterion, val_loader, writer, args)
         scheduler.step()
+        dump_checkpoint(model, epoch, "scratch_checkpoints")
 
     writer.close()
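dump_checkpoint writes one state_dict per epoch, so a chosen epoch can be reloaded later. A hedged sketch of restoring such a checkpoint (the epoch number, file name, and import paths are assumptions, not part of the commit):

import torch
from network import ShuffleNetV2OneShot                     # assumed module name, as above
from nni.nas.pytorch.fixed import apply_fixed_architecture  # assumed import path

# Rebuild the fixed architecture the same way the scratch script does, then load the saved weights.
model = ShuffleNetV2OneShot(affine=True)
apply_fixed_architecture(model, "architecture_final.json")
state_dict = torch.load("scratch_checkpoints/epoch_239.pth.tar", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()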