Unverified Commit 8e1b7655 authored by Himashi Amanda Peiris's avatar Himashi Amanda Peiris Committed by GitHub
Browse files

updated unit tests of network morphism tuner (#2427)

parent f42e53d2
...@@ -102,7 +102,7 @@ class NetworkMorphismTestCase(TestCase): ...@@ -102,7 +102,7 @@ class NetworkMorphismTestCase(TestCase):
graph_init = CnnGenerator(10, (32, 32, 3)).generate() graph_init = CnnGenerator(10, (32, 32, 3)).generate()
json_out = graph_to_json(graph_init, "temp.json") json_out = graph_to_json(graph_init, "temp.json")
graph_recover = json_to_graph(json_out) graph_recover = json_to_graph(json_out)
deeper_graph = to_wider_graph(deepcopy(graph_recover)) deeper_graph = to_deeper_graph(deepcopy(graph_recover))
model = deeper_graph.produce_torch_model() model = deeper_graph.produce_torch_model()
out = model(torch.ones(1, 3, 32, 32)) out = model(torch.ones(1, 3, 32, 32))
self.assertEqual(out.shape, torch.Size([1, 10])) self.assertEqual(out.shape, torch.Size([1, 10]))
...@@ -114,7 +114,7 @@ class NetworkMorphismTestCase(TestCase): ...@@ -114,7 +114,7 @@ class NetworkMorphismTestCase(TestCase):
graph_init = CnnGenerator(10, (32, 32, 3)).generate() graph_init = CnnGenerator(10, (32, 32, 3)).generate()
json_out = graph_to_json(graph_init, "temp.json") json_out = graph_to_json(graph_init, "temp.json")
graph_recover = json_to_graph(json_out) graph_recover = json_to_graph(json_out)
skip_connection_graph = to_wider_graph(deepcopy(graph_recover)) skip_connection_graph = to_skip_connection_graph(deepcopy(graph_recover))
model = skip_connection_graph.produce_torch_model() model = skip_connection_graph.produce_torch_model()
out = model(torch.ones(1, 3, 32, 32)) out = model(torch.ones(1, 3, 32, 32))
self.assertEqual(out.shape, torch.Size([1, 10])) self.assertEqual(out.shape, torch.Size([1, 10]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment