Commit c814bf1a authored by wxj's avatar wxj
Browse files

Update loader_llama_mistral.py 取消fused_kernels.load()方法的使用

parent 5a1b96fa
Pipeline #2196 passed with stage
......@@ -527,7 +527,7 @@ def _load_checkpoint(queue, args):
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size)
fused_kernels.load(margs)
# fused_kernels.load(margs)
# Short aliases.
tp_size = margs.tensor_model_parallel_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment