Commit 391a4ec2 authored by VictorSanh's avatar VictorSanh
Browse files

Small typo in `trange`

I seriously don't understand why they defined num_train_epochs as a float in the originial tf code.
I Will change it at the end to avoir merge conflicts for now.
parent 5676d6f7
...@@ -514,7 +514,7 @@ def main(): ...@@ -514,7 +514,7 @@ def main():
model.train() model.train()
nb_tr_examples = 0 nb_tr_examples = 0
for epoch in trange(args.num_train_epochs, desc="Epoch"): for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"): for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment