Unverified Commit e9d46fa0 authored by RUCer's avatar RUCer Committed by GitHub
Browse files

fix loading model logs (#24)

* fix log

* fix log
parent 64d5fa6e
......@@ -320,6 +320,7 @@ class Trainer(object):
)
had_loaded_model = False
ema_loaded = False
if bexists:
state = None
if is_master:
......@@ -344,6 +345,7 @@ class Trainer(object):
errors = self.model.load_state_dict(
ema_state["params"], strict=False, model_args=self.args
)
ema_loaded = True
else:
errors = self.model.load_state_dict(
state["model"], strict=False, model_args=self.args
......@@ -383,7 +385,7 @@ class Trainer(object):
):
logger.info(f"Loading EMA state...")
self.ema.load_state_dict(ema_state)
elif self.ema is not None:
elif self.ema is not None and not ema_loaded:
logger.info(
f"Cannot find EMA state in checkpoint, load model weight to ema directly"
)
......@@ -444,6 +446,8 @@ class Trainer(object):
elif had_loaded_model:
logger.info("Loaded checkpoint {}".format(filename))
elif ema_loaded:
logger.info("Loaded ema state from checkpoint {}".format(filename))
else:
logger.info("No existing checkpoint found {}".format(filename))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment