Commit cc6ba534 authored by Gustaf Ahdritz

Add DeepSpeed config script

parent e98c202d
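"""Builds a DeepSpeed configuration from command-line arguments and writes
it to stdout as JSON. (Docstring added for clarity; it restates the argparse
description below.)"""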
import argparse
import json
parser = argparse.ArgumentParser(
    description="Outputs a DeepSpeed configuration file to stdout"
)
p = parser.add_argument_group("Optimizer")
p.add_argument("--optimizer", default=None,
help='''Choice of optimizer. Choose between "Adam" or
"OneBitAdam"''')
p.add_argument("--lr", dest="lr", type=float, default=1e-3,
help="The learning rate")
p.add_argument("--freeze_step", type=int, default=100,
help='''Number of warm-up steps before 1-bit compression
activates. Applies only when --optimizer is
OneBitAdam''')
p.add_argument("--cuda_aware", action="store_true", default=False,
help='''Indicates that the underlying MPI library supports
CUDA-Aware communication. Applies only when
--optimizer is OneBitAdam''')
p.add_argument("--comm_backend_name", type=str, default="nccl",
help='''Communication implementation for OneBitAdam. Choose
from nccl and mpi''')
p.add_argument("--eps", type=float, default=1e-8,
help="Adam epsilon parameter")
sched = parser.add_argument_group("Scheduler")
sched.add_argument(
"--scheduler", type=str, default=None,
    help='''The LR scheduler. Choose from "LRRangeTest", "OneCycle", "WarmupLR",
         and "WarmupDecayLR". Documentation for each can be found here:
         deepspeed.readthedocs.io/en/latest/schedulers.html'''
)
# nesting argument groups is deprecated in argparse; attach each scheduler's
# options directly to the parser
range_test = parser.add_argument_group("LRRangeTest")
range_test.add_argument(
"--lr_range_test_min_lr", type=float, default=1e-04
)
range_test.add_argument(
"--lr_range_test_step_size", type=int, default=2000
)
range_test.add_argument(
"--lr_range_test_step_rate", type=float, default=1.0
)
range_test.add_argument(
    # type=bool would treat any non-empty string as True; use a flag instead
    "--lr_range_test_staircase", action="store_true", default=False
)
cycle = parser.add_argument_group("OneCycle")
cycle.add_argument(
"--cycle_min_lr", type=float, default=1e-06
)
cycle.add_argument(
"--cycle_max_lr", type=float, default=1e-03
)
cycle.add_argument(
"--cycle_decay_lr_rate", type=float, default=0
)
cycle.add_argument(
"--cycle_first_step_size", type=int, default=2000
)
cycle.add_argument(
"--cycle_second_step_size", type=int, default=None
)
cycle.add_argument(
"--cycle_first_stair_count", type=int, default=0
)
cycle.add_argument(
"--cycle_second_stair_count", type=int, default=0
)
cycle.add_argument(
"--cycle_decay_step_size", type=int, default=0
)
cycle.add_argument(
    # type=bool would treat any non-empty string as True; parse it explicitly
    "--cycle_momentum", type=lambda s: s.lower() in ("true", "1"), default=True
)
cycle.add_argument(
"--cycle_min_mom", type=float, default=0.8
)
cycle.add_argument(
"--cycle_max_mom", type=float, default=0.9
)
cycle.add_argument(
"--cycle_decay_mom_rate", type=float, default=0
)
warmup = parser.add_argument_group("WarmupLR")
warmup.add_argument(
"--warmup_min_lr", type=float, default=0.
)
warmup.add_argument(
"--warmup_max_lr", type=float, default=0.001
)
warmup.add_argument(
"--warmup_num_steps", type=int, default=1000
)
warmup_decay = parser.add_argument_group("WarmupDecayLR")
warmup_decay.add_argument(
    # 1e05 is a float literal; keep the default an int to match type=int
    "--warmup_decay_total_num_steps", type=int, default=int(1e5)
)
warmup_decay.add_argument(
"--warmup_decay_min_lr", type=float, default=0.
)
warmup_decay.add_argument(
"--warmup_decay_max_lr", type=float, default=0.001
)
warmup_decay.add_argument(
"--warmup_decay_num_steps", type=int, default=1000
)
p = parser.add_argument_group("16-bit training")
p.add_argument("--fp16", dest="fp16", action="store_true", default=False,
help="""Whether to train in 16-bit/mixed-precision mode.
Mutually exclusive with --amp""")
p = parser.add_argument_group("AMP")
p.add_argument("--amp", action="store_true", default=False,
help="""Whether to enable AMP training. Mutually exclusive with
--fp16""")
p.add_argument("--opt_level", action="store_true", default=False,
help="""AMP optimization level. One of "O0", "O1", "O2", or
"O3".""")
p = parser.add_argument_group("Activation checkpointing")
p.add_argument("--partition_activations", action="store_true",
default=False,
help="Activation checkpointing")
p.add_argument("--cpu_checkpointing", action="store_true", default=False,
help="Offload activation checkpoints to CPU")
p.add_argument("--profile", action="store_true",
default=False,
help="Whether to profile activation checkpointing")
p = parser.add_argument_group("ZeRO optimization")
p.add_argument("--zero_stage", type=int, default=2,
help="ZeRO optimizer stage")
p.add_argument("--allgather_partitions", action="store_true",
default=False,
help='''Allgather collective vs. broadcast collectives
for parameter gathering''')
p.add_argument("--allgather_bucket_size", type=int, default=1e9,
help="Number of elements allgathered at one time")
p.add_argument("--overlap_comm", action="store_true", default=False,
help='''Whether to overlap gradient reduction and backward
pass''')
p.add_argument("--reduce_scatter", action="store_true", default=False,
help="Use reduce to average gradients")
p.add_argument("--reduce_bucket_size", type=int, default=1e9,
help="Number of elements reduced at one time")
p.add_argument("--offload_optimizer", action="store_true", default=False,
help='''Offload optimizer state to CPU. Valid only when
--stage is 2 or 3''')
p.add_argument("--pin_memory", action="store_true", default=False,
help="Speeds up offloaded throughput at the cost of memory")
p = parser.add_argument_group("Flops profiler")
p.add_argument("--flops_profiler", action="store_true", default=False,
help="Whether to enable the DeepSpeed Flops Profiler")
p.add_argument("--profile_step", type=int, default=1,
help='''The global training step at which to run the flops
profiler. Has no effect unless --flops_profiler is
given''')
p.add_argument("--module_depth", type=int, default=-1,
help='''Depth to which aggregated module info is printed. Has
no effect unless --flops_profiler is given''')
p.add_argument("--top_modules", type=int, default=3,
help='''Number of top modules to print in the aggregated
profile. Has no effect unless --flops_profiler is
given''')
p.add_argument("--detailed_flops_profile", action="store_true",
default=False,
help='''Whether the flops_profiler should be detailed. Has
no effect unless --flops_profiler is given''')
args = parser.parse_args()
d = {}
# Optimizer settings
if args.optimizer is not None:
    optimizer = {}
    optimizer["type"] = args.optimizer
    params = {}
    params["lr"] = args.lr
    params["eps"] = args.eps
    if args.optimizer == "OneBitAdam":
        params["freeze_step"] = args.freeze_step
        params["cuda_aware"] = args.cuda_aware
        params["comm_backend_name"] = args.comm_backend_name
    optimizer["params"] = params
    d["optimizer"] = optimizer
# LR scheduler
if args.scheduler is not None:
    scheduler = {}
    scheduler["type"] = args.scheduler
    params = {}
    if args.scheduler == "LRRangeTest":
        params["lr_range_test_min_lr"] = args.lr_range_test_min_lr
        params["lr_range_test_step_size"] = args.lr_range_test_step_size
        params["lr_range_test_step_rate"] = args.lr_range_test_step_rate
        params["lr_range_test_staircase"] = args.lr_range_test_staircase
    elif args.scheduler == "OneCycle":
        params["cycle_min_lr"] = args.cycle_min_lr
        params["cycle_max_lr"] = args.cycle_max_lr
        params["decay_lr_rate"] = args.cycle_decay_lr_rate
        params["cycle_first_step_size"] = args.cycle_first_step_size
        params["cycle_second_step_size"] = args.cycle_second_step_size
        params["cycle_first_stair_count"] = args.cycle_first_stair_count
        params["cycle_second_stair_count"] = args.cycle_second_stair_count
        # --cycle_decay_step_size was parsed but never emitted; include it
        params["decay_step_size"] = args.cycle_decay_step_size
        params["cycle_momentum"] = args.cycle_momentum
        params["cycle_min_mom"] = args.cycle_min_mom
        params["cycle_max_mom"] = args.cycle_max_mom
        params["decay_mom_rate"] = args.cycle_decay_mom_rate
    elif args.scheduler == "WarmupLR":
        params["warmup_min_lr"] = args.warmup_min_lr
        params["warmup_max_lr"] = args.warmup_max_lr
        params["warmup_num_steps"] = args.warmup_num_steps
    elif args.scheduler == "WarmupDecayLR":
        params["total_num_steps"] = args.warmup_decay_total_num_steps
        params["warmup_min_lr"] = args.warmup_decay_min_lr
        params["warmup_max_lr"] = args.warmup_decay_max_lr
        # --warmup_decay_num_steps was parsed but never emitted; include it
        params["warmup_num_steps"] = args.warmup_decay_num_steps
    else:
        raise ValueError("Invalid scheduler")
    scheduler["params"] = params
    d["scheduler"] = scheduler
# 16-bit training
if args.fp16 and args.amp:
    raise ValueError("--fp16 and --amp cannot both be enabled")
elif args.amp:
    amp = {}
    amp["enabled"] = True
    # this value was previously written to a "pin_memory" key by mistake
    amp["opt_level"] = args.opt_level
    d["amp"] = amp
else:
    fp16 = {}
    fp16["enabled"] = args.fp16
    d["fp16"] = fp16
# Activation checkpointing
ac = {}
ac["partition_activations"] = args.partition_activations
ac["cpu_checkpointing"] = args.cpu_checkpointing
ac["profile"] = args.profile
d["activation_checkpointing"] = ac
# ZeRO optimization
zo = {}
zo["stage"] = args.zero_stage
zo["allgather_partitions"] = args.allgather_partitions
zo["allgather_bucket_size"] = args.allgather_bucket_size
zo["reduce_bucket_size"] = args.reduce_bucket_size
zo["overlap_comm"] = args.overlap_comm
zo["reduce_scatter"] = args.reduce_scatter
if args.offload_optimizer:
    oo = {}
    oo["device"] = "cpu"
    oo["pin_memory"] = args.pin_memory
    zo["offload_optimizer"] = oo
d["zero_optimization"] = zo
# Flops Profiler
flops_profiler = {}
flops_profiler["enabled"] = args.flops_profiler
flops_profiler["profile_step"] = args.profile_step
flops_profiler["module_depth"] = args.module_depth
flops_profiler["top_modules"] = args.top_modules
flops_profiler["detailed"] = args.detailed_flops_profile
d ["flops_profiler"] = flops_profiler
print(json.dumps(d, indent=2))
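# Example usage (the script and training-script names below are illustrative,
# not part of this commit):
#
#   python make_deepspeed_config.py --optimizer Adam --lr 1e-4 \
#       --scheduler WarmupLR --fp16 --zero_stage 2 > ds_config.json
#
# With these flags, the emitted file contains, among the other sections,
#
#   "optimizer": {"type": "Adam", "params": {"lr": 0.0001, "eps": 1e-08}},
#   "scheduler": {"type": "WarmupLR", "params": {"warmup_min_lr": 0.0,
#                 "warmup_max_lr": 0.001, "warmup_num_steps": 1000}},
#   "fp16": {"enabled": true}
#
# and can then be passed to the DeepSpeed launcher, e.g.
#
#   deepspeed train.py --deepspeed --deepspeed_config ds_config.json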