Commit e74519bb authored by pangjm's avatar pangjm
Browse files

minor fix

parent 559b0558
......@@ -35,7 +35,33 @@ def _init_dist_mpi(backend, **kwargs):
def _init_dist_slurm(backend, **kwargs):
raise NotImplementedError
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
if '[' in node_list:
beg = node_list.find('[')
pos1 = node_list.find('-', beg)
if pos1 < 0:
pos1 = 1000
pos2 = node_list.find(',', beg)
if pos2 < 0:
pos2 = 1000
node_list = node_list[:min(pos1, pos2)].replace('[', '')
addr = node_list[8:].replace('-', '.')
os.environ['MASTER_PORT'] = str(kwargs['port'])
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
if backend == 'nccl':
dist.init_process_group(backend='nccl')
else:
dist.init_process_group(
backend='gloo', rank=proc_id, world_size=ntasks)
rank = dist.get_rank()
world_size = dist.get_world_size()
return rank, world_size
def set_random_seed(seed):
......
from .resnet import ResNet
from .resnext import ResNeXt
__all__ = ['ResNet']
__all__ = ['ResNet', 'ResNeXt']
......@@ -219,9 +219,13 @@ class ResNet(nn.Module):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError('invalid depth {} for resnet'.format(depth))
self.depth = depth,
self.num_stages = num_stages,
self.strides = strides,
self.dilations = dilations,
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
self.block, self.stage_blocks = self.arch_settings[depth]
self.stage_blocks = self.stage_blocks[:num_stages]
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
......@@ -240,12 +244,12 @@ class ResNet(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.res_layers = []
for i, num_blocks in enumerate(stage_blocks):
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
planes = 64 * 2**i
res_layer = make_res_layer(
block,
self.block,
self.inplanes,
planes,
num_blocks,
......@@ -253,12 +257,13 @@ class ResNet(nn.Module):
dilation=dilation,
style=self.style,
with_cp=with_cp)
self.inplanes = planes * block.expansion
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
self.feat_dim = self.block.expansion * 64 * 2**(
len(self.stage_blocks) - 1)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
......
import math
import logging
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from .resnet import ResNet
class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
groups=1,
base_width=4,
style='pytorch',
with_cp=False):
"""Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']
width = planes if groups == 1 else math.floor(
planes * (base_width / 64)) * groups
if style == 'pytorch':
conv1_stride = 1
conv2_stride = stride
else:
conv1_stride = stride
conv2_stride = 1
self.conv1 = nn.Conv2d(
inplanes, width, kernel_size=1, stride=conv1_stride, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(
width,
width,
kernel_size=3,
stride=conv2_stride,
padding=dilation,
dilation=dilation,
groups=groups,
bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(
width, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.with_cp = with_cp
def forward(self, x):
def _inner_forward(x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1,
groups=1,
base_width=4,
style='pytorch',
with_cp=False):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(
block(
inplanes,
planes,
stride,
dilation,
downsample,
groups=groups,
base_width=base_width,
style=style,
with_cp=with_cp))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes,
planes,
1,
dilation,
groups=groups,
base_width=base_width,
style=style,
with_cp=with_cp))
return nn.Sequential(*layers)
class ResNeXt(ResNet):
"""ResNeXt backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
num_stages (int): Resnet stages, normally 4.
groups (int): Group of resnext.
base_width (int): Base width of resnext.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
groups=1,
base_width=4,
*args,
**kwargs):
super(ResNeXt, self).__init__(*args, **kwargs)
self.groups = groups
self.base_width = base_width
self.inplanes = 64
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = self.strides[0][i]
dilation = self.dilations[0][i]
planes = 64 * 2**i
res_layer = make_res_layer(
self.block,
self.inplanes,
planes,
num_blocks,
stride=stride,
dilation=dilation,
groups=self.groups,
base_width=self.base_width,
style=self.style,
with_cp=self.with_cp)
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
from __future__ import division
import sys
sys.path.insert(0, '/mnt/lustre/pangjiangmiao/codebase/mmcv')
sys.path.insert(0, '/mnt/lustre/pangjiangmiao/codebase/mmdet')
import argparse
from mmcv import Config
......@@ -14,6 +17,7 @@ def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work_dir', help='the dir to save logs and models')
parser.add_argument('--resume_from', help='the checkpoint to resume from')
parser.add_argument(
'--validate',
action='store_true',
......@@ -43,6 +47,8 @@ def main():
# update configs according to CLI args
if args.work_dir is not None:
cfg.work_dir = args.work_dir
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.gpus = args.gpus
if cfg.checkpoint_config is not None:
# save mmdet version in checkpoints as meta data
......@@ -67,6 +73,13 @@ def main():
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
import torch.distributed as dist
if dist.get_rank() == 0:
with open('/mnt/lustre/pangjiangmiao/r50_32x4d_mmdet.txt', 'w') as f:
for k in model.state_dict().keys():
if 'num_batches_tracked' in k: continue
f.writelines('{}\n'.format(k))
train_dataset = obj_from_dict(cfg.data.train, datasets)
train_detector(
model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment