"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "df1ddcedf28d6fa419719cc003a640181a080a88"
Commit 656e1386 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fix #3305: run_ner only possible on ModelForTokenClassification models

parent 0c44b119
...@@ -31,7 +31,6 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -31,7 +31,6 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from transformers import ( from transformers import (
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
WEIGHTS_NAME, WEIGHTS_NAME,
AdamW, AdamW,
AutoConfig, AutoConfig,
...@@ -39,7 +38,7 @@ from transformers import ( ...@@ -39,7 +38,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
get_linear_schedule_with_warmup, get_linear_schedule_with_warmup,
) )
from transformers.modeling_auto import MODEL_MAPPING from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
...@@ -51,8 +50,9 @@ except ImportError: ...@@ -51,8 +50,9 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP) MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"] TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
...@@ -384,7 +384,7 @@ def main(): ...@@ -384,7 +384,7 @@ def main():
default=None, default=None,
type=str, type=str,
required=True, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES), help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
) )
parser.add_argument( parser.add_argument(
"--model_name_or_path", "--model_name_or_path",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment