Commit 7c82f221 authored by Michael Carilli

Cleaning up FusedAdam testing

parent d8b5d1be
@@ -96,9 +96,9 @@ def wrap_fused_adam(optimizer, properties):
     assert (properties.keep_batchnorm_fp32 is False or
             properties.keep_batchnorm_fp32 is None), msg
-    if properties.loss_scale == "dynamic"
+    if properties.loss_scale == "dynamic":
         return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else
+    else:
         return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
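The branch above is where amp hands a FusedAdam instance to the FusedAdam-specific FP16_Optimizer wrapper, picking dynamic or static loss scaling from the requested loss_scale. Below is a minimal sketch of the same decision made directly against the wrapper, assuming FP16_Optimizer_for_fused is the FP16_Optimizer exported by apex.optimizers; the constructor keywords come from the diff, while the model and loss-scale values are only illustrative.

# Minimal sketch; FP16_Optimizer here is assumed to be the FusedAdam-specific wrapper
# that wrap_fused_adam imports as FP16_Optimizer_for_fused.
import torch
from apex.optimizers import FusedAdam, FP16_Optimizer

model = torch.nn.Linear(10, 10).cuda().half()
adam = FusedAdam(model.parameters(), lr=1e-3)

loss_scale = "dynamic"        # or a number such as 128.0
if loss_scale == "dynamic":
    optimizer = FP16_Optimizer(adam, dynamic_loss_scale=True)
else:
    optimizer = FP16_Optimizer(adam, static_loss_scale=loss_scale)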
@@ -61,6 +61,8 @@ def scale_loss(loss,
     # I can simply construct a set of attributes (e.g. master params) and assign them
     # directly to optimizer instances.
     if not delay_unscale:
+        # The FP16_Optimizer for FusedAdam will take care of unscaling as part of
+        # its step() method.
         if not isinstance(optimizer, FP16_Optimizer_for_fused):
             if isinstance(optimizer, FP16_Optimizer_general):
                 optimizer.update_master_grads()
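For context, this is the user-facing step that exercises the branch above when the optimizer is a wrapped FusedAdam. A sketch under the usual amp pattern (amp.initialize followed by amp.scale_loss); the tiny model and data are placeholders, and O2 with keep_batchnorm_fp32=False mirrors the ADAM_ARGS used by the test scripts further down.

import torch
from apex import amp, optimizers

model = torch.nn.Linear(10, 10).cuda()
optimizer = optimizers.FusedAdam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O2",
                                  keep_batchnorm_fp32=False,
                                  loss_scale="dynamic")

data = torch.randn(4, 10).cuda()
target = torch.randn(4, 10).cuda()

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(data), target)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
# Per the comment added above, the FusedAdam wrapper unscales gradients inside step()
# rather than on exit from scale_loss.
optimizer.step()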
@@ -5,9 +5,14 @@ parser = argparse.ArgumentParser(description='Compare')
 parser.add_argument('--opt-level', type=str)
 parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
 parser.add_argument('--loss-scale', type=str, default=None)
+parser.add_argument('--fused-adam', action='store_true')
 args = parser.parse_args()
-base_file = str(args.opt_level) + "_" + str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32)
+base_file = str(args.opt_level) + "_" +\
+            str(args.loss_scale) + "_" +\
+            str(args.keep_batchnorm_fp32) + "_" +\
+            str(args.fused_adam)
 file_e = "True_" + base_file
 file_p = "False_" + base_file
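The diff does not show the body of compare.py, but given the filenames built here and the "Checking for bitwise accuracy" banner added to run_test.sh below, the check presumably amounts to loading the two run_info_dict dumps and asserting that the logged values match exactly. A hypothetical sketch of that comparison, continuing from the argument parsing above; everything beyond the filename convention is assumed.

import torch

# "True_" prefix = run with the cpp/cuda extension, "False_" prefix = Python-only run.
dict_e = torch.load(file_e)
dict_p = torch.load(file_p)

for key in dict_e:
    for i, (val_e, val_p) in enumerate(zip(dict_e[key], dict_p[key])):
        if torch.is_tensor(val_e):
            assert torch.equal(val_e, val_p), "{} differs at print {}".format(key, i)
        else:
            assert val_e == val_p, "{} differs at print {}".format(key, i)
print("bitwise match for " + base_file)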
@@ -20,7 +20,7 @@ import numpy as np
 try:
     from apex.parallel import DistributedDataParallel as DDP
     from apex.fp16_utils import *
-    from apex import amp
+    from apex import amp, optimizers
     from apex.multi_tensor_apply import multi_tensor_applier
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
@@ -72,6 +72,7 @@ parser.add_argument('--has-ext', action='store_true')
 parser.add_argument('--opt-level', type=str)
 parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
 parser.add_argument('--loss-scale', type=str, default=None)
+parser.add_argument('--fused-adam', action='store_true')
 parser.add_argument('--prints-to-process', type=int, default=10)
@@ -148,6 +149,9 @@ def main():
     # Scale learning rate based on global batch size
     args.lr = args.lr*float(args.batch_size*args.world_size)/256.
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
+    if args.fused_adam:
+        optimizer = optimizers.FusedAdam(model.parameters())
+    else:
+        optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                    momentum=args.momentum,
+                                    weight_decay=args.weight_decay)
@@ -380,7 +384,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
             if len(run_info_dict["Loss"]) == args.prints_to_process:
                 torch.save(run_info_dict,
                            str(args.has_ext) + "_" + str(args.opt_level) + "_" +
-                           str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32))
+                           str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +
+                           str(args.fused_adam))
                 quit()
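The extra "_" + str(args.fused_adam) keeps main_amp.py's dump names in sync with the base_file that compare.py assembles above. Below is an illustrative example of the resulting name for an O2 FusedAdam run with the extension installed and the loss scale left unset; the flag values are hypothetical, not taken from the test scripts' loops.

# Reproduces the naming convention from the two diffs above; values are illustrative.
has_ext = True                  # --has-ext passed
opt_level = "O2"
loss_scale = None               # --loss-scale left at its argparse default
keep_batchnorm_fp32 = "False"   # parsed with type=str, so it stays a string
fused_adam = True               # --fused-adam passed

dump_name = (str(has_ext) + "_" + str(opt_level) + "_" +
             str(loss_scale) + "_" + str(keep_batchnorm_fp32) + "_" +
             str(fused_adam))
print(dump_name)                # True_O2_None_False_True

compare.py builds the same suffix without the has_ext field ("O2_None_False_True") and prepends "True_" / "False_" to pick up the extension and Python-only dumps respectively.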
@@ -3,6 +3,8 @@
 DATADIR="/home/mcarilli/Desktop/pt18data/apex/examples/imagenet/bare_metal_train_val/"
 BASE_CMD="python main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5"
+ADAM_ARGS="--opt-level O2 --keep-batchnorm-fp32 False --fused-adam"
 
 print_banner() {
   printf "\n\n\n\e[30m\e[42m$1\e[0m\n\n\n\n"
 }
@@ -42,14 +44,26 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
-      print_banner "$BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
+      print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
       set -x
-      $BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
+      ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
       set +x
     done
   done
 done
 
+# Handle FusedAdam separately due to limited support.
+# FusedAdam will not be tested for bitwise accuracy against the Python implementation.
+# The L0 tests already do so. These tests are here to ensure that it actually runs,
+# and get an idea of performance.
+for loss_scale in "${loss_scales[@]}"
+do
+  print_banner "${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR"
+  set -x
+  ${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR
+  set +x
+done
+
 pushd ../../..
 python setup.py install
 popd
@@ -60,14 +74,16 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
    do
-      print_banner "$BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} $DATADIR"
+      print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR"
       set -x
-      $BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} $DATADIR
+      ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR
       set +x
     done
   done
 done
 
+print_banner "Checking for bitwise accuracy between Python-only and cpp/cuda extension installs"
+
 for opt_level in "${opt_levels[@]}"
 do
   for loss_scale in "${loss_scales[@]}"
@@ -75,7 +91,7 @@ do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
       set -x
-      python compare.py --opt-level $opt_level ${loss_scale} ${keep_batchnorm}
+      python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm}
       set +x
     done
   done
@@ -5,9 +5,14 @@ parser = argparse.ArgumentParser(description='Compare')
 parser.add_argument('--opt-level', type=str)
 parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
 parser.add_argument('--loss-scale', type=str, default=None)
+parser.add_argument('--fused-adam', action='store_true')
 args = parser.parse_args()
-base_file = str(args.opt_level) + "_" + str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32)
+base_file = str(args.opt_level) + "_" +\
+            str(args.loss_scale) + "_" +\
+            str(args.keep_batchnorm_fp32) + "_" +\
+            str(args.fused_adam)
 file_e = "True_" + base_file
 file_p = "False_" + base_file
@@ -20,7 +20,7 @@ import numpy as np
 try:
     from apex.parallel import DistributedDataParallel as DDP
     from apex.fp16_utils import *
-    from apex import amp
+    from apex import amp, optimizers
     from apex.multi_tensor_apply import multi_tensor_applier
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
@@ -72,6 +72,7 @@ parser.add_argument('--has-ext', action='store_true')
 parser.add_argument('--opt-level', type=str)
 parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
 parser.add_argument('--loss-scale', type=str, default=None)
+parser.add_argument('--fused-adam', action='store_true')
 parser.add_argument('--prints-to-process', type=int, default=10)
@@ -148,6 +149,9 @@ def main():
     # Scale learning rate based on global batch size
     args.lr = args.lr*float(args.batch_size*args.world_size)/256.
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
+    if args.fused_adam:
+        optimizer = optimizers.FusedAdam(model.parameters())
+    else:
+        optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                    momentum=args.momentum,
+                                    weight_decay=args.weight_decay)
@@ -380,7 +384,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
             if len(run_info_dict["Loss"]) == args.prints_to_process:
                 torch.save(run_info_dict,
                            str(args.has_ext) + "_" + str(args.opt_level) + "_" +
-                           str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32))
+                           str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +
+                           str(args.fused_adam))
                 quit()
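This copy of main_amp.py backs the multi-process run_test.sh below (launched via python -m multiproc) and wraps the model in apex's DistributedDataParallel, imported in the hunk above. A rough sketch of the per-process setup under the usual apex convention that DDP wrapping happens after amp.initialize; the process-group initialization shown here is an assumption about what the launcher and the unchanged parts of the script do.

import torch
from apex import amp, optimizers
from apex.parallel import DistributedDataParallel as DDP

local_rank = 0   # in the real script this comes from the --local_rank argument the launcher supplies
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(10, 10).cuda()
optimizer = optimizers.FusedAdam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O2", keep_batchnorm_fp32=False)
model = DDP(model)   # apex DDP wraps the model after amp.initialize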
@@ -3,6 +3,8 @@
 DATADIR="/home/mcarilli/Desktop/pt18data/apex/examples/imagenet/bare_metal_train_val/"
 BASE_CMD="python -m multiproc python main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5"
+ADAM_ARGS="--opt-level O2 --keep-batchnorm-fp32 False --fused-adam"
 
 print_banner() {
   printf "\n\n\n\e[30m\e[42m$1\e[0m\n\n\n\n"
 }
@@ -42,14 +44,26 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
-      print_banner "$BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
+      print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR"
       set -x
-      $BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
+      ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR
       set +x
     done
   done
 done
 
+# Handle FusedAdam separately due to limited support.
+# FusedAdam will not be tested for bitwise accuracy against the Python implementation.
+# The L0 tests already do so. These tests are here to ensure that it actually runs,
+# and get an idea of performance.
+for loss_scale in "${loss_scales[@]}"
+do
+  print_banner "${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR"
+  set -x
+  ${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR
+  set +x
+done
+
 pushd ../../..
 python setup.py install
 popd
@@ -60,14 +74,16 @@ do
   do
     for keep_batchnorm in "${keep_batchnorms[@]}"
     do
-      print_banner "$BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} $DATADIR"
+      print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR"
       set -x
-      $BASE_CMD --opt-level $opt_level ${loss_scale} ${keep_batchnorm} $DATADIR
+      ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR
       set +x
     done
   done
 done
 
+print_banner "Checking for bitwise accuracy between Python-only and cpp/cuda extension installs"
+
 for opt_level in "${opt_levels[@]}"
 do
   for loss_scale in "${loss_scales[@]}"
@@ -75,7 +91,7 @@ do
     for keep_batchnorm in "${keep_batchnorms[@]}"
    do
       set -x
-      python compare.py --opt-level $opt_level ${loss_scale} ${keep_batchnorm}
+      python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm}
       set +x
     done
   done