"INSTALL/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "0b7fa630a477204220c87eac8a4005ee9f205d64"
Commit e347725d authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

More fine-grained control over pipeline creation with config argument.

parent 55397dfb
......@@ -497,7 +497,7 @@ class QuestionAnsweringPipeline(Pipeline):
'score': score.item(),
'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]: feature.token_to_orig_map[e] + 1])
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]:feature.token_to_orig_map[e] + 1])
}
for s, e, score in zip(starts, ends, scores)
]
......@@ -612,7 +612,8 @@ SUPPORTED_TASKS = {
}
def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
def pipeline(task: str, model, config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, **kwargs) -> Pipeline:
"""
Utility factory method to build a pipeline.
Pipeline are made of:
......@@ -637,13 +638,21 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
# Special handling for model conversion
from_tf = model.endswith('.h5') and not is_tf_available()
from_pt = model.endswith('.bin') and not is_torch_available()
if isinstance(model, str):
from_tf = model.endswith('.h5') and not is_tf_available()
from_pt = model.endswith('.bin') and not is_torch_available()
if from_tf:
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. '
'Trying to load the model with PyTorch.')
elif from_pt:
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. '
'Trying to load the model with Tensorflow.')
else:
from_tf = from_pt = False
if from_tf:
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. Trying to load the model with PyTorch.')
elif from_pt:
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. Trying to load the model with Tensorflow.')
if isinstance(config, str):
config = PretrainedConfig.from_pretrained(config)
if allocator.__name__.startswith('TF'):
model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment