Unverified Commit 0b0a5984 authored by francescorubbo's avatar francescorubbo Committed by GitHub
Browse files

Ensure input tensor are on device. (#11874)

The feature extractor does not create tensors on the appropriate device,
so we call `ensure_tensor_on_device` before feeding the processed inputs
to the model.
parent a9c797f9
...@@ -136,6 +136,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline): ...@@ -136,6 +136,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
processed = self.feature_extractor( processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt" inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
) )
processed = self.ensure_tensor_on_device(**processed)
name = self.model.__class__.__name__ name = self.model.__class__.__name__
if name.endswith("ForConditionalGeneration"): if name.endswith("ForConditionalGeneration"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment