Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ac17f711
Unverified
Commit
ac17f711
authored
Mar 09, 2021
by
Bhadresh Savani
Committed by
GitHub
Mar 09, 2021
Browse files
added max_sample args and metrics changes (#10602)
parent
c19c811a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
28 deletions
+60
-28
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
...directory_name}}/run_{{cookiecutter.example_shortcut}}.py
+60
-28
No files found.
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
View file @
ac17f711
...
...
@@ -144,6 +144,20 @@ class DataTrainingArguments:
default
=
None
,
metadata
=
{
"help"
:
"The number of processes to use for the preprocessing."
},
)
max_train_samples
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)
def
__post_init__
(
self
):
if
self
.
dataset_name
is
None
and
self
.
train_file
is
None
and
self
.
validation_file
is
None
:
...
...
@@ -317,13 +331,37 @@ def main():
def
tokenize_function
(
examples
):
return
tokenizer
(
examples
[
text_column_name
],
padding
=
"max_length"
,
truncation
=
True
)
tokenized_datasets
=
datasets
.
map
(
tokenize_function
,
batched
=
True
,
num_proc
=
data_args
.
preprocessing_num_workers
,
remove_columns
=
[
text_column_name
],
load_from_cache_file
=
not
data_args
.
overwrite_cache
,
)
if
training_args
.
do_train
:
if
"train"
not
in
datasets
:
raise
ValueError
(
"--do_train requires a train dataset"
)
train_dataset
=
datasets
[
"train"
]
if
data_args
.
max_train_samples
is
not
None
:
# Select Sample from Dataset
train_dataset
=
train_dataset
.
select
(
range
(
data_args
.
max_train_samples
))
# tokenize train dataset in batch
train_dataset
=
train_dataset
.
map
(
tokenize_function
,
batched
=
True
,
num_proc
=
data_args
.
preprocessing_num_workers
,
remove_columns
=
[
text_column_name
],
load_from_cache_file
=
not
data_args
.
overwrite_cache
,
)
if
training_args
.
do_eval
:
if
"validation"
not
in
datasets
:
raise
ValueError
(
"--do_eval requires a validation dataset"
)
eval_dataset
=
datasets
[
"validation"
]
# Selecting samples from dataset
if
data_args
.
max_val_samples
is
not
None
:
eval_dataset
=
eval_dataset
.
select
(
range
(
data_args
.
max_val_samples
))
# tokenize validation dataset
eval_dataset
=
eval_dataset
.
map
(
tokenize_function
,
batched
=
True
,
num_proc
=
data_args
.
preprocessing_num_workers
,
remove_columns
=
[
text_column_name
],
load_from_cache_file
=
not
data_args
.
overwrite_cache
,
)
# Data collator
data_collator
=
default_data_collator
if
not
training_args
.
fp16
else
DataCollatorWithPadding
(
tokenizer
,
pad_to_multiple_of
=
8
)
...
...
@@ -332,8 +370,8 @@ def main():
trainer
=
Trainer
(
model
=
model
,
args
=
training_args
,
train_dataset
=
t
okenized_datasets
[
"train"
]
if
training_args
.
do_train
else
None
,
eval_dataset
=
tokenized_datasets
[
"validation"
]
if
training_args
.
do_eval
else
None
,
train_dataset
=
t
rain_dataset
if
training_args
.
do_train
else
None
,
eval_dataset
=
eval_dataset
if
training_args
.
do_eval
else
None
,
tokenizer
=
tokenizer
,
data_collator
=
data_collator
,
)
...
...
@@ -358,33 +396,27 @@ def main():
train_result
=
trainer
.
train
(
resume_from_checkpoint
=
checkpoint
)
trainer
.
save_model
()
# Saves the tokenizer too for easy upload
output_train_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"train_results.txt"
)
if
trainer
.
is_world_process_zero
():
with
open
(
output_train_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Train results *****"
)
for
key
,
value
in
sorted
(
train_result
.
metrics
.
items
()):
logger
.
info
(
f
"
{
key
}
=
{
value
}
"
)
writer
.
write
(
f
"
{
key
}
=
{
value
}
\n
"
)
metrics
=
train_result
.
metrics
max_train_samples
=
(
data_args
.
max_train_samples
if
data_args
.
max_train_samples
is
not
None
else
len
(
train_dataset
)
)
metrics
[
"train_samples"
]
=
min
(
max_train_samples
,
len
(
train_dataset
))
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer
.
state
.
save_to_json
(
os
.
path
.
join
(
training_args
.
output_dir
,
"trainer_state.json"
))
trainer
.
log_metrics
(
"train"
,
metrics
)
trainer
.
save_metrics
(
"train"
,
metrics
)
trainer
.
save_state
()
# Evaluation
results
=
{}
if
training_args
.
do_eval
:
logger
.
info
(
"*** Evaluate ***"
)
result
s
=
trainer
.
evaluate
()
metric
s
=
trainer
.
evaluate
()
output_eval_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"eval_results_{{cookiecutter.example_shortcut}}.txt"
)
if
trainer
.
is_world_process_zero
():
with
open
(
output_eval_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Eval results *****"
)
for
key
,
value
in
sorted
(
results
.
items
()):
logger
.
info
(
f
"
{
key
}
=
{
value
}
"
)
writer
.
write
(
f
"
{
key
}
=
{
value
}
\n
"
)
max_val_samples
=
data_args
.
max_val_samples
if
data_args
.
max_val_samples
is
not
None
else
len
(
eval_dataset
)
metrics
[
"eval_samples"
]
=
min
(
max_val_samples
,
len
(
eval_dataset
))
return
results
trainer
.
log_metrics
(
"eval"
,
metrics
)
trainer
.
save_metrics
(
"eval"
,
metrics
)
def
_mp_fn
(
index
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment