"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "af2b78601b9bbc75835e0d45657b2253155e4abb"
Unverified Commit 2b60a26b authored by J.P Lee, committed by GitHub

Update examples/ner/run_ner.py to use AutoModel (#3305)

* Update examples/ner/run_ner.py to use AutoModel

* Fix missing code and apply `make style` command
parent e41212c7
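For context, the change drops the per-architecture (config, model, tokenizer) class tuples and lets the Auto* classes resolve the concrete implementation from the checkpoint itself. A minimal sketch of the pattern the script now follows; the checkpoint name and label list below are illustrative, not taken from the commit:

```python
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer

# Illustrative values only -- the real script takes these from CLI arguments.
model_name = "bert-base-cased"
labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]

# The Auto* classes pick the right config/tokenizer/model implementation
# from the checkpoint, so one code path covers every supported architecture.
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={str(i): label for i, label in enumerate(labels)},
    label2id={label: i for i, label in enumerate(labels)},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name, config=config)
```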
@@ -31,28 +31,15 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 
 from transformers import (
+    ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
     WEIGHTS_NAME,
     AdamW,
-    AlbertConfig,
-    AlbertForTokenClassification,
-    AlbertTokenizer,
-    BertConfig,
-    BertForTokenClassification,
-    BertTokenizer,
-    CamembertConfig,
-    CamembertForTokenClassification,
-    CamembertTokenizer,
-    DistilBertConfig,
-    DistilBertForTokenClassification,
-    DistilBertTokenizer,
-    RobertaConfig,
-    RobertaForTokenClassification,
-    RobertaTokenizer,
-    XLMRobertaConfig,
-    XLMRobertaForTokenClassification,
-    XLMRobertaTokenizer,
+    AutoConfig,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
     get_linear_schedule_with_warmup,
 )
+from transformers.modeling_auto import MODEL_MAPPING
 
 from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
@@ -64,22 +51,8 @@ except ImportError:
 
 logger = logging.getLogger(__name__)
 
-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
-    ),
-    (),
-)
-
-MODEL_CLASSES = {
-    "albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
-    "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
-    "camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
-    "xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
-}
+ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
+MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
 
 TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
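A quick sketch of what the new MODEL_CLASSES expression iterates over, assuming the transformers version this commit targets, where MODEL_MAPPING is exposed by transformers.modeling_auto and maps config classes to model classes:

```python
from transformers.modeling_auto import MODEL_MAPPING

# Iterating the mapping yields config classes; each carries a `model_type`
# string such as "bert" or "roberta", which is what --model_type is checked
# against after this change.
for config_class in MODEL_MAPPING:
    print(config_class.__name__, "->", config_class.model_type)
```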
@@ -411,7 +384,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
     )
     parser.add_argument(
         "--model_name_or_path",
@@ -594,8 +567,7 @@ def main():
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
     args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         num_labels=num_labels,
         id2label={str(i): label for i, label in enumerate(labels)},
@@ -604,12 +576,12 @@ def main():
     )
     tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
     logger.info("Tokenizer arguments: %s", tokenizer_args)
-    tokenizer = tokenizer_class.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         cache_dir=args.cache_dir if args.cache_dir else None,
         **tokenizer_args,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForTokenClassification.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
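The tokenizer_args dict comprehension kept as context in the hunk above forwards only the tokenizer-related flags the user actually set. A small standalone sketch of that filtering, with made-up argument values standing in for the argparse namespace:

```python
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]

# Stand-in for vars(args); the values here are made up.
args = {"model_type": "bert", "do_lower_case": True, "strip_accents": None, "use_fast": True}

# Only keys listed in TOKENIZER_ARGS whose value is not None get forwarded
# to AutoTokenizer.from_pretrained(..., **tokenizer_args), so unset flags
# fall back to the tokenizer's own defaults.
tokenizer_args = {k: v for k, v in args.items() if v is not None and k in TOKENIZER_ARGS}
print(tokenizer_args)  # {'do_lower_case': True, 'use_fast': True}
```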
@@ -650,7 +622,7 @@ def main():
     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
@@ -660,7 +632,7 @@ def main():
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint)
+            model = AutoModelForTokenClassification.from_pretrained(checkpoint)
             model.to(args.device)
             result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
             if global_step:
@@ -672,8 +644,8 @@ def main():
                 writer.write("{} = {}\n".format(key, str(results[key])))
 
     if args.do_predict and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
-        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
+        model = AutoModelForTokenClassification.from_pretrained(args.output_dir)
         model.to(args.device)
         result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
         # Save results
...