Unverified Commit 149cb0cc authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Add `token` arugment in example scripts (#25172)



* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent c6a8768d
...@@ -20,6 +20,7 @@ import json ...@@ -20,6 +20,7 @@ import json
import logging import logging
import os import os
import sys import sys
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
...@@ -170,15 +171,21 @@ class ModelArguments: ...@@ -170,15 +171,21 @@ class ModelArguments:
default="main", default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
) )
use_auth_token: bool = field( token: str = field(
default=False, default=None,
metadata={ metadata={
"help": ( "help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script " "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"with private models)." "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
) )
}, },
) )
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
# endregion # endregion
...@@ -198,6 +205,12 @@ def main(): ...@@ -198,6 +205,12 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions. # information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_text_classification", model_args, data_args, framework="tensorflow") send_example_telemetry("run_text_classification", model_args, data_args, framework="tensorflow")
...@@ -258,7 +271,7 @@ def main(): ...@@ -258,7 +271,7 @@ def main():
"csv", "csv",
data_files=data_files, data_files=data_files,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None, token=model_args.token,
) )
else: else:
# Loading a dataset from local json files # Loading a dataset from local json files
...@@ -301,20 +314,20 @@ def main(): ...@@ -301,20 +314,20 @@ def main():
num_labels=num_labels, num_labels=num_labels,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
token=True if model_args.use_auth_token else None, token=model_args.token,
) )
else: else:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
config_path, config_path,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
token=True if model_args.use_auth_token else None, token=model_args.token,
) )
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
token=True if model_args.use_auth_token else None, token=model_args.token,
) )
# endregion # endregion
...@@ -402,7 +415,7 @@ def main(): ...@@ -402,7 +415,7 @@ def main():
config=config, config=config,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
token=True if model_args.use_auth_token else None, token=model_args.token,
) )
# endregion # endregion
......
...@@ -21,6 +21,7 @@ import json ...@@ -21,6 +21,7 @@ import json
import logging import logging
import os import os
import random import random
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
...@@ -75,15 +76,21 @@ class ModelArguments: ...@@ -75,15 +76,21 @@ class ModelArguments:
default="main", default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
) )
use_auth_token: bool = field( token: str = field(
default=False, default=None,
metadata={ metadata={
"help": ( "help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script " "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"with private models)." "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
) )
}, },
) )
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
@dataclass @dataclass
...@@ -196,6 +203,12 @@ def main(): ...@@ -196,6 +203,12 @@ def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments)) parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions. # information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_ner", model_args, data_args, framework="tensorflow") send_example_telemetry("run_ner", model_args, data_args, framework="tensorflow")
...@@ -228,7 +241,7 @@ def main(): ...@@ -228,7 +241,7 @@ def main():
raw_datasets = load_dataset( raw_datasets = load_dataset(
data_args.dataset_name, data_args.dataset_name,
data_args.dataset_config_name, data_args.dataset_config_name,
use_auth_token=True if model_args.use_auth_token else None, token=model_args.token,
) )
else: else:
data_files = {} data_files = {}
...@@ -240,7 +253,7 @@ def main(): ...@@ -240,7 +253,7 @@ def main():
raw_datasets = load_dataset( raw_datasets = load_dataset(
extension, extension,
data_files=data_files, data_files=data_files,
use_auth_token=True if model_args.use_auth_token else None, token=model_args.token,
) )
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html. # https://huggingface.co/docs/datasets/loading_datasets.html.
......
...@@ -22,6 +22,7 @@ import json ...@@ -22,6 +22,7 @@ import json
import logging import logging
import os import os
import sys import sys
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
...@@ -93,15 +94,21 @@ class ModelArguments: ...@@ -93,15 +94,21 @@ class ModelArguments:
default="main", default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
) )
use_auth_token: bool = field( token: str = field(
default=False, default=None,
metadata={ metadata={
"help": ( "help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script " "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"with private models)." "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
) )
}, },
) )
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
@dataclass @dataclass
...@@ -268,6 +275,12 @@ def main(): ...@@ -268,6 +275,12 @@ def main():
else: else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions. # information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_translation", model_args, data_args, framework="tensorflow") send_example_telemetry("run_translation", model_args, data_args, framework="tensorflow")
...@@ -322,7 +335,7 @@ def main(): ...@@ -322,7 +335,7 @@ def main():
data_args.dataset_name, data_args.dataset_name,
data_args.dataset_config_name, data_args.dataset_config_name,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None, token=model_args.token,
) )
else: else:
data_files = {} data_files = {}
...@@ -336,7 +349,7 @@ def main(): ...@@ -336,7 +349,7 @@ def main():
extension, extension,
data_files=data_files, data_files=data_files,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None, token=model_args.token,
) )
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading # https://huggingface.co/docs/datasets/loading
...@@ -352,14 +365,14 @@ def main(): ...@@ -352,14 +365,14 @@ def main():
model_args.config_name if model_args.config_name else model_args.model_name_or_path, model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
token=True if model_args.use_auth_token else None, token=model_args.token,
) )
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer, use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision, revision=model_args.model_revision,
token=True if model_args.use_auth_token else None, token=model_args.token,
) )
prefix = data_args.source_prefix if data_args.source_prefix is not None else "" prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
...@@ -466,7 +479,7 @@ def main(): ...@@ -466,7 +479,7 @@ def main():
config=config, config=config,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
token=True if model_args.use_auth_token else None, token=model_args.token,
) )
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment