Commit 02c392f0 authored by Shaden Smith

Revert "Load non-DeepSpeed checkpoints into ZeRO optimizer"

This reverts commit 54c0267e.
parent 54c0267e
@@ -561,6 +561,7 @@ class DeepSpeedLight(Module):
         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
+            logger.info('Creating fp16 ZeRO Optimizer Stage 1')
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
                 optimizer,
                 static_loss_scale=self.loss_scale(),
@@ -592,6 +593,7 @@ class DeepSpeedLight(Module):
                 gradient_predivide_factor=self.gradient_predivide_factor())
         else:
             raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
+        logger.info('Creating fp16 zero stage {} optimizer'.format(zero_stage))
         return optimizer
......
@@ -353,7 +353,6 @@ class FP16_Optimizer(object):
         state_dict['clip_grad'] = self.clip_grad
         return state_dict
 
-    # Refresh fp32 master params from fp16 copies
     def refresh_fp32_params(self):
         for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
             current.data.copy_(saved.data)
......
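For context, the refresh_fp32_params method kept in the hunk above re-copies the flattened fp16 parameter groups into their fp32 master counterparts, which is the refresh the reverted change apparently relied on when pulling weights from a non-DeepSpeed checkpoint. A minimal standalone sketch of that pattern (the free function and the toy buffers are hypothetical stand-ins, not the DeepSpeed API):

import torch

def refresh_fp32_params(fp32_groups_flat, fp16_groups_flat):
    # Copy each flattened fp16 group into its fp32 master counterpart,
    # mirroring FP16_Optimizer.refresh_fp32_params above.
    for current, saved in zip(fp32_groups_flat, fp16_groups_flat):
        current.data.copy_(saved.data)

# Toy buffers standing in for the real flattened parameter groups.
fp16_flat = [torch.randn(8, dtype=torch.float16)]
fp32_flat = [torch.zeros(8, dtype=torch.float32)]
refresh_fp32_params(fp32_flat, fp16_flat)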
@@ -10,24 +10,6 @@ from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
 
 
-def flatten_dense_tensors_sub_partition_aligned_(tensor_list,
-                                                 dp,
-                                                 max_elements_per_comm,
-                                                 pg):
-    num_elements = sum(t.numel() for t in tensor_list)
-    log_dist("Total number of elements in model: {}, max elements per com: {}".format(
-        num_elements,
-        max_elements_per_comm),
-             ranks=[0])
-
-    # Compute aligned partition size based on parameter count
-    aligned_param_partition_size = math.ceil(num_elements / dp)
-
-    # Compute aligned partition size based on communication size
-    aligned_comm_partition_size = int(max_elements_per_comm // dp)
-
-
 def flatten_dense_tensors_sub_partition_aligned(tensor_list,
                                                 dp,
                                                 max_elements_per_comm,
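Note on the arithmetic in the removed helper above: it derives two candidate sub-partition sizes, one from evenly splitting the total parameter count across the data-parallel ranks and one from the per-communication element budget. A minimal sketch of just that calculation (the standalone function name and the sample numbers are hypothetical):

import math

def sub_partition_sizes(num_elements, dp, max_elements_per_comm):
    # Partition size needed to spread all elements across dp ranks.
    aligned_param_partition_size = math.ceil(num_elements / dp)
    # Partition size implied by the per-communication element budget.
    aligned_comm_partition_size = int(max_elements_per_comm // dp)
    return aligned_param_partition_size, aligned_comm_partition_size

# Example: 10 million parameters, 8 data-parallel ranks, 5e8-element comm budget.
print(sub_partition_sizes(10_000_000, 8, int(5e8)))  # (1250000, 62500000)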
@@ -798,14 +780,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
         return state_dict
 
-    # Refresh the fp32 master params from the fp16 copies.
-    def refresh_fp32_params(self):
-        partition_id = dist.get_rank(group=self.dp_process_group)
-        for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
-            for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
-                local_sub_partition_param_fp32.data.copy_(
-                    local_sub_partition_param_fp16.data)
-
     def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
......
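The Stage-1 refresh_fp32_params removed in the hunk above performs the same fp16-to-fp32 refresh, but only over the sub-partitions owned by the local data-parallel rank. A single-process sketch of that logic (the free function, the toy nested lists, and partition_id=0 are stand-ins for the real distributed state such as dist.get_rank on the dp_process_group):

import torch

def refresh_fp32_params(parallel_fp16_groups, local_fp32_groups, partition_id):
    # For each parameter group, copy this rank's fp16 sub-partitions into the
    # local fp32 master sub-partitions it owns.
    for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(parallel_fp16_groups, local_fp32_groups):
        for fp16_sub, fp32_sub in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
            fp32_sub.data.copy_(fp16_sub.data)

# Toy data: one parameter group, two ranks, two sub-partitions per rank.
fp16_groups = [[[torch.randn(4, dtype=torch.float16) for _ in range(2)] for _ in range(2)]]
fp32_groups = [[torch.zeros(4) for _ in range(2)]]
refresh_fp32_params(fp16_groups, fp32_groups, partition_id=0)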