Unverified commit 2312f04b, authored by Olatunji Ruwase, committed by GitHub

Support migration to FP16 optimizer (#249)

* Debugging

* Fix step() bug; Make step timing optional

* Remove unnecessary changes

* Format fixes

* Replace list with scalar variable

* Remove redundant code

* Fix typo
parent f73d717c
@@ -524,6 +524,7 @@ class DeepSpeedLight(Module):
         if self.optimizer_name() == ADAM_OPTIMIZER:
             if self.dynamic_loss_scale():
                 logger.info('Creating fp16 optimizer with dynamic loss scale')
+                timers = self.timers if self.wall_clock_breakdown() else None
                 optimizer = FP16_Optimizer(
                     optimizer,
                     dynamic_loss_scale=True,
@@ -531,7 +532,8 @@ class DeepSpeedLight(Module):
                     dynamic_loss_args=dynamic_loss_args,
                     mpu=self.mpu,
                     clip_grad=clip_grad,
-                    fused_adam_legacy=self.optimizer_legacy_fusion())
+                    fused_adam_legacy=self.optimizer_legacy_fusion(),
+                    timers=timers)
             else:
                 logger.info('Creating fp16 optimizer with static loss scale: {}'.format(
                     self.loss_scale()))
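
The two hunks above wire the engine's timers into the FP16 optimizer only when wall-clock breakdown is enabled. A minimal sketch of the pattern (hypothetical `build_fp16_optimizer` helper; assumes `FP16_Optimizer` as defined later in this diff):

    # Sketch only: the engine passes its timer object through when profiling
    # is requested, otherwise None, so the optimizer stays config-agnostic.
    def build_fp16_optimizer(base_optimizer, engine):
        timers = engine.timers if engine.wall_clock_breakdown() else None
        return FP16_Optimizer(base_optimizer,
                              dynamic_loss_scale=True,
                              timers=timers)  # timers=None disables step timing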
@@ -848,27 +850,30 @@ class DeepSpeedLight(Module):
         if self.wall_clock_breakdown():
             self.timers('step').stop()
             self.timers('step_microstep').stop()
-            self.timers.log([
+            timer_names = [
                 'forward_microstep',
                 'backward_microstep',
                 'backward_inner_microstep',
                 'backward_allreduce_microstep',
                 'step_microstep'
-            ],
-                            memory_breakdown=self.memory_breakdown())
+            ]
+            self.timers.log(names=timer_names, memory_breakdown=self.memory_breakdown())

         # Log timing
         if self.is_gradient_accumulation_boundary():
-            if self.tensorboard_enabled() and torch.distributed.get_rank(
-            ) == 0:  # this is done before the log because log resets timers
-                self.summary_events = [(f'Train/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
-                    (f'Train/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
-                    (f'Train/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
-                    (f'Train/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
-                    (f'Train/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
-                ]
-                for event in self.summary_events:  # write_summary_events
-                    self.summary_writer.add_scalar(event[0], event[1], event[2])
-                self.summary_writer.flush()
+            if self.tensorboard_enabled():
+                if self.global_rank == 0:
+                    self.summary_events = [(f'Train/Samples/elapsed_time_ms_forward', self.timers('forward').elapsed(reset=False) * 1000.0, self.sample_count), \
+                        (f'Train/Samples/elapsed_time_ms_backward', self.timers('backward').elapsed(reset=False) * 1000.0, self.sample_count), \
+                        (f'Train/Samples/elapsed_time_ms_backward_inner', self.timers('backward_inner').elapsed(reset=False) * 1000.0, self.sample_count), \
+                        (f'Train/Samples/elapsed_time_ms_backward_allreduce', self.timers('backward_allreduce').elapsed(reset=False) * 1000.0, self.sample_count), \
+                        (f'Train/Samples/elapsed_time_ms_step', self.timers('step').elapsed(reset=False) * 1000.0, self.sample_count)
+                    ]
+                    for event in self.summary_events:  # write_summary_events
+                        self.summary_writer.add_scalar(event[0], event[1], event[2])
+                    self.summary_writer.flush()

         if self.wall_clock_breakdown():
             self.timers.log([
                 'forward',
                 'backward',
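
This hunk renames the TensorBoard tags under a `Train/Samples/` prefix and swaps the `torch.distributed.get_rank()` call for the engine's cached `self.global_rank`. Each summary event is a `(tag, value, global_step)` triple; a hedged sketch of how any SummaryWriter-compatible backend consumes them (assumes `torch.utils.tensorboard` with a hypothetical log dir; the engine's own writer may be created differently):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='runs/demo')  # hypothetical log dir
    # placeholder values in the same (tag, value, step) shape as above
    events = [('Train/Samples/elapsed_time_ms_forward', 12.5, 1024),
              ('Train/Samples/elapsed_time_ms_backward', 27.3, 1024)]
    for tag, value, step in events:
        writer.add_scalar(tag, value, step)
    writer.flush()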
@@ -29,9 +29,11 @@ class FP16_Optimizer(object):
                  verbose=True,
                  mpu=None,
                  clip_grad=0.0,
-                 fused_adam_legacy=False):
+                 fused_adam_legacy=False,
+                 timers=None):

         self.fused_adam_legacy = fused_adam_legacy
+        self.timers = timers

         if not torch.cuda.is_available:
             raise SystemError("Cannot use fp16 without CUDA.")
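
Because `timers` defaults to `None`, the new parameter is backward compatible: existing call sites pay no timing cost, and passing a timer collection opts in. Sketch (assumes a `base_optimizer` and an engine-provided `engine_timers`):

    opt = FP16_Optimizer(base_optimizer)                        # unchanged behavior
    opt = FP16_Optimizer(base_optimizer, timers=engine_timers)  # per-phase timing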
@@ -158,6 +160,20 @@ class FP16_Optimizer(object):
                 p.data = q.data
         return self.overflow

+    def start_timers(self, name_list):
+        if self.timers is not None:
+            for name in name_list:
+                self.timers(name).start()
+
+    def stop_timers(self, name_list):
+        if self.timers is not None:
+            for name in name_list:
+                self.timers(name).stop()
+
+    def log_timers(self, name_list):
+        if self.timers is not None:
+            self.timers.log(name_list)
+
     def step(self, closure=None):
         """
         Not supporting closure.
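
The three helpers guard every call on `self.timers is not None`, so with timing disabled they are cheap no-ops. The interface they rely on is small: `timers(name).start()/.stop()` and `timers.log(names)`. A runnable stand-in for illustration only (not DeepSpeed's implementation, which also synchronizes outstanding CUDA work):

    import time

    class TinyTimers:
        class _Timer:
            def __init__(self):
                self.elapsed_ms = 0.0
                self._t0 = None

            def start(self):
                self._t0 = time.time()

            def stop(self):
                self.elapsed_ms += (time.time() - self._t0) * 1000.0
                self._t0 = None

        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            # create on first use, like a named-timer registry
            if name not in self._timers:
                self._timers[name] = TinyTimers._Timer()
            return self._timers[name]

        def log(self, names, memory_breakdown=False):
            for name in names:
                print(f'{name}: {self._timers[name].elapsed_ms:.2f} ms')

Constructing the optimizer with `timers=TinyTimers()` would print per-phase times after each step.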
@@ -166,9 +182,16 @@ class FP16_Optimizer(object):
         if self.fused_adam_legacy:
             return self.step_fused_adam()

+        COMPUTE_NORM = "compute_norm"
+        OVERFLOW_CHECK = 'overflow_check'
+        OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
+        UNSCALE_AND_CLIP = 'unscale_and_clip'
+        BASIC_STEP = 'basic_step'
+        UPDATE_FP16 = 'update_fp16'
+        STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]
+
         # First compute norm for all group so we know if there is overflow
         grads_groups_flat = []
-        norm_groups = []
         for i, group in enumerate(self.fp16_groups):
             data_type = self.fp32_groups_flat[i].dtype
@@ -183,33 +206,46 @@ class FP16_Optimizer(object):
             self.fp32_groups_flat[i].grad = grads_groups_flat[i]
-            norm_groups.append(get_grad_norm(self.fp32_groups_flat, mpu=self.mpu))

-        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
+        self.start_timers([COMPUTE_NORM])
+        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
+        self.stop_timers([COMPUTE_NORM])
+
+        self.start_timers([OVERFLOW_CHECK])
+        self.overflow = self.overflow_checker.check_using_norm([all_groups_norm])
+        self.stop_timers([OVERFLOW_CHECK])

         prev_scale = self.cur_scale
         self._update_scale(self.overflow)
         if self.overflow:
             if self.verbose:
-                logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
-                            "scale: {}, reducing to {}".format(
-                                prev_scale,
-                                self.cur_scale))
+                print("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
+                      "scale: {}, reducing to {}".format(prev_scale,
+                                                         self.cur_scale))
+            self.log_timers(OVERFLOW_TIMERS)
             return self.overflow

-        self.unscale_and_clip_grads(grads_groups_flat, norm_groups)
+        self.start_timers([UNSCALE_AND_CLIP])
+        self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
+        self.stop_timers([UNSCALE_AND_CLIP])

+        self.start_timers([BASIC_STEP])
         self.optimizer.step()
+        self.stop_timers([BASIC_STEP])

         # get rid of the fp32 gradients. Not needed anymore
         for group in self.fp32_groups_flat:
             group.grad = None

-        for i in range(len(norm_groups)):
+        self.start_timers([UPDATE_FP16])
+        for i in range(len(self.fp16_groups)):
             updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                       self.fp16_groups[i])
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data.copy_(q.data)
+        self.stop_timers([UPDATE_FP16])
+
+        self.log_timers(STEP_TIMERS)

         return self.overflow
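
Beyond the timers, this hunk applies the "replace list with scalar variable" and "remove redundant code" items from the commit message: the old loop appended `get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)`, the norm over all flattened groups, once per group, so `norm_groups` held one identical entry per group. A runnable schematic of the cleanup (stub norm function standing in for `get_grad_norm`):

    def get_grad_norm(groups):
        # stand-in: L2 norm over all (scalar) groups
        return sum(float(g) ** 2 for g in groups) ** 0.5

    fp32_groups_flat = [3.0, 4.0]  # stand-ins for flattened param groups

    # Before: the same whole-model norm recomputed and stored per group.
    norm_groups = [get_grad_norm(fp32_groups_flat) for _ in fp32_groups_flat]
    assert norm_groups == [5.0, 5.0]  # redundant entries

    # After: computed once; wrapped in a one-element list only where the
    # downstream API (check_using_norm, unscale_and_clip_grads) expects a list.
    all_groups_norm = get_grad_norm(fp32_groups_flat)
    assert [all_groups_norm] == [5.0]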
@@ -317,6 +353,10 @@ class FP16_Optimizer(object):
         state_dict['clip_grad'] = self.clip_grad
         return state_dict

+    def refresh_fp32_params(self):
+        for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
+            current.data.copy_(saved.data)
+
     def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
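
`refresh_fp32_params` is what the commit title refers to: it rebuilds the flattened fp32 master copies from the current fp16 parameters, so a checkpoint that carries fp16 model weights but no matching fp32 optimizer state can be migrated to this optimizer. A hedged sketch of the call sequence (hypothetical file name and checkpoint key; `model` and `fp16_optimizer` assumed wired as in the engine code above):

    import torch

    state = torch.load('checkpoint_fp16_only.pt')  # hypothetical checkpoint
    model.load_state_dict(state['module'])         # restores fp16 weights
    fp16_optimizer.refresh_fp32_params()           # fp32 masters <- fp16 params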