Unverified Commit c225e872 authored by Tom Grek's avatar Tom Grek Committed by GitHub
Browse files

Fix it to work with BART (#6756)

parent 0d2c111a
......@@ -187,7 +187,7 @@ def train(args, train_dataset, model, tokenizer):
"end_positions": batch[4],
}
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
del inputs["token_type_ids"]
if args.model_type in ["xlnet", "xlm"]:
......@@ -300,7 +300,7 @@ def evaluate(args, model, tokenizer, prefix=""):
"token_type_ids": batch[2],
}
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
del inputs["token_type_ids"]
feature_indices = batch[3]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment