Unverified Commit 2f626978 authored by Shaden Smith, committed by GitHub

Pipeline warnings and checkpoint portability (#588)

* Switch from deprecated allreduce interface.

* Make pipeline checkpoint files portable.
parent e8b126d9
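
The portability fix works through a small side channel between the checkpoint entry points and the overridable `module_state_dict()` / `load_module_state_dict()` hooks: the engine records the checkpoint directory in `self._curr_ckpt_path` before invoking the hook, and the pipeline engine writes or reads its per-layer files there instead of embedding an absolute path inside the checkpoint. A minimal sketch of that pattern, with toy class names and print placeholders standing in for the real file I/O (not the DeepSpeed implementation):

```python
import os


class EngineSketch:
    """Toy stand-in for the base engine; only the checkpoint-path side channel."""

    def __init__(self, module):
        self.module = module          # a plain dict in this sketch
        self._curr_ckpt_path = None

    def save_checkpoint(self, save_dir, tag):
        # Set *before* module_state_dict() so an override can write its own
        # files under this directory and return None instead of a path.
        self._curr_ckpt_path = os.path.join(save_dir, tag)
        return {'module': self.module_state_dict()}

    def load_checkpoint(self, load_dir, tag, checkpoint):
        # Rebuilt from wherever the checkpoint lives *now*; nothing stored in
        # the checkpoint refers to the directory it was originally saved to.
        self._curr_ckpt_path = os.path.join(load_dir, tag)
        self.load_module_state_dict(checkpoint['module'])

    def module_state_dict(self):
        return dict(self.module)      # ordinary models: weights travel in the checkpoint

    def load_module_state_dict(self, state_dict, strict=True):
        self.module.update(state_dict)


class PipeEngineSketch(EngineSketch):
    """Pipeline models: per-layer files under the directory, None in the checkpoint."""

    def module_state_dict(self):
        print(f'write one file per layer under {self._curr_ckpt_path}')
        return None

    def load_module_state_dict(self, state_dict, strict=True):
        if state_dict is not None and not isinstance(state_dict, str):
            return super().load_module_state_dict(state_dict, strict)
        print(f'read layer files from {self._curr_ckpt_path}')


engine = PipeEngineSketch(module={'w': 1.0})
ckpt = engine.save_checkpoint('/data/ckpts', tag='global_step100')
# The directory can be copied elsewhere; loading only needs the new location.
engine.load_checkpoint('/mnt/shared/ckpts', tag='global_step100', checkpoint=ckpt)
```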
@@ -33,6 +33,7 @@
 from deepspeed.utils import logger, log_dist
 from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
 from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
+from .pipe.module import PipelineModule
 from .utils import ensure_directory_exists
 from ..ops.op_builder import UtilsBuilder
 from ..ops.adam import DeepSpeedCPUAdam
@@ -1355,6 +1356,10 @@ class DeepSpeedEngine(Module):
         logger.info(f'rank: {self.global_rank} loading checkpoint: {load_path}')
         checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
 
+        if isinstance(self.module, PipelineModule):
+            # Pipeline parallelism uses this to load its own checkpoint files.
+            self._curr_ckpt_path = os.path.join(load_dir, tag)
+
         self.load_module_state_dict(state_dict=checkpoint['module'],
                                     strict=load_module_strict)
         if not self.zero_optimization():
@@ -1522,8 +1527,8 @@ class DeepSpeedEngine(Module):
         save_path = self._get_ckpt_name(save_dir, tag)
         # A hack to save the checkpointing directory. Pipeline parallelism overrides
         # module_state_dict() and uses this path to save the model. module_state_dict()
-        # then instead just returns self._curr_save_path.
-        self._curr_save_path = os.path.dirname(save_path)
+        # then instead just returns None.
+        self._curr_ckpt_path = os.path.join(save_dir, tag)
 
         state = {
             'module':
...
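
A hedged illustration of what the hunk above changes on disk, assuming the usual `mp_rank_00_model_states.pt` naming for the per-rank file: after saving a pipeline model, the per-rank checkpoint carries `'module': None` and no absolute path, while the layer weights sit as separate files inside `<save_dir>/<tag>/`:

```python
import os
import torch

tag_dir = os.path.join('checkpoints', 'global_step100')   # illustrative location
rank_ckpt = torch.load(os.path.join(tag_dir, 'mp_rank_00_model_states.pt'),
                       map_location='cpu')

print(rank_ckpt['module'])   # None for PipelineModule after this change
print(os.listdir(tag_dir))   # per-layer files written by module.save_state_dict()
```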
@@ -52,6 +52,9 @@ class PipelineEngine(DeepSpeedEngine):
         super().__init__(*super_args, **super_kwargs)
         assert isinstance(self.module, PipelineModule), "model must base PipelineModule"
 
+        # We schedule the all-reduces, so disable it in super().backward()
+        self.enable_backward_allreduce = False
+
         # pipeline step for logging
         self.log_batch_step_id = -1
@@ -546,7 +549,7 @@ class PipelineEngine(DeepSpeedEngine):
         # The last stage just runs backward on the loss using DeepSpeed's typical
         # mechanisms.
         if self.is_last_stage():
-            super().backward(self.loss, allreduce_gradients=False)
+            super().backward(self.loss)
 
             self.mem_status('AFTER BWD')
             return
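
The two hunks above are the warning fix from the commit message: instead of passing the deprecated `allreduce_gradients=False` argument on every `backward()` call, the pipeline engine clears the engine-level `enable_backward_allreduce` flag once at construction and schedules the gradient all-reduce itself. A rough sketch of why the two are equivalent (toy classes, not the DeepSpeed implementation):

```python
class BaseEngineSketch:
    """Toy base engine: backward() consults an engine-level flag."""

    def __init__(self):
        self.enable_backward_allreduce = True

    def backward(self, loss):
        # ... autograd would run here ...
        if self.enable_backward_allreduce:
            self.allreduce_gradients()

    def allreduce_gradients(self):
        print('all-reduce gradients across data-parallel ranks')


class PipeEngineSketch(BaseEngineSketch):
    def __init__(self):
        super().__init__()
        # Set once here instead of passing a per-call argument to backward().
        self.enable_backward_allreduce = False

    def train_batch(self, loss):
        self.backward(loss)          # no all-reduce triggered here
        self.allreduce_gradients()   # scheduled explicitly by the pipeline engine


PipeEngineSketch().train_batch(loss=0.0)
```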
@@ -1100,31 +1103,31 @@ class PipelineEngine(DeepSpeedEngine):
         is ``save_state_dict()``.
 
         Returns:
-            str: The directory path where the checkpoint was saved.
+            None
         """
         assert isinstance(self.module, PipelineModule)
-        assert self._curr_save_path is not None, \
+        assert self._curr_ckpt_path is not None, \
             "PipelineEngine expects module_state_dict() to be called from save_checkpoint()"
-        self.module.save_state_dict(self._curr_save_path)
-        return self._curr_save_path
+        self.module.save_state_dict(self._curr_ckpt_path)
+        return None
 
     def load_module_state_dict(self, state_dict, strict=True):
         """Override hack to instead use a directory path.
 
         This is important because pipeline models checkpoint by layer instead of rank.
 
-        If ``state_dict`` is not a ``str``, we revert to ``super()`` expecting a ``dict``.
+        If ``state_dict`` is not ``None`` or a ``str``, we revert to ``super()`` expecting a ``dict``.
 
         Args:
-            state_dict (str): Path to the directory for checkpoint.
+            state_dict (str, None): unused
             strict (bool, optional): Strict state loading. Defaults to True.
         """
-        if not isinstance(state_dict, str):
+        if (state_dict is not None) and (not isinstance(state_dict, str)):
             super().load_module_state_dict(state_dict, strict)
             return
 
-        self.module.load_state_dir(state_dict, strict=strict)
+        self.module.load_state_dir(load_dir=self._curr_ckpt_path, strict=strict)
 
     # A map of PipeInstruction types to methods. Each method will be executed with the
     # kwargs provided to the PipeInstruction from the scheduler.
...
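
End to end, the change means a pipeline checkpoint directory can be moved or copied as a unit and loaded from its new location. A rough usage sketch, assuming a typical DeepSpeed pipeline setup (the config file, layer sizes, and paths are illustrative, and the way the config is passed to `deepspeed.initialize` can differ across versions):

```python
import deepspeed
import torch.nn as nn
from deepspeed.pipe import PipelineModule

layers = [nn.Linear(32, 32) for _ in range(8)]
model = PipelineModule(layers=layers, num_stages=2)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config='ds_config.json')

# ... train with engine.train_batch(...) ...

engine.save_checkpoint('checkpoints', tag='global_step1000')

# Copy or move the whole directory, e.g. to shared storage:
#   cp -r checkpoints /mnt/shared/checkpoints
# Nothing inside the files points at the old absolute path, so loading from
# the new location works:
engine.load_checkpoint('/mnt/shared/checkpoints', tag='global_step1000')
```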