Commit fd338abd authored by Sylvain Gugger

Style

parent aef4cf8c
@@ -76,9 +76,7 @@ def parse_args():
     parser.add_argument(
         "--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
     )
-    parser.add_argument(
-        "--do_predict", action="store_true", help="Eval the question answering model"
-    )
+    parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
@@ -396,7 +394,6 @@ def main():
         return tokenized_examples
     if "train" not in raw_datasets:
         raise ValueError("--do_train requires a train dataset")
     train_dataset = raw_datasets["train"]
@@ -481,7 +478,6 @@ def main():
         return tokenized_examples
     if "validation" not in raw_datasets:
         raise ValueError("--do_eval requires a validation dataset")
     eval_examples = raw_datasets["validation"]
@@ -539,11 +535,8 @@ def main():
         train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
     )
     eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-    eval_dataloader = DataLoader(
-        eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
-    )
+    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
     if args.do_predict:
         test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
@@ -605,7 +598,7 @@ def main():
             if step + batch_size < len(dataset):
                 logits_concat[step : step + batch_size, :cols] = output_logit
             else:
-                logits_concat[step:, :cols] = output_logit[:len(dataset) - step]
+                logits_concat[step:, :cols] = output_logit[: len(dataset) - step]
             step += batch_size
...
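Both changes in this file are mechanical formatting fixes of the kind black produces (in transformers these "Style" commits are typically the output of `make style`, though that is an assumption here): a call that fits within the line-length limit is collapsed onto one line, and a slice bound that is a compound expression rather than a simple name gets a space around the colon. A minimal runnable sketch of the slice rule, with toy NumPy arrays standing in for the script's logits:

    import numpy as np

    # Toy stand-ins for the tensors in the diff; the shapes are illustrative only.
    dataset_len, batch_size, cols = 10, 4, 3
    logits_concat = np.zeros((dataset_len, cols))
    output_logit = np.ones((batch_size, cols))
    step = 8  # offset of the last, partial batch

    # black inserts a space around ":" when a slice bound is an expression
    # (the len-arithmetic here) ...
    logits_concat[step:, :cols] = output_logit[: dataset_len - step]
    # ... but keeps simple-name bounds tight, as in logits_concat[step:, :cols].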
@@ -81,9 +81,7 @@ def parse_args():
     parser.add_argument(
         "--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
     )
-    parser.add_argument(
-        "--do_predict", action="store_true", help="Eval the question answering model"
-    )
+    parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
@@ -543,9 +541,7 @@ def main():
     )
     eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
-    eval_dataloader = DataLoader(
-        eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
-    )
+    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
     if args.do_predict:
         test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
@@ -607,7 +603,7 @@ def main():
             if step + batch_size < len(dataset):
                 logits_concat[step : step + batch_size, :cols] = output_logit
             else:
-                logits_concat[step:, :cols] = output_logit[:len(dataset) - step]
+                logits_concat[step:, :cols] = output_logit[: len(dataset) - step]
             step += batch_size
...
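The second file receives the same two fixes. The call-collapse hunks follow black's line-joining rule: when a split call fits within the configured line length (presumably 119 characters, transformers' usual black setting), black joins it onto a single line. Sketched on the argument from the diff:

    import argparse

    parser = argparse.ArgumentParser()

    # Pre-style, this call was split over three lines even though the joined
    # form fits the limit; black collapses it to a single line.
    parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")

    # Parsing an empty argv shows the store_true flag defaults to False.
    args = parser.parse_args([])
    print(args.do_predict)  # False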