chenpangpang / transformers

Unverified commit 9a754594, authored Jun 26, 2021 by Bhadresh Savani, committed by GitHub Jun 25, 2021

updated example template (#12365)
Parent: 539ee456

Showing 1 changed file with 20 additions and 16 deletions (+20, -16):

templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
@@ -27,6 +27,7 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional
 
+import datasets
 from datasets import load_dataset
 
 import transformers
@@ -226,16 +227,19 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
 
     # Log on each process the small summary:
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if training_args.should_log:
-        transformers.utils.logging.set_verbosity_info()
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set seed before initializing model.
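Why the logging block changed: the old line toggled only this script's logger between INFO and WARN via `should_log`. The new code asks `TrainingArguments.get_process_log_level()` for a per-process level (driven by the `--log_level` and `--log_level_replica` arguments, so replica processes can stay quieter than the main one) and pushes that same level into the `datasets` and `transformers` library loggers so all three agree. A minimal sketch of the new pattern in isolation, assuming a `transformers` version recent enough to ship `get_process_log_level()`; the `output_dir` value is illustrative:

import logging
import sys

import datasets
import transformers
from transformers import TrainingArguments

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

training_args = TrainingArguments(output_dir="out", log_level="info")  # illustrative args

# Resolves --log_level on the main process and --log_level_replica elsewhere,
# returning a stdlib logging constant such as logging.INFO.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)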
@@ -252,7 +256,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        raw_datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -266,7 +270,7 @@ def main():
             extension = data_args.test_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        raw_datasets = load_dataset(extension, data_files=data_files)
 
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
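The substance of this hunk and the previous one is the rename `datasets` → `raw_datasets`, and it is not cosmetic: `main()` now calls `datasets.utils.logging.set_verbosity(...)` on the module imported at the top of the file, and in Python an assignment anywhere in a function makes that name local to the whole function body. A minimal reproduction of the failure the rename avoids; the `"imdb"` dataset is only an illustrative choice:

import datasets  # the Hugging Face `datasets` library
from datasets import load_dataset

def main():
    # Because `datasets` is assigned below, the compiler treats it as a local
    # for the entire function, so this line raises UnboundLocalError instead
    # of reaching the imported module.
    datasets.utils.logging.set_verbosity(20)  # 20 == logging.INFO
    datasets = load_dataset("imdb")  # the old variable name shadowed the module

main()

Renaming the variable to `raw_datasets` leaves the module name untouched.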
@@ -348,20 +352,20 @@ def main():
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
-        column_names = datasets["train"].column_names
+        column_names = raw_datasets["train"].column_names
     elif training_args.do_eval:
-        column_names = datasets["validation"].column_names
+        column_names = raw_datasets["validation"].column_names
     elif training_args.do_predict:
-        column_names = datasets["test"].column_names
+        column_names = raw_datasets["test"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
     def tokenize_function(examples):
         return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
 
     if training_args.do_train:
-        if "train" not in datasets:
+        if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
-        train_dataset = datasets["train"]
+        train_dataset = raw_datasets["train"]
         if data_args.max_train_samples is not None:
             # Select Sample from Dataset
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
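For orientation, the template applies `tokenize_function` to each split with `Dataset.map`, and `select(range(n))` is how `max_train_samples` caps a split without copying it. A self-contained sketch of that flow; the checkpoint and dataset names are illustrative, not part of this diff:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
raw_datasets = load_dataset("imdb")  # illustrative dataset

def tokenize_function(examples):
    # Batched call: examples["text"] is a list of strings.
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized = raw_datasets.map(tokenize_function, batched=True)
# Mirrors what max_train_samples does in the template.
small_train = tokenized["train"].select(range(1000))
print(small_train.column_names)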
@@ -375,9 +379,9 @@ def main():
         )
 
     if training_args.do_eval:
-        if "validation" not in datasets:
+        if "validation" not in raw_datasets:
             raise ValueError("--do_eval requires a validation dataset")
-        eval_dataset = datasets["validation"]
+        eval_dataset = raw_datasets["validation"]
         # Selecting samples from dataset
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
@@ -391,9 +395,9 @@ def main():
         )
 
     if training_args.do_predict:
-        if "test" not in datasets:
+        if "test" not in raw_datasets:
             raise ValueError("--do_predict requires a test dataset")
-        predict_dataset = datasets["test"]
+        predict_dataset = raw_datasets["test"]
         # Selecting samples from dataset
         if data_args.max_predict_samples is not None:
             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
@@ -754,7 +758,7 @@ def main():
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
-    column_names = datasets["train"].column_names
+    column_names = raw_datasets["train"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]
 
     padding = "max_length" if args.pad_to_max_length else False
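A note on the last context line: `padding = "max_length" if args.pad_to_max_length else False` feeds directly into the tokenizer call. `"max_length"` pads every example to a fixed length up front, while `False` leaves sequences ragged so a data collator can pad each batch dynamically. A quick illustration of the difference; the checkpoint name is again illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative

fixed = tokenizer(["short text"], padding="max_length", truncation=True, max_length=16)
ragged = tokenizer(["short text"], padding=False, truncation=True, max_length=16)

print(len(fixed["input_ids"][0]))   # 16: padded to max_length
print(len(ragged["input_ids"][0]))  # 4: [CLS] short text [SEP], no padding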