Unverified Commit 4ccb9ded authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[gemini]fix gemini optimzer, saving Shardformer in Gemini got list assignment...

[gemini]fix gemini optimzer, saving Shardformer in Gemini got list assignment index out of range (#5085)
parent 0d482302
...@@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper):
param = self.id_to_real_params[param_id] param = self.id_to_real_params[param_id]
fake_param = self.id_to_fake_params.get(param_id, None) fake_param = self.id_to_fake_params.get(param_id, None)
chunk = self.chunk_manager.get_chunk(param) chunk = self.chunk_manager.get_chunk(param)
dp_group = chunk.torch_pg zero_group = chunk.torch_pg
rank = dist.get_rank(dp_group) rank = dist.get_rank(zero_group)
master_rank = 0 master_rank = 0
collected_states = {} collected_states = {}
...@@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper):
local_state_names = None local_state_names = None
if fake_param is not None: if fake_param is not None:
local_state_names = list(self.optim.state[fake_param].keys()) local_state_names = list(self.optim.state[fake_param].keys())
gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))] gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier() dist.barrier()
dist.all_gather_object(gathered_state_names, local_state_names, dp_group) dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
state_names = None state_names = None
for names in gathered_state_names: for names in gathered_state_names:
if names is not None: if names is not None:
...@@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper): ...@@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper):
_, shard_offset, shard_size = self.get_offsets(param_id) _, shard_offset, shard_size = self.get_offsets(param_id)
# Collectors gather state shards through all_gathering. # Collectors gather state shards through all_gathering.
gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))] gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier() dist.barrier()
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size]) dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
if is_collector: if is_collector:
for state_shard in gathered_state_shards: for state_shard in gathered_state_shards:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment