"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "7dd79013891d32c46d88812607c50891c91afd54"
Unverified Commit f2b0b488 authored by Jinhua Zhu's avatar Jinhua Zhu Committed by GitHub
Browse files

filter params in ema (#19)

parent 4e4dc4b8
...@@ -15,10 +15,11 @@ class ExponentialMovingAverageModel: ...@@ -15,10 +15,11 @@ class ExponentialMovingAverageModel:
dtype_grouped_names = dict() dtype_grouped_names = dict()
ordered_dtype = [] ordered_dtype = []
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if p.dtype not in dtype_grouped_names: if p.requires_grad:
dtype_grouped_names[p.dtype] = [] if p.dtype not in dtype_grouped_names:
ordered_dtype.append(p.dtype) dtype_grouped_names[p.dtype] = []
dtype_grouped_names[p.dtype].append(n) ordered_dtype.append(p.dtype)
dtype_grouped_names[p.dtype].append(n)
ordered_names = list(chain(*(dtype_grouped_names[n] for n in ordered_dtype))) ordered_names = list(chain(*(dtype_grouped_names[n] for n in ordered_dtype)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment