Unverified commit 64232bc0, authored by Jonathan Chang, committed by GitHub
Browse files

Add --text_column to run_summarization_no_trainer (#11673)

parent 024cd19b
...@@ -184,6 +184,12 @@ def parse_args(): ...@@ -184,6 +184,12 @@ def parse_args():
default=None, default=None,
help="Pretrained tokenizer name or path if not the same as model_name", help="Pretrained tokenizer name or path if not the same as model_name",
) )
parser.add_argument(
"--text_column",
type=str,
default=None,
help="The name of the column in the datasets containing the full texts (for summarization).",
)
parser.add_argument( parser.add_argument(
"--summary_column", "--summary_column",
type=str, type=str,
...@@ -371,9 +377,14 @@ def main(): ...@@ -371,9 +377,14 @@ def main():
# Get the column names for input/target. # Get the column names for input/target.
dataset_columns = summarization_name_mapping.get(args.dataset_name, None) dataset_columns = summarization_name_mapping.get(args.dataset_name, None)
text_column_name = dataset_columns[0] if dataset_columns is not None else column_names[0] if args.text_column is None:
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
padding = "max_length" if args.pad_to_max_length else False else:
text_column = args.text_column
if text_column not in column_names:
raise ValueError(
f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}"
)
if args.summary_column is None: if args.summary_column is None:
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else: else:
...@@ -388,7 +399,7 @@ def main(): ...@@ -388,7 +399,7 @@ def main():
padding = "max_length" if args.pad_to_max_length else False padding = "max_length" if args.pad_to_max_length else False
def preprocess_function(examples): def preprocess_function(examples):
inputs = examples[text_column_name] inputs = examples[text_column]
targets = examples[summary_column] targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs] inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True) model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment