Unverified Commit d6b5c6ca authored by Yiwu Yao's avatar Yiwu Yao Committed by GitHub
Browse files

Fix prioritiezd board update and rand_parameters for cream nas (#3498)

parent aea98dd6
...@@ -131,12 +131,12 @@ class SuperNet(nn.Module): ...@@ -131,12 +131,12 @@ class SuperNet(nn.Module):
yield param yield param
if not meta: if not meta:
for layer, layer_arch in zip(self.blocks, architecture): for choice_blocks, choice_name in zip(self.blocks, architecture):
for blocks, arch in zip(layer, layer_arch): choice_sample = architecture[choice_name]
if arch == -1: for block, arch in zip(choice_blocks, choice_sample):
if not arch:
continue continue
for name, param in blocks[arch].named_parameters( for name, param in block.named_parameters(recurse=True):
recurse=True):
yield param yield param
......
...@@ -165,7 +165,7 @@ class CreamSupernetTrainer(Trainer): ...@@ -165,7 +165,7 @@ class CreamSupernetTrainer(Trainer):
(val_prec1, (val_prec1,
prec1, prec1,
flops, flops,
self.current_teacher_arch, self.current_student_arch,
training_data, training_data,
torch.nn.functional.softmax( torch.nn.functional.softmax(
features, features,
...@@ -174,8 +174,6 @@ class CreamSupernetTrainer(Trainer): ...@@ -174,8 +174,6 @@ class CreamSupernetTrainer(Trainer):
self.prioritized_board, reverse=True) self.prioritized_board, reverse=True)
if len(self.prioritized_board) > self.pool_size: if len(self.prioritized_board) > self.pool_size:
self.prioritized_board = sorted(
self.prioritized_board, reverse=True)
del self.prioritized_board[-1] del self.prioritized_board[-1]
# only update student network weights # only update student network weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment