Unverified commit 7ccc9daf authored by Olatunji Ruwase, committed by GitHub

Support loading and saving ZeRO checkpoints with changing DP degree (#240)



* Support saving and loading ZeRO checkpoints across different data parallelism degrees.

* Fix formatting

* Support checkpoint with varying GPU count in ZeRO stage 1

* Fix formatting

* Formatting fixes

* Update model tests

* Remove pprint

* Minor fix

* Fix formatting

* Update model tests
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 366d8816
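The core idea of the change: each rank saves its ZeRO partition without the data-parallel alignment padding (together with the saving DP and MP world sizes), and on load every rank reads all saved partitions, merges them back into one flat tensor, and re-slices that tensor for the current DP degree. Below is a minimal standalone sketch of that merge-and-repartition step, not the DeepSpeed implementation; the helper name and the tensor values are illustrative only.

import torch

def merge_and_repartition(saved_partitions, new_dp_world_size):
    # Illustrative only: rebuild the full flat tensor from the lean (unpadded)
    # partitions saved at the old DP degree, re-pad to the new degree, and slice again.
    flat = torch.cat(saved_partitions)
    remainder = flat.numel() % new_dp_world_size
    pad = new_dp_world_size - remainder if remainder else 0
    if pad:
        flat = torch.cat([flat, torch.zeros(pad, dtype=flat.dtype)])
    return list(torch.chunk(flat, new_dp_world_size))  # one slice per new rank

# Saved with DP=2, reloaded with DP=4.
saved = [torch.arange(0, 5, dtype=torch.float32), torch.arange(5, 10, dtype=torch.float32)]
new_parts = merge_and_repartition(saved, new_dp_world_size=4)
print([p.numel() for p in new_parts])  # [3, 3, 3, 3]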
@@ -121,6 +121,8 @@ class DeepSpeedLight(Module):
         self.gradient_average = True
         self.warn_unscaled_loss = True
         self.config_params = config_params
+        self.loaded_checkpoint_mp_world_size = None
+        self.loaded_checkpoint_dp_world_size = None
 
         if dist_init_required is None:
             dist_init_required = not dist.is_initialized()
@@ -289,9 +291,6 @@ class DeepSpeedLight(Module):
     def zero_overlap_comm(self):
         return self._config.zero_config.overlap_comm
 
-    def zero_max_elements_per_comm(self):
-        return self._config.zero_max_elements_per_comm
-
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage
@@ -307,6 +306,9 @@ class DeepSpeedLight(Module):
     def zero_contiguous_gradients(self):
         return self._config.zero_config.contiguous_gradients
 
+    def zero_load_from_fp32_weights(self):
+        return self._config.zero_config.load_from_fp32_weights
+
     def allgather_size(self):
         return self._config.allgather_size
@@ -472,10 +474,12 @@ class DeepSpeedLight(Module):
         if self.mpu is None:
             self.data_parallel_group = _initialize_parameter_parallel_groups()
             self.dp_world_size = dist.get_world_size()
+            self.mp_world_size = 1
             self.broadcast_src_rank = 0
         else:
             self.data_parallel_group = self.mpu.get_data_parallel_group()
             self.dp_world_size = self.mpu.get_data_parallel_world_size()
+            self.mp_world_size = self.mpu.get_model_parallel_world_size()
             self.broadcast_src_rank = _get_global_rank(
                 self.mpu.get_data_parallel_group(),
                 0)
@@ -1057,18 +1061,19 @@ class DeepSpeedLight(Module):
     def load_module_state_dict(self, state_dict, strict=True):
         self.module.load_state_dict(state_dict, strict=strict)
 
-    def _get_zero_ckpt_name(self, checkpoints_path, tag):
-        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
-        pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
-        filename = 'zero_pp_rank_{}'.format(pp_rank)
+    def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank):
+        filename = 'zero_pp_rank_{}'.format(dp_rank)
         zero_ckpt_name = os.path.join(
             checkpoints_path,
             str(tag),
             filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
         return zero_ckpt_name
 
+    def _get_zero_ckpt_name(self, checkpoints_path, tag):
+        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
+        pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
+        return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank)
+
     def _get_ckpt_name(self, checkpoints_path, tag):
         mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
@@ -1144,13 +1149,17 @@ class DeepSpeedLight(Module):
         self.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
         self.global_steps = checkpoint['global_steps']
         self.skipped_steps = checkpoint['skipped_steps']
+        self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
+        self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
 
         deepspeed_states = [
             'module',
             'optimizer',
             'lr_scheduler',
             'csr_tensor_module_names',
             'skipped_steps',
-            'global_steps'
+            'global_steps',
+            'dp_world_size',
+            'mp_world_size'
         ]
         client_state = {
             key: value
@@ -1161,18 +1170,72 @@ class DeepSpeedLight(Module):
         return load_path, client_state
 
     def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
-        zero_checkpoint_name = self._get_zero_ckpt_name(load_dir, tag)
-
-        if not os.path.exists(zero_checkpoint_name):
-            logger.warn(
-                'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
-                .format(zero_checkpoint_name))
-            return None
-
-        zero_sd = torch.load(zero_checkpoint_name, map_location='cpu')
-        self.optimizer.load_state_dict(zero_sd['optimizer_state_dict'],
-                                       load_optimizer_states=load_optimizer_states)
-        logger.info('loading zero checkpoint {}'.format(zero_checkpoint_name))
+        zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
+        if zero_sd_list is None:
+            return
+
+        self.optimizer.load_state_dict(
+            state_dict_list=zero_sd_list,
+            load_optimizer_states=load_optimizer_states,
+            load_from_fp32_weights=self.zero_load_from_fp32_weights())
+        print(
+            f'loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}'
+        )
+
+    def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size):
+        zero_ckpt_names = []
+        for dp_rank in range(dp_world_size):
+            ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir,
+                                                      tag=tag,
+                                                      mp_rank=mp_rank,
+                                                      dp_rank=dp_rank)
+            zero_ckpt_names.append(ckpt_name)
+
+        return zero_ckpt_names
+
+    def _get_all_zero_checkpoint_names(self,
+                                       load_dir,
+                                       tag,
+                                       mp_world_size,
+                                       dp_world_size):
+        zero_ckpt_names = []
+        for mp_rank in range(mp_world_size):
+            mp_rank_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
+                load_dir=load_dir,
+                tag=tag,
+                mp_rank=mp_rank,
+                dp_world_size=dp_world_size)
+            zero_ckpt_names += mp_rank_ckpt_names
+
+        return zero_ckpt_names
+
+    def _get_all_zero_checkpoints(self, load_dir, tag):
+        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
+        zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(
+            load_dir=load_dir,
+            tag=tag,
+            mp_rank=mp_rank,
+            dp_world_size=self.loaded_checkpoint_dp_world_size)
+
+        invalid_zero_ckpt_paths = []
+        for ckpt_name in zero_ckpt_names:
+            if not os.path.exists(ckpt_name):
+                invalid_zero_ckpt_paths.append(ckpt_name)
+
+        if len(invalid_zero_ckpt_paths) > 0:
+            logging.warn(
+                f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} does not exist"
+            )
             return None
 
+        zero_sd_list = []
+        for ckpt_name in zero_ckpt_names:
+            zero_sd_list.append(torch.load(ckpt_name, map_location='cpu'))
+
+        zero_optimizer_sd = [sd['optimizer_state_dict'] for sd in zero_sd_list]
+        print(
+            f"successfully loaded {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}"
+        )
+        return zero_optimizer_sd
+
     def save_checkpoint(self, save_dir, tag, client_state={}):
         r"""Save training checkpoint
@@ -1232,6 +1295,10 @@ class DeepSpeedLight(Module):
             self.skipped_steps,
             'global_steps':
             self.global_steps,
+            'dp_world_size':
+            self.dp_world_size,
+            'mp_world_size':
+            self.mp_world_size
         }
         state.update(client_state)
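With the naming scheme introduced above, a restore on any current DP degree simply enumerates one ZeRO checkpoint file per rank of the saved DP world size under <load_dir>/<tag>/. A small sketch of the resulting paths; the directory, tag, and world-size values here are made up:

import os

load_dir, tag = "checkpoints", "global_step1000"
mp_rank, saved_dp_world_size = 0, 4

ckpt_names = [
    os.path.join(load_dir, str(tag),
                 'zero_pp_rank_{}'.format(dp_rank) +
                 '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt')
    for dp_rank in range(saved_dp_world_size)
]
print(ckpt_names[0])  # checkpoints/global_step1000/zero_pp_rank_0_mp_rank_00optim_states.pt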
@@ -23,6 +23,7 @@ ZeRO optimization should be enabled as:
     "contiguous_gradients" : [true|false]
     "overlap_comm": [true|false],
     "reduce_bucket_size": 500000000
+    "load_from_fp32_weights": [true|false]
   }
 }
 '''
@@ -59,17 +60,24 @@ ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT = 500000000
 ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE = 'allgather_bucket_size'
 ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT = 500000000
 ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEPRECATED = 'allgather_size'
+ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
+ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
 
 ZERO_OPTIMIZATION_DEFAULT = {
-    ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
+    ZERO_OPTIMIZATION_STAGE:
+    ZERO_OPTIMIZATION_STAGE_DEFAULT,
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_SCATTER:
+    ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
+    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
-    ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT
+    ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
+    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
+    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
 }
@@ -84,6 +92,7 @@ class DeepSpeedZeroConfig(object):
         self.allgather_partitions = None
         self.allgather_bucket_size = None
         self.overlap_comm = None
+        self.load_from_fp32_weights = None
 
         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -148,3 +157,7 @@ class DeepSpeedZeroConfig(object):
             zero_config_dict,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
+        self.load_from_fp32_weights = get_scalar_param(
+            zero_config_dict,
+            ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
+            ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
@@ -12,6 +12,7 @@ from torch.autograd import Variable
 from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter
+from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS
 
 #Toggle this to true to enable correctness test
 #with gradient partitioning and without
@@ -61,8 +62,8 @@ def lcm(x, y):
     return x * y // gcd(x, y)
 
-#create a flat tensor aligned at the alignment boundary
-def flatten_dense_tensors_aligned(tensor_list, alignment, pg):
+# create a flat tensor aligned at the alignment boundary
+def flatten_dense_tensors_aligned(tensor_list, alignment):
     num_elements = 0
     for tensor in tensor_list:
         num_elements = num_elements + tensor.numel()
@@ -83,11 +84,21 @@ def flatten_dense_tensors_aligned(tensor_list, alignment, pg):
     return _flatten_dense_tensors(padded_tensor_list)
 
 
+def get_alignment_padding(tensor_list, alignment):
+    num_elements = sum([tensor.numel() for tensor in tensor_list])
+    remainder = num_elements % alignment
+    return (alignment - remainder) if remainder else remainder
+
+
 def move_to_cpu(tensor_list):
     for tensor in tensor_list:
         tensor.data = tensor.data.cpu()
 
 
+def print_rank_msg(msg):
+    print(f"rank {dist.get_rank()} - {msg}")
+
+
 class FP16_DeepSpeedZeroOptimizer(object):
     """
     DeepSpeedZeroOptimizer designed to reduce the memory footprint
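A quick illustration of the new get_alignment_padding helper (toy sizes, not real model tensors): with 10 total elements and an alignment of 4, two trailing padding elements are required.

import torch

def get_alignment_padding(tensor_list, alignment):
    # Same logic as the helper added above: pad the flattened group up to a multiple of `alignment`.
    num_elements = sum([tensor.numel() for tensor in tensor_list])
    remainder = num_elements % alignment
    return (alignment - remainder) if remainder else remainder

tensors = [torch.zeros(6), torch.zeros(4)]           # 10 elements total
print(get_alignment_padding(tensors, alignment=4))   # 2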
@@ -195,10 +206,19 @@ class FP16_DeepSpeedZeroOptimizer(object):
         self.all_reduce_print = False
 
+        # padding on each partition for alignment purposes
+        self.groups_padding = []
         # loop to deal with groups
         for i, param_group in enumerate(self.optimizer.param_groups):
             # push this group to list before modify
             self.fp16_groups.append(param_group['params'])
 
+            # Record padding required to align group to world size
+            if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
+                padding = get_alignment_padding(self.fp16_groups[i],
+                                                self.partition_count)
+            else:
+                padding = 0
+            self.groups_padding.append(padding)
+
             #not sure why apex was cloning the weights before flattening
             #removing cloning here
self.fp16_groups_flat.append( self.fp16_groups_flat.append(
flatten_dense_tensors_aligned( flatten_dense_tensors_aligned(
self.fp16_groups[i], self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group), dist.get_world_size(group=self.dp_process_group)).cuda(
self.dp_process_group).cuda(torch.cuda.current_device())) torch.cuda.current_device()))
see_memory_usage(f"After flattening and moving param group {i} to GPU") see_memory_usage(f"After flattening and moving param group {i} to GPU")
if dist.get_rank(group=self.dp_process_group) == 0: if dist.get_rank(group=self.dp_process_group) == 0:
@@ -1114,8 +1134,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
             if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
                 single_grad_partition = flatten_dense_tensors_aligned(
                     self.averaged_gradients[i],
-                    int(self.partition_size[i]),
-                    self.dp_process_group).to(
+                    int(self.partition_size[i])).to(
                         self.single_partition_of_fp32_groups[i].dtype)
             else:
                 single_grad_partition = _flatten_dense_tensors(
@@ -1336,6 +1355,38 @@ class FP16_DeepSpeedZeroOptimizer(object):
     loss_scale = property(_get_loss_scale, _set_loss_scale)
     cur_scale = property(_get_loss_scale, _set_loss_scale)
 
+    # Return group tensor after removing paddings that are added for alignment to DP world size.
+    # This method works on the assumption that each group contains a single flattened tensor.
+    def _get_groups_without_padding(self, groups_with_padding):
+        groups_without_padding = []
+        for i, group in enumerate(groups_with_padding):
+            lean_length = group.numel() - self.groups_padding[i]
+            groups_without_padding.append(group[:lean_length])
+
+        return groups_without_padding
+
+    # Return optimizer state after removing paddings that are added for alignment.
+    def _get_state_without_padding(self, state_with_padding, padding):
+        lean_state = {}
+        for key, value in state_with_padding.items():
+            lean_length = value.numel() - padding
+            lean_state[key] = value[:lean_length]
+
+        return lean_state
+
+    # Return base optimizer states.
+    # This method assumes that each param group contains a single flattened tensor.
+    def _get_base_optimizer_state(self):
+        optimizer_groups_state = []
+        for i, group in enumerate(self.optimizer.param_groups):
+            p = group['params'][0]
+            lean_optimizer_state = self._get_state_without_padding(
+                self.optimizer.state[p],
+                self.groups_padding[i])
+            optimizer_groups_state.append(lean_optimizer_state)
+
+        return optimizer_groups_state
+
     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
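The lean-state helpers above only trim the trailing alignment padding from each flat tensor before it is written to the checkpoint. A toy sketch of that slicing on made-up optimizer state (keys and sizes are illustrative, not real DeepSpeed state):

import torch

state_with_padding = {"exp_avg": torch.arange(8.0), "exp_avg_sq": torch.arange(8.0)}
padding = 3  # pretend the last 3 elements are alignment padding

# Keep only the "lean" prefix of every state tensor, as _get_state_without_padding does.
lean_state = {key: value[:value.numel() - padding]
              for key, value in state_with_padding.items()}
print(lean_state["exp_avg"].numel())  # 5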
@@ -1351,21 +1402,98 @@ class FP16_DeepSpeedZeroOptimizer(object):
         state_dict['loss_scaler'] = self.loss_scaler
         state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
         state_dict['overflow'] = self.overflow
-        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
-        state_dict[
-            'single_partition_of_fp32_groups'] = self.single_partition_of_fp32_groups
+        state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
+
+        state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
         state_dict['partition_count'] = self.partition_count
+
+        # Remove paddings for DP alignment to enable loading for other alignment values
+        fp32_groups_without_padding = self._get_groups_without_padding(
+            self.single_partition_of_fp32_groups)
+        state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
+
         return state_dict
 
-    # Refresh the fp32 master params from the fp16 copies.
-    def refresh_fp32_params(self):
+    # Restore base optimizer fp32 weights from checkpoint by:
+    # 1) Merging fp32 weights from checkpoints of all partitions
+    # 2) Extracting fp32 weights for current partition from merged weights
+    # 3) Using extracted weights to update base optimizer weights directly.
+    def _restore_from_fp32_weights(self, all_state_dict):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+        merged_single_partition_of_fp32_groups = []
+        for i in range(len(self.single_partition_of_fp32_groups)):
+            merged_partitions = [
+                sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
+            ]
+            flat_merged_partitions = flatten_dense_tensors_aligned(
+                merged_partitions,
+                dist.get_world_size(group=self.dp_process_group))
+            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
+            merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])
+
+        for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
+            current.data.copy_(saved.data)
+
+    # Restore base optimizer fp32 weights from ZeRO fp16 weights
+    def _restore_from_fp16_weights(self):
         partition_id = dist.get_rank(group=self.dp_process_group)
         for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
             fp32_partition.data.copy_(fp16_partitions[partition_id].data)
 
-    def load_state_dict(self, state_dict, load_optimizer_states=True):
+    # Refresh the fp32 master params from the fp16 copies.
+    def refresh_fp32_params(self):
+        self._restore_from_fp16_weights()
+
+    # Extract optimizer state for current partition from merged states of all partitions
+    def _partition_base_optimizer_state(self, state_key, all_partition_states):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+        alignment = dist.get_world_size(group=self.dp_process_group)
+        flat_merged_partitions = flatten_dense_tensors_aligned(
+            all_partition_states,
+            alignment)
+        dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
+        return dp_partitions[partition_id]
+
+    # Restore base optimizer state from checkpoint by
+    # 1) Merging optimizer state from checkpoints of all partitions
+    # 2) Extracting optimizer state for current partition from the merged state
+    # 3) Using the extracted value to directly update the base optimizer.
+    def _restore_base_optimizer_state(self, all_state_dict):
+        base_optimizer_group_states = []
+        for i in range(len(self.optimizer.param_groups)):
+            partition_states = {}
+            all_partition_group_states = [
+                sd['base_optimizer_state'][i] for sd in all_state_dict
+            ]
+            for key in all_partition_group_states[0].keys():
+                all_partition_states = [
+                    all_states[key] for all_states in all_partition_group_states
+                ]
+                partition_states[key] = self._partition_base_optimizer_state(
+                    key,
+                    all_partition_states)
+            base_optimizer_group_states.append(partition_states)
+
+        for i, group in enumerate(self.optimizer.param_groups):
+            p = group['params'][0]
+            for key, saved in base_optimizer_group_states[i].items():
+                current = self.optimizer.state[p][key]
+                current.data.copy_(saved.data)
+
+    def load_state_dict(self,
+                        state_dict_list,
+                        load_optimizer_states=True,
+                        load_from_fp32_weights=False):
+        r"""Loading ZeRO checkpoint
+
+        Arguments:
+            state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
+                Note that the number of saved partitions may differ from number of loading partitions to support
+                changing GPU count, specifically DP world size, between saving and loading checkpoints.
+            load_optimizer_states: Boolean indicating whether or not to load base optimizer states
+            load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
+                copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).
+        """
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
@@ -1382,12 +1510,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
             optimizer.load_state_dict(checkpoint['optimizer'])
         """
         # I think it should actually be ok to reload the optimizer before the model.
-        self.loss_scaler = state_dict['loss_scaler']
-        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
-        self.overflow = state_dict['overflow']
+        self.loss_scaler = state_dict_list[0]['loss_scaler']
+        self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
+        self.overflow = state_dict_list[0]['overflow']
 
         if load_optimizer_states:
-            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+            self._restore_base_optimizer_state(state_dict_list)
 
         # At this point, the optimizer's references to the model's fp32 parameters are up to date.
         # The optimizer's hyperparameters and internal buffers are also up to date.
@@ -1404,16 +1532,10 @@ class FP16_DeepSpeedZeroOptimizer(object):
         # constructed in the same way as the one whose state_dict we are loading, the same master params
         # are guaranteed to exist, so we can just copy_() from the saved master params.
 
-        if 'partition_count' in state_dict and state_dict[
-                'partition_count'] == self.partition_count:
-            # Use option 2
-            for current, saved in zip(self.single_partition_of_fp32_groups, state_dict['single_partition_of_fp32_groups']):
-                current.data.copy_(saved.data)
+        if load_from_fp32_weights:
+            self._restore_from_fp32_weights(state_dict_list)
         else:
-            # Use option 1
-            partition_id = dist.get_rank(group=self.dp_process_group)
-            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
-                fp32_partition.data.copy_(fp16_partitions[partition_id].data)
+            self._restore_from_fp16_weights()
 
 def _handle_overflow(cpu_sum, x, i):
@@ -8,6 +8,32 @@ from deepspeed.pt.zero_utils import _initialize_parameter_parallel_groups
 from deepspeed.pt.log_utils import log_dist, logger
 from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
+from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
+
+
+def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
+    sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
+    if sub_partition_high_limit <= flattened_lean_size:
+        return 0
+    else:
+        return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)
+
+
+def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_count):
+    group_paddings = []
+    flattened_size = sum([tensor.numel() for tensor in tensor_list])
+    for i in range(sub_partition_count):
+        padding = get_alignment_padding(flattened_size, i, sub_partition_size)
+        group_paddings.append(padding)
+
+    logger.info("****Padding information*****")
+    logger.info(f"tensor_size = {flattened_size}")
+    logger.info(f"sub_partition_size = {sub_partition_size}")
+    logger.info(f"sub_partition_count = {sub_partition_count}")
+    for i, padding in enumerate(group_paddings):
+        logger.info(f"padding[{i}] = {padding}")
+
+    return group_paddings
+
 
 def flatten_dense_tensors_sub_partition_aligned(tensor_list,
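To illustrate the stage-1 padding helpers just added (toy numbers, not a real group): 10 real elements spread over three sub-partitions of 4 elements each put all of the padding in the last sub-partition.

def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
    # Same per-sub-partition rule as above: padding only appears once the
    # sub-partition's upper bound passes the end of the real (lean) data.
    sub_partition_high_limit = (sub_partition_id + 1) * sub_partition_size
    if sub_partition_high_limit <= flattened_lean_size:
        return 0
    return min(sub_partition_size, sub_partition_high_limit - flattened_lean_size)

paddings = [get_alignment_padding(10, i, 4) for i in range(3)]
print(paddings)  # [0, 0, 2]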
@@ -164,6 +190,9 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         local_rank = dist.get_rank(group=self.dp_process_group)
 
+        self.group_paddings = []
+        self.partition_count = dist.get_world_size(group=self.dp_process_group)
+
         # loop to deal with groups
         for i, param_group in enumerate(self.optimizer.param_groups):
             # push this group to list before modify
@@ -215,6 +244,13 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 local_sub_partitions.append(fp32_sub_partition)
             self.local_sub_partitions_of_fp32_groups.append(local_sub_partitions)
 
+            # Compute sub_partition paddings
+            sub_partition_paddings = get_group_alignment_padding(
+                tensor_list=self.fp16_groups[i],
+                sub_partition_size=sub_partition_size,
+                sub_partition_count=num_comm_intervals * self.partition_count)
+            self.group_paddings.append(sub_partition_paddings)
+
             # modify optimizer of have flat master weight
             # self.single_partition_of_fp32_groups[i].requires_grad = True # keep this in case internal optimizer uses it
             param_group['params'] = self.local_sub_partitions_of_fp32_groups[i]
@@ -256,6 +292,22 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             mpu=self.mpu,
             zero_reduce_scatter=True)
 
+        self._initialize_optimizer_states()
+
+    def _initialize_optimizer_states(self):
+        for group_idx, group in enumerate(self.local_sub_partitions_of_fp32_groups):
+            for idx, sub_partition_param in enumerate(group):
+                sub_partition_grad = torch.zeros(int(
+                    self.sub_partition_sizes[group_idx]),
+                    dtype=sub_partition_param.dtype).cuda()
+                sub_partition_param.grad = sub_partition_grad
+
+        self.optimizer.step()
+
+        for group in self.local_sub_partitions_of_fp32_groups:
+            for idx, sub_partition_param in enumerate(group):
+                sub_partition_param.grad = None
+
     @staticmethod
     def get_data_parallel_sub_partitions(tensor,
                                          max_elements_per_comm,
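_initialize_optimizer_states appears to step the wrapped optimizer once on zero gradients so that its per-parameter state tensors already exist before the restore path copies checkpointed values into them with copy_(). A minimal standalone sketch of the same trick with a plain torch.optim.Adam (not DeepSpeed code; parameter sizes are made up):

import torch

param = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([param], lr=1e-3)

# Stepping once on a zero gradient forces Adam to allocate exp_avg / exp_avg_sq buffers.
param.grad = torch.zeros_like(param)
opt.step()
param.grad = None
print(sorted(opt.state[param].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']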
@@ -721,6 +773,51 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
     loss_scale = property(_get_loss_scale, _set_loss_scale)
     cur_scale = property(_get_loss_scale, _set_loss_scale)
 
+    # Return group tensor after removing paddings that are added for alignment to DP world size.
+    # This method works on the assumption that each group contains sub partitions.
+    def _get_groups_without_padding(self, groups_with_padding):
+        groups_without_padding = []
+        local_rank = dist.get_rank(group=self.dp_process_group)
+        for i, group in enumerate(groups_with_padding):
+            low_index = local_rank * len(group)
+            high_index = (local_rank + 1) * len(group)
+            group_paddings = self.group_paddings[i][low_index:high_index]
+
+            lean_sub_partitions = []
+            for j, sub_partition in enumerate(group):
+                lean_length = sub_partition.numel() - group_paddings[j]
+                lean_sub_partitions.append(sub_partition[:lean_length])
+            groups_without_padding.append(lean_sub_partitions)
+
+        return groups_without_padding
+
+    # Return optimizer state after removing paddings that are added for alignment.
+    def _get_state_without_padding(self, state_with_padding, padding):
+        lean_state = {}
+        for key, value in state_with_padding.items():
+            lean_length = value.numel() - padding
+            lean_state[key] = value[:lean_length]
+
+        return lean_state
+
+    # Return base optimizer states.
+    # This method assumes that each param group contains a single flattened tensor.
+    def _get_base_optimizer_state(self):
+        optimizer_groups_state = []
+        local_rank = dist.get_rank(group=self.dp_process_group)
+        for group_idx, group in enumerate(self.optimizer.param_groups):
+            group_lean_state = []
+            low_index = local_rank * self.num_comm_intervals_per_group[group_idx]
+            high_index = (local_rank + 1) * self.num_comm_intervals_per_group[group_idx]
+            param_paddings = self.group_paddings[group_idx][low_index:high_index]
+            for param_idx, param in enumerate(group['params']):
+                lean_state = self._get_state_without_padding(self.optimizer.state[param],
+                                                             param_paddings[param_idx])
+                group_lean_state.append(lean_state)
+
+            optimizer_groups_state.append(group_lean_state)
+
+        return optimizer_groups_state
+
     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
@@ -736,20 +833,140 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         state_dict['loss_scaler'] = self.loss_scaler
         state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
         state_dict['overflow'] = self.overflow
-        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
-        state_dict[
-            'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
+        state_dict['base_optimizer_state'] = self._get_base_optimizer_state()
+
+        state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
+        state_dict['partition_count'] = self.partition_count
+        state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group
+
+        # Remove paddings for DP alignment to enable loading for other alignment values
+        fp32_groups_without_padding = self._get_groups_without_padding(
+            self.local_sub_partitions_of_fp32_groups)
+        state_dict['local_sub_partitions_of_fp32_groups'] = fp32_groups_without_padding
+
         return state_dict
 
+    def _retrieve_group_sub_partition_weights(self, all_partition_fp32_weights):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+
+        all_sub_partition_weights = []
+        for partition_weights in all_partition_fp32_weights:
+            for sub_partition_weights in partition_weights:
+                all_sub_partition_weights.append(sub_partition_weights)
+
+        flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
+            tensor_list=all_sub_partition_weights,
+            dp=dist.get_world_size(group=self.dp_process_group),
+            max_elements_per_comm=self.max_elements_per_comm,
+            pg=self.dp_process_group)
+
+        comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
+            self.get_data_parallel_sub_partitions(
+                tensor=flat_merged_weights,
+                max_elements_per_comm=self.max_elements_per_comm,
+                world_size=dist.get_world_size(group=self.dp_process_group),
+                dp_process_group=self.dp_process_group
+            )
+
+        return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
+
+    # Restore base optimizer fp32 weights from checkpoint by:
+    # 1) Merging fp32 weights from checkpoints of all partitions
+    # 2) Extracting fp32 weights for current partition from merged weights
+    # 3) Using extracted weights to update base optimizer weights directly.
+    def _restore_from_fp32_weights(self, all_state_dict):
+        sub_partition_of_fp32_groups = []
+        for group_idx in range(len(self.local_sub_partitions_of_fp32_groups)):
+            all_partition_fp32_weights = [
+                sd['local_sub_partitions_of_fp32_groups'][group_idx]
+                for sd in all_state_dict
+            ]
+            sub_partition_weights = self._retrieve_group_sub_partition_weights(
+                all_partition_fp32_weights)
+            sub_partition_of_fp32_groups.append(sub_partition_weights)
+
+        for current_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, sub_partition_of_fp32_groups):
+            for current_sub_part, saved_sub_part in zip(current_group, saved_group):
+                current_sub_part.data.copy_(saved_sub_part.data)
+
+    # Extract optimizer state for current partition from merged states of all partitions
+    def _partition_base_optimizer_state(self, state_key, all_partition_states):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+        alignment = dist.get_world_size(group=self.dp_process_group)
+
+        flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
+            tensor_list=all_partition_states,
+            dp=dist.get_world_size(group=self.dp_process_group),
+            max_elements_per_comm=self.max_elements_per_comm,
+            pg=self.dp_process_group)
+
+        comm_partitions, dp_sub_partitions, element_intervals, sub_partition_size, num_comm_intervals = \
+            self.get_data_parallel_sub_partitions(
+                tensor=flat_merged_partitions,
+                max_elements_per_comm=self.max_elements_per_comm,
+                world_size=dist.get_world_size(group=self.dp_process_group),
+                dp_process_group=self.dp_process_group
+            )
+
+        return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]
+
+    # Compute the optimizer state partitions for the group by
+    # 1) Merging state values across the previous partitioning.
+    # 2) Repartition state values for the new partitioning
+    # 3) Return state corresponding to local partition
+    def _retrieve_group_optimizer_states(self, all_partition_states):
+        merged_optimizer_states = {}
+        for partition_state in all_partition_states:
+            for sub_partition_state in partition_state:
+                for key, value in sub_partition_state.items():
+                    if not key in merged_optimizer_states.keys():
+                        merged_optimizer_states[key] = [value]
+                    else:
+                        merged_optimizer_states[key].append(value)
+
+        group_optimizer_states = {}
+        for key, value in merged_optimizer_states.items():
+            group_optimizer_states[key] = self._partition_base_optimizer_state(
+                key,
+                value)
+
+        return group_optimizer_states
+
+    # Restore base optimizer state from checkpoint by
+    # 1) Merging optimizer state from checkpoints of all partitions
+    # 2) Extracting optimizer state for current partition from the merged state
+    # 3) Using the extracted value to directly update the base optimizer.
+    def _restore_base_optimizer_state(self, state_dict_list):
+        base_optimizer_group_states = []
+        for group_idx in range(len(self.optimizer.param_groups)):
+            all_partition_group_states = [
+                sd['base_optimizer_state'][group_idx] for sd in state_dict_list
+            ]
+            group_optimizer_states = self._retrieve_group_optimizer_states(
+                all_partition_group_states)
+            base_optimizer_group_states.append(group_optimizer_states)
+
+        for group_idx, group in enumerate(self.optimizer.param_groups):
+            for param_idx, param in enumerate(group['params']):
+                for key, saved in base_optimizer_group_states[group_idx].items():
+                    current = self.optimizer.state[param][key]
+                    current.data.copy_(saved[param_idx].data)
+
+    # Restore base optimizer fp32 weights from ZeRO fp16 weights
+    def _restore_from_fp16_weights(self):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+        for fp16_partitions, fp32_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
+            for fp16_sub_partition, fp32_sub_partition in zip(fp16_partitions[partition_id], fp32_partitions):
+                fp32_sub_partition.data.copy_(fp16_sub_partition.data)
+
     # Refresh the fp32 master params from the fp16 copies.
     def refresh_fp32_params(self):
-        partition_id = dist.get_rank(group=self.dp_process_group)
-        for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
-            for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
-                local_sub_partition_param_fp32.data.copy_(
-                    local_sub_partition_param_fp16.data)
+        self._restore_from_fp16_weights()
 
-    def load_state_dict(self, state_dict, load_optimizer_states=True):
+    def load_state_dict(self,
+                        state_dict_list,
+                        load_optimizer_states=True,
+                        load_from_fp32_weights=False):
         """
         Loads a state_dict created by an earlier call to state_dict().
         If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
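A toy sketch of the merge step in _retrieve_group_optimizer_states above: every saved rank contributes a list of sub-partition state dicts, and the values for each key are collected into one list before being flattened and re-partitioned for the current DP degree (tensor contents and counts here are made up):

import torch

all_partition_states = [
    [{"exp_avg": torch.zeros(4)}, {"exp_avg": torch.zeros(4)}],  # saved rank 0, 2 sub-partitions
    [{"exp_avg": torch.zeros(4)}, {"exp_avg": torch.zeros(4)}],  # saved rank 1, 2 sub-partitions
]

merged = {}
for partition_state in all_partition_states:
    for sub_partition_state in partition_state:
        for key, value in sub_partition_state.items():
            merged.setdefault(key, []).append(value)
print(len(merged["exp_avg"]))  # 4 sub-partition tensors ready to merge and re-slice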
@@ -766,12 +983,14 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             optimizer.load_state_dict(checkpoint['optimizer'])
         """
         # I think it should actually be ok to reload the optimizer before the model.
-        self.loss_scaler = state_dict['loss_scaler']
-        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
-        self.overflow = state_dict['overflow']
+        self.loss_scaler = state_dict_list[0]['loss_scaler']
+        self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
+        self.overflow = state_dict_list[0]['overflow']
+
         if load_optimizer_states:
-            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+            self._restore_base_optimizer_state(state_dict_list)
 
-        for curr_group, saved_group in zip(self.local_sub_partitions_of_fp32_groups, state_dict['local_sub_partitions_of_fp32_groups']):
-            for curr_param, saved_param in zip(curr_group, saved_group):
-                curr_param.data.copy_(saved_param.data)
+        if load_from_fp32_weights:
+            self._restore_from_fp32_weights(state_dict_list)
+        else:
+            self._restore_from_fp16_weights()
@@ -53,7 +53,7 @@ class GPT2CheckpointTestCase(BaseTestCase):
     def tearDown(self):
         os.chdir(self.save_dir)
 
-    def test_mp4_gpu16_node1_with_zero1(self):
+    def test_mp2_gpu4_node1_with_zero1(self):
         test_config = {
             "mp": 2,
             "gpus": 4,
@@ -68,14 +68,14 @@ class GPT2CheckpointTestCase(BaseTestCase):
             "tag": "ds_zero1",
             "zero": True,
             "other_args": "",
-            "checkpoint_name": "ckpt_mp4_gpu16_w_zero1",
+            "checkpoint_name": "ckpt_mp2_gpu8_w_zero1",
             "checkpoint_interval": 1000,
             "json": "ds_config_func_bs8_zero1.json",
         }
 
         succ = self.run_test(test_config, 0.01)
         self.assertTrue(succ)
 
-    def test_mp4_gpu16_node1_with_zero2(self):
+    def test_mp2_gpu4_node1_with_zero2(self):
         test_config = {
             "mp": 2,
             "gpus": 4,
@@ -90,14 +90,198 @@ class GPT2CheckpointTestCase(BaseTestCase):
             "tag": "ds_zero2",
             "zero": True,
             "other_args": "",
-            "checkpoint_name": "ckpt_mp4_gpu16_w_zero2",
+            "checkpoint_name": "ckpt_mp2_gpu8_w_zero2",
             "checkpoint_interval": 1000,
             "json": "ds_config_func_bs8_zero2.json",
         }
 
         succ = self.run_test(test_config, 0.01)
         self.assertTrue(succ)
 
-    def test_mp4_gpu16_node1_without_zero(self):
+    def test_mp1_gpu2_load_gpu1_node1_with_zero1(self):
+        test_config = {
+            "mp": 1,
+            "gpus": 2,
+            "load_gpus": 1,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero2",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero1",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero1.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp1_gpu2_load_gpu4_node1_with_zero1(self):
+        test_config = {
+            "mp": 1,
+            "gpus": 2,
+            "load_gpus": 4,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero2",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero1",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero1.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp1_gpu2_load_gpu1_node1_with_zero2(self):
+        test_config = {
+            "mp": 1,
+            "gpus": 2,
+            "load_gpus": 1,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero2",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero2",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero2.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp1_gpu2_load_gpu4_node1_with_zero2(self):
+        test_config = {
+            "mp": 1,
+            "gpus": 2,
+            "load_gpus": 4,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero2",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero2",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero2.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp2_gpu4_load_gpu2_node1_with_zero1(self):
+        test_config = {
+            "mp": 2,
+            "gpus": 4,
+            "load_gpus": 2,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero1",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp2_gpu4_gpu2_w_zero1",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero1.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp2_gpu2_load_gpu4_node1_with_zero1(self):
+        test_config = {
+            "mp": 2,
+            "gpus": 2,
+            "load_gpus": 4,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero1",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp2_gpu2_gpu4_w_zero1",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero1.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp2_gpu4_load_gpu2_node1_with_zero2(self):
+        test_config = {
+            "mp": 2,
+            "gpus": 4,
+            "load_gpus": 2,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero2",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp2_gpu4_gpu2_w_zero2",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero2.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp2_gpu2_load_gpu4_node1_with_zero2(self):
+        test_config = {
+            "mp": 2,
+            "gpus": 2,
+            "load_gpus": 4,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1100,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": 256,
+            "heads": ATTN_HEADS,
+            "deepspeed": True,
+            "tag": "ds_zero2",
+            "zero": True,
+            "other_args": "",
+            "checkpoint_name": "ckpt_mp2_gpu2_gpu4_w_zero2",
+            "checkpoint_interval": 1000,
+            "json": "ds_config_func_bs8_zero2.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
+    def test_mp2_gpu4_node1_without_zero(self):
         test_config = {
             "mp": 2,
             "gpus": 4,
@@ -130,6 +314,14 @@ class GPT2CheckpointTestCase(BaseTestCase):
         print("{0}: starting......".format(self.id()))
 
+        # Cache save and load gpu counts
+        save_gpus = test_config["gpus"]
+        if "load_gpus" in test_config:
+            load_gpus = test_config["load_gpus"]
+            del test_config["load_gpus"]
+        else:
+            load_gpus = test_config["gpus"]
+
         # save to current directory.
         checkpoint_folder = test_config["checkpoint_name"]
         checkpoint_interval = test_config["checkpoint_interval"]
@@ -178,6 +370,9 @@ class GPT2CheckpointTestCase(BaseTestCase):
         prefix = "gpt2_loading_checkpoint"
 
+        # set load gpus
+        test_config["gpus"] = load_gpus
+
         print("{0}: Second run loading checkpoint and continuing.".format(self.id()))
         test_file = self.gen_name(test_config, prefix)
@@ -216,9 +411,22 @@
 def checkpoint_suite():
     suite = unittest.TestSuite()
 
-    suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero1'))
-    suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_with_zero2'))
-    suite.addTest(GPT2CheckpointTestCase('test_mp4_gpu16_node1_without_zero'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero1'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero2'))
+
+    # Shrink DP
+    suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero1'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero2'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero1'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero2'))
+
+    # Expand DP
+    suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero1'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero2'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero1'))
+    suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero2'))
+
+    suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_without_zero'))
 
     return suite