Unverified Commit ec6878f6 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Now supporting pathlike in pipelines too. (#20030)

parent aa39967b
...@@ -21,6 +21,7 @@ import os ...@@ -21,6 +21,7 @@ import os
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from numpy import isin from numpy import isin
...@@ -638,6 +639,8 @@ def pipeline( ...@@ -638,6 +639,8 @@ def pipeline(
" feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class" " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class"
" or a path/identifier to a pretrained model when providing feature_extractor." " or a path/identifier to a pretrained model when providing feature_extractor."
) )
if isinstance(model, Path):
model = str(model)
# Config is the primordial information item. # Config is the primordial information item.
# Instantiate config if needed # Instantiate config if needed
......
...@@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase): ...@@ -356,6 +356,15 @@ class CommonPipelineTest(unittest.TestCase):
self.assertEqual(pipe._batch_size, 2) self.assertEqual(pipe._batch_size, 2)
self.assertEqual(pipe._num_workers, 1) self.assertEqual(pipe._num_workers, 1)
@require_torch
def test_pipeline_pathlike(self):
pipe = pipeline(model="hf-internal-testing/tiny-random-distilbert")
with tempfile.TemporaryDirectory() as d:
pipe.save_pretrained(d)
path = Path(d)
newpipe = pipeline(task="text-classification", model=path)
self.assertIsInstance(newpipe, TextClassificationPipeline)
@require_torch @require_torch
def test_pipeline_override(self): def test_pipeline_override(self):
class MyPipeline(TextClassificationPipeline): class MyPipeline(TextClassificationPipeline):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment