Unverified Commit 585b5c53 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix apply_fixed_architecture device error and ENAS micro mask device error (#2088)

parent 89de4061
......@@ -120,7 +120,7 @@ if __name__ == "__main__":
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
apply_fixed_architecture(model, args.arc_checkpoint, device=device)
apply_fixed_architecture(model, args.arc_checkpoint)
criterion = nn.CrossEntropyLoss()
model.to(device)
......
......@@ -115,7 +115,7 @@ class ENASLayer(nn.Module):
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out)
nodes_used_mask[:mask.size(0)] |= mask
nodes_used_mask[:mask.size(0)] |= mask.to(node_out.device)
prev_nodes_out.append(node_out)
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
......
......@@ -101,6 +101,6 @@ if __name__ == "__main__":
from nni.nas.pytorch.fixed import apply_fixed_architecture
assert os.path.isfile(args.exported_arch_path), \
"exported_arch_path {} should be a file.".format(args.exported_arch_path)
apply_fixed_architecture(model, args.exported_arch_path, device=device)
apply_fixed_architecture(model, args.exported_arch_path)
trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
trainer.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment