Commit ca810168 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add bfloat16 support to DeepSpeed config script

parent 90ce2be9
...@@ -129,11 +129,18 @@ warmup_decay.add_argument( ...@@ -129,11 +129,18 @@ warmup_decay.add_argument(
) )
p = parser.add_argument_group("16-bit training") p = parser.add_argument_group("Half-precision training (fp16)")
p.add_argument("--fp16", dest="fp16", action="store_true", default=False, p.add_argument("--fp16", dest="fp16", action="store_true", default=False,
help="""Whether to train in 16-bit/mixed-precision mode. help="""Whether to train in 16-bit/mixed-precision mode.
Mutually exclusive with --amp""") Mutually exclusive with --amp""")
p = parser.add_argument_group("Half-precision training (bfloat16)")
p.add_argument("--bfloat16", dest="bfloat16", action="store_true",
default=False,
help="""Whether to train in 16-bit bfloat16 mode. Mutually
exclusive with --amp and --fp16. Requires hardware
support""")
p = parser.add_argument_group("AMP") p = parser.add_argument_group("AMP")
p.add_argument("--amp", action="store_true", default=False, p.add_argument("--amp", action="store_true", default=False,
help="""Whether to enable AMP training. Mutually exclusive with help="""Whether to enable AMP training. Mutually exclusive with
...@@ -251,18 +258,22 @@ if(args.scheduler is not None): ...@@ -251,18 +258,22 @@ if(args.scheduler is not None):
d["scheduler"] = scheduler d["scheduler"] = scheduler
# 16-bit training # 16-bit training
if(args.fp16 and args.amp): if(sum([args.amp, args.fp16, args.bfloat16]) > 1):
raise ValueError("--fp16 and --amp cannot both be enabled") raise ValueError("Only one of --fp16, --amp, or --bfloat16 can be enabled")
elif(args.amp):
if(args.amp):
amp = {} amp = {}
amp["enabled"] = True amp["enabled"] = True
amp["pin_memory"] = args.opt_level amp["pin_memory"] = args.opt_level
d["amp"] = amp d["amp"] = amp
else: elif(args.fp16):
fp16 = {} fp16 = {}
fp16["enabled"] = args.fp16 fp16["enabled"] = args.fp16
d["fp16"] = fp16 d["fp16"] = fp16
elif(args.bfloat16):
bfloat16 = {}
bfloat16["enabled"] = args.bfloat16
d["bfloat16"] = bfloat16
# Activation checkpointing # Activation checkpointing
ac = {} ac = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment