"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "ad26f298e27c2ac3e03d78a2f15ab8cbb353b73d"
Unverified commit 3a2c4e6f authored by Thomas Wolf, committed by GitHub

Merge pull request #1548 from huggingface/cli

[2.2] - Command-line interface - Pipeline class
parents 4e3f745b db0795b5
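For orientation, this PR adds a pipeline() factory plus task-specific Pipeline classes (feature extraction, text classification, NER, question answering) and a transformers-cli entry point. A minimal sketch of the Python-side API, assuming the task key and checkpoint names below are valid for this version (they are illustrative, not taken from the diff):

from transformers import pipeline

# Sketch only: 'sentiment-analysis' is assumed to be one of the SUPPORTED_TASKS keys,
# and the checkpoint names are illustrative.
nlp = pipeline(task='sentiment-analysis',
               model='distilbert-base-uncased',
               tokenizer='distilbert-base-uncased')
print(nlp('Transformers pipelines make inference a one-liner.'))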
...@@ -44,32 +44,6 @@ jobs:
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: codecov
build_py2_torch:
working_directory: ~/transformers
resource_class: large
parallelism: 1
docker:
- image: circleci/python:2.7
steps:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: codecov
build_py2_tf:
working_directory: ~/transformers
resource_class: large
parallelism: 1
docker:
- image: circleci/python:2.7
steps:
- checkout
- run: sudo pip install tensorflow
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: codecov
build_py3_custom_tokenizers:
working_directory: ~/transformers
docker:
...@@ -80,17 +54,6 @@ jobs:
- run: sudo pip install pytest
- run: sudo pip install mecab-python3
- run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
build_py2_custom_tokenizers:
working_directory: ~/transformers
docker:
- image: circleci/python:2.7
steps:
- checkout
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest
- run: sudo apt-get -y install libmecab-dev mecab mecab-ipadic-utf8 swig
- run: sudo pip install mecab-python
- run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
deploy_doc:
working_directory: ~/transformers
docker:
...@@ -124,10 +87,7 @@ workflows:
jobs:
- repository_consistency
- build_py3_custom_tokenizers
- build_py2_custom_tokenizers
- build_py3_torch_and_tf
- build_py3_torch
- build_py3_tf
- build_py2_torch
- build_py2_tf
- deploy_doc: *workflow_filters
...@@ -38,7 +38,9 @@ from setuptools import find_packages, setup
extras = {
'serving': ['uvicorn', 'fastapi']
'serving': ['pydantic', 'uvicorn', 'fastapi'],
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
}
extras['all'] = [package for package in extras.values()]
...@@ -62,11 +64,6 @@ setup(
'regex != 2019.12.17',
'sentencepiece',
'sacremoses'],
entry_points={
'console_scripts': [
"transformers=transformers.__main__:main",
]
},
extras_require=extras,
scripts=[
'transformers-cli'
...
#!/usr/bin/env python
from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand
from transformers.commands.user import UserCommands
from transformers.commands.convert import ConvertCommand
from transformers.commands.serving import ServeCommand
if __name__ == '__main__':
parser = ArgumentParser(description='Transformers CLI tool', usage='transformers-cli <command> [<args>]')
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# Register commands
ConvertCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser)
RunCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser)
# Let's go
...
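Each command module below follows the registration contract visible in this script: a static register_subcommand that adds a subparser and binds a factory through set_defaults(func=...), so args.func(args) yields a command object whose run() does the work. A hypothetical extra command, shown only to illustrate the contract:

from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand

class HelloCommand(BaseTransformersCLICommand):
    """Hypothetical command showing the registration contract (not part of the PR)."""
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser('hello', help='Print a greeting (example only)')
        hello_parser.add_argument('--name', type=str, default='world')
        # The stored factory receives the parsed Namespace and returns the command instance.
        hello_parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name: str):
        self._name = name

    def run(self):
        print('Hello, {}!'.format(self._name))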
...@@ -24,6 +24,7 @@ from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH
from .data import (is_sklearn_available,
InputExample, InputFeatures, DataProcessor,
SingleSentenceClassificationProcessor,
glue_output_modes, glue_convert_examples_to_features,
glue_processors, glue_tasks_num_labels,
xnli_output_modes, xnli_processors, xnli_tasks_num_labels,
...@@ -34,7 +35,7 @@ if is_sklearn_available():
from .data import glue_compute_metrics, xnli_compute_metrics
# Model Cards
from .model_card import ModelCard
from .modelcard import ModelCard
# Tokenizers
from .tokenization_utils import (PreTrainedTokenizer)
...@@ -75,7 +76,7 @@ from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_
if is_torch_available():
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
AutoModelWithLMHead, ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
AutoModelWithLMHead, AutoModelForTokenClassification, ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
...@@ -136,7 +137,7 @@ if is_torch_available():
if is_tf_available():
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
TFAutoModelWithLMHead, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
TFAutoModelWithLMHead, TFAutoModelForTokenClassification, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings,
TFBertModel, TFBertForPreTraining,
...@@ -206,6 +207,10 @@ from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name
load_tf2_weights_in_pytorch_model,
load_tf2_model_in_pytorch_model)
# Pipelines
from .pipelines import pipeline, PipelineDataFormat, CsvPipelineDataFormat, JsonPipelineDataFormat, PipedPipelineDataFormat, \
Pipeline, FeatureExtractionPipeline, QuestionAnsweringPipeline, NerPipeline, TextClassificationPipeline
if not is_tf_available() and not is_torch_available():
logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found."
"Models won't be available and only tokenizers, configuration"
...
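The new top-level exports make the pipeline classes and the pipeline() factory importable straight from the package. A hedged sketch of a question-answering pipeline; the task key, checkpoint name, and keyword calling convention (mirroring the multi-column handling in RunCommand) are assumptions:

from transformers import pipeline

qa = pipeline(task='question-answering',
              model='distilbert-base-uncased-distilled-squad',  # illustrative checkpoint
              tokenizer='distilbert-base-uncased')
answer = qa(question='What does the pipeline return?',
            context='A question answering pipeline returns a span extracted from the context.')
print(answer)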
# coding: utf8
def main():
import sys
if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
print(
"First argument to `transformers` command line interface should be one of: \n"
">> convert serve train predict")
if sys.argv[1] == "convert":
from transformers.commands import convert
convert(sys.argv)
elif sys.argv[1] == "train":
from transformers.commands import train
train(sys.argv)
elif sys.argv[1] == "serve":
pass
# from argparse import ArgumentParser
# from transformers.commands.serving import ServeCommand
# parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
# commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# # Register commands
# ServeCommand.register_subcommand(commands_parser)
# # Let's go
# args = parser.parse_args()
# if not hasattr(args, 'func'):
# parser.print_help()
# exit(1)
# # Run
# service = args.func(args)
# service.run()
if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
print(
"This command line utility let you convert original (author released) model checkpoint to pytorch.\n"
"It should be used as one of: \n"
">> transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
">> transformers t5 TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
">> transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
">> transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
">> transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
">> transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
">> transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
else:
if sys.argv[1] == "bert":
try:
from .convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ImportError:
print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "t5":
try:
from .convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ImportError:
print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "gpt":
from .convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
if len(sys.argv) < 4 or len(sys.argv) > 5:
# pylint: disable=line-too-long
print("Should be used as `transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
else:
OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
OPENAI_GPT_CONFIG = sys.argv[4]
else:
OPENAI_GPT_CONFIG = ""
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
OPENAI_GPT_CONFIG,
PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "transfo_xl":
try:
from .convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
except ImportError:
print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) < 4 or len(sys.argv) > 5:
# pylint: disable=line-too-long
print("Should be used as `transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
else:
if 'ckpt' in sys.argv[2].lower():
TF_CHECKPOINT = sys.argv[2]
TF_DATASET_FILE = ""
else:
TF_DATASET_FILE = sys.argv[2]
TF_CHECKPOINT = ""
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
TF_CONFIG = sys.argv[4]
else:
TF_CONFIG = ""
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
elif sys.argv[1] == "gpt2":
try:
from .convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
except ImportError:
print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) < 4 or len(sys.argv) > 5:
# pylint: disable=line-too-long
print("Should be used as `transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
else:
TF_CHECKPOINT = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
TF_CONFIG = sys.argv[4]
else:
TF_CONFIG = ""
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "xlnet":
try:
from .convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
except ImportError:
print("transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) < 5 or len(sys.argv) > 6:
# pylint: disable=line-too-long
print("Should be used as `transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
else:
TF_CHECKPOINT = sys.argv[2]
TF_CONFIG = sys.argv[3]
PYTORCH_DUMP_OUTPUT = sys.argv[4]
if len(sys.argv) == 6:
FINETUNING_TASK = sys.argv[5]
else:
FINETUNING_TASK = None
convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
TF_CONFIG,
PYTORCH_DUMP_OUTPUT,
FINETUNING_TASK)
elif sys.argv[1] == "xlm":
from .convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
if len(sys.argv) != 4:
# pylint: disable=line-too-long
print("Should be used as `transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
else:
XLM_CHECKPOINT_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)
if __name__ == '__main__':
main()
from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers import AutoModel, AutoTokenizer
from transformers.commands import BaseTransformersCLICommand
def convert_command_factory(args: Namespace):
"""
Factory function used to convert a TF 1.0 model checkpoint to a PyTorch checkpoint.
:return: ConvertCommand
"""
return ConvertCommand(args.model_type, args.tf_checkpoint, args.pytorch_dump_output,
args.config, args.finetuning_task_name)
class ConvertCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
"""
Register this command to argparse so it's available for the transformers-cli
:param parser: Root parser to register command-specific arguments
:return:
"""
train_parser = parser.add_parser('convert', help="CLI tool to run convert model from original "
"author checkpoints to Transformesr PyTorch checkpoints.")
train_parser.add_argument('--model_type', type=str, required=True,
help='Model\'s type.')
train_parser.add_argument('--tf_checkpoint', type=str, required=True,
help='TensorFlow checkpoint path or folder.')
train_parser.add_argument('--pytorch_dump_output', type=str, required=True,
help='Path to the PyTorch saved model output.')
train_parser.add_argument('--config', type=str, default="",
help='Configuration file path or folder.')
train_parser.add_argument('--finetuning_task_name', type=str, default=None,
help='Optional fine-tuning task name if the TF model was a finetuned model.')
train_parser.set_defaults(func=convert_command_factory)
def __init__(self, model_type: str, tf_checkpoint: str, pytorch_dump_output: str,
config: str, finetuning_task_name: str, *args):
self._logger = getLogger('transformers-cli/converting')
self._logger.info('Loading model {}'.format(model_type))
self._model_type = model_type
self._tf_checkpoint = tf_checkpoint
self._pytorch_dump_output = pytorch_dump_output
self._config = config
self._finetuning_task_name = finetuning_task_name
def run(self):
if self._model_type == "bert":
try:
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ImportError:
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
"In that case, it requires TensorFlow to be installed. Please see " \
"https://www.tensorflow.org/install/ for installation instructions."
raise ImportError(msg)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "gpt":
from transformers.convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint,
self._config,
self._pytorch_dump_output)
elif self._model_type == "transfo_xl":
try:
from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
except ImportError:
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
"In that case, it requires TensorFlow to be installed. Please see " \
"https://www.tensorflow.org/install/ for installation instructions."
raise ImportError(msg)
if 'ckpt' in self._tf_checkpoint.lower():
TF_CHECKPOINT = self._tf_checkpoint
TF_DATASET_FILE = ""
else:
TF_DATASET_FILE = self._tf_checkpoint
TF_CHECKPOINT = ""
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT,
self._config,
self._pytorch_dump_output,
TF_DATASET_FILE)
elif self._model_type == "gpt2":
try:
from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
except ImportError:
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
"In that case, it requires TensorFlow to be installed. Please see " \
"https://www.tensorflow.org/install/ for installation instructions."
raise ImportError(msg)
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "xlnet":
try:
from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
except ImportError:
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \
"In that case, it requires TensorFlow to be installed. Please see " \
"https://www.tensorflow.org/install/ for installation instructions."
raise ImportError(msg)
convert_xlnet_checkpoint_to_pytorch(self._tf_checkpoint,
self._config,
self._pytorch_dump_output,
self._finetuning_task_name)
elif self._model_type == "xlm":
from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
else:
raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm]")
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
def download_command_factory(args):
return DownloadCommand(args.model, args.cache_dir, args.force)
class DownloadCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
download_parser = parser.add_parser('download')
download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models')
download_parser.add_argument('--force', action='store_true', help='Force the model to be downloaded even if it is already in cache-dir')
download_parser.add_argument('model', type=str, help='Name of the model to download')
download_parser.set_defaults(func=download_command_factory)
def __init__(self, model: str, cache: str, force: bool):
self._model = model
self._cache = cache
self._force = force
def run(self):
from transformers import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
\ No newline at end of file
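The download command is a thin wrapper over the Auto classes' caching, so the same pre-fetch can be done from Python. A sketch with an illustrative model name and cache directory:

from transformers.commands.download import DownloadCommand

# Pre-fetch model weights and tokenizer files into a local cache (illustrative values).
DownloadCommand(model='bert-base-uncased', cache='/tmp/hf-cache', force=False).run()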
import logging
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import pipeline, Pipeline, PipelineDataFormat, SUPPORTED_TASKS
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def try_infer_format_from_ext(path: str):
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
if path.endswith(ext):
return ext
raise Exception(
'Unable to determine file format from file extension {}. '
'Please provide the format through --format {}'.format(path, PipelineDataFormat.SUPPORTED_FORMATS)
)
def run_command_factory(args):
nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
reader = PipelineDataFormat.from_str(format, args.output, args.input, args.column)
return RunCommand(nlp, reader)
class RunCommand(BaseTransformersCLICommand):
def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
self._nlp = nlp
self._reader = reader
@staticmethod
def register_subcommand(parser: ArgumentParser):
run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
run_parser.add_argument('--config', type=str, help='Name or path to the model\'s config to instantiate.')
run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi-column inputs such as QA, use column1,column2)')
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
run_parser.add_argument('--output', type=str, help='Path to the file where results will be written.')
run_parser.set_defaults(func=run_command_factory)
def run(self):
nlp, output = self._nlp, []
for entry in self._reader:
if self._reader.is_multi_columns:
output += nlp(**entry)
else:
output += nlp(entry)
# Saving data
if self._nlp.binary_output:
binary_path = self._reader.save_binary(output)
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
else:
self._reader.save(output)
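RunCommand simply streams entries from a PipelineDataFormat reader into the pipeline and saves the outputs. The same loop without the CLI, as a sketch; the file names, column name, task key and checkpoint are placeholders:

from transformers import pipeline, PipelineDataFormat

reader = PipelineDataFormat.from_str(
    'csv',              # format
    'predictions.csv',  # output path (placeholder)
    'reviews.csv',      # input path (placeholder)
    'text')             # column to read (placeholder)
nlp = pipeline(task='sentiment-analysis',
               model='distilbert-base-uncased',
               tokenizer='distilbert-base-uncased')

outputs = []
for entry in reader:
    # Multi-column readers yield dicts that are splatted into the pipeline call.
    outputs += nlp(**entry) if reader.is_multi_columns else nlp(entry)
reader.save(outputs)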
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any
import logging
try:
from uvicorn import run
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
_serve_dependancies_installed = True
except (ImportError, AttributeError):
BaseModel = object
Body = lambda *x, **y: None
_serve_dependancies_installed = False
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
logger = logging.getLogger('transformers-cli/serving')
def serve_command_factory(args: Namespace):
"""
Factory function used to instantiate serving server from provided command line arguments.
:return: ServeCommand
"""
nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
return ServeCommand(nlp, args.host, args.port)
class ServeModelInfoResult(BaseModel):
"""
Expose model information
"""
infos: dict
class ServeTokenizeResult(BaseModel):
"""
Tokenize result model
"""
tokens: List[str]
tokens_ids: Optional[List[int]]
class ServeDeTokenizeResult(BaseModel):
"""
DeTokenize result model
"""
text: str
class ServeForwardResult(BaseModel):
"""
Forward result model
"""
output: Any
class ServeCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
"""
Register this command to argparse so it's available for the transformers-cli
:param parser: Root parser to register command-specific arguments
:return:
"""
serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
serve_parser.add_argument('--port', type=int, default=8888, help='Port the server will listen on.')
serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model.')
serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int):
self._pipeline = pipeline
self._host = host
self._port = port
if not _serve_dependancies_installed:
raise ImportError("Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly.")
else:
logger.info('Serving model over {}:{}'.format(host, port))
self._app = FastAPI()
# Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
def run(self):
run(self._app, host=self._host, port=self._port)
def model_info(self):
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
"""
Tokenize the provided input and optionally return the corresponding token ids:
- **text_input**: String to tokenize
- **return_ids**: Boolean flag indicating whether to also convert the tokens to their integer ids.
"""
try:
tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
if return_ids:
tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
else:
return ServeTokenizeResult(tokens=tokens_txt)
except Exception as e:
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
skip_special_tokens: bool = Body(False, embed=True),
cleanup_tokenization_spaces: bool = Body(True, embed=True)):
"""
Detokenize the provided token ids into readable text:
- **tokens_ids**: List of token ids
- **skip_special_tokens**: Flag indicating whether to skip special tokens when decoding
- **cleanup_tokenization_spaces**: Flag indicating whether to clean up leading/trailing and intermediate tokenization spaces.
"""
try:
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
return ServeDeTokenizeResult(model='', text=decoded_str)
except Exception as e:
raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
"""
**inputs**:
**attention_mask**:
**tokens_type_ids**:
"""
# Check we don't have empty string
if len(inputs) == 0:
return ServeForwardResult(output=[], attention=[])
try:
# Forward through the model
output = self._pipeline(inputs)
return ServeForwardResult(output=output)
except Exception as e:
raise HTTPException(500, {"error": str(e)})
import os
from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers.commands import BaseTransformersCLICommand
from transformers import (is_tf_available, is_torch_available,
TextClassificationPipeline,
SingleSentenceClassificationProcessor as Processor)
if not is_tf_available() and not is_torch_available():
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
# TF training parameters
USE_XLA = False
USE_AMP = False
def train_command_factory(args: Namespace):
"""
Factory function used to instantiate the training command from provided command line arguments.
:return: TrainCommand
"""
return TrainCommand(args)
class TrainCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
"""
Register this command to argparse so it's available for the transformers-cli
:param parser: Root parser to register command-specific arguments
:return:
"""
train_parser = parser.add_parser('train', help='CLI tool to train a model on a task.')
train_parser.add_argument('--train_data', type=str, required=True,
help="path to train (and optionally evaluation) dataset as a csv with "
"tab separated labels and sentences.")
train_parser.add_argument('--column_label', type=int, default=0,
help='Column of the dataset csv file with example labels.')
train_parser.add_argument('--column_text', type=int, default=1,
help='Column of the dataset csv file with example texts.')
train_parser.add_argument('--column_id', type=int, default=2,
help='Column of the dataset csv file with example ids.')
train_parser.add_argument('--skip_first_row', action='store_true',
help='Skip the first row of the csv file (headers).')
train_parser.add_argument('--validation_data', type=str, default='',
help='path to validation dataset.')
train_parser.add_argument('--validation_split', type=float, default=0.1,
help="if validation dataset is not provided, fraction of train dataset "
"to use as validation dataset.")
train_parser.add_argument('--output', type=str, default='./',
help='path to save the trained model.')
train_parser.add_argument('--task', type=str, default='text_classification',
help='Task to train the model on.')
train_parser.add_argument('--model', type=str, default='bert-base-uncased',
help='Model\'s name or path to stored model.')
train_parser.add_argument('--train_batch_size', type=int, default=32,
help='Batch size for training.')
train_parser.add_argument('--valid_batch_size', type=int, default=64,
help='Batch size for validation.')
train_parser.add_argument('--learning_rate', type=float, default=3e-5,
help="Learning rate.")
train_parser.add_argument('--adam_epsilon', type=float, default=1e-08,
help="Epsilon for Adam optimizer.")
train_parser.set_defaults(func=train_command_factory)
def __init__(self, args: Namespace):
self.logger = getLogger('transformers-cli/training')
self.framework = 'tf' if is_tf_available() else 'torch'
os.makedirs(args.output, exist_ok=True)
assert os.path.isdir(args.output)
self.output = args.output
self.column_label = args.column_label
self.column_text = args.column_text
self.column_id = args.column_id
self.logger.info('Loading {} pipeline for {}'.format(args.task, args.model))
if args.task == 'text_classification':
self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
elif args.task == 'token_classification':
raise NotImplementedError
elif args.task == 'question_answering':
raise NotImplementedError
self.logger.info('Loading dataset from {}'.format(args.train_data))
self.train_dataset = Processor.create_from_csv(args.train_data,
column_label=args.column_label,
column_text=args.column_text,
column_id=args.column_id,
skip_first_row=args.skip_first_row)
self.valid_dataset = None
if args.validation_data:
self.logger.info('Loading validation dataset from {}'.format(args.validation_data))
self.valid_dataset = Processor.create_from_csv(args.validation_data,
column_label=args.column_label,
column_text=args.column_text,
column_id=args.column_id,
skip_first_row=args.skip_first_row)
self.validation_split = args.validation_split
self.train_batch_size = args.train_batch_size
self.valid_batch_size = args.valid_batch_size
self.learning_rate = args.learning_rate
self.adam_epsilon = args.adam_epsilon
def run(self):
if self.framework == 'tf':
return self.run_tf()
return self.run_torch()
def run_torch(self):
raise NotImplementedError
def run_tf(self):
self.pipeline.fit(self.train_dataset,
validation_data=self.valid_dataset,
validation_split=self.validation_split,
learning_rate=self.learning_rate,
adam_epsilon=self.adam_epsilon,
train_batch_size=self.train_batch_size,
valid_batch_size=self.valid_batch_size)
# Save trained pipeline
self.pipeline.save_pretrained(self.output)
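Like the other commands, training can be launched without the CLI by handing TrainCommand a Namespace carrying the same fields the parser defines. A sketch, assuming the module path transformers.commands.train and placeholder file paths:

from argparse import Namespace
from transformers.commands.train import TrainCommand  # assumed module path

# Fields mirror the registered arguments; the CSV path and output directory are placeholders.
args = Namespace(train_data='train.csv', column_label=0, column_text=1, column_id=2,
                 skip_first_row=True, validation_data='', validation_split=0.1,
                 output='./trained/', task='text_classification',
                 model='bert-base-uncased', train_batch_size=32, valid_batch_size=64,
                 learning_rate=3e-5, adam_epsilon=1e-08)
TrainCommand(args).run()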
...@@ -83,6 +83,34 @@ class AutoConfig(object):
raise EnvironmentError("AutoConfig is designed to be instantiated "
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.")
@classmethod
def for_model(cls, model_type, *args, **kwargs):
if 'distilbert' in model_type:
return DistilBertConfig(*args, **kwargs)
elif 'roberta' in model_type:
return RobertaConfig(*args, **kwargs)
elif 'bert' in model_type:
return BertConfig(*args, **kwargs)
elif 'openai-gpt' in model_type:
return OpenAIGPTConfig(*args, **kwargs)
elif 'gpt2' in model_type:
return GPT2Config(*args, **kwargs)
elif 'transfo-xl' in model_type:
return TransfoXLConfig(*args, **kwargs)
elif 'xlnet' in model_type:
return XLNetConfig(*args, **kwargs)
elif 'xlm' in model_type:
return XLMConfig(*args, **kwargs)
elif 'ctrl' in model_type:
return CTRLConfig(*args, **kwargs)
elif 'albert' in model_type:
return AlbertConfig(*args, **kwargs)
elif 'camembert' in model_type:
return CamembertConfig(*args, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', 'ctrl', 'camembert', 'albert'".format(model_type))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate one of the configuration classes of the library
...
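for_model builds a fresh, unpretrained configuration from a model-type string, whereas from_pretrained loads stored values. A short sketch; the extra keyword is only an example of the kwargs forwarded to the concrete config class:

from transformers import AutoConfig

# 'bert' matches the substring check in for_model; kwargs go straight to BertConfig.
config = AutoConfig.for_model('bert', num_labels=4)
print(type(config).__name__, config.num_labels)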
...@@ -32,7 +32,7 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
...@@ -47,7 +47,7 @@ if is_torch_available():
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
...@@ -59,7 +59,7 @@ else:
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
...@@ -70,7 +70,7 @@ else:
None, None,
None, None,
None, None, None,
None, None, None,
None, None, None, None,
None, None,
None, None,
None, None)
...@@ -93,6 +93,7 @@ MODEL_CLASSES = {
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
...
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures
from .processors import InputExample, InputFeatures, DataProcessor, SquadFeatures, SingleSentenceClassificationProcessor
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
...
from .utils import InputExample, InputFeatures, DataProcessor
from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
\ No newline at end of file
...@@ -18,6 +18,11 @@ import csv
import sys
import copy
import json
import logging
from ...file_utils import is_tf_available, is_torch_available
logger = logging.getLogger(__name__)
class InputExample(object):
"""
...@@ -64,7 +69,7 @@ class InputFeatures(object):
label: Label corresponding to the input
"""
def __init__(self, input_ids, attention_mask, token_type_ids, label):
def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
...@@ -86,34 +91,6 @@ class DataProcessor(object):
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_example_from_tensor_dict(self, tensor_dict):
"""Gets an example from a dict with tensorflow tensors
Args:
tensor_dict: Keys and values should match the corresponding Glue
tensorflow_dataset examples.
"""
raise NotImplementedError()
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
def tfds_map(self, example):
"""Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
This method converts examples to the correct format."""
if len(self.get_labels()) > 1:
example.label = self.get_labels()[int(example.label)]
return example
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
...@@ -125,3 +102,215 @@ class DataProcessor(object):
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line)
return lines
class SingleSentenceClassificationProcessor(DataProcessor):
""" Generic processor for a single sentence classification data set."""
def __init__(self, labels=None, examples=None, mode='classification', verbose=False):
self.labels = [] if labels is None else labels
self.examples = [] if examples is None else examples
self.mode = mode
self.verbose = verbose
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
if isinstance(idx, slice):
return SingleSentenceClassificationProcessor(labels=self.labels,
examples=self.examples[idx])
return self.examples[idx]
@classmethod
def create_from_csv(cls, file_name, split_name='', column_label=0, column_text=1,
column_id=None, skip_first_row=False, **kwargs):
processor = cls(**kwargs)
processor.add_examples_from_csv(file_name,
split_name=split_name,
column_label=column_label,
column_text=column_text,
column_id=column_id,
skip_first_row=skip_first_row,
overwrite_labels=True,
overwrite_examples=True)
return processor
@classmethod
def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
processor = cls(**kwargs)
processor.add_examples(texts_or_text_and_labels, labels=labels)
return processor
def add_examples_from_csv(self, file_name, split_name='', column_label=0, column_text=1, column_id=None,
skip_first_row=False, overwrite_labels=False, overwrite_examples=False):
lines = self._read_tsv(file_name)
if skip_first_row:
lines = lines[1:]
texts = []
labels = []
ids = []
for (i, line) in enumerate(lines):
texts.append(line[column_text])
labels.append(line[column_label])
if column_id is not None:
ids.append(line[column_id])
else:
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
ids.append(guid)
return self.add_examples(texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples)
def add_examples(self, texts_or_text_and_labels, labels=None, ids=None,
overwrite_labels=False, overwrite_examples=False):
assert labels is None or len(texts_or_text_and_labels) == len(labels)
assert ids is None or len(texts_or_text_and_labels) == len(ids)
if ids is None:
ids = [None] * len(texts_or_text_and_labels)
if labels is None:
labels = [None] * len(texts_or_text_and_labels)
examples = []
added_labels = set()
for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
text, label = text_or_text_and_label
else:
text = text_or_text_and_label
added_labels.add(label)
examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
# Update examples
if overwrite_examples:
self.examples = examples
else:
self.examples.extend(examples)
# Update labels
if overwrite_labels:
self.labels = list(added_labels)
else:
self.labels = list(set(self.labels).union(added_labels))
return self.examples
def get_features(self,
tokenizer,
max_length=None,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
return_tensors=None):
"""
Convert the processor's examples into a list of ``InputFeatures``
Args:
tokenizer: Instance of a tokenizer that will tokenize the examples
max_length: Maximum example length
pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
pad_token: Padding token
mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
actual values)
return_tensors: If set to ``'tf'`` or ``'pt'``, return a ``tf.data.Dataset`` or a PyTorch ``TensorDataset``
instead of a plain Python list
Returns:
A list of ``InputFeatures`` which can be fed to the model, or a framework-specific dataset when
``return_tensors`` is set.
"""
if max_length is None:
max_length = tokenizer.max_len
label_map = {label: i for i, label in enumerate(self.labels)}
all_input_ids = []
for (ex_index, example) in enumerate(self.examples):
if ex_index % 10000 == 0:
logger.info("Tokenizing example %d", ex_index)
input_ids = tokenizer.encode(
example.text_a,
add_special_tokens=True,
max_length=min(max_length, tokenizer.max_len),
)
all_input_ids.append(input_ids)
batch_length = max(len(input_ids) for input_ids in all_input_ids)
features = []
for (ex_index, (input_ids, example)) in enumerate(zip(all_input_ids, self.examples)):
if ex_index % 10000 == 0:
logger.info("Writing example %d", ex_index)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
padding_length = batch_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
else:
input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
assert len(input_ids) == batch_length, "Error with input length {} vs {}".format(len(input_ids), batch_length)
assert len(attention_mask) == batch_length, "Error with input length {} vs {}".format(len(attention_mask), batch_length)
if self.mode == "classification":
label = label_map[example.label]
elif self.mode == "regression":
label = float(example.label)
else:
raise ValueError(self.mode)
if ex_index < 5 and self.verbose:
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
logger.info("label: %s (id = %d)" % (example.label, label))
features.append(
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
label=label))
if return_tensors is None:
return features
elif return_tensors == 'tf':
if not is_tf_available():
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
import tensorflow as tf
def gen():
for ex in features:
yield ({'input_ids': ex.input_ids,
'attention_mask': ex.attention_mask},
ex.label)
dataset = tf.data.Dataset.from_generator(gen,
({'input_ids': tf.int32,
'attention_mask': tf.int32},
tf.int64),
({'input_ids': tf.TensorShape([None]),
'attention_mask': tf.TensorShape([None])},
tf.TensorShape([])))
return dataset
elif return_tensors == 'pt':
if not is_torch_available():
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported")
import torch
from torch.utils.data import TensorDataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
if self.mode == "classification":
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
elif self.mode == "regression":
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
return dataset
else:
raise ValueError("return_tensors should be one of 'tf' or 'pt'")
...@@ -28,17 +28,27 @@ from . import __version__
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try:
os.environ.setdefault('USE_TORCH', 'YES')
if os.environ['USE_TORCH'].upper() in ('1', 'ON', 'YES'):
import torch
_torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__))
else:
logger.info("USE_TORCH override through env variable, disabling PyTorch")
_torch_available = False
except ImportError:
_torch_available = False # pylint: disable=invalid-name
try:
os.environ.setdefault('USE_TF', 'YES')
if os.environ['USE_TF'].upper() in ('1', 'ON', 'YES'):
import tensorflow as tf
assert hasattr(tf, '__version__') and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
else:
logger.info("USE_TF override through env variable, disabling Tensorflow")
_tf_available = False
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
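These guards let a user disable one backend before transformers is imported, which helps when both frameworks are installed but only one should be loaded. A sketch; the variable must be set before the first import:

import os

# Any value outside ('1', 'ON', 'YES') disables the TensorFlow backend.
os.environ['USE_TF'] = 'NO'

import transformers
print(transformers.is_tf_available(), transformers.is_torch_available())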
...@@ -72,7 +82,7 @@ WEIGHTS_NAME = "pytorch_model.bin" ...@@ -72,7 +82,7 @@ WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = 'tf_model.h5' TF2_WEIGHTS_NAME = 'tf_model.h5'
TF_WEIGHTS_NAME = 'model.ckpt' TF_WEIGHTS_NAME = 'model.ckpt'
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "model_card.json" MODEL_CARD_NAME = "modelcard.json"
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
...@@ -85,6 +95,7 @@ def is_torch_available(): ...@@ -85,6 +95,7 @@ def is_torch_available():
return _torch_available
def is_tf_available():
return _tf_available
if not six.PY2:
...@@ -343,7 +354,7 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
    temp_file_manager = tempfile.NamedTemporaryFile
    resume_size = 0
    if not os.path.exists(cache_path) or force_download:
    if etag is not None and (not os.path.exists(cache_path) or force_download):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
......
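Only the guard is visible in this hunk. As a rough illustration of the pattern the comments describe (stream into a temporary file, move it into the cache only on success), here is a sketch; the download_to_cache name is hypothetical, requests is an assumed dependency, and the real helper additionally handles etags, resumable downloads and proxies:

    import os
    import shutil
    import tempfile

    import requests  # assumed available; not part of the hunk above

    def download_to_cache(url, cache_path):
        # Hypothetical helper: an interrupted download can never leave a
        # corrupt entry because the file only reaches cache_path once complete.
        if os.path.exists(cache_path):
            return cache_path
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            with requests.get(url, stream=True) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        temp_file.write(chunk)
        shutil.move(temp_file.name, cache_path)  # reached only after a complete download
        return cache_path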
...@@ -25,7 +25,8 @@ from io import open
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, cached_path, is_remote_url, hf_bucket_url
from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME, \
    cached_path, is_remote_url, hf_bucket_url
logger = logging.getLogger(__name__)
...@@ -89,7 +90,7 @@ class ModelCard(object):
    - a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.: ``bert-base-uncased``.
    - a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
    - a path to a `directory` containing a model card file saved using the :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
    - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/model_card.json``.
    - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/modelcard.json``.
cache_dir: (`optional`) string:
    Path to a directory in which a downloaded pre-trained model
...@@ -100,16 +101,14 @@ class ModelCard(object):
    - The values in kwargs of any keys which are model card attributes will be used to override the loaded values.
    - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the `return_unused_kwargs` keyword parameter.
force_download: (`optional`) boolean, default False:
    Force to (re-)download the model card file and override the cached version if it exists.
resume_download: (`optional`) boolean, default False:
    Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
    A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
    The proxies are used on each request.
find_from_standard_name: (`optional`) boolean, default True:
    If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them with our standard modelcard filename.
    Can be used to directly feed a model/config url and access the colocated modelcard.
return_unused_kwargs: (`optional`) bool:
    - If False, then this function returns just the final model card object.
...@@ -117,22 +116,21 @@ class ModelCard(object):
Examples::
    model_card = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
    model_card = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
    model_card = ModelCard.from_pretrained('./test/saved_model/model_card.json')
    model_card = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
    modelcard = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
    modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
    modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
    modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
"""
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
find_from_standard_name = kwargs.pop('find_from_standard_name', True)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
    # For simplicity we use the same pretrained url than the configuration files but with a different suffix (model_card.json)
    # For simplicity we use the same pretrained url than the configuration files
    # but with a different suffix (modelcard.json). This suffix is replaced below.
    model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
    model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
elif os.path.isdir(pretrained_model_name_or_path):
    model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
...@@ -140,17 +138,22 @@ class ModelCard(object):
else:
    model_card_file = hf_bucket_url(pretrained_model_name_or_path, postfix=MODEL_CARD_NAME)
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
    model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
    model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
    model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)
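To make the three replace calls above concrete — the URL below is made up and only stands in for a layout where the config, the weights and the model card live next to each other; the constants mirror the ones defined in file_utils:

    CONFIG_NAME = "config.json"
    WEIGHTS_NAME = "pytorch_model.bin"
    TF2_WEIGHTS_NAME = "tf_model.h5"
    MODEL_CARD_NAME = "modelcard.json"

    url = "https://example.com/my-model/config.json"  # hypothetical config URL
    for standard_name in (CONFIG_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME):
        url = url.replace(standard_name, MODEL_CARD_NAME)
    print(url)  # https://example.com/my-model/modelcard.json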
try:
    # Load from URL or cache if already cached
    resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=force_download,
                                           proxies=proxies, resume_download=resume_download)
    resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=True,
                                           proxies=proxies, resume_download=False)
    if resolved_model_card_file == model_card_file:
        logger.info("loading model card file {}".format(model_card_file))
    else:
        logger.info("loading model card file {} from cache at {}".format(
            model_card_file, resolved_model_card_file))
    # Load model card
    model_card = cls.from_json_file(resolved_model_card_file)
    modelcard = cls.from_json_file(resolved_model_card_file)
except EnvironmentError:
    if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
...@@ -166,7 +169,7 @@ class ModelCard(object):
        logger.warning("Creating an empty model card.")
        # We fall back on creating an empty model card
        model_card = cls()
        modelcard = cls()
except json.JSONDecodeError:
    logger.warning("Couldn't reach server at '{}' to download model card file or "
...@@ -175,22 +178,22 @@ class ModelCard(object):
    logger.warning("Creating an empty model card.")
    # We fall back on creating an empty model card
    model_card = cls()
    modelcard = cls()
# Update model card with kwargs if needed
to_remove = []
for key, value in kwargs.items():
    if hasattr(model_card, key):
        setattr(model_card, key, value)
    if hasattr(modelcard, key):
        setattr(modelcard, key, value)
        to_remove.append(key)
for key in to_remove:
    kwargs.pop(key, None)
logger.info("Model card: %s", str(model_card))
logger.info("Model card: %s", str(modelcard))
if return_unused_kwargs:
    return model_card, kwargs
    return modelcard, kwargs
else:
    return model_card
    return modelcard
@classmethod
def from_dict(cls, json_object):
......
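Putting the loading logic together, a short usage sketch that mirrors the docstring example above; it assumes ModelCard is re-exported from the package root and makes no claim about which kwargs are real model card attributes:

    from transformers import ModelCard

    # kwargs naming model card attributes override the loaded values; anything
    # unrecognised is handed back when return_unused_kwargs=True.
    modelcard, unused = ModelCard.from_pretrained('bert-base-uncased',
                                                  output_attention=True, foo=False,
                                                  return_unused_kwargs=True)
    print(type(modelcard).__name__)  # ModelCard
    print(unused)                    # the kwargs that did not match model card attributes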
...@@ -22,6 +22,8 @@ import logging
import os
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format
import h5py
from .configuration_utils import PretrainedConfig
from .file_utils import (TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, DUMMY_INPUTS,
...@@ -182,7 +184,9 @@ class TFPreTrainedModel(tf.keras.Model):
model_args: (`optional`) Sequence of positional arguments:
    All remaining positional arguments will be passed to the underlying model's ``__init__`` method
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
config: (`optional`) one of:
    - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
    Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
    - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
...@@ -206,6 +210,9 @@ class TFPreTrainedModel(tf.keras.Model):
    A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
    The proxies are used on each request.
output_loading_info: (`optional`) boolean:
    Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments:
    Can be used to update the configuration object (after it is loaded) and initiate the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
...@@ -229,11 +236,13 @@ class TFPreTrainedModel(tf.keras.Model):
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False)
# Load config
# Load config if we don't provide a configuration
if config is None:
if not isinstance(config, PretrainedConfig):
    config_path = config if config is not None else pretrained_model_name_or_path
    config, model_kwargs = cls.config_class.from_pretrained(
        pretrained_model_name_or_path, *model_args,
        config_path, *model_args,
        cache_dir=cache_dir, return_unused_kwargs=True,
        force_download=force_download,
        resume_download=resume_download,
...@@ -304,10 +313,39 @@ class TFPreTrainedModel(tf.keras.Model):
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try:
    model.load_weights(resolved_archive_file, by_name=True)
except OSError:
    raise OSError("Unable to load weights from h5 file. "
                  "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ")
ret = model(model.dummy_inputs, training=False)  # Make sure restore ops are run
# Check if the models are the same to output loading informations
with h5py.File(resolved_archive_file, 'r') as f:
    if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']
    hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, 'layer_names'))
model_layer_names = set(layer.name for layer in model.layers)
missing_keys = list(model_layer_names - hdf5_layer_names)
unexpected_keys = list(hdf5_layer_names - model_layer_names)
error_msgs = []
if len(missing_keys) > 0:
    logger.info("Layers of {} not initialized from pretrained model: {}".format(
        model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
    logger.info("Layers from pretrained model not used in {}: {}".format(
        model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
    raise RuntimeError('Error(s) in loading weights for {}:\n\t{}'.format(
        model.__class__.__name__, "\n\t".join(error_msgs)))
if output_loading_info:
    loading_info = {"missing_keys": missing_keys,
                    "unexpected_keys": unexpected_keys,
                    "error_msgs": error_msgs}
    return model, loading_info
return model
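A short usage sketch of the new output_loading_info flag, assuming TensorFlow 2 is installed; TFBertModel stands in for any concrete TFPreTrainedModel subclass:

    from transformers import TFBertModel

    # Surfaces the layer-name comparison computed above instead of only logging it.
    model, loading_info = TFBertModel.from_pretrained('bert-base-uncased',
                                                      output_loading_info=True)
    print(loading_info["missing_keys"])     # layers of this model absent from the h5 file
    print(loading_info["unexpected_keys"])  # layers in the h5 file this model does not use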
class TFConv1D(tf.keras.layers.Layer):
......