// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// A type of data that OrtValue encapsulates.
    /// </summary>
    public enum OnnxValueType
    {
        ONNX_TYPE_UNKNOWN = 0,      // Not set
        ONNX_TYPE_TENSOR = 1,       // It's a Tensor
        ONNX_TYPE_SEQUENCE = 2,     // It's an Onnx sequence which may be a sequence of Tensors/Maps/Sequences
        ONNX_TYPE_MAP = 3,          // It's a map
        ONNX_TYPE_OPAQUE = 4,       // It's an experimental Opaque object
        ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor
    }

    /// <summary>
    /// Represents a disposable OrtValue.
    /// This class exposes a native instance of OrtValue.
    /// The class implements IDisposable via SafeHandle and must
    /// be disposed.
    /// </summary>
    public class OrtValue : SafeHandle
    {
        /// <summary>
        /// Use factory methods to instantiate this class
        /// </summary>
        /// <param name="handle">Pointer to a native instance of OrtValue</param>
        /// <param name="owned">Default true, own the raw handle. Otherwise, the handle is owned by another instance
        /// However, we use this class to expose OrtValue that is owned by DisposableNamedOnnxValue
        /// </param>
        internal OrtValue(IntPtr handle, bool owned = true)
            : base(handle, true)
        {
            IsOwned = owned;
        }

        // Raw native OrtValue* for passing to the native API.
        internal IntPtr Handle { get { return handle; } }

        /// <summary>
        /// Overrides SafeHandle.IsInvalid
        /// </summary>
        /// <value>returns true if handle is equal to Zero</value>
        public override bool IsInvalid { get { return handle == IntPtr.Zero; } }

        #region NamedOnnxValue/DisposableOnnxValue accommodations

        /// <summary>
        /// This internal interface is used to transfer ownership elsewhere.
        /// This instance must still be disposed in case there are other native
        /// objects still owned. This is a convenience method to ensure that an underlying
        /// OrtValue is disposed exactly once when exception is thrown.
        /// </summary>
/// /// internal IntPtr Disown() { var ret = Handle; handle = IntPtr.Zero; IsOwned = false; return ret; } internal bool IsOwned { get; private set; } #endregion /// /// Factory method to construct an OrtValue of Tensor type on top of pre-allocated memory. /// This can be a piece of native memory allocated by OrtAllocator (possibly on a device) /// or a piece of pinned managed memory. /// /// The resulting OrtValue does not own the underlying memory buffer and will not attempt to /// deallocate it. /// /// Memory Info. For managed memory it is a default cpu. /// For Native memory must be obtained from the allocator or OrtMemoryAllocation instance /// DataType for the Tensor /// Tensor shape /// Pointer to a raw memory buffer /// Buffer length in bytes /// A disposable instance of OrtValue public static OrtValue CreateTensorValueWithData(OrtMemoryInfo memInfo, TensorElementType elementType, long[] shape, IntPtr dataBuffer, long bufferLength) { Type type; int width; if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width)) { throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Unable to query type information for data type: " + elementType.ToString()); } if (elementType == TensorElementType.String) { throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Cannot map managed strings buffer to native OrtValue"); } var shapeSize = ArrayUtilities.GetSizeForShape(shape); var requiredBufferSize = shapeSize * width; if (requiredBufferSize > bufferLength) { var message = String.Format("Shape of: {0} elements requires a buffer of at least {1} bytes. 
Provided: {2} bytes", shapeSize, requiredBufferSize, bufferLength); throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message); } IntPtr ortValueHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue( memInfo.Pointer, dataBuffer, (UIntPtr)bufferLength, shape, (UIntPtr)shape.Length, elementType, out ortValueHandle )); return new OrtValue(ortValueHandle); } /// /// This is a factory method creates a native Onnxruntime OrtValue containing a tensor. /// The method will attempt to pin managed memory so no copying occurs when data is passed down /// to native code. /// /// Tensor object /// For all tensor types but string tensors we endeavor to use managed memory /// to avoid additional allocation and copy. This out parameter represents a chunk of pinned memory which will need /// to be disposed when no longer needed. The lifespan of memoryHandle should eclipse the lifespan of the corresponding /// OrtValue. /// /// discovered tensor element type /// And instance of OrtValue constructed on top of the object public static OrtValue CreateFromTensorObject(Object value, out MemoryHandle? memoryHandle, out TensorElementType elementType) { // Check if this is a Tensor if (!(value is TensorBase)) { throw new NotSupportedException("The inference value " + nameof(value) + " is not of a supported type"); } var tensorBase = value as TensorBase; var typeInfo = tensorBase.GetTypeInfo(); if (typeInfo == null) { throw new OnnxRuntimeException(ErrorCode.RequirementNotRegistered, "BUG Check"); } MemoryHandle? 
memHandle; OrtValue ortValue = null; int dataBufferLength = 0; long[] shape = null; int rank = 0; TensorElementType elType = typeInfo.ElementType; var typeSize = typeInfo.TypeSize; if (typeInfo.IsString) { ortValue = CreateStringTensor(value as Tensor); memHandle = null; } else { switch (elType) { case TensorElementType.Float: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.Double: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.Int32: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.UInt32: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.Int64: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.UInt64: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.Int16: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.UInt16: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.UInt8: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.Int8: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.Bool: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.Float16: PinAsTensor(value as Tensor, typeSize, out memHandle, out dataBufferLength, out shape, out rank); break; case TensorElementType.BFloat16: PinAsTensor(value as Tensor, typeSize, out 
memHandle, out dataBufferLength, out shape, out rank); break; default: throw new NotSupportedException("Element type: " + elType + " is not of a supported type"); } try { Debug.Assert(memHandle.HasValue); IntPtr dataBufferPointer = IntPtr.Zero; unsafe { dataBufferPointer = (IntPtr)((MemoryHandle)memHandle).Pointer; } IntPtr nativeValue; NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue( OrtMemoryInfo.DefaultInstance.Pointer, dataBufferPointer, (UIntPtr)(dataBufferLength), shape, (UIntPtr)rank, elType, out nativeValue)); ortValue = new OrtValue(nativeValue); } catch (Exception) { memHandle?.Dispose(); throw; } } memoryHandle = memHandle; elementType = elType; return ortValue; } private static void PinAsTensor( Tensor tensor, int elementSize, out MemoryHandle? pinnedHandle, out int dataBufferLength, out long[] shape, out int rank) { if (tensor == null) { throw new OnnxRuntimeException(ErrorCode.Fail, "Cast to Tensor failed. BUG check!"); } if (tensor.IsReversedStride) { //TODO: not sure how to support reverse stride. may be able to calculate the shape differently throw new NotSupportedException(nameof(Tensor) + " of reverseStride is not supported"); } DenseTensor dt = null; if (tensor is DenseTensor) { dt = tensor as DenseTensor; } else { dt = tensor.ToDenseTensor(); } pinnedHandle = dt.Buffer.Pin(); dataBufferLength = dt.Buffer.Length * elementSize; shape = new long[dt.Dimensions.Length]; for (int i = 0; i < dt.Dimensions.Length; ++i) { shape[i] = dt.Dimensions[i]; } rank = dt.Rank; } private static OrtValue CreateStringTensor(Tensor tensor) { if (tensor == null) { throw new OnnxRuntimeException(ErrorCode.Fail, "Cast to Tensor failed. 
BUG check!"); } int totalLength = 0; for (int i = 0; i < tensor.Length; i++) { totalLength += System.Text.Encoding.UTF8.GetByteCount(tensor.GetValue(i)); } long[] shape = new long[tensor.Dimensions.Length]; for (int i = 0; i < tensor.Dimensions.Length; i++) { shape[i] = tensor.Dimensions[i]; } // allocate the native tensor IntPtr valueHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorAsOrtValue( OrtAllocator.DefaultInstance.Pointer, shape, (UIntPtr)(shape.Length), TensorElementType.String, out valueHandle )); var ortValue = new OrtValue(valueHandle); try { // fill the native tensor, using GetValue(index) from the Tensor var len = tensor.Length; var nativeStrings = new IntPtr[len]; using (var pinnedHandles = new DisposableList((int)len)) { for (int i = 0; i < len; i++) { var utf8str = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(tensor.GetValue(i)); var gcHandle = GCHandle.Alloc(utf8str, GCHandleType.Pinned); nativeStrings[i] = gcHandle.AddrOfPinnedObject(); pinnedHandles.Add(new PinnedGCHandle(gcHandle)); } using (var pinnedStrings = new PinnedGCHandle(GCHandle.Alloc(nativeStrings, GCHandleType.Pinned))) NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len)); } } catch (OnnxRuntimeException) { ortValue.Dispose(); throw; } return ortValue; } #region SafeHandle /// /// Overrides SafeHandle.ReleaseHandle() to properly dispose of /// the native instance of OrtValue /// /// always returns true protected override bool ReleaseHandle() { // We have to surrender ownership to some legacy classes // Or we never had that ownership to begin with if (IsOwned) { NativeMethods.OrtReleaseValue(handle); } // Prevent use after disposal handle = IntPtr.Zero; return true; } // No need for the finalizer #endregion } }