chenpangpang / transformers · Commits

Commit ac17f711 (unverified)
Authored Mar 09, 2021 by Bhadresh Savani, committed by GitHub on Mar 09, 2021

added max_sample args and metrics changes (#10602)

Parent: c19c811a
Showing 1 changed file with 60 additions and 28 deletions.

templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py (+60, -28)
@@ -144,6 +144,20 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_val_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+            "value if set."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
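Both fields use the same dataclasses.field(metadata={"help": ...}) pattern as the existing arguments, so HfArgumentParser exposes them as --max_train_samples and --max_val_samples flags automatically. A minimal sketch of that mechanism, assuming only that transformers is installed; the MiniArgs dataclass below is illustrative and not part of this diff:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


# Illustrative stand-in for DataTrainingArguments; only the new-style field is shown.
@dataclass
class MiniArgs:
    max_train_samples: Optional[int] = field(
        default=None, metadata={"help": "Truncate the training set to this many examples."}
    )


# HfArgumentParser turns each dataclass field into a CLI flag; metadata["help"]
# becomes the --help text, and Optional[int] parses as an int defaulting to None.
(args,) = HfArgumentParser(MiniArgs).parse_args_into_dataclasses(["--max_train_samples", "100"])
print(args.max_train_samples)  # 100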
@@ -317,7 +331,31 @@ def main():
     def tokenize_function(examples):
         return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
 
-    tokenized_datasets = datasets.map(
+    if training_args.do_train:
+        if "train" not in datasets:
+            raise ValueError("--do_train requires a train dataset")
+        train_dataset = datasets["train"]
+        if data_args.max_train_samples is not None:
+            # Select Sample from Dataset
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        # tokenize train dataset in batch
+        train_dataset = train_dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=[text_column_name],
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+
+    if training_args.do_eval:
+        if "validation" not in datasets:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = datasets["validation"]
+        # Selecting samples from dataset
+        if data_args.max_val_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+        # tokenize validation dataset
+        eval_dataset = eval_dataset.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
...
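The order here matters: Dataset.select(range(n)) drops rows before the comparatively expensive tokenization map runs, so a --max_train_samples 100 debug run only ever tokenizes 100 examples. A toy sketch of the same select-then-map pattern, assuming the datasets library is installed; the strings and the uppercasing function stand in for real text and the tokenizer:

from datasets import Dataset

# 1,000 fake rows in place of a real text corpus.
toy = Dataset.from_dict({"text": [f"example {i}" for i in range(1000)]})

# Cap first: select() keeps only the requested indices.
capped = toy.select(range(100))

# Then transform: map(batched=True) receives a dict of column lists per batch.
shouted = capped.map(lambda batch: {"text": [t.upper() for t in batch["text"]]}, batched=True)
print(len(shouted))  # 100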
@@ -332,8 +370,8 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
+        train_dataset=train_dataset if training_args.do_train else None,
+        eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
     )
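Design note: because each split is now truncated and tokenized on its own, the Trainer receives the prepared train_dataset/eval_dataset directly instead of indexing a shared tokenized_datasets dict, and the ValueError guards in the previous hunk guarantee each variable is defined whenever the matching do_train/do_eval branch makes it reachable.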
@@ -358,33 +396,27 @@ def main():
         train_result = trainer.train(resume_from_checkpoint=checkpoint)
         trainer.save_model()  # Saves the tokenizer too for easy upload
 
-        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
-        if trainer.is_world_process_zero():
-            with open(output_train_file, "w") as writer:
-                logger.info("***** Train results *****")
-                for key, value in sorted(train_result.metrics.items()):
-                    logger.info(f"{key} = {value}")
-                    writer.write(f"{key} = {value}\n")
-
-            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
-            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+        metrics = train_result.metrics
+        max_train_samples = (
+            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+        )
+        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+        trainer.log_metrics("train", metrics)
+        trainer.save_metrics("train", metrics)
+        trainer.save_state()
 
     # Evaluation
-    results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        results = trainer.evaluate()
+        metrics = trainer.evaluate()
 
-        output_eval_file = os.path.join(training_args.output_dir, "eval_results_{{cookiecutter.example_shortcut}}.txt")
-        if trainer.is_world_process_zero():
-            with open(output_eval_file, "w") as writer:
-                logger.info("***** Eval results *****")
-                for key, value in sorted(results.items()):
-                    logger.info(f"{key} = {value}")
-                    writer.write(f"{key} = {value}\n")
-
-    return results
+        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
 
 
 def _mp_fn(index):
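The removed is_world_process_zero() writer loops give way to the Trainer.log_metrics, save_metrics, and save_state helpers that recent transformers releases provide. A rough, self-contained sketch of the file-writing side of save_metrics, to show what supersedes train_results.txt; this approximates the helper rather than reproducing library source (the real method also runs only on the main process and maintains a combined all_results.json):

import json
import os


def save_metrics_sketch(split: str, metrics: dict, output_dir: str) -> None:
    # Approximation of Trainer.save_metrics: one sorted JSON file per split
    # (e.g. train_results.json) instead of the hand-written <split>_results.txt.
    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)


save_metrics_sketch("train", {"train_runtime": 12.3, "train_samples": 100}, ".")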