OpenDAS / apex

Commit d8b5d1be, authored Feb 28, 2019 by Michael Carilli
Adding distributed tests and support for FusedAdam
Parent: d24c25b9

Changes: 9 changed files with 681 additions and 24 deletions (+681 / -24)
Changed files:
- apex/amp/_initialize.py (+16 / -0)
- apex/amp/handle.py (+30 / -23)
- apex/optimizers/fp16_optimizer.py (+1 / -1)
- tests/L1/cross_product/compare.py (+0 / -0)
- tests/L1/cross_product/main_amp.py (+0 / -0)
- tests/L1/cross_product/run_test.sh (+0 / -0)
- tests/L1/cross_product_distributed/compare.py (+33 / -0)
- tests/L1/cross_product_distributed/main_amp.py (+515 / -0)
- tests/L1/cross_product_distributed/run_test.sh (+86 / -0)
apex/amp/_initialize.py (+16 / -0)

@@ -86,6 +86,22 @@ def check_optimizers(optimizers):
             "on the specified opt_level (and optional overridden properties)."
         )

+def wrap_fused_adam(optimizer, properties):
+    msg = 'Currently, the usage of FusedAdam is restricted to ' \
+          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, ' \
+          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
+
+    assert properties.master_weights is True, msg
+    assert properties.cast_model_type is torch.float16, msg
+    assert (properties.keep_batchnorm_fp32 is False or
+            properties.keep_batchnorm_fp32 is None), msg
+
+    if properties.loss_scale == "dynamic":
+        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
+    else:
+        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
+
 def _initialize(models, optimizers, properties):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
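For context, the assertions above pin down the only configuration wrap_fused_adam currently accepts. Below is a minimal usage sketch, not part of this commit, assuming apex was built with the CUDA extensions so apex.optimizers.FusedAdam is importable:

# Illustrative sketch only: passing FusedAdam through amp.initialize with the
# settings the new wrap_fused_adam path requires (O2, batchnorm not kept in fp32,
# explicit or dynamic loss scale). Requires a CUDA device and apex built with --cuda_ext.
import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 10).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O2",
                                  keep_batchnorm_fp32=False,
                                  loss_scale="dynamic")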
apex/amp/handle.py (+30 / -23)

@@ -6,7 +6,8 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler, iter_params
 from ._amp_state import _amp_state
-from ..fp16_utils import FP16_Optimizer
+from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
+from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused

 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.

@@ -35,12 +36,17 @@ def scale_loss(loss,
     if optimizer.loss_scaler is None:
         raise RuntimeError("optimizer passed to scale_loss does not have a loss_scaler.")

-    loss_scale = optimizer.loss_scaler.loss_scale()
+    # this is what happens when i have to support tools from different sources under the same API...
+    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
+    if isinstance(optimizer, FP16_Optimizer_for_fused):
+        loss_scale = optimizer.cur_scale
+    else:
+        loss_scale = optimizer.loss_scaler.loss_scale()

     if ((not _amp_state.opt_properties.master_weights)
             and (not optimizer.loss_scaler.dynamic)
             and loss_scale == 1.0):
-        yield loss
+        yield loss.float()
         # Needing to drop the cache here as well is an ugly gotcha.
         # But for now I think it's necessary to short-circuit.
         # Probably ok to skip this if not delay_unscale

@@ -48,32 +54,33 @@ def scale_loss(loss,
             _amp_state.handle._clear_cache()
         return

-    yield loss*loss_scale
+    yield (loss.float())*loss_scale

     # this isn't pretty but it unifies things. Once I deprecate the old API entirely,
     # I will have freedom to clean this up. Maybe instead of wrapping optimizers,
     # I can simply construct a set of attributes (e.g. master params) and assign them
     # directly to optimizer instances.
     if not delay_unscale:
-        if isinstance(optimizer, FP16_Optimizer):
-            optimizer.update_master_grads()
-        else:
-            optimizer.loss_scaler.clear_overflow_state()
-            optimizer.loss_scaler.unscale(
-                iter_params(optimizer.param_groups),
-                iter_params(optimizer.param_groups),
-                loss_scale)
-            # For future fused optimizers that enable sync-free dynamic loss scaling,
-            # should_skip will always be False.
-            should_skip = optimizer.loss_scaler.update_scale()
-            if should_skip:
-                optimizer_step = optimizer.step
-                def skip_step():
-                    logger = logging.getLogger('apex.amp')
-                    logger.warning("Gradient overflow. Skipping step, reducing " +
-                                   "loss scale to {}".format(optimizer.loss_scaler.loss_scale()))
-                    optimizer.step = optimizer_step
-                optimizer.step = skip_step
+        if not isinstance(optimizer, FP16_Optimizer_for_fused):
+            if isinstance(optimizer, FP16_Optimizer_general):
+                optimizer.update_master_grads()
+            else:
+                optimizer.loss_scaler.clear_overflow_state()
+                optimizer.loss_scaler.unscale(
+                    iter_params(optimizer.param_groups),
+                    iter_params(optimizer.param_groups),
+                    loss_scale)
+                # For future fused optimizers that enable sync-free dynamic loss scaling,
+                # should_skip will always be False.
+                should_skip = optimizer.loss_scaler.update_scale()
+                if should_skip:
+                    optimizer_step = optimizer.step
+                    def skip_step():
+                        logger = logging.getLogger('apex.amp')
+                        logger.warning("Gradient overflow. Skipping step, reducing " +
+                                       "loss scale to {}".format(optimizer.loss_scaler.loss_scale()))
+                        optimizer.step = optimizer_step
+                    optimizer.step = skip_step

     # Probably ok to skip this if not delay_unscale
     if _amp_state.opt_properties.patch_torch_functions:
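The user-facing pattern this scale_loss logic serves is the same one exercised by the new distributed test below. A minimal sketch, not part of this diff and requiring a CUDA device with apex installed:

# Illustrative sketch only: the amp.scale_loss context manager in a training step.
# With a FusedAdam-wrapped optimizer the loss scale comes from optimizer.cur_scale;
# otherwise it comes from the optimizer's LossScaler, as in the hunk above.
import torch
from apex import amp

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale="dynamic")

loss = model(torch.randn(8, 4).cuda()).sum()
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()   # gradients are produced w.r.t. the scaled loss, then unscaled
optimizer.step()             # replaced by skip_step for this iteration if an overflow occurred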
apex/optimizers/fp16_optimizer.py (+1 / -1)

@@ -96,7 +96,7 @@ class FP16_Optimizer(object):
             if dynamic_loss_args is not None:
                 raise SystemError("Do not support dynamic loss scale args for now.")
             self.dynamic_loss_scale = True
-            self.cur_scale = 2**32
+            self.cur_scale = 2**16
             self.cur_iter = 0
             self.last_overflow_iter = -1
             self.scale_factor = 2
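For readers unfamiliar with the bookkeeping these fields drive, here is a simplified, hypothetical sketch of dynamic loss scaling as commonly implemented; the scale_window length is an assumption, and this hunk itself only lowers the initial scale from 2**32 to 2**16:

# Hypothetical sketch, not the file's actual implementation: halve the scale when
# gradients overflow, grow it by scale_factor after a run of overflow-free steps.
class DynamicScaleSketch:
    def __init__(self, init_scale=2**16, scale_factor=2, scale_window=1000):
        self.cur_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window      # assumed parameter, not shown in this hunk
        self.cur_iter = 0
        self.last_overflow_iter = -1

    def update(self, overflow):
        if overflow:
            self.cur_scale /= self.scale_factor          # back off after inf/nan gradients
            self.last_overflow_iter = self.cur_iter
        elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
            self.cur_scale *= self.scale_factor          # grow again once training is stable
        self.cur_iter += 1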
tests/L1/compare_python_vs_extensions/compare.py → tests/L1/cross_product/compare.py (file moved)
tests/L1/compare_python_vs_extensions/main_amp.py → tests/L1/cross_product/main_amp.py (file moved)
tests/L1/compare_python_vs_extensions/run_test.sh → tests/L1/cross_product/run_test.sh (file moved)
tests/L1/cross_product_distributed/compare.py (new file, mode 100644, +33)

import argparse
import torch

parser = argparse.ArgumentParser(description='Compare')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
args = parser.parse_args()

base_file = str(args.opt_level) + "_" + str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32)

file_e = "True_" + base_file
file_p = "False_" + base_file

dict_e = torch.load(file_e)
dict_p = torch.load(file_p)

torch.set_printoptions(precision=10)

print(file_e)
print(file_p)

for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
    assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)

    loss_e = dict_e["Loss"][n]
    loss_p = dict_p["Loss"][n]
    assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p)

    print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format(
        i_e, loss_e, loss_p, dict_e["Speed"][n], dict_p["Speed"][n]))
tests/L1/cross_product_distributed/main_amp.py (new file, mode 100644, +515)
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
                    help='Initial learning rate. Will be scaled by <global batch size>/256: '
                         'args.lr = args.lr*float(args.batch_size*args.world_size)/256. '
                         'A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                    help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--prof', dest='prof', action='store_true',
                    help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true', help='enabling apex sync BN.')
parser.add_argument('--has-ext', action='store_true')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--prints-to-process', type=int, default=10)

cudnn.benchmark = True
def fast_collate(batch):
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        tens = torch.from_numpy(nump_array)
        if(nump_array.ndim < 3):
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets


best_prec1 = 0
args = parser.parse_args()
# Let multi_tensor_applier be the canary in the coalmine
# that verifies if the backend is what we think it is
assert multi_tensor_applier.available == args.has_ext

print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)
    torch.set_printoptions(precision=10)
def main():
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(
        model, optimizer,
        # enabled=False,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale)

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume,
                                        map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320  # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        self.preload()
        return input, target
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    run_info_dict = {"Iteration": [],
                     "Loss": [],
                     "Speed": []}

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = -1
    while input is not None:
        i += 1

        # No learning rate warmup for this test, to expose bitwise inaccuracies more quickly
        # adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        if args.prof:
            if i > 10:
                break
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        # for param in model.parameters():
        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

        # torch.cuda.synchronize()
        torch.cuda.nvtx.range_push("step")
        optimizer.step()
        torch.cuda.nvtx.range_pop()

        torch.cuda.synchronize()
        # measure elapsed time
        batch_time.update(time.time() - end)

        end = time.time()

        input, target = prefetcher.next()

        if args.local_rank == 0 and i % args.print_freq == 0 and i > 1:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {3:.3f} ({4:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader),
                   args.world_size*args.batch_size/batch_time.val,
                   args.world_size*args.batch_size/batch_time.avg,
                   batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
            run_info_dict["Iteration"].append(i)
            run_info_dict["Loss"].append(losses.val)
            run_info_dict["Speed"].append(args.world_size*args.batch_size/batch_time.val)
            if len(run_info_dict["Loss"]) == args.prints_to_process:
                torch.save(run_info_dict,
                           str(args.has_ext) + "_" + str(args.opt_level) + "_" +
                           str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32))
                quit()
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = -1
    while input is not None:
        i += 1

        # compute output
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   i, len(val_loader),
                   args.world_size * args.batch_size / batch_time.val,
                   args.world_size * args.batch_size / batch_time.avg,
                   batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

        input, target = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30

    if epoch >= 80:
        factor = factor + 1

    lr = args.lr*(0.1**factor)

    """Warmup"""
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    # if(args.local_rank == 0):
    #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= args.world_size
    return rt
if __name__ == '__main__':
    main()
tests/L1/cross_product_distributed/run_test.sh (new file, mode 100644, +86)

#!/bin/bash

DATADIR="/home/mcarilli/Desktop/pt18data/apex/examples/imagenet/bare_metal_train_val/"

BASE_CMD="python -m multiproc python main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5"

print_banner() {
  printf "\n\n\n\e[30m\e[42m$1\e[0m\n\n\n\n"
}

keep_batchnorms=(
""
"--keep-batchnorm-fp32 True"
"--keep-batchnorm-fp32 False"
)

loss_scales=(
""
"--loss-scale 1.0"
"--loss-scale 128.0"
"--loss-scale dynamic"
)

opt_levels=(
"O0"
"O1"
"O2"
"O3"
)

rm True*
rm False*

set -e

pushd ../../..
python setup.py install --cuda_ext --cpp_ext
popd

for opt_level in "${opt_levels[@]}"
do
  for loss_scale in "${loss_scales[@]}"
  do
    for keep_batchnorm in "${keep_batchnorms[@]}"
    do
      print_banner "$BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
      set -x
      $BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
      set +x
    done
  done
done

pushd ../../..
python setup.py install
popd

for opt_level in "${opt_levels[@]}"
do
  for loss_scale in "${loss_scales[@]}"
  do
    for keep_batchnorm in "${keep_batchnorms[@]}"
    do
      print_banner "$BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} $DATADIR"
      set -x
      $BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} $DATADIR
      set +x
    done
  done
done

for opt_level in "${opt_levels[@]}"
do
  for loss_scale in "${loss_scales[@]}"
  do
    for keep_batchnorm in "${keep_batchnorms[@]}"
    do
      set -x
      python compare.py --opt-level $opt_level ${loss_scale} ${keep_batchnorm}
      set +x
    done
  done
done

pushd ../../..
python setup.py install --cuda_ext --cpp_ext
popd
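For readability, the sweep this script performs can be summarized in Python. This is only a schematic rendering of the final compare loop, not something the test ships; the shell script above is what actually runs, and it first executes the full sweep twice (with and without the compiled extensions) to produce the "True_*" and "False_*" result files that compare.py checks.

# Schematic sketch only: the opt_level x loss_scale x keep_batchnorm cross product,
# with compare.py invoked once per combination.
import itertools
import subprocess

opt_levels = ["O0", "O1", "O2", "O3"]
loss_scales = ["", "--loss-scale 1.0", "--loss-scale 128.0", "--loss-scale dynamic"]
keep_batchnorms = ["", "--keep-batchnorm-fp32 True", "--keep-batchnorm-fp32 False"]

for opt_level, loss_scale, keep_bn in itertools.product(opt_levels, loss_scales, keep_batchnorms):
    args = " ".join(filter(None, ["--opt-level " + opt_level, loss_scale, keep_bn]))
    subprocess.run("python compare.py " + args, shell=True, check=True)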