// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Runtime.InteropServices;
using System.Text;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Kind of allocator described by an OrtMemoryInfo.
/// See documentation for OrtAllocatorType in C API
/// </summary>
public enum OrtAllocatorType
{
DeviceAllocator = 0, // Device specific allocator
ArenaAllocator = 1 // Memory arena based allocator
}
/// <summary>
/// Kind of memory described by an OrtMemoryInfo.
/// See documentation for OrtMemType in C API
/// </summary>
public enum OrtMemType
{
CpuInput = -2, // Any CPU memory used by non-CPU execution provider
CpuOutput = -1, // CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED
Cpu = CpuOutput, // Alias with the same value as CpuOutput: temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED
Default = 0, // The default allocator for execution provider
}
/// <summary>
/// Holds a native arena configuration object. The configuration describes how an
/// arena based allocator is expected to behave.
/// See docs/C_API.md for more details
/// </summary>
public class OrtArenaCfg : SafeHandle
{
    /// <summary>
    /// Creates the native arena configuration that is later used to create an arena based allocator.
    /// See docs/C_API.md for details on what the following parameters mean and how to choose these values
    /// </summary>
    /// <param name="maxMemory">Maximum amount of memory the arena allocates</param>
    /// <param name="arenaExtendStrategy">Strategy for arena expansion</param>
    /// <param name="initialChunkSizeBytes">Size of the region that the arena allocates first</param>
    /// <param name="maxDeadBytesPerChunk">Maximum amount of fragmentation allowed per chunk</param>
    public OrtArenaCfg(uint maxMemory, int arenaExtendStrategy, int initialChunkSizeBytes, int maxDeadBytesPerChunk)
        : base(IntPtr.Zero, true)
    {
        // The native API takes the memory limit as a pointer-sized unsigned value.
        var nativeMaxMemory = (UIntPtr)maxMemory;
        var status = NativeMethods.OrtCreateArenaCfg(
            nativeMaxMemory,
            arenaExtendStrategy,
            initialChunkSizeBytes,
            maxDeadBytesPerChunk,
            out handle);
        NativeApiStatus.VerifySuccess(status);
    }

    /// <summary>
    /// Raw native handle, for passing to native methods.
    /// </summary>
    internal IntPtr Pointer
    {
        get { return handle; }
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid
    {
        get { return IntPtr.Zero == handle; }
    }

    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to release the native
    /// arena configuration instance and clear the stored handle.
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseArenaCfg(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
/// <summary>
/// This class encapsulates and most of the time owns the underlying native OrtMemoryInfo instance.
/// An instance returned from OrtAllocator does not own the native OrtMemoryInfo, but the class
/// must be disposed regardless.
///
/// Use this class to query and create OrtAllocator instances so you can pre-allocate memory for model
/// inputs/outputs and use it for binding. Instances of the class can also be used to create OrtValues bound
/// to pre-allocated memory. In that case, the instance of OrtMemoryInfo contains the information about the allocator
/// used to allocate the underlying memory.
/// </summary>
public class OrtMemoryInfo : SafeHandle
{
    private static readonly Lazy<OrtMemoryInfo> _defaultCpuAllocInfo = new Lazy<OrtMemoryInfo>(CreateCpuMemoryInfo);
    private readonly bool _owned; // false if we are exposing OrtMemoryInfo from an allocator which owns it

    /// <summary>
    /// Factory for the lazily created default CPU memory info.
    /// </summary>
    /// <returns>An owning OrtMemoryInfo created with DeviceAllocator / OrtMemType.Cpu</returns>
    private static OrtMemoryInfo CreateCpuMemoryInfo()
    {
        IntPtr memoryInfo = IntPtr.Zero;
        // Returns OrtMemoryInfo instance that needs to be disposed
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuMemoryInfo(OrtAllocatorType.DeviceAllocator, OrtMemType.Cpu, out memoryInfo));
        return new OrtMemoryInfo(memoryInfo, true);
    }

    /// <summary>
    /// Default CPU based instance
    /// </summary>
    /// <value>Singleton instance of a CpuMemoryInfo</value>
    public static OrtMemoryInfo DefaultInstance
    {
        get
        {
            return _defaultCpuAllocInfo.Value;
        }
    }

    /// <summary>
    /// Raw native handle, for passing to native methods.
    /// </summary>
    internal IntPtr Pointer
    {
        get
        {
            return handle;
        }
    }

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }

    /// <summary>
    /// This constructor takes a native pointer to an already existing
    /// instance of OrtMemoryInfo. That instance may either be owned or not
    /// owned. In the latter case, this class serves to expose native properties
    /// of the instance.
    /// </summary>
    /// <param name="allocInfo">native OrtMemoryInfo pointer to wrap</param>
    /// <param name="owned">true if this wrapper must release the native instance</param>
    internal OrtMemoryInfo(IntPtr allocInfo, bool owned)
        : base(IntPtr.Zero, true)
    {
        // The base constructor argument is the *invalid* handle value (zero here);
        // the actual native pointer is stored explicitly.
        SetHandle(allocInfo);
        _owned = owned;
    }

    /// <summary>
    /// Predefined utf8 encoded allocator names. Use them to construct an instance of
    /// OrtMemoryInfo to avoid UTF-16 to UTF-8 conversion costs.
    /// </summary>
    public static readonly byte[] allocatorCPU = Encoding.UTF8.GetBytes("Cpu" + Char.MinValue);

    /// <summary>
    /// Predefined utf8 encoded allocator names. Use them to construct an instance of
    /// OrtMemoryInfo to avoid UTF-16 to UTF-8 conversion costs.
    /// </summary>
    public static readonly byte[] allocatorCUDA = Encoding.UTF8.GetBytes("Cuda" + Char.MinValue);

    /// <summary>
    /// Predefined utf8 encoded allocator names. Use them to construct an instance of
    /// OrtMemoryInfo to avoid UTF-16 to UTF-8 conversion costs.
    /// </summary>
    public static readonly byte[] allocatorCUDA_PINNED = Encoding.UTF8.GetBytes("CudaPinned" + Char.MinValue);

    /// <summary>
    /// Create an instance of OrtMemoryInfo according to the specification
    /// Memory info instances are usually used to get a handle of a native allocator
    /// that is present within the current inference session object. That, in turn, depends
    /// on what execution providers are available within the binary that you are using and are
    /// registered with Add methods.
    /// </summary>
    /// <param name="utf8AllocatorName">Zero terminated UTF-8 allocator name. Use one of the predefined above.</param>
    /// <param name="allocatorType">Allocator type</param>
    /// <param name="deviceId">Device id</param>
    /// <param name="memoryType">Memory type</param>
    /// <exception cref="ArgumentNullException">utf8AllocatorName is null</exception>
    public OrtMemoryInfo(byte[] utf8AllocatorName, OrtAllocatorType allocatorType, int deviceId, OrtMemType memoryType)
        : base(IntPtr.Zero, true)
    {
        // Fail in managed code instead of handing a null name pointer to the native API.
        if (utf8AllocatorName == null)
        {
            throw new ArgumentNullException("utf8AllocatorName");
        }

        // The name buffer is pinned only for the duration of the native call.
        using (var pinnedName = new PinnedGCHandle(GCHandle.Alloc(utf8AllocatorName, GCHandleType.Pinned)))
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMemoryInfo(pinnedName.Pointer,
                                                                            allocatorType,
                                                                            deviceId,
                                                                            memoryType,
                                                                            out handle));
        }
        _owned = true;
    }

    /// <summary>
    /// Create an instance of OrtMemoryInfo according to the specification.
    /// The name is converted to zero terminated UTF-8 before the native call.
    /// </summary>
    /// <param name="allocatorName">Allocator name</param>
    /// <param name="allocatorType">Allocator type</param>
    /// <param name="deviceId">Device id</param>
    /// <param name="memoryType">Memory type</param>
    public OrtMemoryInfo(string allocatorName, OrtAllocatorType allocatorType, int deviceId, OrtMemType memoryType)
        : this(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(allocatorName), allocatorType, deviceId, memoryType)
    {
    }

    /// <summary>
    /// Name of the allocator associated with the OrtMemoryInfo instance
    /// </summary>
    public string Name
    {
        get
        {
            IntPtr utf8Name = IntPtr.Zero;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtMemoryInfoGetName(handle, out utf8Name));
            return NativeOnnxValueHelper.StringFromNativeUtf8(utf8Name);
        }
    }

    /// <summary>
    /// Returns device ID
    /// </summary>
    /// <value>returns integer Id value</value>
    public int Id
    {
        get
        {
            int id = 0;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtMemoryInfoGetId(handle, out id));
            return id;
        }
    }

    /// <summary>
    /// The below 2 are really properties but naming them is a challenge
    /// as names would conflict with the returned type. Also, there are native
    /// calls behind them so exposing them as Get() would be appropriate.
    /// </summary>
    /// <returns>OrtMemoryType for the instance</returns>
    public OrtMemType GetMemoryType()
    {
        OrtMemType memoryType = OrtMemType.Default;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtMemoryInfoGetMemType(handle, out memoryType));
        return memoryType;
    }

    /// <summary>
    /// Fetches allocator type from the underlying OrtAllocator
    /// </summary>
    /// <returns>Returns allocator type</returns>
    public OrtAllocatorType GetAllocatorType()
    {
        OrtAllocatorType allocatorType = OrtAllocatorType.ArenaAllocator;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtMemoryInfoGetType(handle, out allocatorType));
        return allocatorType;
    }

    /// <summary>
    /// Overrides System.Object.Equals(object)
    /// </summary>
    /// <param name="obj">object to compare to</param>
    /// <returns>true if obj is an instance of OrtMemoryInfo and is equal to this</returns>
    public override bool Equals(object obj)
    {
        // A null or foreign-typed obj becomes null here and is rejected by the typed overload.
        return Equals(obj as OrtMemoryInfo);
    }

    /// <summary>
    /// Compares this instance with another
    /// </summary>
    /// <param name="other">OrtMemoryInfo to compare to</param>
    /// <returns>true if instances are equal according to OrtCompareMemoryInfo; false if other is null.</returns>
    public bool Equals(OrtMemoryInfo other)
    {
        if (ReferenceEquals(other, null))
        {
            return false;
        }
        if (ReferenceEquals(this, other))
        {
            return true;
        }
        int result = -1;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCompareMemoryInfo(handle, other.Pointer, out result));
        return (result == 0);
    }

    /// <summary>
    /// Overrides System.Object.GetHashCode()
    /// </summary>
    /// <returns>integer hash value derived from the native pointer</returns>
    public override int GetHashCode()
    {
        // IntPtr.ToInt32() throws OverflowException on 64-bit platforms when the pointer
        // does not fit into 32 bits; IntPtr.GetHashCode() never throws.
        // NOTE(review): the hash is pointer based while Equals() compares native contents,
        // so two wrappers over distinct but equivalent native instances may hash differently.
        return handle.GetHashCode();
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtMemoryInfo
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        // If this instance exposes OrtMemoryInfo that belongs
        // to the allocator then the allocator owns it
        if (_owned)
        {
            NativeMethods.OrtReleaseMemoryInfo(handle);
        }
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
/// <summary>
/// This class represents an arbitrary buffer of memory
/// allocated and owned by the user. It can be either a CPU, GPU or other device memory
/// that can be suitably represented by IntPtr.
/// This is just a composite of the buffer related information.
/// The memory is assumed to be pinned if necessary and usable immediately
/// in the native code.
/// </summary>
public class OrtExternalAllocation
{
    /// <summary>
    /// Validates that the supplied buffer is large enough for the given shape and
    /// element type and records the buffer description.
    /// </summary>
    /// <param name="memInfo">use to accurately describe a piece of memory that this is wrapping</param>
    /// <param name="shape">shape of this buffer</param>
    /// <param name="elementType">element type</param>
    /// <param name="pointer">the actual pointer to memory</param>
    /// <param name="sizeInBytes">size of the allocation in bytes</param>
    /// <exception cref="OnnxRuntimeException">
    /// Thrown with ErrorCode.InvalidArgument when the element type is unknown or String,
    /// when the required byte size cannot be represented, or when sizeInBytes is smaller
    /// than the shape requires.
    /// </exception>
    public OrtExternalAllocation(OrtMemoryInfo memInfo, long[] shape, Tensors.TensorElementType elementType, IntPtr pointer, long sizeInBytes)
    {
        Type type;
        int width;
        if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width))
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Unable to query type information for data type: " + elementType.ToString());
        }

        if (elementType == TensorElementType.String)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Strings are not supported by this API");
        }

        var shapeSize = ArrayUtilities.GetSizeForShape(shape);

        // Multiply in checked 64-bit arithmetic: a silently wrapped product could become
        // small or negative and let an undersized buffer pass the check below.
        long requiredBufferSize;
        try
        {
            requiredBufferSize = checked((long)shapeSize * width);
        }
        catch (OverflowException)
        {
            var overflowMessage = String.Format("Shape of {0} elements with element width of {1} bytes exceeds the maximum representable buffer size",
                shapeSize, width);
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument, overflowMessage);
        }

        if (requiredBufferSize > sizeInBytes)
        {
            var message = String.Format("Shape of {0} elements requires a buffer of at least {1} bytes. Provided: {2} bytes",
                shapeSize, requiredBufferSize, sizeInBytes);
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message);
        }

        Info = memInfo;
        Shape = shape;
        ElementType = elementType;
        Pointer = pointer;
        Size = sizeInBytes;
    }

    /// <summary>
    /// OrtMemoryInfo describing the wrapped memory
    /// </summary>
    public OrtMemoryInfo Info { get; private set; }

    /// <summary>
    /// Shape
    /// </summary>
    public long[] Shape { get; private set; }

    /// <summary>
    /// Data type
    /// </summary>
    public Tensors.TensorElementType ElementType { get; private set; }

    /// <summary>
    /// Actual memory ptr
    /// </summary>
    public IntPtr Pointer { get; private set; }

    /// <summary>
    /// Size of the allocation in bytes
    /// </summary>
    public long Size { get; private set; }
}
/// <summary>
/// A chunk of memory obtained from a specific onnxruntime allocator.
/// Use OrtAllocator.Allocate() to obtain an instance of this class.
/// The instance is disposable and returns the memory through the same allocator
/// that produced it. The lifespan of the allocator instance must eclipse the
/// lifespan of the allocation. Or, if you prefer, all OrtMemoryAllocation instances must be
/// disposed of before the corresponding allocator instances are disposed of.
/// </summary>
public class OrtMemoryAllocation : SafeHandle
{
    // The allocator that frees this allocation. Holding the reference also
    // keeps the OrtAllocator from being collected/finalized in case
    // the user forgets to Dispose() of this allocation
    private OrtAllocator _allocator;

    /// <summary>
    /// Wraps a native memory allocation.
    /// Typically returned by OrtAllocator.Allocate(). However, some APIs return
    /// natively allocated IntPtr using a specific allocator. It is a good practice
    /// to wrap such a memory into OrtAllocation for proper disposal. You can set
    /// size to zero if not known, which is not important for disposing.
    /// </summary>
    /// <param name="allocator">allocator used to free the memory</param>
    /// <param name="pointer">native pointer to the allocated memory</param>
    /// <param name="size">size of the allocation in bytes</param>
    internal OrtMemoryAllocation(OrtAllocator allocator, IntPtr pointer, uint size)
        : base(IntPtr.Zero, true)
    {
        // Zero is the invalid handle value; the native pointer is stored explicitly.
        SetHandle(pointer);
        Size = size;
        _allocator = allocator;
    }

    /// <summary>
    /// Internal accessor to call native methods
    /// </summary>
    internal IntPtr Pointer
    {
        get { return handle; }
    }

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid
    {
        get { return IntPtr.Zero == handle; }
    }

    /// <summary>
    /// Size of the allocation
    /// </summary>
    /// <value>uint size of the allocation in bytes</value>
    public uint Size { get; private set; }

    /// <summary>
    /// Memory Information about this allocation
    /// </summary>
    /// <value>Returns OrtMemoryInfo from the allocator</value>
    public OrtMemoryInfo Info
    {
        get { return _allocator.Info; }
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to hand the memory chunk
    /// back to the allocator it came from and clear the stored handle.
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        _allocator.FreeMemory(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
/// <summary>
/// The class exposes native internal allocator for Onnxruntime.
/// This allocator enables you to allocate memory from the internal
/// memory pools including device allocations. Useful for binding.
/// </summary>
public class OrtAllocator : SafeHandle
{
    private static readonly Lazy<OrtAllocator> _defaultInstance = new Lazy<OrtAllocator>(GetDefaultCpuAllocator);
    private readonly bool _owned; // false for the native singleton default allocator

    /// <summary>
    /// Factory for the lazily created default CPU allocator wrapper.
    /// </summary>
    /// <returns>A non-owning wrapper over the native default allocator</returns>
    private static OrtAllocator GetDefaultCpuAllocator()
    {
        IntPtr allocator = IntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetAllocatorWithDefaultOptions(out allocator));
        // Instance of default cpu allocator is a native singleton
        // Do not dispose of
        return new OrtAllocator(allocator, false);
    }

    /// <summary>
    /// Default CPU allocator instance.
    /// If creating the instance threw an exception, every access may throw again.
    /// </summary>
    public static OrtAllocator DefaultInstance
    {
        get
        {
            return _defaultInstance.Value;
        }
    }

    /// <summary>
    /// Raw native handle, for passing to native methods.
    /// </summary>
    internal IntPtr Pointer
    {
        get
        {
            return handle;
        }
    }

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }

    /// <summary>
    /// Internal constructor wraps existing native allocators
    /// </summary>
    /// <param name="allocator">native allocator pointer to wrap</param>
    /// <param name="owned">true if this wrapper must release the native allocator</param>
    internal OrtAllocator(IntPtr allocator, bool owned)
        : base(IntPtr.Zero, true)
    {
        // The base constructor argument is the *invalid* handle value (zero here);
        // the actual native pointer is stored explicitly.
        SetHandle(allocator);
        _owned = owned;
    }

    /// <summary>
    /// Creates an instance of OrtAllocator according to the specifications in OrtMemoryInfo.
    /// The requested allocator should be available within the given session instance. This means
    /// both, the native library was built with specific allocators (for instance CUDA) and the corresponding
    /// provider was added to SessionOptions before instantiating the session object.
    /// </summary>
    /// <param name="session">session that provides the allocator</param>
    /// <param name="memInfo">description of the requested allocator</param>
    /// <exception cref="ArgumentNullException">session or memInfo is null</exception>
    public OrtAllocator(InferenceSession session, OrtMemoryInfo memInfo)
        : base(IntPtr.Zero, true)
    {
        // Report a missing argument by name instead of failing on a null dereference.
        if (session == null)
        {
            throw new ArgumentNullException("session");
        }
        if (memInfo == null)
        {
            throw new ArgumentNullException("memInfo");
        }

        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateAllocator(session.Handle, memInfo.Pointer, out handle));
        _owned = true;
    }

    /// <summary>
    /// OrtMemoryInfo instance owned by the allocator.
    /// Every access returns a new non-owning wrapper.
    /// </summary>
    /// <value>Instance of OrtMemoryInfo describing this allocator</value>
    public OrtMemoryInfo Info
    {
        get
        {
            IntPtr memoryInfo = IntPtr.Zero;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtAllocatorGetInfo(handle, out memoryInfo));
            // This serves as an exposure of memory_info owned by the allocator
            return new OrtMemoryInfo(memoryInfo, false);
        }
    }

    /// <summary>
    /// Allocate native memory. Returns a disposable instance of OrtMemoryAllocation
    /// </summary>
    /// <param name="size">number of bytes to allocate</param>
    /// <returns>Instance of OrtMemoryAllocation</returns>
    public OrtMemoryAllocation Allocate(uint size)
    {
        IntPtr allocation = IntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtAllocatorAlloc(handle, (UIntPtr)size, out allocation));
        return new OrtMemoryAllocation(this, allocation, size);
    }

    /// <summary>
    /// This internal interface is used for freeing memory.
    /// </summary>
    /// <param name="allocation">pointer to a native memory chunk allocated by this allocator instance</param>
    internal void FreeMemory(IntPtr allocation)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtAllocatorFree(handle, allocation));
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtAllocator
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        // Singleton default allocator is not owned
        if (_owned)
        {
            NativeMethods.OrtReleaseAllocator(handle);
        }
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
}