"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9660ba1cbdec0e419937af06bd99f06fb5ebbf91"
Commit 1ca52567 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Allow model conversion in the pipeline allocator.

parent 28e64ad5
...@@ -385,5 +385,17 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni ...@@ -385,5 +385,17 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
targeted_task = SUPPORTED_TASKS[task] targeted_task = SUPPORTED_TASKS[task]
task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt'] task, allocator = targeted_task['impl'], targeted_task['tf'] if is_tf_available() else targeted_task['pt']
model = allocator.from_pretrained(model) # Special handling for model conversion
from_tf = model.endswith('.h5') and not is_tf_available()
from_pt = model.endswith('.bin') and not is_torch_available()
if from_tf:
logger.warning('Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. Trying to load the model with PyTorch.')
elif from_pt:
logger.warning('Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. Trying to load the model with Tensorflow.')
if allocator.__name__.startswith('TF'):
model = allocator.from_pretrained(model, config=config, from_pt=from_pt)
else:
model = allocator.from_pretrained(model, config=config, from_tf=from_tf)
return task(model, tokenizer, **kwargs) return task(model, tokenizer, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment