# convert_mixin.py
import torch

from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape


class ConvertMixin:
    """Mixin providing a plain TorchScript-to-graph-IR ``_convert_model``."""

    @staticmethod
    def _convert_model(model, input):
        """Compile ``model`` with TorchScript and convert it to graph IR.

        Args:
            model: module to convert; it is passed to ``torch.jit.script``,
                so it must be scriptable.
            input: not used by this implementation. It is accepted so the
                signature matches ``ConvertWithShapeMixin._convert_model``.

        Returns:
            The model IR returned by ``convert_to_graph``.
        """
        return convert_to_graph(torch.jit.script(model), model)


class ConvertWithShapeMixin:
    """Mixin whose ``_convert_model`` converts with ``GraphConverterWithShape``.

    It has the same call signature as ``ConvertMixin._convert_model``. The
    difference is that the supplied ``input`` is forwarded to the converter
    as ``dummy_input``.
    """

    @staticmethod
    def _convert_model(model, input):
        """Compile ``model`` with TorchScript and convert it to graph IR.

        Args:
            model: module to convert; it is passed to ``torch.jit.script``,
                so it must be scriptable.
            input: example input forwarded to the converter as
                ``dummy_input``. Presumably the converter uses it to record
                tensor shapes; confirm in ``GraphConverterWithShape``.
                (The name shadows the builtin but is kept so keyword
                callers keep working.)

        Returns:
            The model IR returned by ``convert_to_graph``.
        """
        script_module = torch.jit.script(model)
        model_ir = convert_to_graph(
            script_module,
            model,
            converter=GraphConverterWithShape(),
            dummy_input=input,
        )
        return model_ir