NativeOnnxValueHelper.shared.cs 4.82 KB
Newer Older
gaoqiong's avatar
gaoqiong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// Owns a <see cref="GCHandle"/> and frees it when disposed, so the object it refers to
    /// is no longer kept alive/pinned by this handle.
    /// </summary>
    internal class PinnedGCHandle : IDisposable
    {
        private GCHandle _handle;

        /// <summary>
        /// Takes ownership of an already allocated handle.
        /// </summary>
        /// <param name="handle">handle to own; <see cref="Pointer"/> calls AddrOfPinnedObject,
        /// so the handle is expected to have been allocated with <see cref="GCHandleType.Pinned"/></param>
        public PinnedGCHandle(GCHandle handle)
        {
            _handle = handle;
        }

        /// <summary>
        /// Address of the pinned object's data. Only valid until this instance is disposed.
        /// </summary>
        public IntPtr Pointer => _handle.AddrOfPinnedObject();

        #region Disposable Support
        protected virtual void Dispose(bool disposing)
        {
            // GCHandle.Free on an already freed (or never allocated) handle throws
            // InvalidOperationException; checking IsAllocated makes repeated Dispose calls no-ops,
            // as the IDisposable contract requires. Free() clears the handle stored in this field.
            if (disposing && _handle.IsAllocated)
            {
                _handle.Free();
            }
        }

        /// <summary>
        /// Frees the owned handle. Safe to call more than once.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }
        // No need for the finalizer
        // If this is not disposed timely GC can't help us
        #endregion
    }

    /// <summary>
    /// This helper class contains methods to create native OrtValue from a managed value object
    /// </summary>
    internal static class NativeOnnxValueHelper
    {
        /// <summary>
        /// Converts C# UTF-16 string to UTF-8 zero terminated
        /// byte[] instance
        /// </summary>
        /// <param name="s">string to be converted</param>
        /// <returns>UTF-8 encoded equivalent followed by a single 0 byte</returns>
        /// <exception cref="ArgumentNullException"><paramref name="s"/> is null</exception>
        internal static byte[] StringToZeroTerminatedUtf8(string s)
        {
            // Size the buffer once (payload + terminator) rather than encoding and then
            // re-allocating with Array.Resize. The last byte is already 0 after allocation.
            int byteCount = Encoding.UTF8.GetByteCount(s);
            byte[] utf8Bytes = new byte[byteCount + 1];
            Encoding.UTF8.GetBytes(s, 0, s.Length, utf8Bytes, 0);
            return utf8Bytes;
        }

        /// <summary>
        /// Reads UTF-8 encode string from a C zero terminated string
        /// and converts it into a C# UTF-16 encoded string
        /// </summary>
        /// <param name="nativeUtf8">pointer to native or pinned memory where Utf-8 resides</param>
        /// <returns>the decoded string, or null when <paramref name="nativeUtf8"/> is IntPtr.Zero</returns>
        internal static string StringFromNativeUtf8(IntPtr nativeUtf8)
        {
            // Match Marshal.PtrToStringUTF8, which returns null for a null pointer,
            // instead of attempting to read address 0.
            if (nativeUtf8 == IntPtr.Zero)
            {
                return null;
            }

            // .NET 5.0 has Marshal.PtrToStringUTF8 that does the below
            int len = 0;
            while (Marshal.ReadByte(nativeUtf8, len) != 0)
            {
                ++len;
            }
            byte[] buffer = new byte[len];
            Marshal.Copy(nativeUtf8, buffer, 0, len);
            return Encoding.UTF8.GetString(buffer, 0, buffer.Length);
        }

        /// <summary>
        /// Run helper: converts every name to a zero terminated UTF-8 buffer and pins it.
        /// </summary>
        /// <param name="names">names to convert to zero terminated utf8 and pin</param>
        /// <param name="extractor">delegate for string extraction from inputs</param>
        /// <param name="cleanupList">list to add pinned memory to for later disposal</param>
        /// <returns>pointers to the pinned UTF-8 buffers, in the enumeration order of <paramref name="names"/>;
        /// each pointer is valid only while its handle in <paramref name="cleanupList"/> is not disposed</returns>
        internal static IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> names, NameExtractor<T> extractor,
            DisposableList<IDisposable> cleanupList)
        {
            var result = new IntPtr[names.Count];
            int index = 0;
            // Enumerate once: ElementAt(i) on a collection that is not an IList walks from the
            // beginning on every call, making the loop quadratic.
            foreach (var item in names)
            {
                var name = extractor(item);
                var utf8Name = StringToZeroTerminatedUtf8(name);
                var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned));
                // Register for cleanup before using the handle so the pin is never leaked.
                cleanupList.Add(pinnedHandle);
                result[index++] = pinnedHandle.Pointer;
            }
            return result;
        }

        /// <summary>
        /// Converts a string to a zero terminated byte[] in a platform dependent encoding:
        /// UTF-16 (<see cref="Encoding.Unicode"/>, little-endian) with a 2-byte zero terminator on Windows,
        /// UTF-8 with a 1-byte zero terminator on every other platform.
        /// </summary>
        /// <param name="str">string to be converted</param>
        /// <returns>platform-encoded, zero terminated equivalent</returns>
        internal static byte[] GetPlatformSerializedString(string str)
        {
            // NOTE(review): presumably this matches the native side's wide-char convention on Windows; confirm.
            // A null str yields only the terminator on Windows (concatenation treats null as empty)
            // but throws ArgumentNullException on other platforms.
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
            {
                return Encoding.Unicode.GetBytes(str + Char.MinValue);
            }
            return StringToZeroTerminatedUtf8(str);
        }

        // Delegate for string extraction from an arbitrary input/output object
        internal delegate string NameExtractor<in TInput>(TInput input);
    }

    internal static class TensorElementTypeConverter
    {
        /// <summary>
        /// Looks up the managed element type and its size for a tensor element type.
        /// </summary>
        /// <param name="elemType">tensor element type to look up</param>
        /// <param name="type">receives TensorType of the looked-up info, or null if no info exists</param>
        /// <param name="width">receives TypeSize of the looked-up info, or 0 if no info exists</param>
        /// <returns>true when TensorBase.GetElementTypeInfo returned info for <paramref name="elemType"/>, otherwise false</returns>
        public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
        {
            TensorElementTypeInfo info = TensorBase.GetElementTypeInfo(elemType);

            // Unknown element type: report failure with neutral out values.
            if (info == null)
            {
                type = null;
                width = 0;
                return false;
            }

            type = info.TensorType;
            width = info.TypeSize;
            return true;
        }
    }
}