Commit 63c300ba authored by wxj's avatar wxj
Browse files

Update loader_llama_mistral.py

parent be4dda7b
Pipeline #2651 passed with stage
...@@ -514,7 +514,7 @@ def _load_checkpoint(queue, args): ...@@ -514,7 +514,7 @@ def _load_checkpoint(queue, args):
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)
fused_kernels.load(margs) # fused_kernels.load(margs)
# Short aliases. # Short aliases.
tp_size = margs.tensor_model_parallel_size tp_size = margs.tensor_model_parallel_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment