Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9a754594
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8def252de254ee5d85513c261d5a86433e88a191"
Unverified
Commit
9a754594
authored
Jun 26, 2021
by
Bhadresh Savani
Committed by
GitHub
Jun 25, 2021
Browse files
updated example template (#12365)
parent
539ee456
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
16 deletions
+20
-16
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
...directory_name}}/run_{{cookiecutter.example_shortcut}}.py
+20
-16
No files found.
templates/adding_a_new_example_script/{{cookiecutter.directory_name}}/run_{{cookiecutter.example_shortcut}}.py
View file @
9a754594
...
@@ -27,6 +27,7 @@ import sys
...
@@ -27,6 +27,7 @@ import sys
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Optional
import
datasets
from
datasets
import
load_dataset
from
datasets
import
load_dataset
import
transformers
import
transformers
...
@@ -226,16 +227,19 @@ def main():
...
@@ -226,16 +227,19 @@ def main():
datefmt
=
"%m/%d/%Y %H:%M:%S"
,
datefmt
=
"%m/%d/%Y %H:%M:%S"
,
handlers
=
[
logging
.
StreamHandler
(
sys
.
stdout
)],
handlers
=
[
logging
.
StreamHandler
(
sys
.
stdout
)],
)
)
logger
.
setLevel
(
logging
.
INFO
if
training_args
.
should_log
else
logging
.
WARN
)
log_level
=
training_args
.
get_process_log_level
()
logger
.
setLevel
(
log_level
)
datasets
.
utils
.
logging
.
set_verbosity
(
log_level
)
transformers
.
utils
.
logging
.
set_verbosity
(
log_level
)
transformers
.
utils
.
logging
.
enable_default_handler
()
transformers
.
utils
.
logging
.
enable_explicit_format
()
# Log on each process the small summary:
# Log on each process the small summary:
logger
.
warning
(
logger
.
warning
(
f
"Process rank:
{
training_args
.
local_rank
}
, device:
{
training_args
.
device
}
, n_gpu:
{
training_args
.
n_gpu
}
"
f
"Process rank:
{
training_args
.
local_rank
}
, device:
{
training_args
.
device
}
, n_gpu:
{
training_args
.
n_gpu
}
"
+
f
"distributed training:
{
bool
(
training_args
.
local_rank
!=
-
1
)
}
, 16-bits training:
{
training_args
.
fp16
}
"
+
f
"distributed training:
{
bool
(
training_args
.
local_rank
!=
-
1
)
}
, 16-bits training:
{
training_args
.
fp16
}
"
)
)
# Set the verbosity to info of the Transformers logger (on main process only):
if
training_args
.
should_log
:
transformers
.
utils
.
logging
.
set_verbosity_info
()
logger
.
info
(
f
"Training/evaluation parameters
{
training_args
}
"
)
logger
.
info
(
f
"Training/evaluation parameters
{
training_args
}
"
)
# Set seed before initializing model.
# Set seed before initializing model.
...
@@ -252,7 +256,7 @@ def main():
...
@@ -252,7 +256,7 @@ def main():
# download the dataset.
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
raw_
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
else
:
else
:
data_files
=
{}
data_files
=
{}
if
data_args
.
train_file
is
not
None
:
if
data_args
.
train_file
is
not
None
:
...
@@ -266,7 +270,7 @@ def main():
...
@@ -266,7 +270,7 @@ def main():
extension
=
data_args
.
test_file
.
split
(
"."
)[
-
1
]
extension
=
data_args
.
test_file
.
split
(
"."
)[
-
1
]
if
extension
==
"txt"
:
if
extension
==
"txt"
:
extension
=
"text"
extension
=
"text"
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
raw_
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
@@ -348,20 +352,20 @@ def main():
...
@@ -348,20 +352,20 @@ def main():
# Preprocessing the datasets.
# Preprocessing the datasets.
# First we tokenize all the texts.
# First we tokenize all the texts.
if
training_args
.
do_train
:
if
training_args
.
do_train
:
column_names
=
datasets
[
"train"
].
column_names
column_names
=
raw_
datasets
[
"train"
].
column_names
elif
training_args
.
do_eval
:
elif
training_args
.
do_eval
:
column_names
=
datasets
[
"validation"
].
column_names
column_names
=
raw_
datasets
[
"validation"
].
column_names
elif
training_args
.
do_predict
:
elif
training_args
.
do_predict
:
column_names
=
datasets
[
"test"
].
column_names
column_names
=
raw_
datasets
[
"test"
].
column_names
text_column_name
=
"text"
if
"text"
in
column_names
else
column_names
[
0
]
text_column_name
=
"text"
if
"text"
in
column_names
else
column_names
[
0
]
def
tokenize_function
(
examples
):
def
tokenize_function
(
examples
):
return
tokenizer
(
examples
[
text_column_name
],
padding
=
"max_length"
,
truncation
=
True
)
return
tokenizer
(
examples
[
text_column_name
],
padding
=
"max_length"
,
truncation
=
True
)
if
training_args
.
do_train
:
if
training_args
.
do_train
:
if
"train"
not
in
datasets
:
if
"train"
not
in
raw_
datasets
:
raise
ValueError
(
"--do_train requires a train dataset"
)
raise
ValueError
(
"--do_train requires a train dataset"
)
train_dataset
=
datasets
[
"train"
]
train_dataset
=
raw_
datasets
[
"train"
]
if
data_args
.
max_train_samples
is
not
None
:
if
data_args
.
max_train_samples
is
not
None
:
# Select Sample from Dataset
# Select Sample from Dataset
train_dataset
=
train_dataset
.
select
(
range
(
data_args
.
max_train_samples
))
train_dataset
=
train_dataset
.
select
(
range
(
data_args
.
max_train_samples
))
...
@@ -375,9 +379,9 @@ def main():
...
@@ -375,9 +379,9 @@ def main():
)
)
if
training_args
.
do_eval
:
if
training_args
.
do_eval
:
if
"validation"
not
in
datasets
:
if
"validation"
not
in
raw_
datasets
:
raise
ValueError
(
"--do_eval requires a validation dataset"
)
raise
ValueError
(
"--do_eval requires a validation dataset"
)
eval_dataset
=
datasets
[
"validation"
]
eval_dataset
=
raw_
datasets
[
"validation"
]
# Selecting samples from dataset
# Selecting samples from dataset
if
data_args
.
max_eval_samples
is
not
None
:
if
data_args
.
max_eval_samples
is
not
None
:
eval_dataset
=
eval_dataset
.
select
(
range
(
data_args
.
max_eval_samples
))
eval_dataset
=
eval_dataset
.
select
(
range
(
data_args
.
max_eval_samples
))
...
@@ -391,9 +395,9 @@ def main():
...
@@ -391,9 +395,9 @@ def main():
)
)
if
training_args
.
do_predict
:
if
training_args
.
do_predict
:
if
"test"
not
in
datasets
:
if
"test"
not
in
raw_
datasets
:
raise
ValueError
(
"--do_predict requires a test dataset"
)
raise
ValueError
(
"--do_predict requires a test dataset"
)
predict_dataset
=
datasets
[
"test"
]
predict_dataset
=
raw_
datasets
[
"test"
]
# Selecting samples from dataset
# Selecting samples from dataset
if
data_args
.
max_predict_samples
is
not
None
:
if
data_args
.
max_predict_samples
is
not
None
:
predict_dataset
=
predict_dataset
.
select
(
range
(
data_args
.
max_predict_samples
))
predict_dataset
=
predict_dataset
.
select
(
range
(
data_args
.
max_predict_samples
))
...
@@ -754,7 +758,7 @@ def main():
...
@@ -754,7 +758,7 @@ def main():
# Preprocessing the datasets.
# Preprocessing the datasets.
# First we tokenize all the texts.
# First we tokenize all the texts.
column_names
=
datasets
[
"train"
].
column_names
column_names
=
raw_
datasets
[
"train"
].
column_names
text_column_name
=
"text"
if
"text"
in
column_names
else
column_names
[
0
]
text_column_name
=
"text"
if
"text"
in
column_names
else
column_names
[
0
]
padding
=
"max_length"
if
args
.
pad_to_max_length
else
False
padding
=
"max_length"
if
args
.
pad_to_max_length
else
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment