Unverified commit fe65657d, authored by bofeng huang and committed by GitHub
Browse files

Fix FP16 inference in TextGenerationPipeline (#20913)



* add torch_dtype attribute to Pipeline

* Use torch_dtype to cast input tensor type in AutomaticSpeechRecognitionPipeline

* Fix code quality

* Add TextGenerationPipeline fp16 test

* Fix code quality

* Remove useless require in tests
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 11c49ed2
......@@ -242,8 +242,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
if "ignore_warning" in kwargs:
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
if "torch_dtype" in kwargs:
preprocess_params["dtype"] = kwargs["torch_dtype"]
postprocess_params = {}
if "decoder_kwargs" in kwargs:
......@@ -253,7 +251,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
return preprocess_params, {}, postprocess_params
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
if isinstance(inputs, str):
if inputs.startswith("http://") or inputs.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
......@@ -336,14 +334,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("Chunk length must be superior to stride length")
# make sure that
for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype):
for item in chunk_iter(
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
):
yield item
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if dtype is not None:
processed = processed.to(dtype=dtype)
if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
if stride is not None:
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
raise ValueError("Stride is only usable with CTC models, try removing it")
......
......@@ -748,6 +748,7 @@ class Pipeline(_ScikitCompat):
task: str = "",
args_parser: ArgumentHandler = None,
device: Union[int, str, "torch.device"] = -1,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
binary_output: bool = False,
**kwargs,
):
......@@ -771,6 +772,7 @@ class Pipeline(_ScikitCompat):
self.device = torch.device(f"cuda:{device}")
else:
self.device = device
self.torch_dtype = torch_dtype
self.binary_output = binary_output
# Special handling
......
......@@ -300,3 +300,11 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
}
],
)
@require_torch
@require_torch_gpu
def test_small_model_fp16(self):
    """Smoke-test the pipeline with a float16 model placed on device 0."""
    import torch

    generator = pipeline(
        model="hf-internal-testing/tiny-random-bloom",
        device=0,
        torch_dtype=torch.float16,
    )
    # No assertion on the output: completing the call without raising is the check.
    generator("This is a test")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment