Unverified commit 39013dd2, authored by Stas Bekman and committed by GitHub

save_fp16_model consolidated for zero3 (#893)


Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
parent 7531c6bf
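For context, a minimal usage sketch of what this commit enables; the config values, the toy model, and the `config_params` keyword reflect DeepSpeed usage of this era and are illustrative assumptions, not part of this diff. The two key points: `stage3_gather_fp16_weights_on_model_save` must be set to true under `zero_optimization`, and every rank must call `save_fp16_model`:

import deepspeed
import torch.nn as nn

# hypothetical minimal config; only the new zero_optimization key matters here
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # opt in to consolidated fp16 saving (off by default, see the constants below)
        "stage3_gather_fp16_weights_on_model_save": True
    }
}

model = nn.Linear(10, 10)  # stand-in for a real model
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config_params=ds_config)

# every rank must make this call; only rank 0 writes the file
engine.save_fp16_model("checkpoints", "pytorch_model.bin")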
@@ -388,6 +388,9 @@ class DeepSpeedEngine(Module):
    def zero_param_persistence_threshold(self):
        return self._config.zero_config.param_persistence_threshold

    def zero_gather_fp16_weights_on_model_save(self):
        return self._config.zero_config.gather_fp16_weights_on_model_save

    def fp16_enabled(self):
        return self._config.fp16_enabled
@@ -1714,3 +1717,98 @@ class DeepSpeedEngine(Module):
        torch.save(zero_sd, zero_checkpoint_name)
        self._copy_recovery_script(save_path)
        logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
    def _zero3_consolidated_fp16_state_dict(self):
        """
        Get a full non-partitioned state_dict with fp16 weights on cpu.

        This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:

        1. consolidates the weights from different partitions on gpu0
        2. works on one layer at a time to require as little gpu0 memory as possible, by
           moving the already consolidated weights to cpu
        3. takes care to keep the shared params shared when gradually copying the params to cpu

        Returns:
            a consolidated fp16 ``state_dict`` on cpu on rank 0, ``None`` on other ranks
        """
        import deepspeed

        if not self.zero_optimization_partition_weights():
            raise ValueError("this function requires ZeRO-3 mode")

        state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
        shared_weights = {}

        def get_layer_state_dict(module, prefix=""):
            # gather one layer at a time to be memory-efficient
            with deepspeed.zero.GatheredParameters(list(
                    module.parameters(recurse=False))):
                if torch.distributed.get_rank() == 0:
                    for name, param in module.named_parameters(recurse=False):
                        if param is None:
                            continue
                        key = prefix + name
                        # for shared weights we want to make sure not to unshare them when copying to cpu
                        data_ptr_id = param.storage().data_ptr()
                        if data_ptr_id in shared_weights:
                            # shared weights
                            # print(f"`{key}` is shared with `{shared_weights[data_ptr_id]}`")
                            state_dict[key] = state_dict[shared_weights[data_ptr_id]]
                        else:
                            state_dict[key] = param.detach().cpu()
                            shared_weights[data_ptr_id] = key
                        # print(f"param {name} {param.shape}")
                        # print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}")

                    # now buffers - not sure if need to take care of potentially shared weights here
                    for name, buf in module.named_buffers(recurse=False):
                        if buf is not None and name not in module._non_persistent_buffers_set:
                            state_dict[prefix + name] = buf.detach().cpu()

            for name, child in module.named_children():
                if child is not None:
                    get_layer_state_dict(child, prefix + name + ".")

        see_memory_usage("before get_layer_state_dict", force=False)
        get_layer_state_dict(self.module, prefix="")
        see_memory_usage("after get_layer_state_dict", force=False)

        return state_dict
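A side note on the shared-weights handling above: `param.detach().cpu()` makes an independent cpu copy of each gathered parameter, so without the `data_ptr`-keyed map, tied parameters (e.g. tied input/output embeddings) would be silently untied in the saved ``state_dict``. A minimal standalone sketch of the same technique, with no ZeRO involved and two tied linears as an illustrative stand-in:

import torch.nn as nn
from collections import OrderedDict

a = nn.Linear(4, 4, bias=False)
b = nn.Linear(4, 4, bias=False)
b.weight = a.weight  # tie the two parameters

state_dict = OrderedDict()
shared_weights = {}
for key, param in [("a.weight", a.weight), ("b.weight", b.weight)]:
    data_ptr_id = param.storage().data_ptr()
    if data_ptr_id in shared_weights:
        # reuse the already-copied tensor so the tie survives the copy
        state_dict[key] = state_dict[shared_weights[data_ptr_id]]
    else:
        state_dict[key] = param.detach().cpu()
        shared_weights[data_ptr_id] = key

# both entries point at the same storage, as they did in the model
assert state_dict["a.weight"].data_ptr() == state_dict["b.weight"].data_ptr()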
    def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
        r"""Save fp16 model weights

        This method saves the fp16 model weights at the desired destination.

        Arguments:
            save_dir: Required. Directory for saving the model
            save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``

        Important: all processes must call this method, not just the process with rank 0,
        because the processes need to work in sync to gather the weights. If only the rank 0
        process calls it, this method will hang waiting to synchronize with the other processes.
        """

        path = os.path.join(save_dir, save_filename)

        if self.zero_optimization_partition_weights():
            if self.zero_gather_fp16_weights_on_model_save():
                # consolidation is expensive in time and memory and therefore isn't a default
                state_dict = self._zero3_consolidated_fp16_state_dict()
            else:
                # the model will be bogus if not consolidated so don't confuse the user by saving it
                logger.info(
                    f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False"
                )
                return
        else:
            state_dict = self.module.state_dict()

        if torch.distributed.get_rank() == 0:
            os.makedirs(save_dir, exist_ok=True)
            logger.info(f"Saving model weights to {path}")
            torch.save(state_dict, path)
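Since the output is a plain ``state_dict``, it can presumably be loaded back into the same (unwrapped) model architecture without DeepSpeed; a minimal sketch, where the path and `model` are the hypothetical ones from the usage example above:

import torch

state_dict = torch.load("checkpoints/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)

Note that the saved tensors are fp16, so cast the model (or the tensors) accordingly if you need fp32.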
@@ -34,6 +34,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigObject):
        self.param_persistence_threshold = None
        self.max_live_parameters = None
        self.max_reuse_distance = None
        self.gather_fp16_weights_on_model_save = None

        #Stage3 Specific Parameters
        self.prefetch_bucket_size = None
@@ -150,3 +151,8 @@ class DeepSpeedZeroConfig(DeepSpeedConfigObject):
            zero_config_dict,
            ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD,
            ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT)

        self.gather_fp16_weights_on_model_save = get_scalar_param(
            zero_config_dict,
            ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE,
            ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT)
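For reference, `get_scalar_param` reads a key from the user-supplied config dict with a fallback default; a rough sketch of its behavior (an approximation of DeepSpeed's config helper, not the exact source):

def get_scalar_param(param_dict, param_name, param_default):
    # return the user-supplied value if present, else the default
    return param_dict.get(param_name, param_default)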
@@ -99,6 +99,10 @@ ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT = 50000000
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD = 'stage3_param_persistence_threshold'
ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT = 100000
# gathers params for saving a model - inefficient but is required in certain situations
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save'
ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False
ZERO_OPTIMIZATION_DEFAULT = {
    ZERO_OPTIMIZATION_STAGE:
    ZERO_OPTIMIZATION_STAGE_DEFAULT,
@@ -133,5 +137,7 @@ ZERO_OPTIMIZATION_DEFAULT = {
    ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE:
    ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT,
    ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD:
    ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT,
    ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE:
    ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT
}