Unverified commit b33bd91f, authored by 1lint and committed by GitHub

Add option to set dtype in pipeline.to() method (#2317)

Adds test_to_dtype to verify pipe.to(torch_dtype=torch.float16).
parent 1fcf279d
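In short: to() now accepts a torch_dtype argument, so casting a pipeline's modules no longer requires rebuilding its components by hand. A minimal usage sketch (the checkpoint id is illustrative, not part of this commit):

import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint id, not part of this commit.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Device and dtype can now be set in one call, together or separately:
pipe = pipe.to("cuda", torch.float16)
pipe = pipe.to(torch_dtype=torch.float16)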
@@ -512,8 +512,13 @@ class DiffusionPipeline(ConfigMixin):
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
-        if torch_device is None:
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        silence_dtype_warnings: bool = False,
+    ):
+        if torch_device is None and torch_dtype is None:
             return self
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -550,6 +555,7 @@ class DiffusionPipeline(ConfigMixin):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                module.to(torch_device, torch_dtype)
                 if (
                     module.dtype == torch.float16
                     and str(torch_device) in ["cpu"]
@@ -563,7 +569,6 @@ class DiffusionPipeline(ConfigMixin):
                         " support for`float16` operations on this device in PyTorch. Please, remove the"
                         " `torch_dtype=torch.float16` argument, or use another device for inference."
                     )
-                module.to(torch_device)
         return self
 
     @property
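The core of the change is replacing module.to(torch_device) with module.to(torch_device, torch_dtype), which leans on torch.nn.Module.to accepting device and dtype positionally and treating None as "leave unchanged"; that is also why the early-return guard now checks both arguments. A standalone sketch of that behavior (the Linear module is just a stand-in for a pipeline component):

import torch

layer = torch.nn.Linear(4, 4)  # stand-in for a pipeline component
assert layer.weight.dtype == torch.float32

layer.to(None, torch.float16)  # device=None: placement stays put
assert layer.weight.dtype == torch.float16

layer.to("cpu", None)          # dtype=None: precision stays put
assert layer.weight.dtype == torch.float16

The remaining hunks update the shared pipeline tests in PipelineTesterMixin to rely on the new argument.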
@@ -344,11 +344,8 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.half()
         pipe_fp16 = self.pipeline_class(**components)
-        pipe_fp16.to(torch_device)
+        pipe_fp16.to(torch_device, torch.float16)
         pipe_fp16.set_progress_bar_config(disable=None)
 
         output = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -447,6 +444,18 @@ class PipelineTesterMixin:
         output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
         self.assertTrue(np.isnan(output_cuda).sum() == 0)
 
+    def test_to_dtype(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+        pipe.to(torch_dtype=torch.float16)
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass()
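Because test_to_dtype is defined on the mixin, it runs once for each concrete pipeline test class that inherits PipelineTesterMixin. To exercise only this test across the suite (the tests/ path is an assumption about the repo layout, not part of this commit):

import pytest

# Run only the new dtype test in every pipeline suite.
pytest.main(["tests/", "-k", "test_to_dtype"])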