Unverified Commit bf9a641f authored by hlky's avatar hlky Committed by GitHub
Browse files

Fix EMAModel test_from_pretrained (#10325)

parent a756694b
...@@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase): ...@@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase):
# Load the EMA model from the saved directory # Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False) loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model # Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
...@@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase): ...@@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase):
# Load the EMA model from the saved directory # Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True) loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model # Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment