Commit 38a61e7d authored by wxj's avatar wxj
Browse files

Update saver_megatron.py 取消fused_kernels.load()方法的使用

parent c814bf1a
Pipeline #2197 passed with stage
...@@ -206,7 +206,7 @@ def save_checkpoint(queue, args): ...@@ -206,7 +206,7 @@ def save_checkpoint(queue, args):
mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size) mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
mpu.set_tensor_model_parallel_rank(0) mpu.set_tensor_model_parallel_rank(0)
mpu.set_pipeline_model_parallel_rank(0) mpu.set_pipeline_model_parallel_rank(0)
fused_kernels.load(margs) # fused_kernels.load(margs)
# Embeddings # Embeddings
# ----------- # -----------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment