OpenDAS / MMCV · Commits

Commit e0422994, authored Sep 29, 2018 by Kai Chen (parent 961c3388)

    update cifar10 example
Showing 3 changed files, with 107 additions and 27 deletions (+107 -27):

    examples/config_cifar10.py       +7   -6
    examples/dist_train_cifar10.sh   +5   -0
    examples/train_cifar10.py        +95  -21
examples/config_cifar10.py (view file @ e0422994)
...
@@ -8,23 +8,24 @@ batch_size = 64
 # optimizer and learning rate
 optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
-lr_policy = dict(policy='step', step=2)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(policy='step', step=2)
 
 # runtime settings
 work_dir = './demo'
 gpus = range(2)
 dist_params = dict(backend='gloo')  # gloo is much slower than nccl
 data_workers = 2  # data workers per gpu
-checkpoint_cfg = dict(interval=1)  # save checkpoint at every epoch
+checkpoint_config = dict(interval=1)  # save checkpoint at every epoch
 workflow = [('train', 1), ('val', 1)]
-max_epoch = 6
+total_epochs = 6
 resume_from = None
 load_from = None
 
 # logging settings
 log_level = 'INFO'
-log_cfg = dict(
-    # log at every 50 iterations
-    interval=50,
+log_config = dict(
+    interval=50,
+    # log at every 50 iterations
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
...
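The renames track the updated API used by train_cifar10.py below: optimizer_config, lr_config, checkpoint_config, and log_config are exactly the keyword arguments passed to Runner.register_training_hooks, and total_epochs replaces max_epoch in the runner.run call. A quick way to sanity-check the renamed keys (a minimal sketch; it assumes mmcv is installed and that the relative path below matches where you run it from):

    # Sketch: load the updated config and read the renamed keys.
    from mmcv import Config

    cfg = Config.fromfile('examples/config_cifar10.py')

    print(cfg.optimizer)          # dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
    print(cfg.optimizer_config)   # new key: dict(grad_clip=None)
    print(cfg.lr_config)          # renamed from lr_policy
    print(cfg.checkpoint_config)  # renamed from checkpoint_cfg
    print(cfg.log_config)         # renamed from log_cfg
    print(cfg.total_epochs)       # renamed from max_epoch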
examples/dist_train_cifar10.sh (new file, mode 100755) (view file @ e0422994)
+#!/usr/bin/env bash
+
+PYTHON=${PYTHON:-"python"}
+
+$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train_cifar10.py $1 --launcher pytorch ${@:3}
\ No newline at end of file
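The script is a thin wrapper around torch.distributed.launch: $1 is the config file, $2 the number of processes to start on this node, and ${@:3} forwards any remaining arguments, so for example `./dist_train_cifar10.sh config_cifar10.py 2` trains with two workers. Conceptually, the launcher starts one worker per GPU and exports the rendezvous environment variables each worker reads; below is a simplified single-node sketch of that behavior (the launch helper is illustrative, not the real torch/distributed/launch.py):

    # Simplified sketch of what `python -m torch.distributed.launch
    # --nproc_per_node=N script.py args...` does on a single node.
    import os
    import subprocess
    import sys

    def launch(training_script, script_args, nproc_per_node):
        procs = []
        for local_rank in range(nproc_per_node):
            env = os.environ.copy()
            env.update(
                RANK=str(local_rank),        # equals local rank on one node
                LOCAL_RANK=str(local_rank),
                WORLD_SIZE=str(nproc_per_node),
                MASTER_ADDR='127.0.0.1',
                MASTER_PORT='29500',
            )
            # The real launcher also passes --local_rank=<n> to the script,
            # which is why train_cifar10.py defines that argument.
            cmd = [sys.executable, training_script,
                   '--local_rank={}'.format(local_rank)] + list(script_args)
            procs.append(subprocess.Popen(cmd, env=env))
        for p in procs:
            p.wait()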
examples/train_cifar10.py (view file @ e0422994)
 import logging
+import os
 from argparse import ArgumentParser
 from collections import OrderedDict
 
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 import torch.nn.functional as F
 from mmcv import Config
-from mmcv.torchpack import Runner
+from mmcv.torchpack import Runner, DistSamplerSeedHook
+from torch.nn.parallel import DataParallel, DistributedDataParallel
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from torchvision import datasets, transforms
 
 import resnet_cifar
...
@@ -41,18 +48,55 @@ def batch_processor(model, data, train_mode):
     return outputs
 
 
 def get_logger(log_level):
     logging.basicConfig(
         format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
     logger = logging.getLogger()
     return logger
 
 
+def init_dist(backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
 def parse_args():
     parser = ArgumentParser(description='Train CIFAR-10 classification')
     parser.add_argument('config', help='train config file path')
+    parser.add_argument(
+        '--launcher', choices=['none', 'pytorch'], default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
     cfg = Config.fromfile(args.config)
 
-    model = getattr(resnet_cifar, cfg.model)()
-    model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda()
-
     logger = get_logger(cfg.log_level)
 
+    # init distributed environment if necessary
+    if args.launcher == 'none':
+        dist = False
+        logger.info('Disabled distributed training.')
+    else:
+        dist = True
+        init_dist(**cfg.dist_params)
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        if rank != 0:
+            logger.setLevel('ERROR')
+        logger.info('Enabled distributed training.')
+
+    # build datasets and dataloaders
     normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
     train_dataset = datasets.CIFAR10(
         root=cfg.data_root,
...
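init_dist above is what the `--launcher pytorch` path calls: it reads the RANK that torch.distributed.launch exported, pins the process to GPU `rank % num_gpus`, and calls init_process_group, which without an explicit init_method defaults to the env:// rendezvous (reading MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment). A self-contained sketch of that contract, runnable as a single process; the values below are illustrative stand-ins for what the launcher would export:

    # Sketch of the env:// rendezvous contract behind init_dist.
    import os

    import torch
    import torch.distributed as dist

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')

    rank = int(os.environ['RANK'])
    if torch.cuda.is_available():
        # Pin each worker to one GPU, as init_dist does.
        torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend='gloo')  # the config notes nccl is much faster
    print('rank {} of {}'.format(dist.get_rank(), dist.get_world_size()))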
@@ -65,37 +109,67 @@ def main():
         ]))
     val_dataset = datasets.CIFAR10(
         root=cfg.data_root,
         train=False,
         transform=transforms.Compose([
             transforms.ToTensor(),
             normalize,
         ]))
 
-    num_workers = cfg.data_workers * len(cfg.gpus)
-    train_loader = torch.utils.data.DataLoader(
+    if dist:
+        num_workers = cfg.data_workers
+        assert cfg.batch_size % world_size == 0
+        batch_size = cfg.batch_size // world_size
+        train_sampler = DistributedSampler(train_dataset, world_size, rank)
+        val_sampler = DistributedSampler(val_dataset, world_size, rank)
+        shuffle = False
+    else:
+        num_workers = cfg.data_workers * len(cfg.gpus)
+        batch_size = cfg.batch_size
+        train_sampler = None
+        val_sampler = None
+        shuffle = True
+    train_loader = DataLoader(
         train_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=True,
-        num_workers=num_workers,
-        pin_memory=True)
-    val_loader = torch.utils.data.DataLoader(
+        batch_size=batch_size,
+        shuffle=shuffle,
+        sampler=train_sampler,
+        num_workers=num_workers)
+    val_loader = DataLoader(
         val_dataset,
-        batch_size=cfg.batch_size,
+        batch_size=batch_size,
         shuffle=False,
-        num_workers=num_workers,
-        pin_memory=True)
-
-    runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
-    runner.register_default_hooks(
-        lr_config=cfg.lr_policy,
-        checkpoint_config=cfg.checkpoint_cfg,
-        log_config=cfg.log_cfg)
+        sampler=val_sampler,
+        num_workers=num_workers)
+
+    # build model
+    model = getattr(resnet_cifar, cfg.model)()
+    if dist:
+        model = DistributedDataParallel(
+            model.cuda(), device_ids=[torch.cuda.current_device()])
+    else:
+        model = DataParallel(model, device_ids=cfg.gpus).cuda()
+
+    # build runner and register hooks
+    runner = Runner(
+        model, batch_processor, cfg.optimizer, cfg.work_dir,
+        log_level=cfg.log_level)
+    runner.register_training_hooks(
+        lr_config=cfg.lr_config,
+        optimizer_config=cfg.optimizer_config,
+        checkpoint_config=cfg.checkpoint_config,
+        log_config=cfg.log_config)
+    if dist:
+        runner.register_hook(DistSamplerSeedHook())
 
-    runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
+    # load param (if necessary) and run
+    if cfg.get('resume_from') is not None:
+        runner.resume(cfg.resume_from)
+    elif cfg.get('load_from') is not None:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)
 
 
 if __name__ == '__main__':
...
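In the dist branch above, the global batch size from the config is split evenly across workers (hence the assert), and each DistributedSampler serves a disjoint 1/world_size shard of the dataset; the sampler does its own shuffling, which is why the loaders pass shuffle=False and why DistSamplerSeedHook is registered to reseed that shuffle every epoch. The arithmetic as a standalone sketch (sizes are illustrative; CIFAR-10's real train split has 50,000 images and the example config uses batch_size=64):

    # Sketch of the per-worker dataloader arithmetic used in the dist branch.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    world_size, rank = 2, 0                      # pretend to be worker 0 of 2
    global_batch = 64
    assert global_batch % world_size == 0
    per_gpu_batch = global_batch // world_size   # 32 samples per worker per step

    dataset = TensorDataset(torch.randn(5000, 3, 32, 32))   # small CIFAR-10 stand-in
    sampler = DistributedSampler(dataset, world_size, rank)  # disjoint half of the data
    loader = DataLoader(dataset, batch_size=per_gpu_batch, sampler=sampler)
    print(len(sampler), len(loader))             # -> 2500 79 (ceil(2500 / 32) batches)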