Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
64232bc0
Unverified
Commit
64232bc0
authored
May 11, 2021
by
Jonathan Chang
Committed by
GitHub
May 11, 2021
Browse files
Add --text_column to run_summarization_no_trainer (#11673)
parent
024cd19b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
4 deletions
+15
-4
examples/pytorch/summarization/run_summarization_no_trainer.py
...les/pytorch/summarization/run_summarization_no_trainer.py
+15
-4
No files found.
examples/pytorch/summarization/run_summarization_no_trainer.py
View file @
64232bc0
...
@@ -184,6 +184,12 @@ def parse_args():
...
@@ -184,6 +184,12 @@ def parse_args():
default
=
None
,
default
=
None
,
help
=
"Pretrained tokenizer name or path if not the same as model_name"
,
help
=
"Pretrained tokenizer name or path if not the same as model_name"
,
)
)
parser
.
add_argument
(
"--text_column"
,
type
=
str
,
default
=
None
,
help
=
"The name of the column in the datasets containing the full texts (for summarization)."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--summary_column"
,
"--summary_column"
,
type
=
str
,
type
=
str
,
...
@@ -371,9 +377,14 @@ def main():
...
@@ -371,9 +377,14 @@ def main():
# Get the column names for input/target.
# Get the column names for input/target.
dataset_columns
=
summarization_name_mapping
.
get
(
args
.
dataset_name
,
None
)
dataset_columns
=
summarization_name_mapping
.
get
(
args
.
dataset_name
,
None
)
text_column_name
=
dataset_columns
[
0
]
if
dataset_columns
is
not
None
else
column_names
[
0
]
if
args
.
text_column
is
None
:
text_column
=
dataset_columns
[
0
]
if
dataset_columns
is
not
None
else
column_names
[
0
]
padding
=
"max_length"
if
args
.
pad_to_max_length
else
False
else
:
text_column
=
args
.
text_column
if
text_column
not
in
column_names
:
raise
ValueError
(
f
"--text_column' value '
{
args
.
text_column
}
' needs to be one of:
{
', '
.
join
(
column_names
)
}
"
)
if
args
.
summary_column
is
None
:
if
args
.
summary_column
is
None
:
summary_column
=
dataset_columns
[
1
]
if
dataset_columns
is
not
None
else
column_names
[
1
]
summary_column
=
dataset_columns
[
1
]
if
dataset_columns
is
not
None
else
column_names
[
1
]
else
:
else
:
...
@@ -388,7 +399,7 @@ def main():
...
@@ -388,7 +399,7 @@ def main():
padding
=
"max_length"
if
args
.
pad_to_max_length
else
False
padding
=
"max_length"
if
args
.
pad_to_max_length
else
False
def
preprocess_function
(
examples
):
def
preprocess_function
(
examples
):
inputs
=
examples
[
text_column
_name
]
inputs
=
examples
[
text_column
]
targets
=
examples
[
summary_column
]
targets
=
examples
[
summary_column
]
inputs
=
[
prefix
+
inp
for
inp
in
inputs
]
inputs
=
[
prefix
+
inp
for
inp
in
inputs
]
model_inputs
=
tokenizer
(
inputs
,
max_length
=
args
.
max_source_length
,
padding
=
padding
,
truncation
=
True
)
model_inputs
=
tokenizer
(
inputs
,
max_length
=
args
.
max_source_length
,
padding
=
padding
,
truncation
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment