Unverified Commit f2b0b488 authored by Jinhua Zhu's avatar Jinhua Zhu Committed by GitHub
Browse files

filter params in ema (#19)

parent 4e4dc4b8
......@@ -15,6 +15,7 @@ class ExponentialMovingAverageModel:
dtype_grouped_names = dict()
ordered_dtype = []
for n, p in model.named_parameters():
if p.requires_grad:
if p.dtype not in dtype_grouped_names:
dtype_grouped_names[p.dtype] = []
ordered_dtype.append(p.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment