Commit a20445d3 authored by Jared Casper

Fix finetuning tasks after T5 pipeline merge.

parent 5478d67e
@@ -25,6 +25,7 @@ from megatron import get_timers
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
+from megatron.model import ModelType
 from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
@@ -248,6 +249,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 def finetune(train_valid_datasets_provider, model_provider,
+             model_type=ModelType.encoder_or_decoder,
              forward_step=_cross_entropy_forward_step,
              end_of_epoch_callback_provider=None,
              task_collate_fn=None):
@@ -277,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Build model, optimizer and learning rate scheduler.
     timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
     timers('model and optimizer').stop()
     # If pretrained checkpoint is provided and we have not trained for
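The change threads a new `model_type` argument through `finetune()` into `setup_model_and_optimizer()`, so that pipeline setup knows whether it is building an encoder-only/decoder-only model or an encoder-and-decoder model after the T5 pipeline merge. A minimal sketch of how a task would call the updated API; the provider functions and the `tasks.finetune_utils` import path are assumptions for illustration, not part of this commit:

    from megatron.model import ModelType
    from tasks.finetune_utils import finetune  # import path assumed

    def train_valid_datasets_provider():
        # Hypothetical: build and return (train_dataset, valid_dataset) for the task.
        ...

    def model_provider(pre_process=True, post_process=True):
        # Hypothetical: build and return the task model.
        ...

    # Encoder-only/decoder-only tasks can rely on the new default,
    # ModelType.encoder_or_decoder, so existing callers need no change:
    finetune(train_valid_datasets_provider, model_provider)

    # A T5-style task would pass the encoder-and-decoder type explicitly
    # (enum value assumed to exist alongside encoder_or_decoder):
    finetune(train_valid_datasets_provider, model_provider,
             model_type=ModelType.encoder_and_decoder)

Keeping `ModelType.encoder_or_decoder` as the default means only encoder-decoder tasks have to opt in, which matches how the existing finetuning tasks call `finetune()` without the new argument.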