Commit 62ac7e9a authored by VictorSanh

Fix small bug in `run_squad_pytorch.py`

parent 98b9771d
--- a/run_squad_pytorch.py
+++ b/run_squad_pytorch.py
@@ -27,6 +27,7 @@ import tokenization
 import six
 import argparse
+import torch
 from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -103,6 +104,10 @@ parser.add_argument("--max_answer_length", default=30, type=int,
 parser.add_argument("--verbose_logging", default=False, type=bool,
                     help="If true, all of the warnings related to data processing will be printed. "
                          "A number of warnings are expected for a normal SQuAD evaluation.")
+parser.add_argument("--no_cuda",
+                    default=False,
+                    action='store_true',
+                    help="Whether not to use CUDA when available")
 parser.add_argument("--local_rank",
                     type=int,
                     default=-1,
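The hunk above only registers the new `--no_cuda` flag (note that `default=False` is redundant with `action='store_true'`, which already defaults to False). As a point of reference, here is a minimal sketch of how such a flag is commonly combined with `--local_rank` to choose the device; the consuming code is outside this diff, so the wiring below is an assumption rather than the file's actual logic:

```python
# Illustrative only: turning a --no_cuda flag into a device choice.
# The real device-selection code in run_squad_pytorch.py is not shown in this diff.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--no_cuda", action='store_true',
                    help="Whether not to use CUDA when available")
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()

if args.local_rank == -1 or args.no_cuda:
    # Single-process run: use the GPU only if one exists and CUDA is not disabled.
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
else:
    # Distributed run: each process is pinned to the GPU given by its local rank.
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)

print(device)
```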
@@ -769,8 +774,7 @@ def main():
                         (args.max_seq_length, bert_config.max_position_embeddings))
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
-        raise ValueError(f"Output directory ({args.output_dir}) already exists and is "
-                         f"not empty.")
+        raise ValueError("Output directory () already exists and is not empty.")
     os.makedirs(args.output_dir, exist_ok=True)
     tokenizer = tokenization.FullTokenizer(
@@ -795,7 +799,8 @@ def main():
                          lr=args.learning_rate, schedule='warmup_linear',
                          warmup=args.warmup_proportion,
                          t_total=num_train_steps)
+    global_step = 0
     if args.do_train:
         train_features = convert_examples_to_features(
             examples=train_examples,
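One common reason for a `global_step` counter in scripts like this is to track progress for the warmup-linear learning-rate schedule named in the optimizer arguments above, and to index logs and checkpoints. A toy sketch of that pattern follows; the `warmup_linear` helper is illustrative only, not the repository's optimizer code:

```python
# Toy illustration of why global_step is initialized before training:
# a warmup-linear schedule scales the learning rate by global_step / t_total.
def warmup_linear(progress, warmup=0.1):
    # Ramp up linearly during the warmup fraction, then decay linearly to zero.
    if progress < warmup:
        return progress / warmup
    return max(0.0, 1.0 - progress)

base_lr, t_total = 5e-5, 100
global_step = 0
for step in range(t_total):
    # ... forward pass, backward pass and optimizer step would happen here ...
    global_step += 1
    lr_this_step = base_lr * warmup_linear(global_step / t_total)
print(lr_this_step)
```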
@@ -823,7 +828,7 @@ def main():
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
         model.train()
-        for epoch in args.num_train_epochs:
+        for epoch in range(int(args.num_train_epochs)):
             for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
                 input_ids = input_ids.to(device)
                 input_mask = input_mask.float().to(device)
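The last hunk is the bug named in the commit title: `args.num_train_epochs` is a number, so iterating over it directly raises a `TypeError`. A minimal standalone reproduction, assuming the argument is parsed as a float (its argparse definition is not part of this diff):

```python
# Minimal reproduction of the bug fixed above; 3.0 stands in for the
# (assumed) float value of --num_train_epochs.
num_train_epochs = 3.0

try:
    for epoch in num_train_epochs:  # old code: a float is not iterable
        pass
except TypeError as err:
    print(err)  # "'float' object is not iterable"

for epoch in range(int(num_train_epochs)):  # fixed loop runs epochs 0, 1, 2
    print("epoch", epoch)
```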