"tests/vscode:/vscode.git/clone" did not exist on "f56174ac5b6f76d14babc135a709f4e9d53cc144"
Unverified commit 955b2b97 authored by kumapo, committed by GitHub

Enable add_prefix_space if model_type is roberta or gpt2 (#12116)

parent 60b1d6b4
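
For context: both token-classification example scripts call the tokenizer on pre-split words with `is_split_into_words=True`, and the byte-level BPE fast tokenizers behind `gpt2` and `roberta` checkpoints only accept pretokenized input when they were constructed with `add_prefix_space=True`, which is the option this commit forwards. Below is a minimal sketch of that behaviour (it assumes the public `roberta-base` checkpoint; it is not part of the commit itself):

```python
from transformers import AutoTokenizer

words = ["My", "name", "is", "Wolfgang"]  # pre-split words, as in an NER dataset

# Constructed the way the patched scripts now do it for roberta/gpt2 model types:
# the extra kwarg is forwarded to RobertaTokenizerFast.
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True, add_prefix_space=True)

encoding = tokenizer(words, is_split_into_words=True)
print(encoding.tokens())    # sub-tokens, each word carrying the BPE space marker
print(encoding.word_ids())  # word index per sub-token, used to align NER labels

# Without add_prefix_space=True, the same call on a fast RoBERTa/GPT-2 tokenizer
# raises an error asking you to instantiate it with add_prefix_space=True.
```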
@@ -304,13 +304,26 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        use_fast=True,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
-    )
+    tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
+    if config.model_type in {"gpt2", "roberta"}:
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=True,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+            add_prefix_space=True,
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=True,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
     model = AutoModelForTokenClassification.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -317,16 +317,18 @@ def main():
         config = CONFIG_MAPPING[args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
-    if args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
-    elif args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
-    else:
+    tokenizer_name_or_path = args.tokenizer_name if args.tokenizer_name else args.model_name_or_path
+    if not tokenizer_name_or_path:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
+    if config.model_type in {"gpt2", "roberta"}:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True, add_prefix_space=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True)
 
     if args.model_name_or_path:
         model = AutoModelForTokenClassification.from_pretrained(
             args.model_name_or_path,