chenpangpang / transformers · Commits

Unverified commit 9fa29959
Authored Apr 13, 2021 by Philipp Schmid; committed by GitHub on Apr 13, 2021

added cache_dir=model_args.cache_dir to all example with cache_dir arg (#11220)

Parent: 3312e96b

Showing 12 changed files with 37 additions and 27 deletions (+37 / -27)
examples/language-modeling/run_clm.py                 +4 -2
examples/language-modeling/run_mlm.py                 +4 -2
examples/language-modeling/run_mlm_flax.py            +4 -2
examples/language-modeling/run_plm.py                 +4 -2
examples/multiple-choice/run_swag.py                  +2 -2
examples/question-answering/run_qa.py                 +2 -2
examples/question-answering/run_qa_beam_search.py     +2 -2
examples/seq2seq/run_summarization.py                 +2 -2
examples/seq2seq/run_translation.py                   +2 -2
examples/text-classification/run_glue.py              +3 -3
examples/text-classification/run_xnli.py              +6 -4
examples/token-classification/run_ner.py              +2 -2
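All twelve changes thread the scripts' existing cache_dir option into datasets.load_dataset. For orientation, here is a minimal sketch of how these example scripts declare that option on their ModelArguments dataclass; the real classes carry more fields and the help text below is paraphrased, not copied from the repository:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    # Sketch only: paraphrased help text, illustrative default.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory to cache models and datasets downloaded from huggingface.co"},
    )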
examples/language-modeling/run_clm.py

@@ -230,17 +230,19 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}

@@ -255,7 +257,7 @@ def main():
         )
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

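The same pattern repeats in the remaining files: each existing load_dataset call gains cache_dir=model_args.cache_dir, so dataset downloads and the processed Arrow files land in the user-chosen directory instead of the default ~/.cache/huggingface/datasets. A standalone sketch of the resulting call, with a placeholder dataset name and cache path that are not taken from the diff:

from datasets import load_dataset

# Placeholder dataset and path, for illustration only.
raw_datasets = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    cache_dir="/mnt/shared/hf_cache",  # downloads and cached Arrow files go here
)
print(raw_datasets)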
examples/language-modeling/run_mlm.py

@@ -239,17 +239,19 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}

@@ -260,7 +262,7 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/language-modeling/run_mlm_flax.py

@@ -475,17 +475,19 @@ if __name__ == "__main__":
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}

@@ -496,7 +498,7 @@ if __name__ == "__main__":
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/language-modeling/run_plm.py

@@ -236,17 +236,19 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}

@@ -257,7 +259,7 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/multiple-choice/run_swag.py

@@ -268,10 +268,10 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
         extension = data_args.train_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     else:
         # Downloading and loading the swag dataset from the hub.
-        datasets = load_dataset("swag", "regular")
+        datasets = load_dataset("swag", "regular", cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/question-answering/run_qa.py

@@ -256,7 +256,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:

@@ -269,7 +269,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files, field="data")
+        datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/question-answering/run_qa_beam_search.py

@@ -255,7 +255,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:

@@ -267,7 +267,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files, field="data")
+        datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/seq2seq/run_summarization.py

@@ -310,7 +310,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:

@@ -322,7 +322,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/seq2seq/run_translation.py

@@ -294,7 +294,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:

@@ -306,7 +306,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/text-classification/run_glue.py

@@ -239,7 +239,7 @@ def main():
     # download the dataset.
     if data_args.task_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset("glue", data_args.task_name)
+        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
     else:
         # Loading a dataset from your local files.
         # CSV/JSON training and evaluation files are needed.

@@ -263,10 +263,10 @@ def main():
         if data_args.train_file.endswith(".csv"):
             # Loading a dataset from local csv files
-            datasets = load_dataset("csv", data_files=data_files)
+            datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
         else:
             # Loading a dataset from local json files
-            datasets = load_dataset("json", data_files=data_files)
+            datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

examples/text-classification/run_xnli.py

@@ -209,17 +209,19 @@ def main():
     # Downloading and loading xnli dataset from the hub.
     if training_args.do_train:
         if model_args.train_language is None:
-            train_dataset = load_dataset("xnli", model_args.language, split="train")
+            train_dataset = load_dataset("xnli", model_args.language, split="train", cache_dir=model_args.cache_dir)
         else:
-            train_dataset = load_dataset("xnli", model_args.train_language, split="train")
+            train_dataset = load_dataset(
+                "xnli", model_args.train_language, split="train", cache_dir=model_args.cache_dir
+            )
         label_list = train_dataset.features["label"].names

     if training_args.do_eval:
-        eval_dataset = load_dataset("xnli", model_args.language, split="validation")
+        eval_dataset = load_dataset("xnli", model_args.language, split="validation", cache_dir=model_args.cache_dir)
         label_list = eval_dataset.features["label"].names

     if training_args.do_predict:
-        test_dataset = load_dataset("xnli", model_args.language, split="test")
+        test_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
         label_list = test_dataset.features["label"].names

     # Labels

examples/token-classification/run_ner.py

@@ -229,7 +229,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:

@@ -239,7 +239,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
         extension = data_args.train_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

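For context, these scripts already forwarded the same option when downloading configs, tokenizers, and model weights; this commit extends it to the dataset side so everything can share one cache location. A hedged sketch of that pre-existing pattern, with a placeholder model name and path that are not taken from the diff:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

cache_dir = "/mnt/shared/hf_cache"  # placeholder path, for illustration only

# The example scripts pass model_args.cache_dir to these calls as well,
# so models, tokenizers, and (after this commit) datasets use one cache.
config = AutoConfig.from_pretrained("gpt2", cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained("gpt2", config=config, cache_dir=cache_dir)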