Unverified Commit 9444e275 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Support nested ModuleList and fix an issue in list append (#3652)

parent ac14b9e4
...@@ -20,17 +20,17 @@ class GraphConverter: ...@@ -20,17 +20,17 @@ class GraphConverter:
self.global_graph_id = 0 self.global_graph_id = 0
def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index): def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
if _input in graph_inputs: if _input in output_remap:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
elif _input in output_remap:
assert output_remap[_input].kind() == 'aten::append' assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input] predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node) assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None src_node_idx = None
src_node = node_index[predecessor_node] src_node = node_index[predecessor_node]
assert isinstance(src_node, Node) assert isinstance(src_node, Node)
elif _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
else: else:
predecessor_node = _input.node() predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node) assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
...@@ -315,16 +315,31 @@ class GraphConverter: ...@@ -315,16 +315,31 @@ class GraphConverter:
if submodule.inputsAt(0).type().name() == 'ModuleList': if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList # handle ModuleList
predecessor = submodule.inputsAt(0).node() predecessor = submodule.inputsAt(0).node()
module_name_space = [submodule_name]
while predecessor.inputsAt(0).debugName() != 'self':
# this is for dealing with nested ModuleList. below is an example
# %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
# %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
# %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
# %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
# %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
# %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
# %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
# %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
# %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
assert predecessor.kind() == 'prim::GetAttr'
module_name_space.append(predecessor.s('name'))
predecessor = predecessor.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr' assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name') assert predecessor.hasAttribute('name')
assert predecessor.inputsAt(0).debugName() == 'self' module_name_space.append(predecessor.s('name'))
predecessor_name = predecessor.s('name') submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
# TODO: exchange submodule_name and predecessor_name submodule_obj = module
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name]) script_submodule = script_module
predecessor_obj = getattr(module, predecessor_name) for each_name in list(reversed(module_name_space)):
submodule_obj = getattr(predecessor_obj, submodule_name) submodule_obj = getattr(submodule_obj, each_name)
subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name], script_submodule = script_submodule._modules[each_name]
submodule_obj, submodule_full_name, ir_model) subgraph, sub_m_attrs = self.convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
else: else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str())) raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
......
import os
import sys
import unittest
from typing import (Dict)
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
class TestModels(unittest.TestCase):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
if check_value:
try:
self.assertEqual(len(converted_output), len(expected_output))
for a, b in zip(converted_output, expected_output):
torch.eq(a, b)
except:
self.assertEqual(converted_output, expected_output)
return converted_model
def test_nested_modulelist(self):
class Net(nn.Module):
def __init__(self, num_nodes, num_ops_per_node):
super().__init__()
self.ops = nn.ModuleList()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
for _ in range(num_nodes):
self.ops.append(nn.ModuleList([nn.Linear(16, 16) for __ in range(num_ops_per_node)]))
def forward(self, x):
state = x
for ops in self.ops:
for op in ops:
state = op(state)
return state
model = Net(4, 2)
x = torch.rand((16, 16), dtype=torch.float)
self.run_test(model, (x, ))
def test_append_input_tensor(self):
from typing import List
class Net(nn.Module):
def __init__(self, num_nodes):
super().__init__()
self.ops = nn.ModuleList()
self.num_nodes = num_nodes
for _ in range(num_nodes):
self.ops.append(nn.Linear(16, 16))
def forward(self, x: List[torch.Tensor]):
state = x
for ops in self.ops:
state.append(ops(state[-1]))
return state[-1]
model = Net(4)
x = torch.rand((1, 16), dtype=torch.float)
self.run_test(model, ([x], ))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment