Unverified Commit ec396a79 authored by robotcator's avatar robotcator Committed by GitHub
Browse files

fix ema (#33)



* fix ema

* add assert

---------
Co-authored-by: default avatarjixh <jixh@dp.tech>
parent 44f6386f
Pipeline #1068 failed with stages
in 0 seconds
......@@ -113,13 +113,11 @@ class Trainer(object):
if args.ema_decay > 0 and (
self.data_parallel_rank == 0 or args.validate_with_ema
):
assert isinstance(
self.optimizer, optim.FP16Optimizer
), "ema must with fp16 optimizer"
assert (self.args.fp16 or self.args.bf16), "ema must with fp16 or bf16"
self.ema = ExponentialMovingAverageModel(
model,
args.ema_decay,
self._optimizer.fp32_params,
)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment