// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Collections.Generic;
using System.Linq;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Owns a <see cref="GCHandle"/> and frees it when disposed. The handle is
/// expected to have been allocated with <see cref="GCHandleType.Pinned"/>,
/// since <see cref="Pointer"/> relies on <see cref="GCHandle.AddrOfPinnedObject"/>.
/// </summary>
internal class PinnedGCHandle : IDisposable
{
    private GCHandle _handle;

    /// <summary>
    /// Takes ownership of an already-allocated handle.
    /// </summary>
    /// <param name="handle">handle to own; freed by <see cref="Dispose()"/></param>
    public PinnedGCHandle(GCHandle handle)
    {
        _handle = handle;
    }

    /// <summary>
    /// Address of the pinned object's data, as returned by
    /// <see cref="GCHandle.AddrOfPinnedObject"/>.
    /// </summary>
    public IntPtr Pointer
    {
        get
        {
            return _handle.AddrOfPinnedObject();
        }
    }

    #region Disposable Support
    /// <summary>
    /// Frees the owned handle when called from <see cref="Dispose()"/>.
    /// </summary>
    /// <param name="disposing">true when invoked from <see cref="Dispose()"/></param>
    protected virtual void Dispose(bool disposing)
    {
        // GCHandle.Free throws InvalidOperationException on a handle that was
        // already freed, so check IsAllocated to make repeated Dispose() calls
        // a harmless no-op, as the IDisposable contract expects.
        if (disposing && _handle.IsAllocated)
        {
            _handle.Free();
        }
    }

    /// <summary>
    /// Releases the pinned handle. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    // No finalizer on purpose: if this object is not disposed in a timely
    // manner, the GC cannot help us, because the pinned handle keeps the
    // target object alive.
    #endregion
}
/// <summary>
/// This helper class contains methods to create native OrtValue from a managed value object
/// </summary>
internal static class NativeOnnxValueHelper
{
    /// <summary>
    /// Converts C# UTF-16 string to UTF-8 zero terminated
    /// byte[] instance
    /// </summary>
    /// <param name="s">string to be converted</param>
    /// <returns>UTF-8 encoded equivalent, followed by a single 0 byte</returns>
    internal static byte[] StringToZeroTerminatedUtf8(string s)
    {
        // Size the buffer once (payload + terminator) instead of encoding and then
        // growing with Array.Resize, which allocates and copies a second time.
        int byteCount = Encoding.UTF8.GetByteCount(s);
        byte[] utf8Bytes = new byte[byteCount + 1];
        Encoding.UTF8.GetBytes(s, 0, s.Length, utf8Bytes, 0);
        // utf8Bytes[byteCount] is the terminator: new arrays are zero-initialized.
        return utf8Bytes;
    }

    /// <summary>
    /// Reads UTF-8 encoded string from a C zero terminated string
    /// and converts it into a C# UTF-16 encoded string
    /// </summary>
    /// <param name="nativeUtf8">pointer to native or pinned memory where Utf-8 resides</param>
    /// <returns>the decoded string; empty when the first byte is the terminator</returns>
    /// <exception cref="ArgumentNullException">nativeUtf8 is IntPtr.Zero</exception>
    internal static string StringFromNativeUtf8(IntPtr nativeUtf8)
    {
        // Reading through a null pointer would be an access violation that managed
        // code cannot catch; fail with a regular argument exception instead.
        if (nativeUtf8 == IntPtr.Zero)
        {
            throw new ArgumentNullException(nameof(nativeUtf8));
        }

        // .NET 5.0 has Marshal.PtrToStringUTF8 that does the below
        int len = 0;
        while (Marshal.ReadByte(nativeUtf8, len) != 0)
        {
            ++len;
        }
        byte[] buffer = new byte[len];
        Marshal.Copy(nativeUtf8, buffer, 0, len);
        return Encoding.UTF8.GetString(buffer, 0, buffer.Length);
    }

    /// <summary>
    /// Run helper: converts each name to a zero terminated UTF-8 byte[] and pins it
    /// so that its address can be handed to native code.
    /// </summary>
    /// <typeparam name="T">element type of <paramref name="names"/></typeparam>
    /// <param name="names">names to convert to zero terminated utf8 and pin</param>
    /// <param name="extractor">delegate for string extraction from inputs</param>
    /// <param name="cleanupList">list to add pinned memory to for later disposal</param>
    /// <returns>pointers to the pinned UTF-8 names, in the enumeration order of <paramref name="names"/></returns>
    internal static IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> names, NameExtractor<T> extractor,
        DisposableList<IDisposable> cleanupList)
    {
        var result = new IntPtr[names.Count];
        int i = 0;
        // Enumerate once: ElementAt(i) inside an index loop is O(i) per call for
        // collections that are not IList<T>, which made the whole loop O(n^2).
        foreach (var item in names)
        {
            var name = extractor(item);
            var utf8Name = StringToZeroTerminatedUtf8(name);
            var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned));
            // Hand ownership to the cleanup list before using the handle further.
            cleanupList.Add(pinnedHandle);
            result[i++] = pinnedHandle.Pointer;
        }
        return result;
    }

    /// <summary>
    /// Converts a C# string to a zero terminated byte[] whose encoding depends on the
    /// current OS: on Windows it is UTF-16 little-endian (Encoding.Unicode) ending in a
    /// two-byte zero terminator; elsewhere it is UTF-8 ending in a single zero byte.
    /// </summary>
    /// <param name="str">string to be converted</param>
    /// <returns>platform-encoded, zero terminated bytes</returns>
    internal static byte[] GetPlatformSerializedString(string str)
    {
        // NOTE(review): a null str yields just the terminator on Windows (null + '\0' == "\0")
        // but throws ArgumentNullException elsewhere; kept as-is for caller compatibility.
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            return System.Text.Encoding.Unicode.GetBytes(str + Char.MinValue);
        }
        else
        {
            return StringToZeroTerminatedUtf8(str);
        }
    }

    /// <summary>
    /// Delegate for string extraction from an arbitrary input/output object
    /// </summary>
    internal delegate string NameExtractor<in TInput>(TInput input);
}
/// <summary>
/// Maps a <see cref="TensorElementType"/> to its managed element type and size.
/// </summary>
internal static class TensorElementTypeConverter
{
    /// <summary>
    /// Looks up <paramref name="elemType"/> via TensorBase.GetElementTypeInfo.
    /// </summary>
    /// <param name="elemType">element type to look up</param>
    /// <param name="type">the info's TensorType on success; null otherwise</param>
    /// <param name="width">the info's TypeSize on success (presumably bytes per element — confirm); 0 otherwise</param>
    /// <returns>true when type info exists for <paramref name="elemType"/>; false otherwise</returns>
    public static bool GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
    {
        TensorElementTypeInfo info = TensorBase.GetElementTypeInfo(elemType);

        // Unknown element type: report failure with neutral out values.
        if (info == null)
        {
            type = null;
            width = 0;
            return false;
        }

        type = info.TensorType;
        width = info.TypeSize;
        return true;
    }
}
}