Commit 3f26348b authored by dongcl's avatar dongcl
Browse files

modify pretrain_gpt.py

parent 390eac88
......@@ -61,7 +61,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()
use_te = args.transformer_impl == "transformer_engine"
use_te = args.transformer_impl == "transformer_engine" or bool(os.getenv("USE_FLUX_OVERLAP", 0))
if args.record_memory_history:
torch.cuda.memory._record_memory_history(True,
......@@ -87,8 +87,6 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
else:
config = core_transformer_config_from_args(args)
print_rank_0(f"config: {config}")
if args.use_legacy_models:
model = megatron.legacy.model.GPTModel(
config,
......@@ -195,7 +193,7 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
args = get_args()
losses = output_tensor.float()
if args.num_nextn_predict_layers > 0:
if getattr(args, "num_nextn_predict_layers", 0) > 0:
loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
......@@ -287,7 +285,7 @@ def core_gpt_dataset_config_from_args(args):
return GPTDatasetConfig(
random_seed=args.seed,
sequence_length=args.seq_length + args.num_nextn_predict_layers,
sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
blend=blend,
blend_per_split=blend_per_split,
split=args.split,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment