backend.py 4.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -------------------------------------------------------------------------
# Copyright (c) Advanced Micro Devices. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""
import sys
if sys.version_info < (3, 0):
    # Fail loudly at import time instead of silently terminating the whole
    # interpreter with a success (0) exit status.
    raise ImportError("onnx_migraphx backend requires Python 3")

from onnx import ModelProto
from onnx.checker import check_model
from onnx.backend.base import Backend
import migraphx
from onnx_migraphx.backend_rep import MIGraphXBackendRep


def get_device():
    """Return the tuple of device names this backend can compile for."""
    supported_devices = ("CPU", "GPU")
    return supported_devices


class MIGraphXBackend(Backend):
    """
    Implements
    `ONNX's backend API <https://github.com/onnx/onnx/blob/master/docs/ImplementingAnOnnxBackend.md>`_
    with *MIGraphX*.
    The backend is mostly used when you need to switch between
    multiple runtimes with the same API.
    `Importing models from ONNX to Caffe2 <https://github.com/onnx/tutorials/blob/master/tutorials/OnnxCaffe2Import.ipynb>`_
    shows how to use *caffe2* as a backend for a converted model.
    Note: This is not the official Python API.

    The compilation target is a class-level setting chosen with
    :meth:`set_device`. A ``device`` argument passed to :meth:`prepare`
    is only validated against :func:`get_device`; it does not change the
    compilation target.
    """  # noqa: E501

    # Device name used to compile every program prepared by this class;
    # shared by all callers because it is a class attribute.
    _device = "GPU"
    # Parameter names of the most recently parsed ONNX model. prepare()
    # rebinds it (never mutates it), so the shared default list stays empty.
    _input_names = []

    @classmethod
    def set_device(cls, device):
        """
        Select the device used to compile subsequently prepared models.

        :param device: device name, expected to be one of :func:`get_device`
        """
        cls._device = device

    @classmethod
    def is_compatible(cls, model, device=None, **kwargs):
        """
        Return whether the model is compatible with the backend.

        :param model: unused
        :param device: None to use the class-level device (see
            :meth:`set_device`) or a device name (ex: `'CPU'`)
        :return: boolean
        """
        # Honour an explicitly requested device; fall back to the
        # class-level one only when none is given.
        if device is None:
            device = cls._device
        return cls.supports_device(device)

    @classmethod
    def supports_device(cls, device):
        """
        Check whether the backend is compiled with particular device support.
        In particular it's used in the testing suite.
        """
        return device in get_device()

    @classmethod
    def prepare(cls, model, device=None, **kwargs):
        """
        Load the model and return a :class:`MIGraphXBackendRep` wrapping a
        compiled :class:`migraphx.program`, ready to be used as a backend.

        :param model: a :class:`MIGraphXBackendRep` (returned unchanged),
            a :class:`migraphx.program` (wrapped without recompiling),
            a ModelProto (returned by `onnx.load`),
            a string for a filename or bytes for a serialized model
        :param device: requested device for the computation; when not None
            it must be one of :func:`get_device`. It is only validated: the
            program is compiled for the class-level device set with
            :meth:`set_device`.
        :param kwargs: accepted for API compatibility and forwarded to the
            recursive call; not otherwise used
        :return: :class:`MIGraphXBackendRep`
        :raises RuntimeError: if *device* is given but not supported
        """
        if isinstance(model, MIGraphXBackendRep):
            return model
        elif isinstance(model, migraphx.program):
            # NOTE(review): the names come from the last ONNX model parsed by
            # this class, so a program built elsewhere gets whatever names
            # were recorded most recently.
            return MIGraphXBackendRep(model, cls._input_names)
        elif isinstance(model, (str, bytes)):
            if device is not None and not cls.supports_device(device):
                raise RuntimeError(
                    "Incompatible device expected '{0}', got '{1}'".format(
                        get_device(), device))
            if isinstance(model, str):
                # A string is a path to an .onnx file on disk.
                prog = migraphx.parse_onnx(model)
            else:
                # Bytes are an in-memory serialized ModelProto.
                prog = migraphx.parse_onnx_buffer(model)
            # Compilation always targets the class-level device.
            device = cls._device
            cls._input_names = prog.get_parameter_names()
            prog.compile(migraphx.get_target(device.lower()))
            return cls.prepare(prog, device, **kwargs)
        else:
            # type: ModelProto
            check_model(model)
            serialized = model.SerializeToString()
            return cls.prepare(serialized, device, **kwargs)

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """
        Compute the prediction.

        :param model: anything accepted by :meth:`prepare` (for example the
            :class:`MIGraphXBackendRep` it returns, a ModelProto, a
            filename or serialized bytes)
        :param inputs: inputs, passed to ``MIGraphXBackendRep.run``
        :param device: requested device for the computation; validated
            by :meth:`prepare`, None means the class-level device
        :param kwargs: forwarded to both :meth:`prepare` and
            ``MIGraphXBackendRep.run``
        :return: predictions
        """
        rep = cls.prepare(model, device, **kwargs)
        return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        '''
        This method is not implemented as it is much more efficient
        to run a whole model than every node independently.
        '''
        raise NotImplementedError(
            "It is much more efficient to run a whole model than every node independently."
        )


# Module-level aliases of MIGraphXBackend's class methods, exposing the
# backend API (is_compatible / prepare / run / supports_device) directly on
# this module.
is_compatible, prepare, run, supports_device = (
    MIGraphXBackend.is_compatible,
    MIGraphXBackend.prepare,
    MIGraphXBackend.run_model,
    MIGraphXBackend.supports_device,
)