"...text-generation-inference.git" did not exist on "cb0a29484d573125336a02dc2191479f18cacabe"
Unverified Commit 08c96a1b authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

ZeRO-1 tune max-elems + bug fix (#532)

* zero-1 memory fix

* auto-tune max elems per comm to reduce padding/comm intervals

* clean-up and added previously missing reduction options

* fix testing backend to work with torch 1.7
parent fdd81c30
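The "auto-tune max elems per comm" bullet refers to sizing ZeRO-1's flattened all-reduce buckets from the model rather than from a fixed constant, so less padding is needed to split each bucket evenly across ranks. A minimal sketch of that idea follows; the function name, arguments, and heuristic are hypothetical and not the actual DeepSpeed Stage-1 code.

    # Illustrative only: derive a per-allreduce element count from the model size
    # instead of a fixed constant, so buckets partition evenly across ranks.
    def pick_max_elems_per_comm(total_elems, world_size, cap=500_000_000):
        # Never use a bucket larger than the parameters we actually have.
        max_elems = min(cap, total_elems)
        # Round up to a multiple of world_size so every bucket splits evenly
        # across ranks without per-bucket padding.
        remainder = max_elems % world_size
        if remainder != 0:
            max_elems += world_size - remainder
        return max_elems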
@@ -661,7 +661,7 @@ class DeepSpeedEngine(Module):
     def _configure_zero_optimizer(self, optimizer):
         zero_stage = self.zero_optimization_stage()
         logger.info('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage))
+        assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
...
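The new assertion turns a previously silent limitation into an explicit startup failure: ZeRO performs its own fp16 reductions, so forcing fp32 all-reduce is rejected when the ZeRO optimizer is built. A rough example of a config that would now trip this assert; the keys follow the standard DeepSpeed config layout, but the exact dict is illustrative only.

    # Illustrative DeepSpeed config combining ZeRO stage 1 with fp32_allreduce.
    # With this commit, _configure_zero_optimizer raises an AssertionError for it.
    ds_config = {
        "train_batch_size": 8,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 1},
        "fp32_allreduce": True,  # not supported together with ZeRO
    }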
@@ -68,12 +68,6 @@ class CheckOverflow(object):
         return bool(overflow)

     def check(self, param_groups=None):
-        #TODO: what's the equivalent here? do we need this?
-        # for group in self.fp32_from_fp32_groups:
-        #     for param in group:
-        #         params.append(param)
         params = []
         if param_groups is None:
             params = self.params
...
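The removed lines were a stale TODO referring to the fp32-master-weights code path; check() already gathers the live parameters itself. As a rough sketch of what an overflow check of this kind does (not the actual DeepSpeed implementation, which typically also reduces the flag across ranks so every process agrees on whether to skip the step):

    import torch

    # Minimal sketch: scan each parameter's gradient for inf/nan.
    def has_overflow(params):
        for p in params:
            if p.grad is not None:
                g = p.grad.data.float()
                if torch.isinf(g).any() or torch.isnan(g).any():
                    return True
        return False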
@@ -41,6 +41,8 @@ def distributed_test(world_size=2, backend='nccl'):
            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)
+           if 'args' in func_kwargs:
+               func_kwargs['args'].local_rank = local_rank
            run_func(*func_args, **func_kwargs)

        def dist_launcher(num_procs, *func_args, **func_kwargs):
...
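This change lets a test pass an argparse-style namespace through the args keyword and have each spawned process see its own local_rank, which torch.distributed.launch-style code paths expect. A hedged usage sketch of the distributed_test decorator from tests/unit/common.py; the test body and names are hypothetical.

    from types import SimpleNamespace

    # Each spawned process gets args.local_rank filled in before the body runs.
    @distributed_test(world_size=2)
    def _run_with_args(args):
        assert args.local_rank in (0, 1)

    def test_local_rank_injection():
        _run_with_args(args=SimpleNamespace(local_rank=None))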