# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

__all__ = ['model_to_pytorch_script']

import logging
import re
from typing import Dict, List, Tuple, Any, cast

from nni.common.device import Device, GPUDevice

from nni.nas.execution.common.graph import IllegalGraphError, Edge, Graph, Node, Model
from nni.nas.execution.common.graph_op import PyTorchOperation
from nni.nas.utils import STATE_DICT_PY_MAPPING

from .op_def import ToDevice

_logger = logging.getLogger(__name__)

def model_to_pytorch_script(model: Model, placement=None) -> str:
    """
    Render a graph-IR model as a complete PyTorch source file.

    Parameters
    ----------
    model : Model
        the model whose graphs are converted, one class per graph
    placement : dict or None
        optional node-to-device mapping, forwarded to each subgraph

    Returns
    -------
    str
        the generated python script
    """
    graph_codes = []
    all_pkgs = set()
    for graph_name, graph in model.graphs.items():
        pkgs, code = graph_to_pytorch_model(graph_name, graph, placement=placement)
        all_pkgs.update(pkgs)
        graph_codes.append(code)
    import_lines = '\n'.join('import {}'.format(pkg) for pkg in all_pkgs)
    return _PyTorchScriptTemplate.format(import_lines, '\n\n'.join(graph_codes)).strip()

def _sorted_incoming_edges(node: Node) -> List[Edge]:
    """
    Collect the edges feeding into ``node``, ordered by tail slot.

    Returns an empty list when the node has no inputs. When every tail
    slot is ``None`` the edges are returned in discovery order; otherwise
    all slots must be integers forming exactly ``0..k-1``, or the graph
    is rejected as illegal.
    """
    incoming = [e for e in node.graph.edges if e.tail is node]
    _logger.debug('sorted_incoming_edges: %s', str(incoming))
    if not incoming:
        return []
    _logger.debug('all tail_slots are None: %s', str([e.tail_slot for e in incoming]))
    slots = [e.tail_slot for e in incoming]
    if all(s is None for s in slots):
        return incoming
    if all(isinstance(s, int) for s in slots):
        ordered = sorted(incoming, key=lambda e: cast(int, e.tail_slot))
        # slots must be a gapless, duplicate-free 0..k-1 sequence
        if [e.tail_slot for e in ordered] == list(range(len(ordered))):
            return ordered
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))

def _format_inputs(node: Node, graph_name: str) -> Tuple[List[str], List[Any]]:
    """
    Format the inputs of a given node.
    Names are normalized through ``_format_variable_name``.

    Parameters
    ----------
    node : Node
        a graph node, get and format its inputs
    graph_name : str
        subgraph name, to format variable names

    Returns
    -------
    list
        the list of input names
    list
        the list of input values; an input's value is recorded when it is a
        simple constant, otherwise the entry is None
    """
    names: List[str] = []
    values: List[Any] = []
    for edge in _sorted_incoming_edges(node):
        head = edge.head
        if head.name == '_inputs':
            assert isinstance(edge.head_slot, int)
            if head.operation.io_names is not None:
                # named graph inputs, e.g., forward(self, tensor1, tensor2, another_one)
                names.append(_format_variable_name(head.operation.io_names[edge.head_slot], graph_name))
            else:
                # anonymous graph inputs, e.g., forward(*_inputs)
                names.append('_inputs[{}]'.format(edge.head_slot))
            values.append(None)
        elif edge.head_slot is None:
            # the input comes from a single-output operator
            names.append(_format_variable_name(head.name, graph_name))
            if head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
                    'value' in head.operation.parameters:
                values.append(head.operation.parameters['value'])
            else:
                values.append(None)
        else:
            # the input comes from a multi-output operator: needs to know which one it comes from
            names.append('{}[{}]'.format(_format_variable_name(head.name, graph_name), edge.head_slot))
            values.append(None)
    return names, values
def _format_variable_name(name: str, graph_name: str) -> str:
QuanluZhang's avatar
QuanluZhang committed
96
    """
97
98
    1. replace invalid characters in node name
    2. variables name (full name space) is too long, shorten the name by removing the prefix ```graph_name```
QuanluZhang's avatar
QuanluZhang committed
99
    """
100
101
102
103
    name = name[len(graph_name):] if name.startswith(graph_name) else name
    name = name.replace('/', '__')

    # https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
104
    name = re.sub(r'\W|^(?=\d)','_', name)
