"ts/nni_manager/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "d047d6f4d869cf8a227280ae51577eb175f67d6a"
Unverified Commit 6d9c49da authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Set the strict to false in mask conflict utils. (#4078)

parent 4adea9ab
......@@ -39,9 +39,13 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
if traced is None:
assert model is not None and dummy_input is not None
training = model.training
model.eval()
# We need to trace the model in eval mode
traced = torch.jit.trace(model, dummy_input)
model.eval()
kw_args = {}
if torch.__version__ >= '1.6.0':
# only pytorch with version greater than 1.6.0 has the strict option
kw_args['strict'] = False
traced = torch.jit.trace(model, dummy_input, **kw_args)
model.train(training)
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment