// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.Win32.SafeHandles;
using System;
using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
using static Microsoft.ML.OnnxRuntime.Tests.InferenceTest;
namespace Microsoft.ML.OnnxRuntime.Tests
{
    public class OrtIoBindingAllocationTest
    {
        /// <summary>
        /// Copies <paramref name="elements"/> into the native memory owned by <paramref name="buffer"/>.
        /// This works only for allocations accessible from host memory.
        /// </summary>
        /// <param name="buffer">Destination allocation; it must hold at least
        /// <c>elements.Length * sizeof(float)</c> bytes.</param>
        /// <param name="elements">Values to copy into the allocation.</param>
        private static void PopulateNativeBufferFloat(OrtMemoryAllocation buffer, float[] elements)
        {
            // Widen to long so the byte count cannot overflow for very large arrays.
            long requiredBytes = (long)elements.Length * sizeof(float);
            Assert.True(buffer.Size >= requiredBytes,
                $"Allocation of {buffer.Size} bytes is too small to hold {requiredBytes} bytes");
            PopulateNativeBuffer(buffer.Pointer, elements);
        }

        /// <summary>
        /// Copies <paramref name="elements"/> to the host-accessible native memory at <paramref name="buffer"/>.
        /// The caller must ensure the destination holds at least <c>elements.Length * sizeof(float)</c> bytes.
        /// </summary>
        /// <param name="buffer">Start of the destination native memory.</param>
        /// <param name="elements">Values to copy.</param>
        private static void PopulateNativeBuffer(IntPtr buffer, float[] elements)
        {
            Marshal.Copy(elements, 0, buffer, elements.Length);
        }

        /// <summary>
        /// Owns memory obtained from <see cref="Marshal.AllocHGlobal(int)"/> and frees it on dispose,
        /// so the allocation can be tracked by the test's disposable list.
        /// </summary>
        private sealed class OrtSafeMemoryHandle : SafeHandle
        {
            // The first base-constructor argument is the *invalid* handle value, not the handle itself;
            // the owned pointer is stored with SetHandle.
            public OrtSafeMemoryHandle(IntPtr allocPtr) : base(IntPtr.Zero, true)
            {
                SetHandle(allocPtr);
            }

            public override bool IsInvalid => handle == IntPtr.Zero;

            protected override bool ReleaseHandle()
            {
                Marshal.FreeHGlobal(handle);
                handle = IntPtr.Zero;
                return true;
            }
        }

        /// <summary>
        /// Synchronizes the bound inputs, runs <paramref name="session"/> with <paramref name="ioBinding"/>,
        /// synchronizes the bound outputs, and checks that exactly one output named
        /// <paramref name="outputName"/> is produced and that its values match <paramref name="expectedOutput"/>.
        /// </summary>
        private static void RunAndVerify(InferenceSession session, RunOptions runOptions, OrtIoBinding ioBinding,
                                         string outputName, float[] expectedOutput)
        {
            ioBinding.SynchronizeBoundInputs();
            using (var outputs = session.RunWithBindingAndNames(runOptions, ioBinding))
            {
                ioBinding.SynchronizeBoundOutputs();
                var output = Assert.Single(outputs);
                Assert.Equal(outputName, output.Name);
                var tensor = output.AsTensor<float>();
                Assert.True(tensor.IsFixedSize);
                Assert.Equal(expectedOutput, tensor.ToArray(), new FloatComparer());
            }
        }

        /// <summary>
        /// Runs the SqueezeNet model through an IO binding several times. Across the runs, inputs and
        /// outputs are bound to a FixedBufferOnnxValue, to buffers allocated by the default OrtAllocator,
        /// and to a user-owned external allocation. Each run is checked against the expected output.
        /// The test also checks that invalid external allocations are rejected.
        /// </summary>
        [Fact(DisplayName = "TestIOBindingWithOrtAllocation")]
        public void TestIOBindingWithOrtAllocation()
        {
            // Input/output names defined by the model
            var inputName = "data_0";
            var outputName = "softmaxout_1";
            var allocator = OrtAllocator.DefaultInstance;
            using (var dispList = new DisposableListTest<IDisposable>())
            {
                var tuple = OpenSessionSqueezeNet();
                var session = tuple.Item1;
                var inputData = tuple.Item2;
                var inputTensor = tuple.Item3;
                var outputData = tuple.Item4;
                dispList.Add(session);
                var runOptions = new RunOptions();
                dispList.Add(runOptions);

                var inputMeta = session.InputMetadata;
                var outputMeta = session.OutputMetadata;
                var ioBinding = session.CreateIoBinding();
                dispList.Add(ioBinding);

                var ortAllocationInput = allocator.Allocate((uint)inputData.Length * sizeof(float));
                dispList.Add(ortAllocationInput);
                long[] inputShape = Array.ConvertAll(inputMeta[inputName].Dimensions, d => (long)d);
                var shapeSize = ArrayUtilities.GetSizeForShape(inputShape);
                Assert.Equal(shapeSize, inputData.Length);
                PopulateNativeBufferFloat(ortAllocationInput, inputData);

                // Create an external (user-owned) allocation for testing OrtExternalAllocation
                var cpuMemInfo = OrtMemoryInfo.DefaultInstance;
                var sizeInBytes = shapeSize * sizeof(float);
                IntPtr allocPtr = Marshal.AllocHGlobal((int)sizeInBytes);
                dispList.Add(new OrtSafeMemoryHandle(allocPtr));
                PopulateNativeBuffer(allocPtr, inputData);

                var ortAllocationOutput = allocator.Allocate((uint)outputData.Length * sizeof(float));
                dispList.Add(ortAllocationOutput);
                long[] outputShape = Array.ConvertAll(outputMeta[outputName].Dimensions, d => (long)d);

                // Test 1. Bind the input to a FixedBufferOnnxValue and the output to an OrtAllocated buffer
                using (FixedBufferOnnxValue fixedInputBuffer = FixedBufferOnnxValue.CreateFromTensor(inputTensor))
                {
                    ioBinding.BindInput(inputName, fixedInputBuffer);
                    ioBinding.BindOutput(outputName, Tensors.TensorElementType.Float, outputShape, ortAllocationOutput);
                    RunAndVerify(session, runOptions, ioBinding, outputName, outputData);
                }

                // Test 2. Bind both the input and the output to OrtAllocated buffers
                {
                    ioBinding.BindInput(inputName, Tensors.TensorElementType.Float, inputShape, ortAllocationInput);
                    ioBinding.BindOutput(outputName, Tensors.TensorElementType.Float, outputShape, ortAllocationOutput);
                    RunAndVerify(session, runOptions, ioBinding, outputName, outputData);
                }

                // Test 3. Bind the input to an external allocation and the output to an OrtAllocated buffer
                {
                    var externalInputAllocation = new OrtExternalAllocation(cpuMemInfo, inputShape,
                        Tensors.TensorElementType.Float, allocPtr, sizeInBytes);
                    ioBinding.BindInput(inputName, externalInputAllocation);
                    ioBinding.BindOutput(outputName, Tensors.TensorElementType.Float, outputShape, ortAllocationOutput);
                    RunAndVerify(session, runOptions, ioBinding, outputName, outputData);
                }

                // Test 4. Negative tests for external allocation
                {
                    // Buffer smaller than the shape requires
                    Action smallBuffer = () => new OrtExternalAllocation(cpuMemInfo, inputShape,
                        Tensors.TensorElementType.Float, allocPtr, sizeInBytes - 10);
                    Assert.Throws<OnnxRuntimeException>(smallBuffer);

                    // String tensors cannot be backed by an external allocation
                    Action stringType = () => new OrtExternalAllocation(cpuMemInfo, inputShape,
                        Tensors.TensorElementType.String, allocPtr, sizeInBytes);
                    Assert.Throws<OnnxRuntimeException>(stringType);
                }
            }
        }
    }
}