Commit b219ff00 authored by Jared Casper

Update code used for finetuning to latest API.

parent b4b0d739
@@ -52,8 +52,7 @@ class Classification(MegatronModule):

     def forward(self, input_ids, attention_mask, tokentype_ids):

-        extended_attention_mask = bert_extended_attention_mask(
-            attention_mask, next(self.language_model.parameters()).dtype)
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)

         _, pooled_output = self.language_model(input_ids,
...
@@ -64,8 +64,7 @@ class MultipleChoice(MegatronModule):
         attention_mask = attention_mask.view(-1, attention_mask.size(-1))
         tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

-        extended_attention_mask = bert_extended_attention_mask(
-            attention_mask, next(self.language_model.parameters()).dtype)
+        extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)

         _, pooled_output = self.language_model(input_ids,
...
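Both hunks above make the same change: bert_extended_attention_mask is now called with the 2D padding mask alone, so callers no longer look up the model's parameter dtype. As a rough illustration only (the body below is an assumption about the updated helper, not code from this commit), a dtype-free version could expand the padding mask into a boolean [b, 1, s, s] mask directly:

```python
import torch

def bert_extended_attention_mask(attention_mask):
    # Assumed sketch: expand a [b, s] padding mask into the [b, 1, s, s]
    # shape BERT-style attention expects, without consulting the model's
    # parameter dtype (which is why the call sites above can drop the
    # next(self.language_model.parameters()).dtype argument).
    # [b, 1, s] * [b, s, 1] -> [b, s, s]
    attention_mask_bss = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)
    # Boolean mask: True marks positions that should be masked out.
    return extended_attention_mask < 0.5

mask = torch.ones(2, 5, dtype=torch.long)   # [batch, seq]
ext = bert_extended_attention_mask(mask)    # [2, 1, 5, 5], torch.bool
```

Returning a boolean mask defers any dtype-dependent conversion to the attention implementation, which would explain why the callers no longer need a parameter dtype at all.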
@@ -161,7 +161,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0

             # Train for one step.
-            losses_dict, _ = train_step(forward_step, batch, model,
+            losses_dict, skipped_iter = train_step(forward_step, batch, model,
                                         optimizer, lr_scheduler)
             iteration += 1
@@ -169,7 +169,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                               optimizer.param_groups[0]['lr'],
                                               iteration, optimizer.loss_scale,
-                                              report_memory_flag)
+                                              report_memory_flag, skipped_iter)
             # Autoresume
             if args.adlr_autoresume and \
...
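The second pair of hunks threads a new skipped_iter value from train_step into training_log. With fp16 loss scaling, an overflow step performs no optimizer update, and the logger needs that signal to keep its statistics honest. A toy, self-contained sketch of the pattern (every name below is a hypothetical stand-in, not Megatron's actual API):

```python
def train_step(overflow):
    # Stand-in for the real train_step: report whether the optimizer
    # update was skipped (e.g. due to an fp16 loss-scale overflow).
    losses_dict = {"lm loss": 1.0}
    skipped_iter = 1 if overflow else 0
    return losses_dict, skipped_iter

def training_log(losses_dict, iteration, skipped_iters):
    # Stand-in for the real training_log: fold the skip count into
    # what gets reported alongside the losses.
    print(f"iter {iteration} | {losses_dict} | skipped so far: {skipped_iters}")

skipped_iters = 0
for iteration, overflow in enumerate([False, True, False], start=1):
    losses_dict, skipped_iter = train_step(overflow)
    skipped_iters += skipped_iter
    training_log(losses_dict, iteration, skipped_iters)
```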