Unverified Commit e9d46fa0 authored by RUCer's avatar RUCer Committed by GitHub
Browse files

fix loading model logs (#24)

* fix log

* fix log
parent 64d5fa6e
...@@ -320,6 +320,7 @@ class Trainer(object): ...@@ -320,6 +320,7 @@ class Trainer(object):
) )
had_loaded_model = False had_loaded_model = False
ema_loaded = False
if bexists: if bexists:
state = None state = None
if is_master: if is_master:
...@@ -344,6 +345,7 @@ class Trainer(object): ...@@ -344,6 +345,7 @@ class Trainer(object):
errors = self.model.load_state_dict( errors = self.model.load_state_dict(
ema_state["params"], strict=False, model_args=self.args ema_state["params"], strict=False, model_args=self.args
) )
ema_loaded = True
else: else:
errors = self.model.load_state_dict( errors = self.model.load_state_dict(
state["model"], strict=False, model_args=self.args state["model"], strict=False, model_args=self.args
...@@ -383,7 +385,7 @@ class Trainer(object): ...@@ -383,7 +385,7 @@ class Trainer(object):
): ):
logger.info(f"Loading EMA state...") logger.info(f"Loading EMA state...")
self.ema.load_state_dict(ema_state) self.ema.load_state_dict(ema_state)
elif self.ema is not None: elif self.ema is not None and not ema_loaded:
logger.info( logger.info(
f"Cannot find EMA state in checkpoint, load model weight to ema directly" f"Cannot find EMA state in checkpoint, load model weight to ema directly"
) )
...@@ -444,6 +446,8 @@ class Trainer(object): ...@@ -444,6 +446,8 @@ class Trainer(object):
elif had_loaded_model: elif had_loaded_model:
logger.info("Loaded checkpoint {}".format(filename)) logger.info("Loaded checkpoint {}".format(filename))
elif ema_loaded:
logger.info("Loaded ema state from checkpoint {}".format(filename))
else: else:
logger.info("No existing checkpoint found {}".format(filename)) logger.info("No existing checkpoint found {}".format(filename))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment