"docs/source/en/model_doc/pix2struct.md" did not exist on "fd3eb3e3cd62f1a078aadba791d03d042678313e"
Unverified commit 453a70d4, authored by Sylvain Gugger, committed by GitHub

Allow example to use a revision and work with private models (#9407)

* Allow example to use a revision and work with private models

* Copy to other examples and template

* Styling
parent 7988edc0
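
The two new `ModelArguments` fields added below surface automatically as `--model_revision` and `--use_auth_token` command-line flags, because the example scripts build their CLI with `HfArgumentParser`. A minimal sketch of that mapping (the script, organization, and tag names in the comment are illustrative, not taken from this commit):

# Sketch: HfArgumentParser turns dataclass fields into CLI flags, so after
# this change the examples can be invoked along the lines of
#   python run_clm.py --model_name_or_path my-org/private-model \
#       --model_revision v1.0 --use_auth_token
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    model_revision: str = field(default="main")
    use_auth_token: bool = field(default=False)  # bool defaulting to False becomes an opt-in flag


parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=["--model_revision", "v1.0", "--use_auth_token"])
print(model_args.model_revision, model_args.use_auth_token)  # v1.0 True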
@@ -83,6 +83,17 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 @dataclass
@@ -224,22 +235,29 @@ def main():
     # The .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
     elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
+    tokenizer_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "use_fast": model_args.use_fast_tokenizer,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
     elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
@@ -252,6 +270,8 @@ def main():
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
             cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         logger.info("Training new model from scratch")
...
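
Note the convention `use_auth_token=True if model_args.use_auth_token else None` used throughout: as I read it (the commit does not spell this out), `True` makes `from_pretrained` use the token cached by `transformers-cli login`, while `None` behaves as if the argument were omitted, so anonymous downloads of public models keep working. A self-contained sketch of the pattern, using a public checkpoint:

# Sketch of the token-passing convention above; `bert-base-cased` is public,
# so the anonymous (None) branch needs no login.
from transformers import AutoConfig

wants_token = False  # stand-in for model_args.use_auth_token

config = AutoConfig.from_pretrained(
    "bert-base-cased",
    revision="main",                               # branch, tag, or commit id
    use_auth_token=True if wants_token else None,  # True -> cached login token
)
print(type(config).__name__)  # BertConfig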
@@ -81,6 +81,17 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 @dataclass
@@ -234,22 +245,29 @@ def main():
     # Distributed training:
     # The .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
     elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
+    tokenizer_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "use_fast": model_args.use_fast_tokenizer,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
     elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
@@ -262,6 +280,8 @@ def main():
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
             cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         logger.info("Training new model from scratch")
...
@@ -83,6 +83,17 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 @dataclass
@@ -247,22 +258,29 @@ def main():
     # Distributed training:
     # The .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
     elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
+    tokenizer_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "use_fast": model_args.use_fast_tokenizer,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
     elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
@@ -275,6 +293,8 @@ def main():
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
             cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         logger.info("Training new model from scratch")
...
@@ -71,6 +71,17 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 @dataclass
@@ -231,22 +242,29 @@ def main():
     # Distributed training:
     # The .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
     elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     else:
         config = XLNetConfig()
         logger.warning("You are instantiating a new config instance from scratch.")
 
+    tokenizer_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "use_fast": model_args.use_fast_tokenizer,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
     elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
@@ -259,6 +277,8 @@ def main():
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         logger.info("Training new model from scratch")
...
@@ -68,6 +68,17 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 @dataclass
@@ -245,17 +256,23 @@ def main():
     config = AutoConfig.from_pretrained(
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     model = AutoModelForMultipleChoice.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 
     # When using your own dataset or a different dataset from swag, you will probably need to change this.
...
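
Since `model_revision` accepts any git revision of a model repo, a run can be pinned to a tag or an exact commit, which keeps results reproducible even if the repo's `main` branch later moves. A short sketch under that assumption (checkpoint and revision are placeholders, not values from this commit):

# Pinning a checkpoint to a fixed revision; swap in a tag or commit id of
# your own model repo to pin exactly.
from transformers import AutoModelForMultipleChoice, AutoTokenizer

checkpoint = "roberta-base"  # placeholder public checkpoint
revision = "main"            # placeholder; use a tag or commit id to pin

tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision=revision)
model = AutoModelForMultipleChoice.from_pretrained(checkpoint, revision=revision)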
@@ -65,6 +65,17 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 @dataclass
@@ -220,17 +231,23 @@ def main():
     config = AutoConfig.from_pretrained(
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=True,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
    )
     model = AutoModelForQuestionAnswering.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 
     # Tokenizer check: this script requires a fast tokenizer.
...
@@ -64,9 +64,16 @@ class ModelArguments:
         default=None,
         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
     )
-    use_fast_tokenizer: bool = field(
-        default=True,
-        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
     )
@@ -223,16 +230,22 @@ def main():
     config = XLNetConfig.from_pretrained(
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     tokenizer = XLNetTokenizerFast.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     model = XLNetForQuestionAnswering.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 
     # Preprocessing the datasets.
...
@@ -131,6 +131,17 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 def main():
@@ -236,17 +247,23 @@ def main():
         num_labels=num_labels,
         finetuning_task=data_args.task_name,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     model = AutoModelForSequenceClassification.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 
     # Preprocessing the datasets
...
@@ -65,6 +65,17 @@ class ModelArguments:
         default=None,
         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 
 
 @dataclass
@@ -238,17 +249,23 @@ def main():
         num_labels=num_labels,
         finetuning_task=data_args.task_name,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=True,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     model = AutoModelForTokenClassification.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 
     # Tokenizer check: this script requires a fast tokenizer.
...
@@ -104,6 +104,17 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
     )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+            "with private models)."
+        },
+    )
 {% endif %}
@@ -219,22 +230,29 @@ def main():
     # The .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
 {%- if cookiecutter.can_train_from_scratch == "True" %}
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
     elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
+    tokenizer_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "use_fast": model_args.use_fast_tokenizer,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
     if model_args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
     elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
@@ -247,6 +265,8 @@ def main():
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
             cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
         )
     else:
         logger.info("Training new model from scratch")
@@ -259,17 +279,23 @@ def main():
         num_labels=num_labels,
         finetuning_task=data_args.task_name,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     model = AutoModelForSequenceClassification.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
         cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 {% endif %}
...