Commit a65d5678 authored by wxj's avatar wxj
Browse files

Update saver_legacy.py

parent 63c300ba
Pipeline #2652 passed with stage
......@@ -206,7 +206,7 @@ def save_checkpoint(queue, args):
mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
mpu.set_tensor_model_parallel_rank(0)
mpu.set_pipeline_model_parallel_rank(0)
fused_kernels.load(margs)
# fused_kernels.load(margs)
# Embeddings
# -----------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment