"vscode:/vscode.git/clone" did not exist on "3bc65505fc0801e3d9ff741ec725fb0cb4d863d6"
Unverified Commit c38f4e1f authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Running a pipeline of `float16`. (#17637)

When we're preparing the tensors for CPU for postprocessing, we need
to upgrade the `float16` to `float32` since CPUs don't have instructions
for `[b]float16`.
parent 90ed9ae2
......@@ -869,6 +869,8 @@ class Pipeline(_ScikitCompat):
elif isinstance(inputs, tuple):
return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
elif isinstance(inputs, torch.Tensor):
if device == torch.device("cpu") and inputs.dtype in {torch.float16, torch.bfloat16}:
inputs = inputs.float()
return inputs.to(device)
else:
return inputs
......
......@@ -16,7 +16,14 @@ import unittest
from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
from transformers.pipelines import PipelineException
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_gpu,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
......@@ -130,6 +137,19 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
],
)
@require_torch_gpu
def test_fp16_casting(self):
pipe = pipeline("fill-mask", model="hf-internal-testing/tiny-random-distilbert", device=0, framework="pt")
# convert model to fp16
pipe.model.half()
response = pipe("Paris is the [MASK] of France.")
# We actually don't care about the result, we just want to make sure
# it works, meaning the float16 tensor got casted back to float32
# for postprocessing.
self.assertIsInstance(response, list)
@slow
@require_torch
def test_large_model_pt(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment