Unverified Commit 149cb0cc authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Add `token` arugment in example scripts (#25172)



* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent c6a8768d
......@@ -20,6 +20,7 @@ import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
......@@ -170,15 +171,21 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
token: str = field(
default=None,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
# endregion
......@@ -198,6 +205,12 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_text_classification", model_args, data_args, framework="tensorflow")
......@@ -258,7 +271,7 @@ def main():
"csv",
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
# Loading a dataset from local json files
......@@ -301,20 +314,20 @@ def main():
num_labels=num_labels,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
config = AutoConfig.from_pretrained(
config_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
# endregion
......@@ -402,7 +415,7 @@ def main():
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
# endregion
......
......@@ -21,6 +21,7 @@ import json
import logging
import os
import random
import warnings
from dataclasses import dataclass, field
from typing import Optional
......@@ -75,15 +76,21 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
token: str = field(
default=None,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
@dataclass
......@@ -196,6 +203,12 @@ def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_ner", model_args, data_args, framework="tensorflow")
......@@ -228,7 +241,7 @@ def main():
raw_datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
data_files = {}
......@@ -240,7 +253,7 @@ def main():
raw_datasets = load_dataset(
extension,
data_files=data_files,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
......
......@@ -22,6 +22,7 @@ import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional
......@@ -93,15 +94,21 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
token: str = field(
default=None,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
@dataclass
......@@ -268,6 +275,12 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_translation", model_args, data_args, framework="tensorflow")
......@@ -322,7 +335,7 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
data_files = {}
......@@ -336,7 +349,7 @@ def main():
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading
......@@ -352,14 +365,14 @@ def main():
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
......@@ -466,7 +479,7 @@ def main():
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment