Commit 79526f82 authored by Bilal Khan's avatar Bilal Khan Committed by Lysandre Debut
Browse files

Remove unnecessary epoch variable

parent 9626e045
...@@ -245,7 +245,7 @@ def train(args, train_dataset, model, tokenizer): ...@@ -245,7 +245,7 @@ def train(args, train_dataset, model, tokenizer):
model.zero_grad() model.zero_grad()
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproducibility (even between python 2 and 3) set_seed(args) # Added here for reproducibility (even between python 2 and 3)
for epoch in train_iterator: for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator): for step, batch in enumerate(epoch_iterator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment