Unverified Commit 9f8fa4e9 authored by Eliza Szczechla's avatar Eliza Szczechla Committed by GitHub
Browse files

Use DataCollatorForSeq2Seq in run_summarization in all cases (#10856)


Co-authored-by: default avatarEliza <eliza@habanero.tiger.com.pl>
parent a8d4d677
...@@ -38,7 +38,6 @@ from transformers import ( ...@@ -38,7 +38,6 @@ from transformers import (
HfArgumentParser, HfArgumentParser,
Seq2SeqTrainer, Seq2SeqTrainer,
Seq2SeqTrainingArguments, Seq2SeqTrainingArguments,
default_data_collator,
set_seed, set_seed,
) )
from transformers.file_utils import is_offline_mode from transformers.file_utils import is_offline_mode
...@@ -466,15 +465,12 @@ def main(): ...@@ -466,15 +465,12 @@ def main():
# Data collator # Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length: data_collator = DataCollatorForSeq2Seq(
data_collator = default_data_collator tokenizer,
else: model=model,
data_collator = DataCollatorForSeq2Seq( label_pad_token_id=label_pad_token_id,
tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None,
model=model, )
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
# Metric # Metric
metric = load_metric("rouge") metric = load_metric("rouge")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment