// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Runtime.InteropServices;
using System.Text;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// This class enables binding of inputs and/or outputs to pre-allocated
/// memory. This enables interesting scenarios. For example, if your input
/// already resides in some pre-allocated memory, such as on a GPU, you can bind
/// that piece of memory to an input name and shape, and onnxruntime will use it as input.
/// Traditional inputs that already exist as Tensors can also be bound.
///
/// Note that this arrangement is designed to minimize data copies, and to that effect
/// your memory allocations must match what is expected by the model, whether you run on
/// CPU or GPU. A data copy will still be made if your pre-allocated memory location does not
/// match the one expected by the model. However, copies with OrtIoBinding are only done once,
/// at the time of the binding, not at run time. This means that if your input data required a copy,
/// your further input modifications would not be seen by onnxruntime unless you rebind it, even if it is
/// the same buffer. If you require the scenario where data is copied, OrtIoBinding may not be the best match
/// for your use case.
///
/// The fact that a data copy is not made during runtime also has performance implications.
/// </summary>
public class OrtIoBinding : SafeHandle
{
    /// <summary>
    /// Creates the native binding object for the given session.
    /// Use InferenceSession.CreateIoBinding() to obtain an instance.
    /// </summary>
    /// <param name="session">session whose native handle is passed to OrtCreateIoBinding</param>
    internal OrtIoBinding(InferenceSession session)
        : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateIoBinding(session.Handle, out handle));
    }

    /// <summary>
    /// Raw native handle of this binding.
    /// </summary>
    internal IntPtr Handle
    {
        get { return handle; }
    }

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid
    {
        get { return handle == IntPtr.Zero; }
    }

    /// <summary>
    /// Bind a piece of pre-allocated native memory as an OrtValue Tensor with a given shape
    /// to an input with a given name. The model will read the specified input from that memory,
    /// possibly avoiding the need to copy between devices. OrtMemoryAllocation continues to own
    /// the chunk of native memory, and the allocation should be alive until the end of execution.
    /// </summary>
    /// <param name="name">name of the input</param>
    /// <param name="elementType">tensor element type</param>
    /// <param name="shape">tensor shape</param>
    /// <param name="allocation">native memory allocation</param>
    /// <exception cref="ArgumentNullException">name or allocation is null</exception>
    public void BindInput(string name, Tensors.TensorElementType elementType, long[] shape, OrtMemoryAllocation allocation)
    {
        BindOrtAllocation(name, elementType, shape, allocation, true);
    }

    /// <summary>
    /// Bind externally (not from OrtAllocator) allocated memory as input.
    /// The model will read the specified input from that memory,
    /// possibly avoiding the need to copy between devices. The user code continues to own
    /// the chunk of externally allocated memory, and the allocation should be alive until the end of execution.
    /// </summary>
    /// <param name="name">name of the input</param>
    /// <param name="allocation">non ort allocated memory</param>
    /// <exception cref="ArgumentNullException">name or allocation is null</exception>
    public void BindInput(string name, OrtExternalAllocation allocation)
    {
        BindExternalAllocation(name, allocation, true);
    }

    /// <summary>
    /// Bind the input with the given name as an OrtValue Tensor allocated in pinned managed memory.
    /// Instance of FixedBufferOnnxValue owns the memory and should be alive until the end of execution.
    /// </summary>
    /// <param name="name">name of the input</param>
    /// <param name="fixedValue">value to bind; must be a tensor</param>
    /// <exception cref="ArgumentNullException">name or fixedValue is null</exception>
    /// <exception cref="OnnxRuntimeException">fixedValue is not a tensor</exception>
    public void BindInput(string name, FixedBufferOnnxValue fixedValue)
    {
        EnsureNotNull(fixedValue, nameof(fixedValue));
        if (fixedValue.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Binding works only with Tensors");
        }
        BindInputOrOutput(name, fixedValue.Value.Handle, true);
    }

    /// <summary>
    /// Blocks until device completes all preceding requested tasks.
    /// Useful for memory synchronization.
    /// </summary>
    public void SynchronizeBoundInputs()
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSynchronizeBoundInputs(handle));
    }

    /// <summary>
    /// Bind model output to an OrtValue as Tensor with a given type and shape. An instance of OrtMemoryAllocation
    /// owns the memory and should be alive for the time of execution.
    /// </summary>
    /// <param name="name">name of the output</param>
    /// <param name="elementType">tensor element type</param>
    /// <param name="shape">tensor shape</param>
    /// <param name="allocation">allocated memory</param>
    /// <exception cref="ArgumentNullException">name or allocation is null</exception>
    public void BindOutput(string name, Tensors.TensorElementType elementType, long[] shape, OrtMemoryAllocation allocation)
    {
        BindOrtAllocation(name, elementType, shape, allocation, false);
    }

    /// <summary>
    /// Bind externally (not from OrtAllocator) allocated memory as output.
    /// The model will write the specified output to that memory,
    /// possibly avoiding the need to copy between devices. The user code continues to own
    /// the chunk of externally allocated memory, and the allocation should be alive until the end of execution.
    /// </summary>
    /// <param name="name">name of the output</param>
    /// <param name="allocation">non ort allocated memory</param>
    /// <exception cref="ArgumentNullException">name or allocation is null</exception>
    public void BindOutput(string name, OrtExternalAllocation allocation)
    {
        BindExternalAllocation(name, allocation, false);
    }

    /// <summary>
    /// Bind model output to a given instance of FixedBufferOnnxValue which owns the underlying
    /// pinned managed memory and should be alive for the time of execution.
    /// </summary>
    /// <param name="name">name of the output</param>
    /// <param name="fixedValue">value to bind; must be a tensor</param>
    /// <exception cref="ArgumentNullException">name or fixedValue is null</exception>
    /// <exception cref="OnnxRuntimeException">fixedValue is not a tensor</exception>
    public void BindOutput(string name, FixedBufferOnnxValue fixedValue)
    {
        EnsureNotNull(fixedValue, nameof(fixedValue));
        if (fixedValue.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Binding works only with Tensors");
        }
        BindInputOrOutput(name, fixedValue.Value.Handle, false);
    }

    /// <summary>
    /// This function will bind model output with the given name to a device
    /// specified by the memInfo.
    /// </summary>
    /// <param name="name">output name</param>
    /// <param name="memInfo">instance of memory info</param>
    /// <exception cref="ArgumentNullException">name or memInfo is null</exception>
    public void BindOutputToDevice(string name, OrtMemoryInfo memInfo)
    {
        EnsureNotNull(name, nameof(name));
        EnsureNotNull(memInfo, nameof(memInfo));

        // The UTF-8 name bytes stay pinned only for the duration of the native call.
        var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned);
        using (var pinnedName = new PinnedGCHandle(utf8NamePinned))
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtBindOutputToDevice(handle, pinnedName.Pointer, memInfo.Pointer));
        }
    }

    /// <summary>
    /// Blocks until device completes all preceding requested tasks.
    /// Useful for memory synchronization.
    /// </summary>
    public void SynchronizeBoundOutputs()
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSynchronizeBoundOutputs(handle));
    }

    /// <summary>
    /// Bind allocation obtained from an Ort allocator. The allocation is wrapped in a temporary
    /// OrtValue tensor that is disposed as soon as the native bind call returns.
    /// </summary>
    /// <param name="name">name of the input or output</param>
    /// <param name="elementType">data type</param>
    /// <param name="shape">tensor shape</param>
    /// <param name="allocation">ort allocation</param>
    /// <param name="isInput">whether this is input or output</param>
    private void BindOrtAllocation(string name, Tensors.TensorElementType elementType, long[] shape,
                                   OrtMemoryAllocation allocation, bool isInput)
    {
        // Validate before creating any native object so a bad argument allocates nothing.
        EnsureNotNull(name, nameof(name));
        EnsureNotNull(allocation, nameof(allocation));

        using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info,
                                                                 elementType,
                                                                 shape,
                                                                 allocation.Pointer, allocation.Size))
        {
            BindInputOrOutput(name, ortValue.Handle, isInput);
        }
    }

    /// <summary>
    /// Bind external allocation as input or output.
    /// The allocation is owned by the user code. It is wrapped in a temporary
    /// OrtValue tensor that is disposed as soon as the native bind call returns.
    /// </summary>
    /// <param name="name">name of the input or output</param>
    /// <param name="allocation">non ort allocated memory</param>
    /// <param name="isInput">whether this is an input or output</param>
    private void BindExternalAllocation(string name, OrtExternalAllocation allocation, bool isInput)
    {
        // Validate before creating any native object so a bad argument allocates nothing.
        EnsureNotNull(name, nameof(name));
        EnsureNotNull(allocation, nameof(allocation));

        using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info,
                                                                 allocation.ElementType,
                                                                 allocation.Shape,
                                                                 allocation.Pointer,
                                                                 allocation.Size))
        {
            BindInputOrOutput(name, ortValue.Handle, isInput);
        }
    }

    /// <summary>
    /// Internal helper: converts the name to pinned zero-terminated UTF-8 and calls
    /// OrtBindInput or OrtBindOutput with the supplied native OrtValue handle.
    /// </summary>
    /// <param name="name">name of the input or output</param>
    /// <param name="ortValue">native OrtValue handle to bind</param>
    /// <param name="isInput">true to bind as input, false to bind as output</param>
    private void BindInputOrOutput(string name, IntPtr ortValue, bool isInput)
    {
        EnsureNotNull(name, nameof(name));

        var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned);
        using (var pinnedName = new PinnedGCHandle(utf8NamePinned))
        {
            if (isInput)
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtBindInput(handle, pinnedName.Pointer, ortValue));
            }
            else
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtBindOutput(handle, pinnedName.Pointer, ortValue));
            }
        }
    }

    /// <summary>
    /// Throws ArgumentNullException when a required argument is null, so callers get the
    /// offending parameter name instead of a NullReferenceException from deeper in the call.
    /// </summary>
    /// <param name="argument">argument value to check</param>
    /// <param name="parameterName">name of the parameter being checked</param>
    private static void EnsureNotNull(object argument, string parameterName)
    {
        if (argument == null)
        {
            throw new ArgumentNullException(parameterName);
        }
    }

    /// <summary>
    /// Returns an array of output names in the same order they were bound
    /// </summary>
    /// <returns>array of output names; empty when the native call reports a count of zero</returns>
    public string[] GetOutputNames()
    {
        IntPtr buffer = IntPtr.Zero;
        IntPtr lengths = IntPtr.Zero;
        UIntPtr count = UIntPtr.Zero;
        var allocator = OrtAllocator.DefaultInstance;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetBoundOutputNames(handle, allocator.Pointer, out buffer, out lengths, out count));

        if (count.Equals(UIntPtr.Zero))
        {
            return new string[0];
        }

        // Both native buffers were obtained through the default allocator; the wrappers are
        // disposed when this scope exits.
        using (var bufferAllocation = new OrtMemoryAllocation(allocator, buffer, 0))
        using (var lengthsAllocation = new OrtMemoryAllocation(allocator, lengths, 0))
        {
            int outputCount = (int)count;

            // First pass: read one pointer-sized length per name and sum them to size the managed copy.
            var lens = new int[outputCount];
            int totalLength = 0;
            for (int i = 0; i < outputCount; ++i)
            {
                var len = (int)Marshal.ReadIntPtr(lengths, IntPtr.Size * i);
                lens[i] = len;
                totalLength += len;
            }

            // The names are read back-to-back from a single buffer, delimited only by the lengths.
            var stringData = new byte[totalLength];
            Marshal.Copy(buffer, stringData, 0, stringData.Length);

            // Second pass: slice the byte block into individual UTF-8 strings.
            string[] result = new string[outputCount];
            int readOffset = 0;
            for (int i = 0; i < outputCount; ++i)
            {
                var strLen = lens[i];
                result[i] = Encoding.UTF8.GetString(stringData, readOffset, strLen);
                readOffset += strLen;
            }
            return result;
        }
    }

    /// <summary>
    /// This fetches bound outputs after running the model with RunWithBinding()
    /// </summary>
    /// <returns>disposable collection of OrtValue instances; empty when the native call reports a count of zero</returns>
    public IDisposableReadOnlyCollection GetOutputValues()
    {
        IntPtr ortValues = IntPtr.Zero;
        UIntPtr count = UIntPtr.Zero;
        var allocator = OrtAllocator.DefaultInstance;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetBoundOutputValues(handle, allocator.Pointer, out ortValues, out count));

        if (count.Equals(UIntPtr.Zero))
        {
            return new DisposableList();
        }

        // The wrapper covers only the native array of handles; each handle is wrapped
        // in its own OrtValue and added to the returned list.
        using (var ortValuesAllocation = new OrtMemoryAllocation(allocator, ortValues, 0))
        {
            int outputCount = (int)count;
            var ortList = new DisposableList(outputCount);
            try
            {
                for (int i = 0; i < outputCount; ++i)
                {
                    IntPtr ortValue = Marshal.ReadIntPtr(ortValues, IntPtr.Size * i);
                    ortList.Add(new OrtValue(ortValue));
                }
            }
            catch (Exception)
            {
                // Dispose the values wrapped so far before propagating the failure.
                // NOTE(review): handles not yet wrapped when the failure occurs are not
                // released here - confirm whether they need explicit release.
                ortList.Dispose();
                throw;
            }
            return ortList;
        }
    }

    /// <summary>
    /// Clear all bound inputs and start anew
    /// </summary>
    public void ClearBoundInputs()
    {
        NativeMethods.OrtClearBoundInputs(handle);
    }

    /// <summary>
    /// Clear all bound outputs
    /// </summary>
    public void ClearBoundOutputs()
    {
        NativeMethods.OrtClearBoundOutputs(handle);
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtIoBinding
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseIoBinding(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
}