Commit e347725d authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

More fine-grained control over pipeline creation with config argument.

parent 55397dfb
...@@ -497,7 +497,7 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -497,7 +497,7 @@ class QuestionAnsweringPipeline(Pipeline):
'score': score.item(), 'score': score.item(),
'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(), 'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(), 'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]: feature.token_to_orig_map[e] + 1]) 'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]:feature.token_to_orig_map[e] + 1])
} }
for s, e, score in zip(starts, ends, scores) for s, e, score in zip(starts, ends, scores)
] ]
...@@ -612,7 +612,8 @@ SUPPORTED_TASKS = { ...@@ -612,7 +612,8 @@ SUPPORTED_TASKS = {
} }
def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline: def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
""" """
Utility factory method to build a pipeline. Utility factory method to build a pipeline.
Pipeline are made of: Pipeline are made of:
...@@ -637,13 +638,21 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni ...@@ -637,13 +638,21 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt'] task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
# Special handling for model conversion # Special handling for model conversion
from_tf = model.endswith('.h5') and not is_tf_available() if isinstance(model, str):
from_pt = model.endswith('.bin') and not is_torch_available() from_tf = model.endswith('.h5') and not is_tf_available()
from_pt = model.endswith('.bin') and not is_torch_available()
if from_tf:
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
'Trying to load the model with PyTorch.')
elif from_pt:
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
'Trying to load the model with Tensorflow.')
else:
from_tf = from_pt = False
if from_tf: if isinstance(config, str):
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. Trying to load the model with PyTorch.') config = PretrainedConfig.from_pretrained(config)
elif from_pt:
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. Trying to load the model with Tensorflow.')
if allocator.__name__.startswith('TF'): if allocator.__name__.startswith('TF'):
model = allocator.from_pretrained(model, config=config, from_pt=from_pt) model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment