Commit c4ea7127 authored by wxj's avatar wxj
Browse files

Update pretrain_gpt.py 添加全量编译

parent c1c977c9
Pipeline #2288 passed with stage
...@@ -128,7 +128,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat ...@@ -128,7 +128,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
rotary_base=args.rotary_base, rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling rope_scaling=args.use_rope_scaling
) )
#model = torch.compile(model,mode='max-autotune-no-cudagraphs') model = torch.compile(model,mode='max-autotune-no-cudagraphs')
print_rank_0(model) print_rank_0(model)
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment