backend_rep.py 1.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -------------------------------------------------------------------------
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Implements ONNX's backend API.
"""
import sys
# This module only supports Python 3. Under an older interpreter the import
# is aborted right here: sys.exit() with no argument raises SystemExit with
# a "success" status, so the process ends without printing any diagnostic.
if sys.version_info < (3, 0):
    sys.exit()

import migraphx
from onnx.backend.base import BackendRep
import numpy as np
from typing import Any, Tuple


class MIGraphXBackendRep(BackendRep):
    """
    Runs a :class:`migraphx.program` through ONNX's ``BackendRep``
    interface and hands the results back as numpy arrays.
    """
    def __init__(self, prog, input_names):
        """
        :param prog: :class:`migraphx.program` executed by :meth:`run`
        :param input_names: program parameter names, in the order in which
            positional inputs are passed to :meth:`run`
        """
        self._program = prog
        self._input_names = input_names

    def run(self, inputs, **kwargs):  # type: (Any, **Any) -> Tuple[Any, ...]
        """
        Computes the prediction.
        See :meth:`migraphx.program.run`.

        :param inputs: either a list/tuple with one value per entry of
            ``input_names`` (matched by position), or a single value for a
            program that has exactly one parameter
        :param kwargs: accepted for interface compatibility; not used
        :return: list with one :class:`numpy.ndarray` per program output
        :raises RuntimeError: if a single (non-sequence) input is given but
            the program does not have exactly one parameter
        """
        if isinstance(inputs, (list, tuple)):
            # Positional inputs are bound to parameter names by index.
            params = {
                name: migraphx.argument(inputs[i])
                for i, name in enumerate(self._input_names)
            }
        else:
            # A keys view cannot be indexed on Python 3, so materialize the
            # parameter names into a list before picking the only one.
            names = list(self._program.get_parameter_shapes().keys())
            if len(names) != 1:
                raise RuntimeError(
                    "Model expect {0} inputs".format(len(names)))
            params = {names[0]: migraphx.argument(inputs)}

        return self._run_program(params)

    def _run_program(self, params):
        """
        Runs the program once with ``params`` and converts every output
        to a numpy array.
        """
        return [np.array(out) for out in self._program.run(params)]