MMCV commit e0422994, authored Sep 29, 2018 by Kai Chen

    update cifar10 example

parent 961c3388

Showing 3 changed files with 107 additions and 27 deletions (+107 -27)
examples/config_cifar10.py      +7  -6
examples/dist_train_cifar10.sh  +5  -0
examples/train_cifar10.py       +95 -21
examples/config_cifar10.py

...
@@ -8,23 +8,24 @@ batch_size = 64
 # optimizer and learning rate
 optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=5e-4)
-lr_policy = dict(policy='step', step=2)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(policy='step', step=2)

 # runtime settings
 work_dir = './demo'
 gpus = range(2)
+dist_params = dict(backend='gloo')  # gloo is much slower than nccl
 data_workers = 2  # data workers per gpu
-checkpoint_cfg = dict(interval=1)  # save checkpoint at every epoch
+checkpoint_config = dict(interval=1)  # save checkpoint at every epoch
 workflow = [('train', 1), ('val', 1)]
-max_epoch = 6
+total_epochs = 6
 resume_from = None
 load_from = None

 # logging settings
 log_level = 'INFO'
-log_cfg = dict(
+log_config = dict(
     # log at every 50 iterations
     interval=50,
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook', log_dir=work_dir + '/log'),
...
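For reference, the renamed keys match the keyword arguments that the updated train_cifar10.py (shown further down) passes to the runner: lr_config, optimizer_config, checkpoint_config and log_config go to register_training_hooks, and total_epochs goes to runner.run. A minimal sketch of reading them back, assuming the file above is saved as examples/config_cifar10.py:

    from mmcv import Config

    # Config.fromfile is the same loader used by train_cifar10.py below
    cfg = Config.fromfile('examples/config_cifar10.py')

    print(cfg.total_epochs)            # 6, used by runner.run(...)
    print(cfg.lr_config)               # step policy, passed to register_training_hooks
    print(cfg.dist_params['backend'])  # 'gloo', passed to init_dist(**cfg.dist_params)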
examples/dist_train_cifar10.sh (new file, mode 100755)

+#!/usr/bin/env bash
+
+PYTHON=${PYTHON:-"python"}
+
+$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train_cifar10.py $1 --launcher pytorch ${@:3}
\ No newline at end of file
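A likely invocation, inferred from the positional arguments ($1 is the config file, $2 the number of processes per node, and anything from the third argument on is forwarded to the training script), would be:

    ./dist_train_cifar10.sh config_cifar10.py 2

This runs train_cifar10.py under torch.distributed.launch with --launcher pytorch, which is what makes the script take the distributed code path shown below.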
examples/train_cifar10.py

 import logging
 import os
 from argparse import ArgumentParser
 from collections import OrderedDict

 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 import torch.nn.functional as F
 from mmcv import Config
-from mmcv.torchpack import Runner
+from mmcv.torchpack import Runner, DistSamplerSeedHook
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from torchvision import datasets, transforms

 import resnet_cifar
...
@@ -41,18 +48,55 @@ def batch_processor(model, data, train_mode):
     return outputs


+def get_logger(log_level):
+    logging.basicConfig(
+        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
+    logger = logging.getLogger()
+    return logger
+
+
+def init_dist(backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
 def parse_args():
     parser = ArgumentParser(description='Train CIFAR-10 classification')
     parser.add_argument('config', help='train config file path')
+    parser.add_argument(
+        '--launcher', choices=['none', 'pytorch'], default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     return parser.parse_args()


 def main():
     args = parse_args()
     cfg = Config.fromfile(args.config)

-    model = getattr(resnet_cifar, cfg.model)()
-    model = torch.nn.DataParallel(model, device_ids=cfg.gpus).cuda()
+    logger = get_logger(cfg.log_level)
+
+    # init distributed environment if necessary
+    if args.launcher == 'none':
+        dist = False
+        logger.info('Disabled distributed training.')
+    else:
+        dist = True
+        init_dist(**cfg.dist_params)
+        world_size = torch.distributed.get_world_size()
+        rank = torch.distributed.get_rank()
+        if rank != 0:
+            logger.setLevel('ERROR')
+        logger.info('Enabled distributed training.')

+    # build datasets and dataloaders
     normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
     train_dataset = datasets.CIFAR10(
         root=cfg.data_root,
...
@@ -65,37 +109,67 @@ def main():
         ]))
     val_dataset = datasets.CIFAR10(
         root=cfg.data_root,
         train=False,
         transform=transforms.Compose([
             transforms.ToTensor(),
             normalize,
         ]))

-    num_workers = cfg.data_workers * len(cfg.gpus)
-    train_loader = torch.utils.data.DataLoader(
+    if dist:
+        num_workers = cfg.data_workers
+        assert cfg.batch_size % world_size == 0
+        batch_size = cfg.batch_size // world_size
+        train_sampler = DistributedSampler(train_dataset, world_size, rank)
+        val_sampler = DistributedSampler(val_dataset, world_size, rank)
+        shuffle = False
+    else:
+        num_workers = cfg.data_workers * len(cfg.gpus)
+        batch_size = cfg.batch_size
+        train_sampler = None
+        val_sampler = None
+        shuffle = True
+    train_loader = DataLoader(
         train_dataset,
-        batch_size=cfg.batch_size,
-        shuffle=True,
-        num_workers=num_workers)
-    val_loader = torch.utils.data.DataLoader(
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=num_workers,
+        sampler=train_sampler,
+        pin_memory=True)
+    val_loader = DataLoader(
         val_dataset,
-        batch_size=cfg.batch_size,
+        batch_size=batch_size,
         shuffle=False,
-        num_workers=num_workers)
+        num_workers=num_workers,
+        sampler=val_sampler,
+        pin_memory=True)

-    runner = Runner(model, cfg.optimizer, batch_processor, cfg.work_dir)
-    runner.register_default_hooks(
-        lr_config=cfg.lr_policy,
-        checkpoint_config=cfg.checkpoint_cfg,
-        log_config=cfg.log_cfg)
+    # build model
+    model = getattr(resnet_cifar, cfg.model)()
+    if dist:
+        model = DistributedDataParallel(
+            model.cuda(), device_ids=[torch.cuda.current_device()])
+    else:
+        model = DataParallel(model, device_ids=cfg.gpus).cuda()

+    # build runner and register hooks
+    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
+                    log_level=cfg.log_level)
+    runner.register_training_hooks(
+        lr_config=cfg.lr_config,
+        optimizer_config=cfg.optimizer_config,
+        checkpoint_config=cfg.checkpoint_config,
+        log_config=cfg.log_config)
+    if dist:
+        runner.register_hook(DistSamplerSeedHook())

+    # load param (if necessary) and run
     if cfg.get('resume_from') is not None:
         runner.resume(cfg.resume_from)
     elif cfg.get('load_from') is not None:
         runner.load_checkpoint(cfg.load_from)
-    runner.run([train_loader, val_loader], cfg.workflow, cfg.max_epoch)
+    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)


 if __name__ == '__main__':
...
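One practical consequence of the new data-loading branch is that batch_size in the config is treated as the global batch size and split evenly across processes in distributed mode. A minimal sketch of the arithmetic (not part of the commit; values taken from the example config, which sets batch_size = 64 and gpus = range(2)):

    # global batch size from the config, divided across the launched processes
    cfg_batch_size = 64   # cfg.batch_size
    world_size = 2        # one process per GPU under torch.distributed.launch

    assert cfg_batch_size % world_size == 0
    per_process_batch_size = cfg_batch_size // world_size  # 32

    # each DistributedSampler(train_dataset, world_size, rank) yields a disjoint
    # shard of the dataset, so one optimizer step still covers 64 samples total
    print(per_process_batch_size)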