Unverified commit 7e7f62bf, authored by Sylvain Gugger, committed by GitHub
Browse files

Fix pipeline tests for Roberta-like tokenizers (#19365)

* Fix pipeline tests for Roberta-like tokenizers

* Fix fix
parent bad353ce
...@@ -37,8 +37,6 @@ from transformers import ( ...@@ -37,8 +37,6 @@ from transformers import (
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoTokenizer, AutoTokenizer,
DistilBertForSequenceClassification, DistilBertForSequenceClassification,
IBertConfig,
RobertaConfig,
TextClassificationPipeline, TextClassificationPipeline,
TFAutoModelForSequenceClassification, TFAutoModelForSequenceClassification,
pipeline, pipeline,
...@@ -71,6 +69,16 @@ from test_module.custom_pipeline import PairClassificationPipeline # noqa E402 ...@@ -71,6 +69,16 @@ from test_module.custom_pipeline import PairClassificationPipeline # noqa E402
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Config class names (compared as plain strings against
# `model.config.__class__.__name__`) for which the pipeline test setup sets
# `tokenizer.model_max_length` to `max_position_embeddings - 2`.
# NOTE(review): presumably these are the models whose position embeddings
# reserve two slots, RoBERTa-style — confirm before adding new entries.
ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS = [
    f"{family}Config"
    for family in ("Camembert", "IBert", "Longformer", "MarkupLM", "Roberta", "XLMRoberta")
]
def get_checkpoint_from_architecture(architecture): def get_checkpoint_from_architecture(architecture):
try: try:
module = importlib.import_module(architecture.__module__) module = importlib.import_module(architecture.__module__)
...@@ -194,7 +202,7 @@ class PipelineTestCaseMeta(type): ...@@ -194,7 +202,7 @@ class PipelineTestCaseMeta(type):
try: try:
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint) tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
# XLNet actually defines it as -1. # XLNet actually defines it as -1.
if isinstance(model.config, (RobertaConfig, IBertConfig)): if model.config.__class__.__name__ in ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS:
tokenizer.model_max_length = model.config.max_position_embeddings - 2 tokenizer.model_max_length = model.config.max_position_embeddings - 2
elif ( elif (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment