Unverified Commit 955b2b97 authored by kumapo's avatar kumapo Committed by GitHub
Browse files

Enable add_prefix_space if model_type is roberta or gpt2 (#12116)

parent 60b1d6b4
...@@ -304,13 +304,26 @@ def main(): ...@@ -304,13 +304,26 @@ def main():
revision=model_args.model_revision, revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
) )
tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
if config.model_type in {"gpt2", "roberta"}:
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, tokenizer_name_or_path,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
use_fast=True, use_fast=True,
revision=model_args.model_revision, revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
add_prefix_space=True,
) )
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=True,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForTokenClassification.from_pretrained( model = AutoModelForTokenClassification.from_pretrained(
model_args.model_name_or_path, model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path), from_tf=bool(".ckpt" in model_args.model_name_or_path),
......
...@@ -317,16 +317,18 @@ def main(): ...@@ -317,16 +317,18 @@ def main():
config = CONFIG_MAPPING[args.model_type]() config = CONFIG_MAPPING[args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.") logger.warning("You are instantiating a new config instance from scratch.")
if args.tokenizer_name: tokenizer_name_or_path = args.tokenizer_name if args.tokenizer_name else args.model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True) if not tokenizer_name_or_path:
elif args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
else:
raise ValueError( raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script." "You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name." "You can do it from another script, save it, and load it from here, using --tokenizer_name."
) )
if config.model_type in {"gpt2", "roberta"}:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True, add_prefix_space=True)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True)
if args.model_name_or_path: if args.model_name_or_path:
model = AutoModelForTokenClassification.from_pretrained( model = AutoModelForTokenClassification.from_pretrained(
args.model_name_or_path, args.model_name_or_path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment