// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
// NOTE: The order of the APIs in this struct should match exactly that in
// OrtTrainingApi (onnxruntime_training_c_api.cc)
[StructLayout(LayoutKind.Sequential)]
public struct OrtTrainingApi
{
public IntPtr LoadCheckpoint;
public IntPtr SaveCheckpoint;
public IntPtr CreateTrainingSession;
public IntPtr TrainingSessionGetTrainingModelOutputCount;
public IntPtr TrainingSessionGetEvalModelOutputCount;
public IntPtr TrainingSessionGetTrainingModelOutputName;
public IntPtr TrainingSessionGetEvalModelOutputName;
public IntPtr ResetGrad;
public IntPtr TrainStep;
public IntPtr EvalStep;
public IntPtr SetLearningRate;
public IntPtr GetLearningRate;
public IntPtr OptimizerStep;
public IntPtr RegisterLinearLRScheduler;
public IntPtr SchedulerStep;
public IntPtr GetParametersSize;
public IntPtr CopyParametersToBuffer;
public IntPtr CopyBufferToParameters;
public IntPtr ReleaseTrainingSession;
public IntPtr ReleaseCheckpointState;
public IntPtr ExportModelForInferencing;
}
internal static class NativeTrainingMethods
{
static OrtApi api_;
static OrtTrainingApi trainingApi_;
static IntPtr trainingApiPtr;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate ref OrtApi DOrtGetApi(UInt32 version);
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtTrainingApi* */ DOrtGetTrainingApi(UInt32 version);
public static DOrtGetTrainingApi OrtGetTrainingApi;
static NativeTrainingMethods()
{
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));
// TODO: Make this save the pointer, and not copy the whole structure across
api_ = (OrtApi)OrtGetApi(4 /*ORT_API_VERSION*/);
OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi));
trainingApiPtr = OrtGetTrainingApi(4 /*ORT_API_VERSION*/);
if (trainingApiPtr != IntPtr.Zero)
{
trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi));
OrtLoadCheckpoint = (DOrtLoadCheckpoint)Marshal.GetDelegateForFunctionPointer(trainingApi_.LoadCheckpoint, typeof(DOrtLoadCheckpoint));
OrtSaveCheckpoint = (DOrtSaveCheckpoint)Marshal.GetDelegateForFunctionPointer(trainingApi_.SaveCheckpoint, typeof(DOrtSaveCheckpoint));
OrtCreateTrainingSession = (DOrtCreateTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.CreateTrainingSession, typeof(DOrtCreateTrainingSession));
OrtGetTrainingModelOutputCount = (DOrtGetTrainingModelOutputCount)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetTrainingModelOutputCount, typeof(DOrtGetTrainingModelOutputCount));
OrtGetEvalModelOutputCount = (DOrtGetEvalModelOutputCount)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelOutputCount, typeof(DOrtGetEvalModelOutputCount));
OrtGetTrainingModelOutputName = (DOrtGetTrainingModelOutputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetTrainingModelOutputName, typeof(DOrtGetTrainingModelOutputName));
OrtGetEvalModelOutputName = (DOrtGetEvalModelOutputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelOutputName, typeof(DOrtGetEvalModelOutputName));
OrtResetGrad = (DOrtResetGrad)Marshal.GetDelegateForFunctionPointer(trainingApi_.ResetGrad, typeof(DOrtResetGrad));
OrtTrainStep = (DOrtTrainStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainStep, typeof(DOrtTrainStep));
OrtEvalStep = (DOrtEvalStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.EvalStep, typeof(DOrtEvalStep));
OrtSetLearningRate = (DOrtSetLearningRate)Marshal.GetDelegateForFunctionPointer(trainingApi_.SetLearningRate, typeof(DOrtSetLearningRate));
OrtGetLearningRate = (DOrtGetLearningRate)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetLearningRate, typeof(DOrtGetLearningRate));
OrtOptimizerStep = (DOrtOptimizerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.OptimizerStep, typeof(DOrtOptimizerStep));
OrtRegisterLinearLRScheduler = (DOrtRegisterLinearLRScheduler)Marshal.GetDelegateForFunctionPointer(trainingApi_.RegisterLinearLRScheduler, typeof(DOrtRegisterLinearLRScheduler));
OrtSchedulerStep = (DOrtSchedulerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.SchedulerStep, typeof(DOrtSchedulerStep));
OrtReleaseTrainingSession = (DOrtReleaseTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseTrainingSession, typeof(DOrtReleaseTrainingSession));
OrtReleaseCheckpointState = (DOrtReleaseCheckpointState)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseCheckpointState, typeof(DOrtReleaseCheckpointState));
}
}
#region TrainingSession API
///
/// Creates an instance of OrtSession with provided parameters
///
/// checkpoint string path
/// (Output) Loaded OrtCheckpointState instance
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */DOrtLoadCheckpoint(
byte[] checkpointPath,
out IntPtr /* (OrtCheckpointState**) */ checkpointState);
public static DOrtLoadCheckpoint OrtLoadCheckpoint;
///
/// Creates an instance of OrtSession with provided parameters
///
/// checkpoint string path
/// (Output) Loaded OrtCheckpointState instance
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */DOrtSaveCheckpoint(
byte[] checkpointPath,
IntPtr /*(OrtTrainingSession*)*/ session,
bool saveOptimizerState);
public static DOrtSaveCheckpoint OrtSaveCheckpoint;
///
/// Creates an instance of OrtSession with provided parameters
///
/// Native OrtEnv instance
/// Native SessionOptions instance
/// Loaded OrtCheckpointState instance
/// model string path
/// model string path
/// model string path
/// (Output) Created native OrtTrainingSession instance
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */DOrtCreateTrainingSession(
IntPtr /* (OrtEnv*) */ environment,
IntPtr /* (OrtSessionOptions*) */ sessionOptions,
IntPtr /* (OrtCheckpointState*) */ checkpointState,
byte[] trainModelPath,
byte[] evalModelPath,
byte[] optimizerModelPath,
out IntPtr /* (OrtTrainingSession**) */ session);
public static DOrtCreateTrainingSession OrtCreateTrainingSession;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTrainingModelOutputCount(
IntPtr /*(OrtSession*)*/ session,
out UIntPtr count);
public static DOrtGetTrainingModelOutputCount OrtGetTrainingModelOutputCount;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetEvalModelOutputCount(
IntPtr /*(OrtSession*)*/ session,
out UIntPtr count);
public static DOrtGetEvalModelOutputCount OrtGetEvalModelOutputCount;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTrainingModelOutputName(
IntPtr /*(OrtSession*)*/ session,
UIntPtr index,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/name);
public static DOrtGetTrainingModelOutputName OrtGetTrainingModelOutputName;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetEvalModelOutputName(
IntPtr /*(OrtSession*)*/ session,
UIntPtr index,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/name);
public static DOrtGetEvalModelOutputName OrtGetEvalModelOutputName;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtResetGrad(
IntPtr /*(OrtSession*)*/ session);
public static DOrtResetGrad OrtResetGrad;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtTrainStep(
IntPtr /*(OrtTrainingSession*)*/ session,
IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
UIntPtr inputCount,
IntPtr[] /* (OrtValue*[])*/ inputValues,
UIntPtr outputCount,
IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */
);
public static DOrtTrainStep OrtTrainStep;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtEvalStep(
IntPtr /*(OrtTrainingSession*)*/ session,
IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
UIntPtr inputCount,
IntPtr[] /* (OrtValue*[])*/ inputValues,
UIntPtr outputCount,
IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */
);
public static DOrtEvalStep OrtEvalStep;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtOptimizerStep(
IntPtr /*(OrtTrainingSession*)*/ session,
IntPtr /*(OrtSessionRunOptions*)*/ runOptions // can be null to use the default options
);
public static DOrtOptimizerStep OrtOptimizerStep;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtSetLearningRate(
IntPtr /*(OrtTrainingSession*)*/ session,
float learningRate
);
public static DOrtSetLearningRate OrtSetLearningRate;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtGetLearningRate(
IntPtr /*(OrtTrainingSession*)*/ session,
out float learningRate
);
public static DOrtGetLearningRate OrtGetLearningRate;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtRegisterLinearLRScheduler(
IntPtr /*(OrtTrainingSession*)*/ session,
long warmupStepCount,
long totalStepCount,
float learningRate
);
public static DOrtRegisterLinearLRScheduler OrtRegisterLinearLRScheduler;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtSchedulerStep(
IntPtr /*(OrtTrainingSession*)*/ session
);
public static DOrtSchedulerStep OrtSchedulerStep;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseTrainingSession(IntPtr /*(OrtSession*)*/session);
public static DOrtReleaseTrainingSession OrtReleaseTrainingSession;
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseCheckpointState(IntPtr /*(OrtSession*)*/session);
public static DOrtReleaseCheckpointState OrtReleaseCheckpointState;
#endregion TrainingSession API
public static bool TrainingEnabled()
{
if (trainingApiPtr == IntPtr.Zero)
{
return false;
}
return true;
}
} //class NativeTrainingMethods
#endif
} //namespace