"vscode:/vscode.git/clone" did not exist on "796c01534dcc8856aa81d69753a2df758274d625"
Unverified commit f55f1f7e, authored by SahilCarterr and committed by GitHub

Fixes EMAModel "from_pretrained" method (#9779)



* fix from_pretrained and added test

* make style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 9dcac830
@@ -379,7 +379,7 @@ class EMAModel:
     @classmethod
     def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
-        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)
         ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
...
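For reviewers who want to reproduce the fix outside the test suite, the sketch below round-trips an EMAModel through save_pretrained/from_pretrained, mirroring the new tests added in this PR. The tiny UNet2DConditionModel config is an assumption borrowed from common diffusers test fixtures, not part of this commit:

import tempfile

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# A deliberately tiny UNet so the example runs quickly (hypothetical config).
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)

with tempfile.TemporaryDirectory() as tmpdir:
    ema_unet.save_pretrained(tmpdir)
    # Before this fix, the EMA kwargs stored in the saved config were dropped here.
    loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)

assert loaded_ema_unet.decay == ema_unet.decay
assert loaded_ema_unet.optimization_step == ema_unet.optimization_step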
@@ -59,6 +59,25 @@ class EMAModelTests(unittest.TestCase):
         unet.load_state_dict(updated_state_dict)
         return unet

+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
@@ -194,6 +213,25 @@ class EMAModelTestsForeach(unittest.TestCase):
         unet.load_state_dict(updated_state_dict)
         return unet

+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
...
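A note on the foreach flag threaded through from_pretrained and the two duplicated test classes: in diffusers it selects between a per-tensor Python loop and fused torch._foreach_* updates for the EMA step. The snippet below is a simplified sketch of that difference under those assumptions, not a copy of EMAModel.step:

import torch

@torch.no_grad()
def ema_update(shadow_params, params, decay, foreach=False):
    # Both branches compute: shadow = decay * shadow + (1 - decay) * param
    params = list(params)
    if foreach:
        # Fused ops: one kernel launch per operation across all tensors.
        torch._foreach_mul_(shadow_params, decay)
        torch._foreach_add_(shadow_params, params, alpha=1.0 - decay)
    else:
        # Plain Python loop; same math, more per-tensor overhead.
        for s_param, param in zip(shadow_params, params):
            s_param.mul_(decay).add_(param, alpha=1.0 - decay)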