Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9fa29959
Unverified
Commit
9fa29959
authored
Apr 13, 2021
by
Philipp Schmid
Committed by
GitHub
Apr 13, 2021
Browse files
added cache_dir=model_args.cache_dir to all example with cache_dir arg (#11220)
parent
3312e96b
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
37 additions
and
27 deletions
+37
-27
examples/language-modeling/run_clm.py
examples/language-modeling/run_clm.py
+4
-2
examples/language-modeling/run_mlm.py
examples/language-modeling/run_mlm.py
+4
-2
examples/language-modeling/run_mlm_flax.py
examples/language-modeling/run_mlm_flax.py
+4
-2
examples/language-modeling/run_plm.py
examples/language-modeling/run_plm.py
+4
-2
examples/multiple-choice/run_swag.py
examples/multiple-choice/run_swag.py
+2
-2
examples/question-answering/run_qa.py
examples/question-answering/run_qa.py
+2
-2
examples/question-answering/run_qa_beam_search.py
examples/question-answering/run_qa_beam_search.py
+2
-2
examples/seq2seq/run_summarization.py
examples/seq2seq/run_summarization.py
+2
-2
examples/seq2seq/run_translation.py
examples/seq2seq/run_translation.py
+2
-2
examples/text-classification/run_glue.py
examples/text-classification/run_glue.py
+3
-3
examples/text-classification/run_xnli.py
examples/text-classification/run_xnli.py
+6
-4
examples/token-classification/run_ner.py
examples/token-classification/run_ner.py
+2
-2
No files found.
examples/language-modeling/run_clm.py
View file @
9fa29959
...
...
@@ -230,17 +230,19 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
if
"validation"
not
in
datasets
.
keys
():
datasets
[
"validation"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[:
{
data_args
.
validation_split_percentage
}
%]"
,
cache_dir
=
model_args
.
cache_dir
,
)
datasets
[
"train"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[
{
data_args
.
validation_split_percentage
}
%:]"
,
cache_dir
=
model_args
.
cache_dir
,
)
else
:
data_files
=
{}
...
...
@@ -255,7 +257,7 @@ def main():
)
if
extension
==
"txt"
:
extension
=
"text"
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/language-modeling/run_mlm.py
View file @
9fa29959
...
...
@@ -239,17 +239,19 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
if
"validation"
not
in
datasets
.
keys
():
datasets
[
"validation"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[:
{
data_args
.
validation_split_percentage
}
%]"
,
cache_dir
=
model_args
.
cache_dir
,
)
datasets
[
"train"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[
{
data_args
.
validation_split_percentage
}
%:]"
,
cache_dir
=
model_args
.
cache_dir
,
)
else
:
data_files
=
{}
...
...
@@ -260,7 +262,7 @@ def main():
extension
=
data_args
.
train_file
.
split
(
"."
)[
-
1
]
if
extension
==
"txt"
:
extension
=
"text"
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/language-modeling/run_mlm_flax.py
View file @
9fa29959
...
...
@@ -475,17 +475,19 @@ if __name__ == "__main__":
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
if
"validation"
not
in
datasets
.
keys
():
datasets
[
"validation"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[:
{
data_args
.
validation_split_percentage
}
%]"
,
cache_dir
=
model_args
.
cache_dir
,
)
datasets
[
"train"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[
{
data_args
.
validation_split_percentage
}
%:]"
,
cache_dir
=
model_args
.
cache_dir
,
)
else
:
data_files
=
{}
...
...
@@ -496,7 +498,7 @@ if __name__ == "__main__":
extension
=
data_args
.
train_file
.
split
(
"."
)[
-
1
]
if
extension
==
"txt"
:
extension
=
"text"
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/language-modeling/run_plm.py
View file @
9fa29959
...
...
@@ -236,17 +236,19 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
if
"validation"
not
in
datasets
.
keys
():
datasets
[
"validation"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[:
{
data_args
.
validation_split_percentage
}
%]"
,
cache_dir
=
model_args
.
cache_dir
,
)
datasets
[
"train"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
split
=
f
"train[
{
data_args
.
validation_split_percentage
}
%:]"
,
cache_dir
=
model_args
.
cache_dir
,
)
else
:
data_files
=
{}
...
...
@@ -257,7 +259,7 @@ def main():
extension
=
data_args
.
train_file
.
split
(
"."
)[
-
1
]
if
extension
==
"txt"
:
extension
=
"text"
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/multiple-choice/run_swag.py
View file @
9fa29959
...
...
@@ -268,10 +268,10 @@ def main():
if
data_args
.
validation_file
is
not
None
:
data_files
[
"validation"
]
=
data_args
.
validation_file
extension
=
data_args
.
train_file
.
split
(
"."
)[
-
1
]
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
else
:
# Downloading and loading the swag dataset from the hub.
datasets
=
load_dataset
(
"swag"
,
"regular"
)
datasets
=
load_dataset
(
"swag"
,
"regular"
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/question-answering/run_qa.py
View file @
9fa29959
...
...
@@ -256,7 +256,7 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
else
:
data_files
=
{}
if
data_args
.
train_file
is
not
None
:
...
...
@@ -269,7 +269,7 @@ def main():
if
data_args
.
test_file
is
not
None
:
data_files
[
"test"
]
=
data_args
.
test_file
extension
=
data_args
.
test_file
.
split
(
"."
)[
-
1
]
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
field
=
"data"
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
field
=
"data"
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/question-answering/run_qa_beam_search.py
View file @
9fa29959
...
...
@@ -255,7 +255,7 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
else
:
data_files
=
{}
if
data_args
.
train_file
is
not
None
:
...
...
@@ -267,7 +267,7 @@ def main():
if
data_args
.
test_file
is
not
None
:
data_files
[
"test"
]
=
data_args
.
test_file
extension
=
data_args
.
test_file
.
split
(
"."
)[
-
1
]
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
field
=
"data"
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
field
=
"data"
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/seq2seq/run_summarization.py
View file @
9fa29959
...
...
@@ -310,7 +310,7 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
else
:
data_files
=
{}
if
data_args
.
train_file
is
not
None
:
...
...
@@ -322,7 +322,7 @@ def main():
if
data_args
.
test_file
is
not
None
:
data_files
[
"test"
]
=
data_args
.
test_file
extension
=
data_args
.
test_file
.
split
(
"."
)[
-
1
]
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/seq2seq/run_translation.py
View file @
9fa29959
...
...
@@ -294,7 +294,7 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
else
:
data_files
=
{}
if
data_args
.
train_file
is
not
None
:
...
...
@@ -306,7 +306,7 @@ def main():
if
data_args
.
test_file
is
not
None
:
data_files
[
"test"
]
=
data_args
.
test_file
extension
=
data_args
.
test_file
.
split
(
"."
)[
-
1
]
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/text-classification/run_glue.py
View file @
9fa29959
...
...
@@ -239,7 +239,7 @@ def main():
# download the dataset.
if
data_args
.
task_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
"glue"
,
data_args
.
task_name
)
datasets
=
load_dataset
(
"glue"
,
data_args
.
task_name
,
cache_dir
=
model_args
.
cache_dir
)
else
:
# Loading a dataset from your local files.
# CSV/JSON training and evaluation files are needed.
...
...
@@ -263,10 +263,10 @@ def main():
if
data_args
.
train_file
.
endswith
(
".csv"
):
# Loading a dataset from local csv files
datasets
=
load_dataset
(
"csv"
,
data_files
=
data_files
)
datasets
=
load_dataset
(
"csv"
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
else
:
# Loading a dataset from local json files
datasets
=
load_dataset
(
"json"
,
data_files
=
data_files
)
datasets
=
load_dataset
(
"json"
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
examples/text-classification/run_xnli.py
View file @
9fa29959
...
...
@@ -209,17 +209,19 @@ def main():
# Downloading and loading xnli dataset from the hub.
if
training_args
.
do_train
:
if
model_args
.
train_language
is
None
:
train_dataset
=
load_dataset
(
"xnli"
,
model_args
.
language
,
split
=
"train"
)
train_dataset
=
load_dataset
(
"xnli"
,
model_args
.
language
,
split
=
"train"
,
cache_dir
=
model_args
.
cache_dir
)
else
:
train_dataset
=
load_dataset
(
"xnli"
,
model_args
.
train_language
,
split
=
"train"
)
train_dataset
=
load_dataset
(
"xnli"
,
model_args
.
train_language
,
split
=
"train"
,
cache_dir
=
model_args
.
cache_dir
)
label_list
=
train_dataset
.
features
[
"label"
].
names
if
training_args
.
do_eval
:
eval_dataset
=
load_dataset
(
"xnli"
,
model_args
.
language
,
split
=
"validation"
)
eval_dataset
=
load_dataset
(
"xnli"
,
model_args
.
language
,
split
=
"validation"
,
cache_dir
=
model_args
.
cache_dir
)
label_list
=
eval_dataset
.
features
[
"label"
].
names
if
training_args
.
do_predict
:
test_dataset
=
load_dataset
(
"xnli"
,
model_args
.
language
,
split
=
"test"
)
test_dataset
=
load_dataset
(
"xnli"
,
model_args
.
language
,
split
=
"test"
,
cache_dir
=
model_args
.
cache_dir
)
label_list
=
test_dataset
.
features
[
"label"
].
names
# Labels
...
...
examples/token-classification/run_ner.py
View file @
9fa29959
...
...
@@ -229,7 +229,7 @@ def main():
# download the dataset.
if
data_args
.
dataset_name
is
not
None
:
# Downloading and loading a dataset from the hub.
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
)
datasets
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_config_name
,
cache_dir
=
model_args
.
cache_dir
)
else
:
data_files
=
{}
if
data_args
.
train_file
is
not
None
:
...
...
@@ -239,7 +239,7 @@ def main():
if
data_args
.
test_file
is
not
None
:
data_files
[
"test"
]
=
data_args
.
test_file
extension
=
data_args
.
train_file
.
split
(
"."
)[
-
1
]
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
)
datasets
=
load_dataset
(
extension
,
data_files
=
data_files
,
cache_dir
=
model_args
.
cache_dir
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment