Unverified Commit 64232bc0 authored by Jonathan Chang's avatar Jonathan Chang Committed by GitHub
Browse files

Add --text_column to run_summarization_no_trainer (#11673)

parent 024cd19b
......@@ -184,6 +184,12 @@ def parse_args():
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--text_column",
type=str,
default=None,
help="The name of the column in the datasets containing the full texts (for summarization).",
)
parser.add_argument(
"--summary_column",
type=str,
......@@ -371,9 +377,14 @@ def main():
# Get the column names for input/target.
dataset_columns = summarization_name_mapping.get(args.dataset_name, None)
text_column_name = dataset_columns[0] if dataset_columns is not None else column_names[0]
padding = "max_length" if args.pad_to_max_length else False
if args.text_column is None:
text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
text_column = args.text_column
if text_column not in column_names:
raise ValueError(
f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}"
)
if args.summary_column is None:
summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
......@@ -388,7 +399,7 @@ def main():
padding = "max_length" if args.pad_to_max_length else False
def preprocess_function(examples):
inputs = examples[text_column_name]
inputs = examples[text_column]
targets = examples[summary_column]
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment