Unverified Commit 7a001c3e authored by kaixuanliu's avatar kaixuanliu Committed by GitHub
Browse files

adjust unit tests for `test_save_load_float16` (#12500)



* adjust unit tests for wan pipeline
Signed-off-by: default avatarLiu, Kaixuan <kaixuan.liu@intel.com>

* update code
Signed-off-by: default avatarLiu, Kaixuan <kaixuan.liu@intel.com>

* avoid adjusting common `get_dummy_components` API
Signed-off-by: default avatarLiu, Kaixuan <kaixuan.liu@intel.com>

* use `form_pretrained` to `transformer` and `transformer_2`
Signed-off-by: default avatarLiu, Kaixuan <kaixuan.liu@intel.com>

* update code
Signed-off-by: default avatarLiu, Kaixuan <kaixuan.liu@intel.com>

* update
Signed-off-by: default avatarLiu, Kaixuan <kaixuan.liu@intel.com>

---------
Signed-off-by: default avatarLiu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent d8e48058
...@@ -1422,7 +1422,18 @@ class PipelineTesterMixin: ...@@ -1422,7 +1422,18 @@ class PipelineTesterMixin:
def test_save_load_float16(self, expected_max_diff=1e-2): def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components() components = self.get_dummy_components()
for name, module in components.items(): for name, module in components.items():
if hasattr(module, "half"): # Account for components with _keep_in_fp32_modules
if hasattr(module, "_keep_in_fp32_modules") and module._keep_in_fp32_modules is not None:
for name, param in module.named_parameters():
if any(
module_to_keep_in_fp32 in name.split(".")
for module_to_keep_in_fp32 in module._keep_in_fp32_modules
):
param.data = param.data.to(torch_device).to(torch.float32)
else:
param.data = param.data.to(torch_device).to(torch.float16)
elif hasattr(module, "half"):
components[name] = module.to(torch_device).half() components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment