"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "aa629e7a7c87e99490b673009ce09a5e9fbdd8c6"
Commit d6c929e2 authored by Sylvain Gugger

Merge remote-tracking branch 'origin/master'

parents a8694b88 955b2b97
@@ -304,13 +304,26 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        use_fast=True,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
+    tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
+    if config.model_type in {"gpt2", "roberta"}:
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=True,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+            add_prefix_space=True,
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=True,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
     model = AutoModelForTokenClassification.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
...
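Why this hunk exists (illustrative note, not part of the commit): the NER examples feed pre-split words to the tokenizer with is_split_into_words=True, and the fast byte-level BPE tokenizers used by gpt2 and roberta reject pre-tokenized input unless they were instantiated with add_prefix_space=True. A minimal sketch of that behaviour, using roberta-base purely as a stand-in checkpoint:

from transformers import AutoTokenizer

# Pre-tokenized input, as produced by token-classification datasets such as CoNLL.
words = ["Hugging", "Face", "is", "based", "in", "NYC", "."]

# Without add_prefix_space=True, the fast RoBERTa/GPT-2 tokenizers raise an error
# when called with is_split_into_words=True; with it, the call succeeds.
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True, add_prefix_space=True)
encoding = tokenizer(words, is_split_into_words=True)

# word_ids() maps each sub-token back to its source word, which is how the
# script later aligns NER labels with sub-tokens.
print(encoding.tokens())
print(encoding.word_ids())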
@@ -317,16 +317,18 @@ def main():
         config = CONFIG_MAPPING[args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")

-    if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
-    elif args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
-    else:
+    tokenizer_name_or_path = args.tokenizer_name if args.tokenizer_name else args.model_name_or_path
+    if not tokenizer_name_or_path:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
+    if config.model_type in {"gpt2", "roberta"}:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True, add_prefix_space=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True)

     if args.model_name_or_path:
         model = AutoModelForTokenClassification.from_pretrained(
             args.model_name_or_path,
...
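The second hunk applies the same dispatch in the older script. Outside these example scripts, the pattern can be reproduced directly from the model config; a hedged sketch follows (gpt2 is just an illustrative checkpoint, not something referenced by this diff):

from transformers import AutoConfig, AutoTokenizer

model_name = "gpt2"  # illustrative checkpoint, not taken from the commit
config = AutoConfig.from_pretrained(model_name)

# Mirror the dispatch from the diff: byte-level BPE models need add_prefix_space=True
# before they will accept pre-tokenized (is_split_into_words=True) input.
tokenizer_kwargs = {"use_fast": True}
if config.model_type in {"gpt2", "roberta"}:
    tokenizer_kwargs["add_prefix_space"] = True

tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
print(tokenizer(["New", "York"], is_split_into_words=True).tokens())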