Unverified Commit ad0daff1 authored by Francisco Massa, committed by GitHub

Add groups support to ResNet (#822)

* Add groups support to ResNet

* Kill BaseResNet

* Make it support multi-machine training
parent 8697f9e0
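Not part of the commit itself, just context for the headline change: the groups argument threaded through the diff below is standard grouped convolution in nn.Conv2d, which is what gives ResNeXt its "cardinality". A minimal sketch (the 32x4 numbers are simply the ResNeXt 32x4d configuration, not something this PR introduces):

import torch
import torch.nn as nn

# 128 input / 128 output channels split into 32 independent groups of 4 channels each
conv = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=32, bias=False)
out = conv(torch.randn(1, 128, 56, 56))   # spatial size preserved: (1, 128, 56, 56)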
@@ -8,3 +8,4 @@ torchvision.egg-info/
docs/build
.coverage
htmlcov
.*.swp
@@ -60,17 +60,8 @@ def evaluate(model, criterion, data_loader, device):
def main(args):
args.gpu = args.local_rank
if args.distributed:
args.rank = int(os.environ["RANK"])
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
dist_url = 'env://'
print('| distributed init (rank {}): {}'.format(
args.rank, dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=dist_url)
utils.setup_for_distributed(args.rank == 0)
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
@@ -203,15 +194,15 @@ if __name__ == "__main__":
help="Only test the model",
action="store_true",
)
parser.add_argument('--local_rank', default=0, type=int, help='print frequency')
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
args = parser.parse_args()
print(args)
if args.output_dir:
utils.mkdir(args.output_dir)
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = num_gpus > 1
main(args)
@@ -211,3 +211,27 @@ def is_main_process():
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
setup_for_distributed(args.rank == 0)
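# Usage sketch, not part of this diff: init_distributed_mode expects the launcher to
# provide the process-rank environment variables. One way to drive it (an assumption
# about intended usage, based on the 'env://' default) is
#     python -m torch.distributed.launch --nproc_per_node=<gpus> --use_env train.py ...
# where --use_env makes the launcher export RANK / WORLD_SIZE / LOCAL_RANK so the
# RANK/WORLD_SIZE branch above fires; under SLURM the SLURM_PROCID branch is taken and
# world_size falls back to the --world-size argument. In train.py this reduces to:
#     args = parser.parse_args()
#     utils.init_distributed_mode(args)   # sets args.rank / args.gpu / args.distributed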
@@ -3,7 +3,7 @@ import torch.utils.model_zoo as model_zoo
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
model_urls = {
@@ -15,10 +15,10 @@ model_urls = {
}
def conv3x3(in_planes, out_planes, stride=1):
def conv3x3(in_planes, out_planes, stride=1, groups=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
padding=1, groups=groups, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
@@ -29,10 +29,12 @@ def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1:
raise ValueError('BasicBlock only supports groups=1')
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
@@ -64,14 +66,14 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, planes)
self.bn1 = norm_layer(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.conv2 = conv3x3(planes, planes, stride, groups)
self.bn2 = norm_layer(planes)
self.conv3 = conv1x1(planes, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
@@ -104,22 +106,24 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, norm_layer=None):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
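# e.g. groups=1,  width_per_group=64 (plain ResNet)       -> planes = [64, 128, 256, 512]
#      groups=32, width_per_group=4  (ResNeXt-50 32x4d)   -> planes = [128, 256, 512, 1024]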
self.inplanes = planes[0]
self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(64)
self.bn1 = norm_layer(planes[0])
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer)
self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer)
self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer)
self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.fc = nn.Linear(planes[3] * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
@@ -138,7 +142,7 @@ class ResNet(nn.Module):
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
downsample = None
@@ -149,10 +153,10 @@ class ResNet(nn.Module):
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))
layers.append(block(self.inplanes, planes, stride, downsample, groups, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
layers.append(block(self.inplanes, planes, groups=groups, norm_layer=norm_layer))
return nn.Sequential(*layers)
@@ -232,3 +236,17 @@ def resnet152(pretrained=False, **kwargs):
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
def resnext50_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-50 32x4d model (32 groups, 4 channels per group)."""
    model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
    # Pretrained ResNeXt weights are not published yet, so loading stays disabled;
    # the model_urls key below is a placeholder.
    # if pretrained:
    #     model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d']))
    return model
def resnext101_32x8d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-101 32x8d model (32 groups, 8 channels per group)."""
    model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
    # Pretrained ResNeXt weights are not published yet, so loading stays disabled;
    # the model_urls key below is a placeholder.
    # if pretrained:
    #     model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d']))
    return model
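A quick usage sketch for the new constructors (an assumption about how they will be exposed: __all__ now lists them, so they should be importable like the existing resnet* helpers; pretrained weights are not available yet):

import torch
from torchvision.models.resnet import resnext50_32x4d

model = resnext50_32x4d(num_classes=1000)
out = model(torch.randn(1, 3, 224, 224))   # -> torch.Size([1, 1000])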