chenpangpang / transformers · Commits · 8e6c34b3

Unverified commit 8e6c34b3, authored Mar 22, 2023 by Connor Henderson, committed by GitHub on Mar 22, 2023.
fix: Allow only test_file in pytorch and flax summarization (#22293)

Allow a run to be configured with only a test_file in the PyTorch and Flax summarization examples: accept `test_file` in the `DataTrainingArguments` validation, and check for each dataset split at the point where it is first used instead of rejecting prediction-only runs up front.
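With this change, a prediction-only run can be configured with just a test file. A hedged invocation sketch for the PyTorch script (the data path and model choice are illustrative; the flags come from the script's `HfArgumentParser` setup):

    python examples/pytorch/summarization/run_summarization.py \
        --model_name_or_path t5-small \
        --do_predict \
        --test_file data/test.json \
        --source_prefix "summarize: " \
        --output_dir /tmp/summarization_predictions \
        --predict_with_generate

Before this commit, the same invocation failed in `DataTrainingArguments.__post_init__` with "Need either a dataset name or a training/validation file."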
Parent: 4ccaf268
Changes: 2 files, with 32 additions and 16 deletions (+32 / -16)

  examples/flax/summarization/run_summarization_flax.py   +16 / -8
  examples/pytorch/summarization/run_summarization.py     +16 / -8
examples/flax/summarization/run_summarization_flax.py
@@ -308,8 +308,13 @@ class DataTrainingArguments:
         )

     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+            and self.test_file is None
+        ):
+            raise ValueError("Need either a dataset name or a training, validation, or test file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
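This hunk widens the `__post_init__` guard so that a `test_file` alone satisfies the input check. A minimal sketch of the new behavior, using a hypothetical `Args` dataclass that keeps only the four fields the check reads (the real `DataTrainingArguments` has many more):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class Args:
        # Hypothetical trimmed stand-in for DataTrainingArguments.
        dataset_name: Optional[str] = None
        train_file: Optional[str] = None
        validation_file: Optional[str] = None
        test_file: Optional[str] = None

        def __post_init__(self):
            # Mirrors the condition added in this hunk: any one input source is enough.
            if (
                self.dataset_name is None
                and self.train_file is None
                and self.validation_file is None
                and self.test_file is None
            ):
                raise ValueError("Need either a dataset name or a training, validation, or test file.")


    Args(test_file="test.json")  # accepted after this commit; rejected before
    try:
        Args()  # providing no input source at all still raises
    except ValueError as err:
        print(err)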
@@ -317,6 +322,9 @@ class DataTrainingArguments:
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+            if self.test_file is not None:
+                extension = self.test_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
         if self.val_max_target_length is None:
             self.val_max_target_length = self.max_target_length
@@ -553,10 +561,16 @@ def main():
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     if training_args.do_train:
+        if "train" not in dataset:
+            raise ValueError("--do_train requires a train dataset")
         column_names = dataset["train"].column_names
     elif training_args.do_eval:
+        if "validation" not in dataset:
+            raise ValueError("--do_eval requires a validation dataset")
         column_names = dataset["validation"].column_names
     elif training_args.do_predict:
+        if "test" not in dataset:
+            raise ValueError("--do_predict requires a test dataset")
         column_names = dataset["test"].column_names
     else:
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
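This hunk adds the split-existence checks at the point where `column_names` is selected, and the following three hunks delete the now-redundant copies from the per-split blocks. The effect is that a dataset built from only a test file fails fast only when a genuinely missing split is requested. A rough illustration of the pattern, with a plain dict standing in for the loaded `DatasetDict` (all names here are placeholders):

    # Stand-in for a dataset loaded from only a test_file.
    dataset = {"test": ["first document to summarize", "second document to summarize"]}

    do_train, do_predict = False, True

    if do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        selected_split = dataset["train"]
    elif do_predict:
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        selected_split = dataset["test"]

    print(f"Prediction will run on {len(selected_split)} examples.")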
@@ -620,8 +634,6 @@ def main():
         return model_inputs

     if training_args.do_train:
-        if "train" not in dataset:
-            raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
         if data_args.max_train_samples is not None:
             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
@@ -637,8 +649,6 @@ def main():
     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
-        if "validation" not in dataset:
-            raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = dataset["validation"]
         if data_args.max_eval_samples is not None:
             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
@@ -654,8 +664,6 @@ def main():
     if training_args.do_predict:
         max_target_length = data_args.val_max_target_length
-        if "test" not in dataset:
-            raise ValueError("--do_predict requires a test dataset")
         predict_dataset = dataset["test"]
         if data_args.max_predict_samples is not None:
             max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
examples/pytorch/summarization/run_summarization.py
@@ -262,8 +262,13 @@ class DataTrainingArguments:
         )

     def __post_init__(self):
-        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+            and self.test_file is None
+        ):
+            raise ValueError("Need either a dataset name or a training, validation, or test file.")
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -271,6 +276,9 @@ class DataTrainingArguments:
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+            if self.test_file is not None:
+                extension = self.test_file.split(".")[-1]
+                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
         if self.val_max_target_length is None:
             self.val_max_target_length = self.max_target_length
@@ -467,10 +475,16 @@ def main():
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     if training_args.do_train:
+        if "train" not in raw_datasets:
+            raise ValueError("--do_train requires a train dataset")
         column_names = raw_datasets["train"].column_names
     elif training_args.do_eval:
+        if "validation" not in raw_datasets:
+            raise ValueError("--do_eval requires a validation dataset")
         column_names = raw_datasets["validation"].column_names
     elif training_args.do_predict:
+        if "test" not in raw_datasets:
+            raise ValueError("--do_predict requires a test dataset")
         column_names = raw_datasets["test"].column_names
     else:
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
@@ -546,8 +560,6 @@ def main():
         return model_inputs

     if training_args.do_train:
-        if "train" not in raw_datasets:
-            raise ValueError("--do_train requires a train dataset")
         train_dataset = raw_datasets["train"]
         if data_args.max_train_samples is not None:
             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
@@ -564,8 +576,6 @@ def main():
     if training_args.do_eval:
         max_target_length = data_args.val_max_target_length
-        if "validation" not in raw_datasets:
-            raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = raw_datasets["validation"]
         if data_args.max_eval_samples is not None:
             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
@@ -582,8 +592,6 @@ def main():
     if training_args.do_predict:
         max_target_length = data_args.val_max_target_length
-        if "test" not in raw_datasets:
-            raise ValueError("--do_predict requires a test dataset")
         predict_dataset = raw_datasets["test"]
         if data_args.max_predict_samples is not None:
             max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)