"server/text_generation_server/models/flash_cohere.py" did not exist on "85aa7e2e7b02608eea04206b6cc0fa0ccced80ef"
Unverified Commit dce054db authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

backwards compatibility w. v020 ckpts, fix issue with zero-1 ckpts (#543)

parent 9de21b72
...@@ -42,6 +42,8 @@ sys.modules['deepspeed.pt'] = deepspeed.pt ...@@ -42,6 +42,8 @@ sys.modules['deepspeed.pt'] = deepspeed.pt
sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils
setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config) setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config)
sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config
setattr(deepspeed.pt, 'loss_scaler', deepspeed.runtime.fp16.loss_scaler)
sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler
def initialize(args, def initialize(args,
......
...@@ -194,7 +194,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object): ...@@ -194,7 +194,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
# max elems per param group # max elems per param group
self.max_elems_per_comm = [] self.max_elems_per_comm = []
self.legacy_max_elements_per_comm = max_elements_per_comm
# loop to deal with groups # loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups): for i, param_group in enumerate(self.optimizer.param_groups):
...@@ -859,7 +858,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object): ...@@ -859,7 +858,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES state_dict['zero_stage'] = ZERO_OPTIMIZATION_OPTIMIZER_STATES
state_dict['partition_count'] = self.partition_count state_dict['partition_count'] = self.partition_count
state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group state_dict['num_comm_intervals_per_group'] = self.num_comm_intervals_per_group
state_dict['max_elems_per_comm'] = self.max_elems_per_comm
# Remove paddings for DP alignment to enable loading for other alignment values # Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding = self._get_groups_without_padding( fp32_groups_without_padding = self._get_groups_without_padding(
...@@ -933,10 +931,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object): ...@@ -933,10 +931,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
sd['local_sub_partitions_of_fp32_groups'][group_idx] sd['local_sub_partitions_of_fp32_groups'][group_idx]
for sd in all_state_dict for sd in all_state_dict
] ]
if 'max_elems_per_comm' in all_state_dict[0]: max_elems_per_comm = self.max_elems_per_comm[group_idx]
max_elems_per_comm = all_state_dict[0]['max_elems_per_comm'][group_idx]
else:
max_elems_per_comm = self.legacy_max_elements_per_comm
sub_partition_weights = self._retrieve_group_sub_partition_weights( sub_partition_weights = self._retrieve_group_sub_partition_weights(
all_partition_fp32_weights, all_partition_fp32_weights,
...@@ -1009,10 +1004,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object): ...@@ -1009,10 +1004,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
all_partition_group_states = [ all_partition_group_states = [
sd['base_optimizer_state'][group_idx] for sd in state_dict_list sd['base_optimizer_state'][group_idx] for sd in state_dict_list
] ]
if 'max_elems_per_comm' in state_dict_list[0]: max_elems_per_comm = self.max_elems_per_comm[group_idx]
max_elems_per_comm = state_dict_list[0]['max_elems_per_comm'][group_idx]
else:
max_elems_per_comm = self.legacy_max_elements_per_comm
group_optimizer_states = self._retrieve_group_optimizer_states( group_optimizer_states = self._retrieve_group_optimizer_states(
all_partition_group_states, all_partition_group_states,
max_elems_per_comm) max_elems_per_comm)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment