Unverified Commit 4ccb9ded authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[gemini]fix gemini optimzer, saving Shardformer in Gemini got list assignment...

[gemini]fix gemini optimzer, saving Shardformer in Gemini got list assignment index out of range (#5085)
parent 0d482302
......@@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper):
param = self.id_to_real_params[param_id]
fake_param = self.id_to_fake_params.get(param_id, None)
chunk = self.chunk_manager.get_chunk(param)
dp_group = chunk.torch_pg
rank = dist.get_rank(dp_group)
zero_group = chunk.torch_pg
rank = dist.get_rank(zero_group)
master_rank = 0
collected_states = {}
......@@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper):
local_state_names = None
if fake_param is not None:
local_state_names = list(self.optim.state[fake_param].keys())
gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
state_names = None
for names in gathered_state_names:
if names is not None:
......@@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper):
_, shard_offset, shard_size = self.get_offsets(param_id)
# Collectors gather state shards through all_gathering.
gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]
gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
if is_collector:
for state_shard in gathered_state_shards:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment