Unverified Commit 0bab9d6b authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[Single File] Allow loading T5 encoder in mixed precision (#8778)

* update

* update

* update

* update
parent 2e2684f0
...@@ -555,7 +555,4 @@ class FromSingleFileMixin: ...@@ -555,7 +555,4 @@ class FromSingleFileMixin:
pipe = pipeline_class(**init_kwargs) pipe = pipeline_class(**init_kwargs)
if torch_dtype is not None:
pipe.to(dtype=torch_dtype)
return pipe return pipe
...@@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint( ...@@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint(
else: else:
model.load_state_dict(diffusers_format_checkpoint) model.load_state_dict(diffusers_format_checkpoint)
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = model._keep_in_fp32_modules
else:
keep_in_fp32_modules = []
if keep_in_fp32_modules is not None:
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
# param = param.to(torch.float32) does not work here as only in the local scope.
param.data = param.data.to(torch.float32)
return model return model
...@@ -201,6 +201,20 @@ class SDSingleFileTesterMixin: ...@@ -201,6 +201,20 @@ class SDSingleFileTesterMixin:
self._compare_component_configs(pipe, single_file_pipe) self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_setting_pipeline_dtype_to_fp16(
self,
single_file_pipe=None,
):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, torch_dtype=torch.float16
)
for component_name, component in single_file_pipe.components.items():
if not isinstance(component, torch.nn.Module):
continue
assert component.dtype == torch.float16
class SDXLSingleFileTesterMixin: class SDXLSingleFileTesterMixin:
def _compare_component_configs(self, pipe, single_file_pipe): def _compare_component_configs(self, pipe, single_file_pipe):
...@@ -378,3 +392,17 @@ class SDXLSingleFileTesterMixin: ...@@ -378,3 +392,17 @@ class SDXLSingleFileTesterMixin:
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten()) max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < expected_max_diff assert max_diff < expected_max_diff
def test_single_file_setting_pipeline_dtype_to_fp16(
self,
single_file_pipe=None,
):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, torch_dtype=torch.float16
)
for component_name, component in single_file_pipe.components.items():
if not isinstance(component, torch.nn.Module):
continue
assert component.dtype == torch.float16
...@@ -180,3 +180,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD ...@@ -180,3 +180,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
local_files_only=True, local_files_only=True,
) )
super()._compare_component_configs(pipe, pipe_single_file) super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
...@@ -181,3 +181,12 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC ...@@ -181,3 +181,12 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
local_files_only=True, local_files_only=True,
) )
super()._compare_component_configs(pipe, pipe_single_file) super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
...@@ -169,3 +169,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD ...@@ -169,3 +169,12 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
local_files_only=True, local_files_only=True,
) )
super()._compare_component_configs(pipe, pipe_single_file) super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
...@@ -200,3 +200,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX ...@@ -200,3 +200,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
local_files_only=True, local_files_only=True,
) )
self._compare_component_configs(pipe, pipe_single_file) self._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, adapter=adapter, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
...@@ -195,3 +195,12 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, ...@@ -195,3 +195,12 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
local_files_only=True, local_files_only=True,
) )
super()._compare_component_configs(pipe, pipe_single_file) super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_setting_pipeline_dtype_to_fp16(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
single_file_pipe = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment