Unverified Commit 311795d0 authored by Olatunji Ruwase, committed by GitHub

Control ZeRO wall clock timers (#849)



* Control ZeRO wall clock timers

* Disable more ZeRO3 debug prints
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 7925d0c3
@@ -722,6 +722,7 @@ class DeepSpeedEngine(Module):
         zero_stage = self.zero_optimization_stage()
         log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
         assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
+        timers = self.timers if self.wall_clock_breakdown() else None
 
         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
@@ -740,7 +741,7 @@ class DeepSpeedEngine(Module):
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
-                timers=self.timers,
+                timers=timers,
                 static_loss_scale=self.loss_scale(),
                 dynamic_loss_scale=self.dynamic_loss_scale(),
                 dynamic_loss_args=self.dynamic_loss_scale_args(),
@@ -762,7 +763,7 @@ class DeepSpeedEngine(Module):
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
                 self.module,
                 optimizer,
-                timers=self.timers,
+                timers=timers,
                 static_loss_scale=self.loss_scale(),
                 dynamic_loss_scale=self.dynamic_loss_scale(),
                 dynamic_loss_args=self.dynamic_loss_scale_args(),
...
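Note: `self.wall_clock_breakdown()` reflects the `wall_clock_breakdown` flag in the DeepSpeed config, so after this change the ZeRO optimizers only receive a real timer object when that flag is enabled. Below is a minimal sketch of a config that turns the breakdown on; every field other than `wall_clock_breakdown` is an illustrative placeholder, not something prescribed by this commit.

import json

# Illustrative DeepSpeed config sketch -- only `wall_clock_breakdown` is the flag
# relevant to this commit; the remaining fields are arbitrary example values.
ds_config = {
    "train_batch_size": 16,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    # False (the default) now also skips ZeRO's internal optimizer timers entirely.
    "wall_clock_breakdown": True,
}

# Typically written to a JSON file that is passed to the DeepSpeed launcher/engine.
with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)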
@@ -1326,6 +1326,26 @@ class FP16_DeepSpeedZeroOptimizer(object):
         self.norm_for_param_grads = {}
         self.local_overflow = False
 
+    def log_timers(self, timer_names):
+        if self.timers is None:
+            return
+
+        self.timers.log(names=list(timer_names))
+
+    def start_timers(self, timer_names):
+        if self.timers is None:
+            return
+
+        for name in timer_names:
+            self.timers(name).start()
+
+    def stop_timers(self, timer_names):
+        if self.timers is None:
+            return
+
+        for name in timer_names:
+            self.timers(name).stop()
+
     def step(self, closure=None):
         """
         Not supporting closure.
@@ -1340,7 +1360,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
         # First compute norm for all group so we know if there is overflow
         self.check_overflow()
 
-        timers = self.timers
+        OPTIMIZER_ALLGATHER = 'optimizer_allgather'
+        OPTIMIZER_GRADIENTS = 'optimizer_gradients'
+        OPTIMIZER_STEP = 'optimizer_step'
+        timer_names = [OPTIMIZER_ALLGATHER, OPTIMIZER_GRADIENTS, OPTIMIZER_STEP]
 
         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
@@ -1359,15 +1382,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
                     "reducing to {}".format(dist.get_rank(),
                                             prev_scale,
                                             self.loss_scale))
-            timers('optimizer_gradients').start()
-            timers('optimizer_gradients').stop()
-            timers('optimizer_step').start()
-            timers('optimizer_step').stop()
-            timers('optimizer_allgather').start()
-            timers('optimizer_allgather').stop()
+            self.start_timers(timer_names)
+            self.stop_timers(timer_names)
             return
 
-        timers('optimizer_gradients').start()
+        self.start_timers([OPTIMIZER_GRADIENTS])
         norm_groups = []
         single_partition_grad_groups = []
         skip = False
@@ -1409,10 +1428,9 @@ class FP16_DeepSpeedZeroOptimizer(object):
             single_partition_grad_groups.append(single_grad_partition)
 
         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
-        timers('optimizer_gradients').stop()
+        self.stop_timers([OPTIMIZER_GRADIENTS])
 
-        #torch.set_num_threads(12)
-        timers('optimizer_step').start()
+        self.start_timers([OPTIMIZER_STEP])
         if self.deepspeed_adam_offload:
             from deepspeed.ops.adam import DeepSpeedCPUAdam
             if type(self.optimizer) == DeepSpeedCPUAdam:
@@ -1436,12 +1454,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
             for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
                 fp16_partitions[partition_id].data.copy_(fp32_partition.data)
 
-        timers('optimizer_step').stop()
+        self.stop_timers([OPTIMIZER_STEP])
 
         if self.cpu_offload:
            self.reset_cpu_buffers()
 
-        timers('optimizer_allgather').start()
+        self.start_timers([OPTIMIZER_ALLGATHER])
         #gather the updated weights from everyone
         for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):
@@ -1474,7 +1492,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
                 dist.all_gather(shard_list,
                                 shard_list[partition_id],
                                 group=self.dp_process_group)
-        timers('optimizer_allgather').stop()
+        self.stop_timers([OPTIMIZER_ALLGATHER])
 
         # TODO: we probably don't need this? just to be safe
         for i in range(len(norm_groups)):
@@ -1483,11 +1501,9 @@ class FP16_DeepSpeedZeroOptimizer(object):
             for p, q in zip(self.fp16_groups[i], updated_params):
                 p.data = q.data
 
-        timers.log(
-            names=['optimizer_gradients',
-                   'optimizer_step',
-                   'optimizer_allgather'])
+        self.log_timers(timer_names)
 
         see_memory_usage('After zero_optimizer step')
         return
 
     def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
...
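The three helper methods added above make every timer interaction a no-op when the engine passes `timers=None`. The following is a self-contained sketch of that guard pattern, using a hypothetical `SimpleTimers` class in place of DeepSpeed's timer object; the names below are illustrative, not the library's API.

import time


class _Event:
    """One named stopwatch (hypothetical stand-in, not DeepSpeed's timer class)."""

    def __init__(self):
        self.elapsed = 0.0
        self._start = None

    def start(self):
        self._start = time.time()

    def stop(self):
        self.elapsed += time.time() - self._start


class SimpleTimers:
    """Collection addressed as timers(name).start()/stop(), like the object ZeRO receives."""

    def __init__(self):
        self._events = {}

    def __call__(self, name):
        return self._events.setdefault(name, _Event())

    def log(self, names):
        for name in names:
            print(f"{name}: {self._events[name].elapsed * 1000.0:.2f} ms")


class TimerGuard:
    """The guard pattern from the diff: every call degrades to a no-op when timers is None."""

    def __init__(self, timers=None):
        self.timers = timers

    def start_timers(self, timer_names):
        if self.timers is None:
            return
        for name in timer_names:
            self.timers(name).start()

    def stop_timers(self, timer_names):
        if self.timers is None:
            return
        for name in timer_names:
            self.timers(name).stop()

    def log_timers(self, timer_names):
        if self.timers is None:
            return
        self.timers.log(names=list(timer_names))


# Usage: with a timer object the names are recorded and logged; with timers=None the
# identical calls return immediately, so step() needs no scattered if-checks.
guard = TimerGuard(timers=SimpleTimers())
guard.start_timers(['optimizer_step'])
guard.stop_timers(['optimizer_step'])
guard.log_timers(['optimizer_step'])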
@@ -580,7 +580,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
                  gradient_accumulation_steps=1,
                  elastic_checkpoint=False):
 
-        see_memory_usage("Stage 3 intialize begining", force=True)
+        see_memory_usage("Stage 3 intialize beginning", force=True)
 
         if dist.get_rank() == 0:
             logger.info(f"Reduce bucket size {reduce_bucket_size}")
@@ -628,7 +628,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
         ############################################################################
 
-        see_memory_usage("Before Partitioned Parameter Coordinator", force=True)
+        see_memory_usage("Before Partitioned Parameter Coordinator", force=False)
 
         fetch_stream = torch.cuda.Stream() if self.overlap_comm else None
         self.param_coordinator = PartitionedParameterCoordinator(
@@ -636,7 +636,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
             max_reuse_distance_in_numel=int(max_reuse_distance),
             max_available_parameters_in_numel=int(max_live_parameters))
 
-        see_memory_usage("After Partitioned Parameter Coordinator", force=True)
+        see_memory_usage("After Partitioned Parameter Coordinator", force=False)
 
         #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream())
        #-------------Stage 3 Setup-------------------#
@@ -711,20 +711,20 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         self.sub_group_to_group_id = {}
 
-        see_memory_usage("Before creating fp16 partitions", force=True)
+        see_memory_usage("Before creating fp16 partitions", force=False)
         #self._create_fp16_partitions()
         self._create_fp16_partitions_with_defragmentation()
         num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
         see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
-                         force=True)
+                         force=False)
 
-        see_memory_usage("Before creating fp32 partitions", force=True)
+        see_memory_usage("Before creating fp32 partitions", force=False)
         self._create_fp32_partitions()
-        see_memory_usage("After creating fp32 partitions", force=True)
+        see_memory_usage("After creating fp32 partitions", force=False)
 
-        see_memory_usage("Before initializing optimizer states", force=True)
+        see_memory_usage("Before initializing optimizer states", force=False)
         self.initialize_optimizer_states()
-        see_memory_usage("After initializing optimizer states", force=True)
+        see_memory_usage("After initializing optimizer states", force=False)
 
         if dist.get_rank() == 0:
             logger.info(f"optimizer state initialized")
@@ -767,11 +767,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         #Largest partitioned param
         largest_partitioned_param_numel = self._get_largest_partitioned_numel()
 
-        see_memory_usage(f"Before Set Grad positions", force=True)
+        see_memory_usage(f"Before Set Grad positions", force=False)
 
         self.grad_position = {}
         self.set_grad_positions()
-        see_memory_usage(f"Before CPU Offload initialization", force=True)
+        see_memory_usage(f"Before CPU Offload initialization", force=False)
 
         self.grads_in_partition = None
@@ -785,7 +785,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
             self.temp_grad_gpu_buffer = torch.zeros(
                 largest_partitioned_param_numel,
                 device=torch.cuda.current_device()).half()
-        see_memory_usage(f"After CPU Offload initialization", force=True)
+        see_memory_usage(f"After CPU Offload initialization", force=False)
 
         # stores if a partition has been reduced in this step
         self.is_partition_reduced = {}
@@ -1614,7 +1614,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
             see_memory_usage(
                 f"group {i} before creating {total_size} reduced gradients into partition",
-                force=True)
+                force=False)
             if self.cpu_offload_use_pin_memory:
                 self.grads_in_partition.append(
                     torch.zeros(int(total_size),
@@ -1627,7 +1627,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
                                 device=self.device))
             see_memory_usage(
                 f"group {i} after creating {total_size} reduced gradients into partition",
-                force=True)
+                force=False)
 
         for param in self.previous_reduced_grads:
@@ -2044,13 +2044,22 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         self.local_overflow = False
 
     def log_timers(self, timer_names):
+        if self.timers is None:
+            return
+
         self.timers.log(names=list(timer_names))
 
     def start_timers(self, timer_names):
+        if self.timers is None:
+            return
+
         for name in timer_names:
             self.timers(name).start()
 
     def stop_timers(self, timer_names):
+        if self.timers is None:
+            return
+
         for name in timer_names:
             self.timers(name).stop()
@@ -2210,7 +2219,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         see_memory_usage('After zero_optimizer step', force=False)
         print_rank_0(f"------------------Finishing Step-----------------------",
-                     force=True)
+                     force=False)
         return
 
     def _pre_step(self):
@@ -2327,7 +2336,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         self.log_timers(timer_names)
 
-        see_memory_usage('After zero_optimizer step', force=True)
+        see_memory_usage('After zero_optimizer step', force=False)
 
         print_rank_0(f"------------------Finishing Step-----------------------")
 
     def step(self, closure=None):
@@ -2342,7 +2351,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
         norm_groups = self._get_norm_groups()
 
-        timers = self.timers
         timer_names = set()
         timer_names.add('optimizer_step')
...
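The second part of the commit ("Disable more ZeRO3 debug prints") mostly flips `force=True` to `force=False` in the Stage 3 `see_memory_usage` / `print_rank_0` calls, so these messages appear only when debugging output is explicitly requested. Below is a rough sketch of that gating idea; the `DEBUG_ENABLED` switch and the function body are assumptions for illustration, not DeepSpeed's actual implementation.

import torch

DEBUG_ENABLED = False  # assumed global switch standing in for the engine's debug/memory config


def see_memory_usage_sketch(message, force=False):
    """Print CUDA memory stats only when forced or when debugging is globally enabled."""
    if not (force or DEBUG_ENABLED):
        return  # force=False calls (the ones changed in this commit) stay silent by default
    if not torch.cuda.is_available():
        print(f"{message} | no CUDA device")
        return
    allocated_gb = torch.cuda.memory_allocated() / 2**30
    reserved_gb = torch.cuda.memory_reserved() / 2**30
    print(f"{message} | allocated: {allocated_gb:.2f} GB | reserved: {reserved_gb:.2f} GB")


# Routine checkpoints now pass force=False and are skipped unless debugging is enabled:
see_memory_usage_sketch("Before creating fp32 partitions", force=False)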