Unverified Commit 05d9baea authored by asfiyab-nvidia's avatar asfiyab-nvidia Committed by GitHub
Browse files

Fix TensorRT community pipeline device set function (#3157)



pass silence_dtype_warnings as kwarg
Signed-off-by: default avatarAsfiya Baig <asfiyab@nvidia.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent e573ae06
...@@ -703,7 +703,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline): ...@@ -703,7 +703,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
) )
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False): def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
super().to(torch_device, silence_dtype_warnings) super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment