InferenceSampleApi.cs 4.22 KB
Newer Older
gaoqiong's avatar
gaoqiong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.OnnxRuntime.Tensors;

namespace Microsoft.ML.OnnxRuntime.InferenceSample
{
    /// <summary>
    /// Sample showing how to load an ONNX model and its input data from embedded resources,
    /// create an <see cref="InferenceSession"/>, and run inference with it.
    /// </summary>
    /// <remarks>
    /// The instance owns the inference session (and any default <see cref="SessionOptions"/> it
    /// created itself) and releases them in <see cref="Dispose()"/>.
    /// </remarks>
    public class InferenceSampleApi : IDisposable
    {
        /// <summary>
        /// Loads the embedded model and sample input, creates a session with default options,
        /// and prepares one named input value per model input.
        /// </summary>
        /// <exception cref="FileNotFoundException">An embedded resource is missing from the assembly.</exception>
        /// <exception cref="InvalidDataException">The embedded input file has no data line.</exception>
        public InferenceSampleApi()
        {
            model = LoadModelFromEmbeddedResource("TestData.squeezenet.onnx");

            // this is the data for only one input tensor for this model
            var inputTensor = LoadTensorFromEmbeddedResource("TestData.bench.in");

            // create default session with default session options
            // Creating an InferenceSession and loading the model is an expensive operation, so generally you would
            // do this once. InferenceSession.Run can be called multiple times, and concurrently.
            CreateInferenceSession();

            // setup sample input data
            inputData = new List<NamedOnnxValue>();
            var inputMeta = inferenceSession.InputMetadata;
            foreach (var name in inputMeta.Keys)
            {
                // the loaded values are shaped using the dimensions reported by the model metadata
                var tensor = new DenseTensor<float>(inputTensor, inputMeta[name].Dimensions);
                inputData.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
            }
        }

        /// <summary>
        /// Creates the inference session for the loaded model, replacing (and disposing) any
        /// session created by an earlier call.
        /// </summary>
        /// <param name="options">
        /// Session options to use. When null, default options with <c>LogId = "Sample"</c> are
        /// created and owned by this instance. Caller-supplied options remain owned by the caller.
        /// </param>
        public void CreateInferenceSession(SessionOptions options = null)
        {
            // Optional : Create session options and set any relevant values.
            // If an additional execution provider is needed it should be added to the SessionOptions prior to
            // creating the InferenceSession. The CPU Execution Provider is always added by default.
            SessionOptions createdOptions = null;
            if (options == null)
            {
                createdOptions = new SessionOptions { LogId = "Sample" };
                options = createdOptions;
            }

            InferenceSession newSession;
            try
            {
                newSession = new InferenceSession(model, options);
            }
            catch
            {
                // Do not leak the options we created if the session could not be built.
                createdOptions?.Dispose();
                throw;
            }

            // Swap in the new session only after it was created successfully, then release
            // whatever the previous call left behind so repeated calls do not leak native resources.
            var previousSession = inferenceSession;
            var previousOptions = ownedOptions;
            inferenceSession = newSession;
            ownedOptions = createdOptions;
            previousSession?.Dispose();
            previousOptions?.Dispose();
        }

        /// <summary>
        /// Runs the model on the prepared sample input and writes every output to the console.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The session has already been disposed.</exception>
        public void Execute()
        {
            if (inferenceSession == null)
            {
                throw new ObjectDisposedException(nameof(InferenceSampleApi));
            }

            // Run the inference
            // 'results' is an IDisposableReadOnlyCollection<DisposableNamedOnnxValue> container
            using (var results = inferenceSession.Run(inputData))
            {
                // dump the results
                foreach (var r in results)
                {
                    Console.WriteLine("Output for {0}", r.Name);
                    Console.WriteLine(r.AsTensor<float>().GetArrayString());
                }
            }
        }

