Unverified commit 8d5a47c7, authored by Thomas Wolf, committed by GitHub

Merge pull request #2243 from huggingface/fix-xlm-roberta

fixing xlm-roberta tokenizer max_length and automodels
parents 65c75fc5 79e4a6a2
@@ -9,6 +9,9 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 def try_infer_format_from_ext(path: str):
+    if not path:
+        return 'pipe'
+
     for ext in PipelineDataFormat.SUPPORTED_FORMATS:
         if path.endswith(ext):
             return ext
@@ -20,9 +23,16 @@ def try_infer_format_from_ext(path: str):
 def run_command_factory(args):
-    nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
+    nlp = pipeline(task=args.task,
+                   model=args.model if args.model else None,
+                   config=args.config,
+                   tokenizer=args.tokenizer,
+                   device=args.device)
     format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
-    reader = PipelineDataFormat.from_str(format, args.output, args.input, args.column)
+    reader = PipelineDataFormat.from_str(format=format,
+                                         output_path=args.output,
+                                         input_path=args.input,
+                                         column=args.column if args.column else nlp.default_input_names)
     return RunCommand(nlp, reader)
@@ -35,24 +45,26 @@ class RunCommand(BaseTransformersCLICommand):
     @staticmethod
     def register_subcommand(parser: ArgumentParser):
         run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
-        run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
-        run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
+        run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
+        run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
+        run_parser.add_argument('--model', type=str, help='Name or path to the model to instantiate.')
         run_parser.add_argument('--config', type=str, help='Name or path to the model\'s config to instantiate.')
         run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
         run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
         run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
-        run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
-        run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
+        run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         run_parser.set_defaults(func=run_command_factory)

     def run(self):
-        nlp, output = self._nlp, []
+        nlp, outputs = self._nlp, []
         for entry in self._reader:
-            if self._reader.is_multi_columns:
-                output += nlp(**entry)
+            output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
+            if isinstance(output, dict):
+                outputs.append(output)
             else:
-                output += nlp(entry)
+                outputs += output

         # Saving data
         if self._nlp.binary_output:
...
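Why run() now distinguishes dict from list results: QuestionAnsweringPipeline returns a bare dict when there is exactly one answer (see the `if len(answers) == 1` early return later in this diff), so the old `output += nlp(entry)` would have splatted a dict's keys into the accumulator. A self-contained sketch of the new rule, using made-up sample outputs:

# Standalone illustration of the dict-vs-list handling added to run();
# the two sample outputs below are hypothetical pipeline results.
sample_outputs = [
    {'answer': 'Paris', 'score': 0.98},          # single answer -> bare dict
    [{'answer': 'Berlin'}, {'answer': 'Bern'}],  # several answers -> list
]

outputs = []
for output in sample_outputs:
    if isinstance(output, dict):
        outputs.append(output)  # keep the dict whole
    else:
        outputs += output       # extend with the list's items

assert outputs == [{'answer': 'Paris', 'score': 0.98},
                   {'answer': 'Berlin'}, {'answer': 'Bern'}]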
@@ -24,7 +24,11 @@ def serve_command_factory(args: Namespace):
     Factory function used to instantiate serving server from provided command line arguments.
     :return: ServeCommand
     """
-    nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
+    nlp = pipeline(task=args.task,
+                   model=args.model if args.model else None,
+                   config=args.config,
+                   tokenizer=args.tokenizer,
+                   device=args.device)
     return ServeCommand(nlp, args.host, args.port)
@@ -68,12 +72,12 @@ class ServeCommand(BaseTransformersCLICommand):
         """
         serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
         serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
-        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
         serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
-        serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model.')
+        serve_parser.add_argument('--model', type=str, help='Model\'s name or path to stored model.')
         serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
         serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
+        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.set_defaults(func=serve_command_factory)

     def __init__(self, pipeline: Pipeline, host: str, port: int):
...
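The same relaxation of --model applies here: pipeline() accepts model=None and falls back to the task's default model and tokenizer from the SUPPORTED_TASKS table. A hedged sketch, assuming the sentiment-analysis task and its bundled default checkpoint (downloaded on first use):

from transformers import pipeline

# model=None is exactly what the CLI now passes when --model is omitted.
nlp = pipeline(task='sentiment-analysis', model=None)
print(nlp('I love this movie.'))  # e.g. [{'label': 'POSITIVE', 'score': ...}]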
@@ -20,7 +20,7 @@ import logging
 from .configuration_auto import (AlbertConfig, BertConfig, CamembertConfig, CTRLConfig,
                                  DistilBertConfig, GPT2Config, OpenAIGPTConfig, RobertaConfig,
-                                 TransfoXLConfig, XLMConfig, XLNetConfig)
+                                 TransfoXLConfig, XLMConfig, XLNetConfig, XLMRobertaConfig)

 from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering, \
     BertForTokenClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -41,7 +41,8 @@ from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertF
 from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, \
     AlbertForQuestionAnswering, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
 from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
-from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, XLMRobertaForMultipleChoice, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, \
+    XLMRobertaForMultipleChoice, XLMRobertaForTokenClassification, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
 from .modeling_utils import PreTrainedModel, SequenceSummary
@@ -146,6 +147,8 @@ class AutoModel(object):
             return AlbertModel(config)
         elif isinstance(config, CamembertConfig):
             return CamembertModel(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaModel(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -333,6 +336,8 @@ class AutoModelWithLMHead(object):
             return XLMWithLMHeadModel(config)
         elif isinstance(config, CTRLConfig):
             return CTRLLMHeadModel(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForMaskedLM(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -509,6 +514,8 @@ class AutoModelForSequenceClassification(object):
             return XLNetForSequenceClassification(config)
         elif isinstance(config, XLMConfig):
             return XLMForSequenceClassification(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForSequenceClassification(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -787,6 +794,8 @@ class AutoModelForTokenClassification:
             return XLNetForTokenClassification(config)
         elif isinstance(config, RobertaConfig):
             return RobertaForTokenClassification(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForTokenClassification(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -865,6 +874,8 @@ class AutoModelForTokenClassification:
             return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'distilbert' in pretrained_model_name_or_path:
             return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        elif 'xlm-roberta' in pretrained_model_name_or_path:
+            return XLMRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'roberta' in pretrained_model_name_or_path:
             return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'bert' in pretrained_model_name_or_path:
@@ -873,4 +884,4 @@ class AutoModelForTokenClassification:
             return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                         "'bert', 'xlnet', 'camembert', 'distilbert', 'roberta'".format(pretrained_model_name_or_path))
+                         "'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(pretrained_model_name_or_path))
@@ -14,12 +14,14 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals

+import sys
 import csv
 import json
 import os
 import pickle
 import logging
 import six
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from itertools import groupby
@@ -49,7 +51,7 @@ logger = logging.getLogger(__name__)
 def get_framework(model=None):
     """ Select framework (TensorFlow/PyTorch) to use.
-        If both frameworks are installed and no specific model is provided, defaults to using TensorFlow.
+        If both frameworks are installed and no specific model is provided, defaults to using PyTorch.
     """
     if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
         # Both frameworks are available but the user supplied a model class instance.
@@ -60,7 +62,8 @@ def get_framework(model=None):
                          "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
                          "To install PyTorch, read the instructions at https://pytorch.org/.")
     else:
-        framework = 'tf' if is_tf_available() else 'pt'
+        # framework = 'tf' if is_tf_available() else 'pt'
+        framework = 'pt' if is_torch_available() else 'tf'
     return framework

 class ArgumentHandler(ABC):
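The behavioral change here in plain terms: when both backends are installed and there is no model instance to inspect, the default framework flips from TensorFlow to PyTorch. A minimal standalone restatement of the new rule:

def pick_framework(torch_available: bool, tf_available: bool) -> str:
    """Mirrors the new line: 'pt' if is_torch_available() else 'tf'."""
    return 'pt' if torch_available else 'tf'

assert pick_framework(torch_available=True, tf_available=True) == 'pt'   # both installed: PyTorch wins now
assert pick_framework(torch_available=False, tf_available=True) == 'tf'  # TF-only setups are unaffected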
@@ -97,28 +100,29 @@ class PipelineDataFormat:
     Supported data formats currently includes:
      - JSON
      - CSV
+     - stdin/stdout (pipe)

     PipelineDataFormat also includes some utilities to work with multi-columns like mapping from datasets columns
     to pipelines keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
     """
     SUPPORTED_FORMATS = ['json', 'csv', 'pipe']

-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        self.output = output
-        self.path = input
-        self.column = column.split(',') if column else ['']
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        self.output_path = output_path
+        self.input_path = input_path
+        self.column = column.split(',') if column is not None else ['']
         self.is_multi_columns = len(self.column) > 1

         if self.is_multi_columns:
             self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]

-        if output is not None:
-            if exists(abspath(self.output)):
-                raise OSError('{} already exists on disk'.format(self.output))
+        if output_path is not None:
+            if exists(abspath(self.output_path)):
+                raise OSError('{} already exists on disk'.format(self.output_path))

-        if input is not None:
-            if not exists(abspath(self.path)):
-                raise OSError('{} doesnt exist on disk'.format(self.path))
+        if input_path is not None:
+            if not exists(abspath(self.input_path)):
+                raise OSError('{} doesnt exist on disk'.format(self.input_path))

     @abstractmethod
     def __iter__(self):
@@ -139,7 +143,7 @@ class PipelineDataFormat:
         :param data: data to store
        :return: (str) Path where the data has been saved
         """
-        path, _ = os.path.splitext(self.output)
+        path, _ = os.path.splitext(self.output_path)
         binary_path = os.path.extsep.join((path, 'pickle'))

         with open(binary_path, 'wb+') as f_output:
@@ -148,23 +152,23 @@ class PipelineDataFormat:
         return binary_path

     @staticmethod
-    def from_str(name: str, output: Optional[str], path: Optional[str], column: Optional[str]):
-        if name == 'json':
-            return JsonPipelineDataFormat(output, path, column)
-        elif name == 'csv':
-            return CsvPipelineDataFormat(output, path, column)
-        elif name == 'pipe':
-            return PipedPipelineDataFormat(output, path, column)
+    def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        if format == 'json':
+            return JsonPipelineDataFormat(output_path, input_path, column)
+        elif format == 'csv':
+            return CsvPipelineDataFormat(output_path, input_path, column)
+        elif format == 'pipe':
+            return PipedPipelineDataFormat(output_path, input_path, column)
         else:
-            raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(name))
+            raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))

 class CsvPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        super().__init__(output, input, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        super().__init__(output_path, input_path, column)

     def __iter__(self):
-        with open(self.path, 'r') as f:
+        with open(self.input_path, 'r') as f:
             reader = csv.DictReader(f)
             for row in reader:
                 if self.is_multi_columns:
@@ -173,7 +177,7 @@ class CsvPipelineDataFormat(PipelineDataFormat):
                 yield row[self.column[0]]

     def save(self, data: List[dict]):
-        with open(self.output, 'w') as f:
+        with open(self.output_path, 'w') as f:
             if len(data) > 0:
                 writer = csv.DictWriter(f, list(data[0].keys()))
                 writer.writeheader()
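For reference, a hedged sketch of the renamed from_str signature in use. The file names are hypothetical, the input CSV is written inline so the sketch actually runs, and output_path must not already exist (the constructor raises otherwise):

from pathlib import Path
from transformers.pipelines import PipelineDataFormat

# Hypothetical two-column input, created here to keep the example self-contained.
Path('questions.csv').write_text('q,c\nWho wrote it?,Thomas Wolf wrote it.\n')

reader = PipelineDataFormat.from_str(format='csv',
                                     output_path='answers.csv',    # must not exist yet
                                     input_path='questions.csv',
                                     column='question=q,context=c')

for entry in reader:
    print(entry)  # {'question': 'Who wrote it?', 'context': 'Thomas Wolf wrote it.'}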
@@ -181,10 +185,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):

 class JsonPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        super().__init__(output, input, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        super().__init__(output_path, input_path, column)

-        with open(input, 'r') as f:
+        with open(input_path, 'r') as f:
             self._entries = json.load(f)

     def __iter__(self):
@@ -195,7 +199,7 @@ class JsonPipelineDataFormat(PipelineDataFormat):
             yield entry[self.column[0]]

     def save(self, data: dict):
-        with open(self.output, 'w') as f:
+        with open(self.output_path, 'w') as f:
             json.dump(data, f)
@@ -207,9 +211,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
     If columns are provided, then the output will be a dictionary with {column_x: value_x}
     """
     def __iter__(self):
-        import sys
-
         for line in sys.stdin:
             # Split for multi-columns
             if '\t' in line:
@@ -228,7 +230,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
         print(data)

     def save_binary(self, data: Union[dict, List[dict]]) -> str:
-        if self.output is None:
+        if self.output_path is None:
             raise KeyError(
                 'When using piped input on pipeline outputting large object requires an output file path. '
                 'Please provide such output path through --output argument.'
@@ -293,6 +295,9 @@ class Pipeline(_ScikitCompat):
         nlp = NerPipeline(model='...', config='...', tokenizer='...')
         nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
     """
+    default_input_names = None
+
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
@@ -581,6 +586,8 @@ class QuestionAnsweringPipeline(Pipeline):
     Question Answering pipeline using ModelForQuestionAnswering head.
     """
+    default_input_names = 'question,context'
+
     def __init__(self, model,
                  tokenizer: Optional[PreTrainedTokenizer],
                  modelcard: Optional[ModelCard],
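These two class attributes close the loop with the CLI change at the top of this diff: when --column is omitted, run_command_factory falls back to nlp.default_input_names, so question-answering input columns are mapped automatically. A standalone sketch using stand-in classes (not the real ones):

class Pipeline:                             # stand-in for the real base class
    default_input_names = None

class QuestionAnsweringPipeline(Pipeline):  # stand-in carrying the new attribute
    default_input_names = 'question,context'

nlp = QuestionAnsweringPipeline()
args_column = None                          # simulates a missing --column flag
column = args_column if args_column else nlp.default_input_names
assert column == 'question,context'         # later parsed into per-column mappings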
@@ -683,7 +690,6 @@ class QuestionAnsweringPipeline(Pipeline):
             }
             for s, e, score in zip(starts, ends, scores)
         ]
-
         if len(answers) == 1:
             return answers[0]
         return answers
...
@@ -434,7 +434,11 @@ class PreTrainedTokenizer(object):
             init_kwargs[key] = value

         # Instantiate tokenizer.
-        tokenizer = cls(*init_inputs, **init_kwargs)
+        try:
+            tokenizer = cls(*init_inputs, **init_kwargs)
+        except OSError:
+            raise OSError("Unable to load vocabulary from file. "
+                          "Please check that the provided vocabulary is accessible and not corrupted.")

         # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
         tokenizer.init_inputs = init_inputs
...
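The guard's intent, restated as a self-contained sketch (instantiate_tokenizer and the lambda below are stand-ins, not library code): translate a low-level file error into a clearer message while keeping the exception type.

def instantiate_tokenizer(factory):
    """Stand-in for the guarded cls(*init_inputs, **init_kwargs) call."""
    try:
        return factory()
    except OSError:
        raise OSError("Unable to load vocabulary from file. "
                      "Please check that the provided vocabulary is accessible and not corrupted.")

# A factory that fails the way a missing SentencePiece file would:
try:
    instantiate_tokenizer(lambda: open('no-such-vocab.model'))
except OSError as err:
    print(err)  # the clearer message above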
@@ -40,8 +40,12 @@ PRETRAINED_VOCAB_FILES_MAP = {
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'xlm-roberta-base': None,
-    'xlm-roberta-large': None,
+    'xlm-roberta-base': 512,
+    'xlm-roberta-large': 512,
+    'xlm-roberta-large-finetuned-conll02-dutch': 512,
+    'xlm-roberta-large-finetuned-conll02-spanish': 512,
+    'xlm-roberta-large-finetuned-conll03-english': 512,
+    'xlm-roberta-large-finetuned-conll03-german': 512,
 }

 class XLMRobertaTokenizer(PreTrainedTokenizer):
@@ -58,7 +62,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
                  cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
                  **kwargs):
-        super(XLMRobertaTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
+        super(XLMRobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
                                                   sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
                                                   mask_token=mask_token,
                                                   **kwargs)
...
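Net effect of the two tokenizer hunks: the 512-position limit now flows from max_model_input_sizes (populated from PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES above) at from_pretrained() time, rather than being hardcoded in __init__, where a max_len arriving through **kwargs would collide with the explicit keyword. A hedged check (downloads the SentencePiece vocabulary on first run):

from transformers import XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
assert tokenizer.max_len == 512  # now read from the pretrained sizes map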