105
106
107
108
109
110
111
112
113
114
115

    if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
        # name can't start with double underscore
        # it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
        # but it's actually very common in our generated code
        name = name[1:]
    elif name.startswith('_'):
        # to avoid conflicts between '_' and '__'
        name = 'i' + name

    return name
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
    '''
    Since CUDA_VISIBLE_DEVICES will be set to the list of real GPU ID,
    we need to remap the GPU ID when generating code to match them correctly.
    For example, when CUDA_VISIBLE_DEVICES="0,3", we need to use "cuda:0", "cuda:1" in the generated code.

    Parameters
    ----------
    placement : Dict[Node, Device]
        mapping from graph nodes to the devices they are placed on

    Returns
    -------
    Dict[Device, int]
        mapping from each GPU device to its remapped (visible) CUDA index
    '''
    # Deduplicate first: several nodes may share one physical GPU.
    # A set comprehension avoids the redundant list() inside sorted().
    unique_devices = sorted({d for d in placement.values() if isinstance(d, GPUDevice)})
    node_gpu_cnt: Dict[Any, int] = {}
    cuda_remapped_id: Dict[Device, int] = {}
    for d in unique_devices:
        # devices are numbered 0, 1, ... per host node, in sorted device order
        node_gpu_cnt[d.node_id] = node_gpu_cnt.get(d.node_id, 0) + 1
        cuda_remapped_id[d] = node_gpu_cnt[d.node_id] - 1

    return cuda_remapped_id


def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
    """
    Generate the code of one ``nn.Module`` subclass from one graph.

    Parameters
    ----------
    graph_name : str
        name of the subgraph; also used to derive the generated class name
    graph : Graph
        the graph IR to convert
    placement : dict or None
        optional node -> Device mapping; when given, ``.to(...)`` calls are
        appended to the generated module constructions

    Returns
    -------
    Tuple[set, str]
        the set of package names the generated script must import,
        and the generated class code
    """
    nodes = graph.topo_sort()

    # handle module node and function node differently
    # only need to generate code for module here
    import_pkgs = set()
    node_codes = []
    node_python_mappings = {}
    cuda_remapped_id = None
    if placement:
        cuda_remapped_id = generate_cuda_mapping(placement)
    for node in nodes:
        if node.operation:
            if placement and isinstance(node.operation, ToDevice):
                # rewrite the device string so it matches the CUDA_VISIBLE_DEVICES remapping
                cuda_remapped_id = cast(dict, cuda_remapped_id)
                node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])

            if node.operation.type == 'shared':
                # a shared node reuses a module created by another node; no init code here
                continue
            pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
            if pkg_name is not None:
                import_pkgs.add(pkg_name)

            py_variable_name = _format_variable_name(node.name, graph_name)
            node_code = node.operation.to_init_code(py_variable_name)
            if node_code is not None:
                if placement and node in placement and len(node_code) > 0:
                    if isinstance(placement[node], GPUDevice):
                        assert cuda_remapped_id is not None
                        device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
                    else:
                        device_repr = placement[node].device_repr()
                    node_codes.append(f"{node_code}.to('{device_repr}')")
                else:
                    node_codes.append(node_code)

                # Map to module hierarchies in original search space python code
                node_python_mappings[py_variable_name] = node.python_name

    node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')

    # build the parameter list of the generated ``forward``
    if graph.input_node.operation.io_names is None:
        input_code = '*_inputs'
    else:
        for name in graph.input_node.operation.io_names:
            assert not name.startswith(graph_name)
        input_code = ', '.join(graph.input_node.operation.io_names)

    # emit one forward statement per node, in topological order
    edge_codes = []
    sorted_nodes = graph.topo_sort()
    for node in sorted_nodes:
        if node.operation:
            inputs, inputs_value = _format_inputs(node, graph_name)
            node_name = _format_variable_name(node.name, graph_name)
            submodule_name = node_name
            if node.operation.type == 'shared':
                # shared node: call the referenced module but store the result under its own name
                submodule_name = _format_variable_name(node.operation.parameters['reference'], graph_name)
            edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))

    output_names, _ = _format_inputs(graph.output_node, graph_name)
    if not output_names:
        raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
    output_code = ', '.join(output_names)

    # matches the 8-space body indentation of the class template below
    linebreak = '\n        '
    return import_pkgs, _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


# TODO: handle imports

_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

218
import nni.nas.nn.pytorch
219

220
221
{}

222
223
224
225
226
227
228
229
230
231
232
233
234
{}
'''

_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
    def __init__(self):
        super().__init__()
        {nodes}

    def forward(self, {inputs}):
        {edges}
        return {outputs}
'''