"docs/source/vscode:/vscode.git/clone" did not exist on "7ffe25f2b935dcaf65079b04c5f91c8a42a99e28"
Unverified Commit 8e9a2207 authored by Yuki Watanabe's avatar Yuki Watanabe Committed by GitHub
Browse files

Populate torch_dtype from model to pipeline (#28940)



* Populate torch_dtype from model to pipeline
Signed-off-by: default avatarB-Step62 <yuki.watanabe@databricks.com>

* use property
Signed-off-by: default avatarB-Step62 <yuki.watanabe@databricks.com>

* lint
Signed-off-by: default avatarB-Step62 <yuki.watanabe@databricks.com>

* Remove default handling
Signed-off-by: default avatarB-Step62 <yuki.watanabe@databricks.com>

---------
Signed-off-by: default avatarB-Step62 <yuki.watanabe@databricks.com>
parent afe73aed
...@@ -861,7 +861,7 @@ class Pipeline(_ScikitCompat): ...@@ -861,7 +861,7 @@ class Pipeline(_ScikitCompat):
raise ValueError(f"{device} unrecognized or not available.") raise ValueError(f"{device} unrecognized or not available.")
else: else:
self.device = device if device is not None else -1 self.device = device if device is not None else -1
self.torch_dtype = torch_dtype
self.binary_output = binary_output self.binary_output = binary_output
# We shouldn't call `model.to()` for models loaded with accelerate # We shouldn't call `model.to()` for models loaded with accelerate
...@@ -964,6 +964,13 @@ class Pipeline(_ScikitCompat): ...@@ -964,6 +964,13 @@ class Pipeline(_ScikitCompat):
""" """
return self(X) return self(X)
@property
def torch_dtype(self) -> Optional["torch.dtype"]:
"""
Torch dtype of the model (if it's Pytorch model), `None` otherwise.
"""
return getattr(self.model, "dtype", None)
@contextmanager @contextmanager
def device_placement(self): def device_placement(self):
""" """
......
...@@ -199,6 +199,29 @@ class CommonPipelineTest(unittest.TestCase): ...@@ -199,6 +199,29 @@ class CommonPipelineTest(unittest.TestCase):
outputs = text_classifier(["This is great !"] * 20, batch_size=32) outputs = text_classifier(["This is great !"] * 20, batch_size=32)
self.assertEqual(len(outputs), 20) self.assertEqual(len(outputs), 20)
@require_torch
def test_torch_dtype_property(self):
import torch
model_id = "hf-internal-testing/tiny-random-distilbert"
# If dtype is specified in the pipeline constructor, the property should return that type
pipe = pipeline(model=model_id, torch_dtype=torch.float16)
self.assertEqual(pipe.torch_dtype, torch.float16)
# If the underlying model changes dtype, the property should return the new type
pipe.model.to(torch.bfloat16)
self.assertEqual(pipe.torch_dtype, torch.bfloat16)
# If dtype is NOT specified in the pipeline constructor, the property should just return
# the dtype of the underlying model (default)
pipe = pipeline(model=model_id)
self.assertEqual(pipe.torch_dtype, torch.float32)
# If underlying model doesn't have dtype property, simply return None
pipe.model = None
self.assertIsNone(pipe.torch_dtype)
@is_pipeline_test @is_pipeline_test
class PipelineScikitCompatTest(unittest.TestCase): class PipelineScikitCompatTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment