Commit 62ac7e9a authored by VictorSanh's avatar VictorSanh
Browse files

Fix small bug in `run_squad_pytorch.py`

parent 98b9771d
......@@ -27,6 +27,7 @@ import tokenization
import six
import argparse
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
......@@ -103,6 +104,10 @@ parser.add_argument("--max_answer_length", default=30, type=int,
parser.add_argument("--verbose_logging", default=False, type=bool,
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.")
parser.add_argument("--no_cuda",
default = False,
action='store_true',
help = "Whether not to use CUDA when available")
parser.add_argument("--local_rank",
type=int,
default=-1,
......@@ -769,8 +774,7 @@ def main():
(args.max_seq_length, bert_config.max_position_embeddings))
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
f"not empty.")
raise ValueError("Output directory () already exists and is not empty.")
os.makedirs(args.output_dir, exist_ok=True)
tokenizer = tokenization.FullTokenizer(
......@@ -795,7 +799,8 @@ def main():
lr=args.learning_rate, schedule='warmup_linear',
warmup=args.warmup_proportion,
t_total=num_train_steps)
global_step = 0
if args.do_train:
train_features = convert_examples_to_features(
examples=train_examples,
......@@ -823,7 +828,7 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train()
for epoch in args.num_train_epochs:
for epoch in range(int(args.num_train_epochs)):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment