Commit ca6bdb28 authored by thomwolf

fix pipelines and rename model_card => modelcard

parent 61d9ee45
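In short: the module transformers.model_card becomes transformers.modelcard, the default filename model_card.json becomes modelcard.json, and the pipelines gain a modelcard argument plus corrected default tokenizers. A minimal sketch of what the rename means for callers, based only on the import changes in the diff below:

# Before this commit:
#     from transformers.model_card import ModelCard
# After this commit:
from transformers.modelcard import ModelCard
# The top-level re-export keeps working either way:
from transformers import ModelCard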
@@ -35,7 +35,7 @@ if is_sklearn_available():
     from .data import glue_compute_metrics, xnli_compute_metrics

 # Model Cards
-from .model_card import ModelCard
+from .modelcard import ModelCard

 # Tokenizers
 from .tokenization_utils import (PreTrainedTokenizer)
...
@@ -81,7 +81,7 @@ WEIGHTS_NAME = "pytorch_model.bin"
 TF2_WEIGHTS_NAME = 'tf_model.h5'
 TF_WEIGHTS_NAME = 'model.ckpt'
 CONFIG_NAME = "config.json"
-MODEL_CARD_NAME = "model_card.json"
+MODEL_CARD_NAME = "modelcard.json"

 DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
 DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
@@ -339,7 +339,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
         temp_file_manager = tempfile.NamedTemporaryFile
         resume_size = 0

-    if not os.path.exists(cache_path) or force_download:
+    if etag is not None and (not os.path.exists(cache_path) or force_download):
         # Download to temporary file, then copy to cache dir once finished.
         # Otherwise you get corrupt cache entries if the download gets interrupted.
         with temp_file_manager() as temp_file:
...
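The file_utils.py hunk above tightens get_from_cache: a download is only attempted when an etag was actually resolved, and the pre-existing comment explains the temp-file trick. A minimal standalone sketch of that download pattern, assuming requests is available (download_atomically is an illustrative name, not part of the library):

import shutil
import tempfile

import requests

def download_atomically(url, cache_path):
    # Download to a temporary file, then copy to the cache path once finished.
    # An interrupted download therefore never leaves a corrupt cache entry.
    with tempfile.NamedTemporaryFile() as temp_file:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive chunks
                temp_file.write(chunk)
        temp_file.flush()
        temp_file.seek(0)
        with open(cache_path, 'wb') as cache_file:
            shutil.copyfileobj(temp_file, cache_file)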
@@ -25,7 +25,8 @@ from io import open
 from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP

-from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, cached_path, is_remote_url, hf_bucket_url
+from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME, \
+    cached_path, is_remote_url, hf_bucket_url

 logger = logging.getLogger(__name__)
@@ -89,7 +90,7 @@ class ModelCard(object):
             - a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.: ``bert-base-uncased``.
             - a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
             - a path to a `directory` containing a mode card file saved using the :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
-            - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/model_card.json``.
+            - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/modelcard.json``.

         cache_dir: (`optional`) string:
             Path to a directory in which a downloaded pre-trained model
@@ -100,16 +101,14 @@ class ModelCard(object):
             - The values in kwargs of any keys which are model card attributes will be used to override the loaded values.
             - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the `return_unused_kwargs` keyword parameter.

-        force_download: (`optional`) boolean, default False:
-            Force to (re-)download the model card file and override the cached version if it exists.
-
-        resume_download: (`optional`) boolean, default False:
-            Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
-
         proxies: (`optional`) dict, default None:
             A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
             The proxies are used on each request.

+        find_from_standard_name: (`optional`) boolean, default True:
+            If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them with our standard modelcard filename.
+            Can be used to directly feed a model/config url and access the colocated modelcard.
+
         return_unused_kwargs: (`optional`) bool:
             - If False, then this function returns just the final model card object.
@@ -117,22 +116,21 @@ class ModelCard(object):
         Examples::

-            model_card = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
-            model_card = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
-            model_card = ModelCard.from_pretrained('./test/saved_model/model_card.json')
-            model_card = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
+            modelcard = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
+            modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
+            modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
+            modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)

         """
         cache_dir = kwargs.pop('cache_dir', None)
-        force_download = kwargs.pop('force_download', False)
-        resume_download = kwargs.pop('resume_download', False)
         proxies = kwargs.pop('proxies', None)
+        find_from_standard_name = kwargs.pop('find_from_standard_name', True)
         return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

         if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
-            # For simplicity we use the same pretrained url than the configuration files but with a different suffix (model_card.json)
+            # For simplicity we use the same pretrained url than the configuration files
+            # but with a different suffix (modelcard.json). This suffix is replaced below.
             model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-            model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
         elif os.path.isdir(pretrained_model_name_or_path):
             model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
         elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
@@ -140,17 +138,22 @@ class ModelCard(object):
         else:
             model_card_file = hf_bucket_url(pretrained_model_name_or_path, postfix=MODEL_CARD_NAME)

+        if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
+            model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
+            model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
+            model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)
+
         try:
             # Load from URL or cache if already cached
-            resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=force_download,
-                                                   proxies=proxies, resume_download=resume_download)
+            resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=True,
+                                                   proxies=proxies, resume_download=False)
             if resolved_model_card_file == model_card_file:
                 logger.info("loading model card file {}".format(model_card_file))
             else:
                 logger.info("loading model card file {} from cache at {}".format(
                     model_card_file, resolved_model_card_file))
             # Load model card
-            model_card = cls.from_json_file(resolved_model_card_file)
+            modelcard = cls.from_json_file(resolved_model_card_file)

         except EnvironmentError:
             if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
@@ -166,7 +169,7 @@ class ModelCard(object):
             logger.warning("Creating an empty model card.")

             # We fall back on creating an empty model card
-            model_card = cls()
+            modelcard = cls()

         except json.JSONDecodeError:
             logger.warning("Couldn't reach server at '{}' to download model card file or "
@@ -175,22 +178,22 @@ class ModelCard(object):
             logger.warning("Creating an empty model card.")

             # We fall back on creating an empty model card
-            model_card = cls()
+            modelcard = cls()

         # Update model card with kwargs if needed
         to_remove = []
         for key, value in kwargs.items():
-            if hasattr(model_card, key):
-                setattr(model_card, key, value)
+            if hasattr(modelcard, key):
+                setattr(modelcard, key, value)
                 to_remove.append(key)
         for key in to_remove:
             kwargs.pop(key, None)

-        logger.info("Model card: %s", str(model_card))
+        logger.info("Model card: %s", str(modelcard))
         if return_unused_kwargs:
-            return model_card, kwargs
+            return modelcard, kwargs
         else:
-            return model_card
+            return modelcard

     @classmethod
     def from_dict(cls, json_object):
...
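With the new find_from_standard_name behavior above, ModelCard.from_pretrained can be pointed at a standard config or weights location and will swap the filename suffix for modelcard.json. A hedged usage sketch drawn from the docstring (the URL is illustrative):

from transformers import ModelCard

# Shortcut name: the config url from ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is reused,
# with config.json replaced by modelcard.json before downloading.
modelcard = ModelCard.from_pretrained('bert-base-uncased')

# find_from_standard_name (default True) also lets a config url be fed directly;
# the trailing config.json is rewritten to modelcard.json to find the colocated card.
modelcard = ModelCard.from_pretrained(
    'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json')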
@@ -18,6 +18,8 @@ import csv
 import json
 import os
 import pickle
+import logging
+import six
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from itertools import groupby
@@ -26,8 +28,12 @@ from typing import Union, Optional, Tuple, List, Dict
 import numpy as np

-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
-    SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger, BasicTokenizer
+from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+                          PretrainedConfig, ModelCard, SquadExample,
+                          squad_convert_examples_to_features, is_tf_available,
+                          is_torch_available, BasicTokenizer,
+                          ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
+                          ALL_PRETRAINED_CONFIG_ARCHIVE_MAP)

 if is_tf_available():
     import tensorflow as tf
@@ -40,6 +46,8 @@ if is_torch_available():
         AutoModelForQuestionAnswering, AutoModelForTokenClassification

+logger = logging.getLogger(__name__)
+
 class ArgumentHandler(ABC):
     """
     Base interface for handling varargs for each Pipeline
@@ -271,11 +279,13 @@ class Pipeline(_ScikitCompat):
         nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
     """
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
+                 modelcard: ModelCard = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
                  binary_output: bool = False):
         self.model = model
         self.tokenizer = tokenizer
+        self.modelcard = modelcard
         self.device = device
         self.binary_output = binary_output
         self._args_parser = args_parser or DefaultArgumentHandler()
@@ -294,6 +304,7 @@ class Pipeline(_ScikitCompat):
         self.model.save_pretrained(save_directory)
         self.tokenizer.save_pretrained(save_directory)
+        self.modelcard.save_pretrained(save_directory)

     def transform(self, X):
         """
@@ -393,9 +404,10 @@ class FeatureExtractionPipeline(Pipeline):
     def __init__(self, model,
                  tokenizer: PreTrainedTokenizer = None,
+                 modelcard: ModelCard = None,
                  args_parser: ArgumentHandler = None,
                  device: int = -1):
-        super().__init__(model, tokenizer, args_parser, device, binary_output=True)
+        super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output=True)

     def __call__(self, *args, **kwargs):
         return super().__call__(*args, **kwargs).tolist()
@@ -418,9 +430,10 @@ class NerPipeline(Pipeline):
     """
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
+                 modelcard: ModelCard = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
                  binary_output: bool = False):
-        super().__init__(model, tokenizer, args_parser, device, binary_output)
+        super().__init__(model, tokenizer, modelcard, args_parser, device, binary_output)

         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
@@ -554,8 +567,10 @@ class QuestionAnsweringPipeline(Pipeline):
         else:
             return SquadExample(None, question, context, None, None, None)

-    def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer], device: int = -1, **kwargs):
-        super().__init__(model, tokenizer, args_parser=QuestionAnsweringArgumentHandler(),
+    def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer],
+                 modelcard: Optional[ModelCard],
+                 device: int = -1, **kwargs):
+        super().__init__(model, tokenizer, modelcard, args_parser=QuestionAnsweringArgumentHandler(),
                          device=device, **kwargs)

     def __call__(self, *texts, **kwargs):
@@ -725,7 +740,7 @@ SUPPORTED_TASKS = {
         'default': {
             'model': 'distilbert-base-uncased',
             'config': None,
-            'tokenizer': 'bert-base-uncased'
+            'tokenizer': 'distilbert-base-uncased'
         }
     },
     'sentiment-analysis': {
@@ -735,7 +750,7 @@ SUPPORTED_TASKS = {
         'default': {
             'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
             'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json',
-            'tokenizer': 'bert-base-uncased'
+            'tokenizer': 'distilbert-base-uncased'
         }
     },
     'ner': {
@@ -745,7 +760,7 @@ SUPPORTED_TASKS = {
         'default': {
             'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
             'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json',
-            'tokenizer': 'bert-base-cased'
+            'tokenizer': 'bert-large-cased'
         }
     },
     'question-answering': {
@@ -755,7 +770,7 @@ SUPPORTED_TASKS = {
         'default': {
             'model': 'distilbert-base-uncased-distilled-squad',
             'config': None,
-            'tokenizer': 'bert-base-uncased'
+            'tokenizer': 'distilbert-base-uncased'
         }
     }
 }
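The four default blocks above fix mismatched tokenizers: each task default now pairs the model with the tokenizer it was trained with. A quick sketch of the effect (downloads the default weights on first use):

from transformers import pipeline

# question-answering now loads the distilbert-base-uncased tokenizer alongside
# distilbert-base-uncased-distilled-squad, instead of the old bert-base-uncased.
nlp = pipeline('question-answering')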
@@ -763,7 +778,9 @@ SUPPORTED_TASKS = {

 def pipeline(task: str, model: Optional = None,
              config: Optional[Union[str, PretrainedConfig]] = None,
-             tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
+             tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
+             modelcard: Optional[Union[str, ModelCard]] = None,
+             **kwargs) -> Pipeline:
     """
     Utility factory method to build a pipeline.
     Pipeline are made of:
@@ -777,48 +794,63 @@ def pipeline(task: str, model: Optional = None,
         pipeline('ner', model=AutoModel.from_pretrained(...), tokenizer=AutoTokenizer.from_pretrained(...)
         pipeline('ner', model='https://...pytorch-model.bin', config='https://...config.json', tokenizer='bert-base-cased')
     """
-    # Try to infer tokenizer from model name (if provided as str)
-    if tokenizer is None:
-        if model is not None and not isinstance(model, str):
-            # Impossible to guest what is the right tokenizer here
-            raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance')
-        else:
-            tokenizer = model
-
     # Retrieve the task
     if task not in SUPPORTED_TASKS:
         raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

+    pipeline_framework = 'tf' if is_tf_available() else ('pt' if is_torch_available() else None)
+    if pipeline_framework is None:
+        raise ImportError("At least one of TensorFlow 2.0 or PyTorch should be installed. "
+                          "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
+                          "To install PyTorch, read the instructions at https://pytorch.org/.")
+
     targeted_task = SUPPORTED_TASKS[task]
-    task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
+    task, model_class = targeted_task['impl'], targeted_task[pipeline_framework]

-    # Handling for default model for the task
+    # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
         model, config, tokenizer = tuple(targeted_task['default'].values())

-    # Allocate tokenizer
-    tokenizer = tokenizer if isinstance(tokenizer, PreTrainedTokenizer) else AutoTokenizer.from_pretrained(tokenizer)
+    # Try to infer tokenizer from model or config name (if provided as str)
+    if tokenizer is None:
+        if isinstance(model, str) and model in ALL_PRETRAINED_MODEL_ARCHIVE_MAP:
+            tokenizer = model
+        elif isinstance(config, str) and model in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
+            tokenizer = config
+        else:
+            # Impossible to guest what is the right tokenizer here
+            raise Exception("Impossible to guess which tokenizer to use. "
+                            "Please provided a PretrainedTokenizer class or a path/url/shortcut name to a pretrained tokenizer.")
+
+    # Try to infer modelcard from model or config name (if provided as str)
+    if modelcard is None:
+        # Try to fallback on one of the provided string for model or config (will replace the suffix)
+        if isinstance(model, str):
+            modelcard = model
+        elif isinstance(config, str):
+            modelcard = config
+
+    # Instantiate tokenizer if needed
+    if isinstance(tokenizer, six.string_types):
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer)

-    # Special handling for model conversion
-    if isinstance(model, str):
-        from_tf = model.endswith('.h5') and not is_tf_available()
-        from_pt = model.endswith('.bin') and not is_torch_available()
-
-        if from_tf:
-            logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
-                           'Trying to load the model with PyTorch.')
-        elif from_pt:
-            logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
-                           'Trying to load the model with Tensorflow.')
-        else:
-            from_tf = from_pt = False
-
-    if isinstance(config, str):
-        config = AutoConfig.from_pretrained(config)
-
-    if isinstance(model, str):
-        if allocator.__name__.startswith('TF'):
-            model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
-        else:
-            model = allocator.from_pretrained(model, config=config, from_tf=from_tf)
+    # Instantiate config if needed
+    if isinstance(config, str):
+        config = AutoConfig.from_pretrained(config)
+
+    # Instantiate model if needed
+    if isinstance(model, str):
+        # Handle transparent TF/PT model conversion
+        model_kwargs = {}
+        if pipeline_framework == 'pt' and model.endswith('.h5'):
+            model_kwargs['from_tf'] = True
+            logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
+                           'Trying to load the model with PyTorch.')
+        elif pipeline_framework == 'tf' and model.endswith('.bin'):
+            model_kwargs['from_pt'] = True
+            logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
+                           'Trying to load the model with Tensorflow.')
+        model = model_class.from_pretrained(model, config=config, **model_kwargs)

     return task(model, tokenizer, **kwargs)
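Taken together, the rewritten factory picks a framework up front (TF 2.0 if available, else PyTorch), fills in task defaults, infers the tokenizer and modelcard from the model or config strings, and only then instantiates everything. A hedged usage sketch, assuming at least one framework is installed:

from transformers import pipeline

# All-defaults: model, config and tokenizer come from SUPPORTED_TASKS[task]['default'].
nlp = pipeline('sentiment-analysis')
print(nlp('Transformers pipelines are pleasant to use.'))

# Shortcut name: the tokenizer (and modelcard) are inferred from the model string,
# since it appears in ALL_PRETRAINED_MODEL_ARCHIVE_MAP.
nlp = pipeline('question-answering', model='distilbert-base-uncased-distilled-squad')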
@@ -18,7 +18,7 @@ import os
 import json
 import unittest

-from transformers.model_card import ModelCard
+from transformers.modelcard import ModelCard

 from .tokenization_tests_commons import TemporaryDirectory

 class ModelCardTester(unittest.TestCase):
@@ -49,20 +49,20 @@ class ModelCardTester(unittest.TestCase):
         }

     def test_model_card_common_properties(self):
-        model_card = ModelCard.from_dict(self.inputs_dict)
-        self.assertTrue(hasattr(model_card, 'model_details'))
-        self.assertTrue(hasattr(model_card, 'intended_use'))
-        self.assertTrue(hasattr(model_card, 'factors'))
-        self.assertTrue(hasattr(model_card, 'metrics'))
-        self.assertTrue(hasattr(model_card, 'evaluation_data'))
-        self.assertTrue(hasattr(model_card, 'training_data'))
-        self.assertTrue(hasattr(model_card, 'quantitative_analyses'))
-        self.assertTrue(hasattr(model_card, 'ethical_considerations'))
-        self.assertTrue(hasattr(model_card, 'caveats_and_recommendations'))
+        modelcard = ModelCard.from_dict(self.inputs_dict)
+        self.assertTrue(hasattr(modelcard, 'model_details'))
+        self.assertTrue(hasattr(modelcard, 'intended_use'))
+        self.assertTrue(hasattr(modelcard, 'factors'))
+        self.assertTrue(hasattr(modelcard, 'metrics'))
+        self.assertTrue(hasattr(modelcard, 'evaluation_data'))
+        self.assertTrue(hasattr(modelcard, 'training_data'))
+        self.assertTrue(hasattr(modelcard, 'quantitative_analyses'))
+        self.assertTrue(hasattr(modelcard, 'ethical_considerations'))
+        self.assertTrue(hasattr(modelcard, 'caveats_and_recommendations'))

     def test_model_card_to_json_string(self):
-        model_card = ModelCard.from_dict(self.inputs_dict)
-        obj = json.loads(model_card.to_json_string())
+        modelcard = ModelCard.from_dict(self.inputs_dict)
+        obj = json.loads(modelcard.to_json_string())
         for key, value in self.inputs_dict.items():
             self.assertEqual(obj[key], value)
@@ -70,7 +70,7 @@ class ModelCardTester(unittest.TestCase):
         model_card_first = ModelCard.from_dict(self.inputs_dict)

         with TemporaryDirectory() as tmpdirname:
-            filename = os.path.join(tmpdirname, u"model_card.json")
+            filename = os.path.join(tmpdirname, u"modelcard.json")
             model_card_first.to_json_file(filename)
             model_card_second = ModelCard.from_json_file(filename)
...
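The round-trip the last test exercises can be reproduced in a few lines; a minimal sketch based on the test above (the dict key is one of the standard model card sections):

import json
from transformers import ModelCard

card = ModelCard.from_dict({'model_details': {'language': 'english'}})
card.to_json_file('modelcard.json')
restored = ModelCard.from_json_file('modelcard.json')
assert json.loads(card.to_json_string()) == json.loads(restored.to_json_string())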