        /// <summary>
        /// Releases the inference session and any session options created by this instance.
        /// </summary>
        /// <param name="disposing">
        /// True when called from <see cref="Dispose()"/>; nothing is released when false.
        /// </param>
        protected virtual void Dispose(bool disposing)
        {
            if (!disposing)
            {
                return;
            }

            if (inferenceSession != null)
            {
                inferenceSession.Dispose();
                inferenceSession = null;
            }

            // Released after the session that was created from them.
            if (ownedOptions != null)
            {
                ownedOptions.Dispose();
                ownedOptions = null;
            }
        }

        /// <summary>
        /// Disposes the resources held by this instance. Safe to call more than once.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Opens an embedded resource whose name is the assembly name followed by
        /// <paramref name="path"/>.
        /// </summary>
        /// <param name="path">Resource path relative to the assembly name, dot-separated.</param>
        /// <returns>An open stream over the resource; the caller disposes it.</returns>
        /// <exception cref="FileNotFoundException">No resource with that name exists in the assembly.</exception>
        private static Stream OpenEmbeddedResource(string path)
        {
            var assembly = typeof(InferenceSampleApi).Assembly;
            string resourceName = $"{assembly.GetName().Name}.{path}";

            // GetManifestResourceStream returns null (it does not throw) for an unknown resource.
            Stream stream = assembly.GetManifestResourceStream(resourceName);
            if (stream == null)
            {
                throw new FileNotFoundException(
                    $"Embedded resource '{resourceName}' was not found.", resourceName);
            }

            return stream;
        }

        /// <summary>
        /// Reads tensor values from an embedded text resource: the first line is skipped and the
        /// second line is parsed as float values separated by ',', '[' or ']'.
        /// </summary>
        /// <param name="path">Resource path relative to the assembly name.</param>
        /// <returns>The parsed values in file order.</returns>
        /// <exception cref="FileNotFoundException">The resource does not exist.</exception>
        /// <exception cref="InvalidDataException">The resource has no second (data) line.</exception>
        static float[] LoadTensorFromEmbeddedResource(string path)
        {
            using (var inputFile = new StreamReader(OpenEmbeddedResource(path)))
            {
                inputFile.ReadLine(); // skip the input name

                string dataLine = inputFile.ReadLine();
                if (dataLine == null)
                {
                    throw new InvalidDataException(
                        $"Embedded resource '{path}' does not contain a tensor data line.");
                }

                string[] dataStr = dataLine.Split(new char[] { ',', '[', ']' },
                                                  StringSplitOptions.RemoveEmptyEntries);

                var tensorData = new float[dataStr.Length];
                for (int i = 0; i < dataStr.Length; i++)
                {
                    // Parse with the invariant culture so '.' is always the decimal separator,
                    // regardless of the machine's regional settings.
                    tensorData[i] = Single.Parse(dataStr[i],
                                                 System.Globalization.CultureInfo.InvariantCulture);
                }

                return tensorData;
            }
        }

        /// <summary>
        /// Reads an embedded resource fully into a byte array.
        /// </summary>
        /// <param name="path">Resource path relative to the assembly name.</param>
        /// <returns>The raw bytes of the resource.</returns>
        /// <exception cref="FileNotFoundException">The resource does not exist.</exception>
        static byte[] LoadModelFromEmbeddedResource(string path)
        {
            using (Stream stream = OpenEmbeddedResource(path))
            using (MemoryStream memoryStream = new MemoryStream())
            {
                stream.CopyTo(memoryStream);
                return memoryStream.ToArray();
            }
        }

        // Serialized model bytes loaded from the embedded resource.
        private readonly byte[] model;
        // One named input value per model input, built once in the constructor.
        private readonly List<NamedOnnxValue> inputData;
        // Current session; null after Dispose.
        private InferenceSession inferenceSession;
        // Default options created by CreateInferenceSession when the caller passed none; null otherwise.
        private SessionOptions ownedOptions;
    }
}