Commit 65e5d6ba authored by Hongxin Liu's avatar Hongxin Liu Committed by ver217
Browse files

[moe] fix mixtral optim checkpoint (#5344)

parent 956b561b
...@@ -393,7 +393,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): ...@@ -393,7 +393,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
# Store param groups. # Store param groups.
index_file.append_meta_data("param_groups", param_group_file) index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file) group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path) param_groups = [
{**group, "params": group_info["params"]}
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
]
save_param_groups({"param_groups": param_groups}, group_file_path)
# Store index file. # Store index file.
index_file.append_meta_data("total_size", total_size) index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file) index_file.write_index_file(save_index_file)
...@@ -451,7 +455,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): ...@@ -451,7 +455,11 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
# Store param groups. # Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file) final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file) group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path) param_groups = [
{**group, "params": group_info["params"]}
for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
]
save_param_groups({"param_groups": param_groups}, group_file_path)
final_index_file.write_index_file(final_index_file_path) final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder) rmtree(tmp_index_file_folder)
......
...@@ -117,6 +117,8 @@ def check_mixtral_moe_layer(): ...@@ -117,6 +117,8 @@ def check_mixtral_moe_layer():
# check save optimizer # check save optimizer
optimizer.step() optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1
snapshot = get_optimizer_snapshot(optimizer.unwrap()) snapshot = get_optimizer_snapshot(optimizer.unwrap())
booster.save_optimizer(optimizer, "mixtral_optim", shard=True) booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
dist.barrier() dist.barrier()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment