Commit ed8fad73 authored by Mathieu Prouveur

Update example files so that tr_loss is not affected by args.gradient_accumulation_steps

parent c36cca07
@@ -845,7 +845,7 @@ def main():
             else:
                 loss.backward()
-            tr_loss += loss.item()
+            tr_loss += loss.item() * args.gradient_accumulation_steps
             nb_tr_examples += input_ids.size(0)
             nb_tr_steps += 1
             if (step + 1) % args.gradient_accumulation_steps == 0:
@@ -452,7 +452,7 @@ def main():
                 loss = loss * args.loss_scale
             if args.gradient_accumulation_steps > 1:
                 loss = loss / args.gradient_accumulation_steps
-            tr_loss += loss.item()
+            tr_loss += loss.item() * args.gradient_accumulation_steps
             nb_tr_examples += input_ids.size(0)
             nb_tr_steps += 1
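For context, here is a minimal sketch of the pattern this commit corrects, assuming a standard PyTorch training loop with gradient accumulation. The model, optimizer, and data below are hypothetical stand-ins for the example scripts. Because the loss is divided by gradient_accumulation_steps before backward() (so accumulated gradients average over the window), the raw loss.item() is scaled down by the same factor; multiplying it back restores the unscaled per-batch loss for reporting.

```python
# Minimal sketch of gradient accumulation with corrected loss logging.
# The model, optimizer, and dataloader are hypothetical stand-ins.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
gradient_accumulation_steps = 4

tr_loss, nb_tr_steps = 0.0, 0
dataloader = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))
              for _ in range(8)]

for step, (input_ids, labels) in enumerate(dataloader):
    loss = loss_fn(model(input_ids), labels)
    if gradient_accumulation_steps > 1:
        # Scale the loss so the gradients accumulated over the window
        # average rather than sum.
        loss = loss / gradient_accumulation_steps
    loss.backward()
    # Multiply back so the logged value reflects the unscaled batch loss;
    # without this, tr_loss shrinks as gradient_accumulation_steps grows.
    tr_loss += loss.item() * gradient_accumulation_steps
    nb_tr_steps += 1
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

print(f"mean training loss: {tr_loss / nb_tr_steps:.4f}")
```

The rescaling only affects what is logged; the gradient computation is unchanged, since backward() still sees the divided loss.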