chenpangpang/transformers · Commits · 9fa29959
Unverified commit 9fa29959, authored Apr 13, 2021 by Philipp Schmid, committed by GitHub on Apr 13, 2021.

added cache_dir=model_args.cache_dir to all examples with a cache_dir arg (#11220)
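For context: each of these example scripts defines a cache_dir field on its ModelArguments, which was already passed to the config, tokenizer, and model from_pretrained() calls but not to datasets.load_dataset(), so dataset downloads ignored the user-chosen cache location. A minimal sketch of the difference (dataset name and path are illustrative, not taken from the commit):

    from datasets import load_dataset

    # Before: the dataset is cached in the default HF datasets cache location.
    ds = load_dataset("wikitext", "wikitext-2-raw-v1")

    # After: the dataset is cached in the same user-chosen directory as the model files.
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", cache_dir="/data/hf-cache")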
parent 3312e96b
Showing 12 changed files with 37 additions and 27 deletions.
examples/language-modeling/run_clm.py                 +4 −2
examples/language-modeling/run_mlm.py                 +4 −2
examples/language-modeling/run_mlm_flax.py            +4 −2
examples/language-modeling/run_plm.py                 +4 −2
examples/multiple-choice/run_swag.py                  +2 −2
examples/question-answering/run_qa.py                 +2 −2
examples/question-answering/run_qa_beam_search.py     +2 −2
examples/seq2seq/run_summarization.py                 +2 −2
examples/seq2seq/run_translation.py                   +2 −2
examples/text-classification/run_glue.py              +3 −3
examples/text-classification/run_xnli.py              +6 −4
examples/token-classification/run_ner.py              +2 −2
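The user-visible effect: passing --cache_dir to any of these scripts now redirects dataset caching as well as model caching. A hypothetical invocation (model, dataset, and paths chosen for illustration, not from the commit):

    python examples/language-modeling/run_clm.py \
        --model_name_or_path gpt2 \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --cache_dir /data/hf-cache \
        --do_train \
        --output_dir /tmp/test-clm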
examples/language-modeling/run_clm.py
@@ -230,17 +230,19 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}
@@ -255,7 +257,7 @@ def main():
         )
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
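The hunk above also shows the datasets split-slicing syntax the language-modeling scripts use to carve a validation set out of train when the dataset ships without one: train[:N%] takes the first N percent, train[N%:] the rest. A small standalone illustration (dataset name assumed, not from the diff):

    from datasets import load_dataset

    # First 5% of train becomes validation; the remaining 95% stays as train.
    val = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:5%]")
    train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[5%:]")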
examples/language-modeling/run_mlm.py
@@ -239,17 +239,19 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}
@@ -260,7 +262,7 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/language-modeling/run_mlm_flax.py
@@ -475,17 +475,19 @@ if __name__ == "__main__":
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}
@@ -496,7 +498,7 @@ if __name__ == "__main__":
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/language-modeling/run_plm.py
@@ -236,17 +236,19 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
             )
     else:
         data_files = {}
@@ -257,7 +259,7 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/multiple-choice/run_swag.py
@@ -268,10 +268,10 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
         extension = data_args.train_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     else:
         # Downloading and loading the swag dataset from the hub.
-        datasets = load_dataset("swag", "regular")
+        datasets = load_dataset("swag", "regular", cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/question-answering/run_qa.py
@@ -256,7 +256,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -269,7 +269,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files, field="data")
+        datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
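Note the field="data" argument that the QA scripts use: SQuAD-style JSON files nest their examples under a top-level "data" key, and field tells the json loader which key to read. A short sketch of the same call outside the script (the file name is hypothetical):

    from datasets import load_dataset

    # Reads the examples stored under the top-level "data" field of the JSON file.
    datasets = load_dataset(
        "json", data_files={"train": "squad-style.json"}, field="data", cache_dir="/data/hf-cache"
    )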
examples/question-answering/run_qa_beam_search.py
@@ -255,7 +255,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -267,7 +267,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files, field="data")
+        datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/seq2seq/run_summarization.py
@@ -310,7 +310,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -322,7 +322,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/seq2seq/run_translation.py
@@ -294,7 +294,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -306,7 +306,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
             extension = data_args.test_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/text-classification/run_glue.py
@@ -239,7 +239,7 @@ def main():
     # download the dataset.
     if data_args.task_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset("glue", data_args.task_name)
+        datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
     else:
         # Loading a dataset from your local files.
         # CSV/JSON training and evaluation files are needed.
@@ -263,10 +263,10 @@ def main():
         if data_args.train_file.endswith(".csv"):
             # Loading a dataset from local csv files
-            datasets = load_dataset("csv", data_files=data_files)
+            datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
         else:
             # Loading a dataset from local json files
-            datasets = load_dataset("json", data_files=data_files)
+            datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
examples/text-classification/run_xnli.py
@@ -209,17 +209,19 @@ def main():
     # Downloading and loading xnli dataset from the hub.
     if training_args.do_train:
         if model_args.train_language is None:
-            train_dataset = load_dataset("xnli", model_args.language, split="train")
+            train_dataset = load_dataset("xnli", model_args.language, split="train", cache_dir=model_args.cache_dir)
         else:
-            train_dataset = load_dataset("xnli", model_args.train_language, split="train")
+            train_dataset = load_dataset(
+                "xnli", model_args.train_language, split="train", cache_dir=model_args.cache_dir
+            )
         label_list = train_dataset.features["label"].names

     if training_args.do_eval:
-        eval_dataset = load_dataset("xnli", model_args.language, split="validation")
+        eval_dataset = load_dataset("xnli", model_args.language, split="validation", cache_dir=model_args.cache_dir)
         label_list = eval_dataset.features["label"].names

     if training_args.do_predict:
-        test_dataset = load_dataset("xnli", model_args.language, split="test")
+        test_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
         label_list = test_dataset.features["label"].names

     # Labels
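Unlike the other scripts, run_xnli.py passes split= on every call, so load_dataset returns a single Dataset per split rather than a DatasetDict; that is why each branch re-reads label_list from the split it just loaded. A small sketch (language config and path assumed):

    from datasets import load_dataset

    # A single split comes back as a Dataset, not a DatasetDict.
    eval_dataset = load_dataset("xnli", "en", split="validation", cache_dir="/data/hf-cache")
    print(eval_dataset.features["label"].names)  # ["entailment", "neutral", "contradiction"]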
examples/token-classification/run_ner.py
@@ -229,7 +229,7 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
     else:
         data_files = {}
         if data_args.train_file is not None:
@@ -239,7 +239,7 @@ def main():
         if data_args.test_file is not None:
             data_files["test"] = data_args.test_file
         extension = data_args.train_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files)
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.