TrainingSession.shared.cs 22.4 KB
Newer Older
gaoqiong's avatar
gaoqiong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
    /// <summary>
    /// Records which mechanism currently controls the learning rate of a TrainingSession.
    /// </summary>
    internal enum LRScheduler
    {
        /// <summary>No constant learning rate has been set and no scheduler has been registered.</summary>
        None = 0,

        /// <summary>A constant learning rate was set through SetLearningRate.</summary>
        Constant = 1,

        /// <summary>A linear learning rate scheduler was registered.</summary>
        Linear = 2,
    }
    /// <summary>
    /// Represents a Training Session on an ONNX Model.
    /// This is an IDisposable class and it must be disposed of
    /// using either an explicit call to Dispose() method or
    /// a pattern of using() block. If this is a member of another
    /// class that class must also become IDisposable and it must
    /// dispose of TrainingSession in its Dispose() method.
    /// </summary>
    public class TrainingSession : IDisposable
    {
        /// <summary>
        /// A pointer to an underlying native instance of OrtTrainingSession
        /// </summary>
        private IntPtr _nativeHandle;

        // Output counts and names queried from the native session in Init().
        // The eval fields keep their defaults (0 / null) when no eval model path was supplied.
        private ulong _trainOutputCount;
        private ulong _evalOutputCount;
        private List<string> _trainOutputNames;
        private List<string> _evalOutputNames;

        // Options owned and disposed by this instance. Session options are only created when the
        // caller passes none; run options back every overload that takes no RunOptions parameter.
        private SessionOptions _builtInSessionOptions = null;
        private RunOptions _builtInRunOptions = null;

        // How the learning rate is currently controlled (nothing yet, a constant LR, or a scheduler).
        private LRScheduler _scheduler = LRScheduler.None;
        private bool _disposed = false;

        #region Public API

        /// <summary>
        /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>.
        /// </summary>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        /// <param name="evalModelPath">Specify path to eval model graph.</param>
        /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
        public TrainingSession(CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
        {
            Init(null, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(evalModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
        }

        /// <summary>
        /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>, without an eval graph.
        /// </summary>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
        public TrainingSession(CheckpointState state, string trainModelPath, string optimizerModelPath)
        {
            Init(null, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), null, NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
        }

        /// <summary>
        /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>, with only a training graph.
        /// </summary>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        public TrainingSession(CheckpointState state, string trainModelPath)
        {
            Init(null, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), null, null);
        }

        /// <summary>
        /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>.
        /// </summary>
        /// <param name="options">Session options</param>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        /// <param name="evalModelPath">Specify path to eval model graph.</param>
        /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
        public TrainingSession(SessionOptions options, CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
        {
            Init(options, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(evalModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs, using the built-in RunOptions.
        /// </summary>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        public void TrainStep(
           IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
           IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            TrainStep(_builtInRunOptions, inputValues, outputValues);
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs. Uses the given RunOptions for this run.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        /// <exception cref="ArgumentException">The number of output values differs from the train model's output count.</exception>
        /// <exception cref="NotSupportedException">A string-typed value is passed as an output.</exception>
        public void TrainStep(
            RunOptions options,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            if (_trainOutputCount != (ulong)outputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
            }
            IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);

            // Handles of the caller's pre-allocated OrtValue instances that receive the results.
            IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false);
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
                inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs using the built-in RunOptions,
        /// and fetches all train model outputs.
        /// </summary>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
        {
            return TrainStep(_builtInRunOptions, inputValues);
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs, and fetches all train model outputs.
        /// Uses the given RunOptions for this run.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
            RunOptions options,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
        {
            using (var ortValues = new DisposableList<OrtValue>((int)_trainOutputCount))
            {
                IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);

                // Empty output slots: the native call fills them with OrtValue handles, which are wrapped below.
                IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount];

                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
                    inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray));
                foreach (var v in outputValuesArray)
                {
                    ortValues.Add(new OrtValue(v));
                }

                // NOTE(review): ortValues is disposed when this scope exits; this relies on
                // CreateFromOrtValue taking ownership of the values it wraps — verify.
                var result = new DisposableList<DisposableNamedOnnxValue>(_trainOutputNames.Count);
                try
                {
                    for (int i = 0; i < ortValues.Count; i++)
                    {
                        result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(_trainOutputNames[i], ortValues[i]));
                    }
                }
                catch (Exception)
                {
                    // Release any outputs already converted before propagating the failure.
                    result.Dispose();
                    throw;
                }
                return result;
            }
        }

        /// <summary>
        /// Sets the reset grad flag on the training graph. The gradient buffers will be reset while executing the
        /// next train step.
        /// </summary>
        public void ResetGrad()
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtResetGrad(_nativeHandle));
        }

        /// <summary>
        /// Runs an eval step on the loaded model for the given inputs, using the built-in RunOptions.
        /// The eval graph must be passed while TrainingSession creation.
        /// </summary>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        public void EvalStep(
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            EvalStep(_builtInRunOptions, inputValues, outputValues);
        }

        /// <summary>
        /// Runs an eval step on the loaded model for the given inputs. The eval graph must be passed while TrainingSession creation.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        /// <exception cref="ArgumentException">The number of output values differs from the eval model's output count.</exception>
        /// <exception cref="NotSupportedException">A string-typed value is passed as an output.</exception>
        public void EvalStep(
            RunOptions options,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            // Compare as ulong: ulong.Equals(int) binds to Equals(object) and is always false.
            if (_evalOutputCount != (ulong)outputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
            }
            IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);

            // Handles of the caller's pre-allocated OrtValue instances that receive the results.
            IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false);
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
                inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
        }

        /// <summary>
        /// Sets a constant learning rate for the session. LR must be controlled by either this method
        /// or by registering a LR scheduler.
        /// </summary>
        /// <param name="learningRate">The learning rate to use.</param>
        /// <exception cref="InvalidOperationException">A LR scheduler has already been registered.</exception>
        public void SetLearningRate(float learningRate)
        {
            if (_scheduler != LRScheduler.None && _scheduler != LRScheduler.Constant)
            {
                throw new InvalidOperationException("Cannot set constant LR while using LR scheduler.");
            }
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSetLearningRate(_nativeHandle, learningRate));
            _scheduler = LRScheduler.Constant;
        }

        /// <summary>
        /// Gets the current learning rate for the session.
        /// </summary>
        /// <returns>The learning rate reported by the native session.</returns>
        public float GetLearningRate()
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetLearningRate(_nativeHandle, out float lr));
            return lr;
        }

        /// <summary>
        /// Registers a linear learning rate scheduler for the session. LR must be controlled by either
        /// the SetLearningRate method or by registering a LR scheduler.
        /// </summary>
        /// <param name="warmupStepCount"> Number of warmup steps</param>
        /// <param name="totalStepCount"> Number of total steps</param>
        /// <param name="initialLearningRate"> Initial learning rate</param>
        /// <exception cref="InvalidOperationException">A LR scheduler has already been registered.</exception>
        public void RegisterLinearLRScheduler(long warmupStepCount,
                                              long totalStepCount,
                                              float initialLearningRate)
        {
            // Registration is allowed after a constant LR was set (the scheduler then takes over),
            // but not when a scheduler is already registered.
            if (_scheduler != LRScheduler.None && _scheduler != LRScheduler.Constant)
            {
                throw new InvalidOperationException("Cannot register LR scheduler: a LR scheduler is already registered.");
            }

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtRegisterLinearLRScheduler(_nativeHandle, warmupStepCount, totalStepCount, initialLearningRate));
            _scheduler = LRScheduler.Linear;
        }

        /// <summary>
        /// Runs a LR scheduler step. There must be a valid LR scheduler registered for the training session.
        /// </summary>
        /// <exception cref="InvalidOperationException">No LR scheduler has been registered.</exception>
        public void SchedulerStep()
        {
            if (_scheduler == LRScheduler.Constant || _scheduler == LRScheduler.None)
            {
                throw new InvalidOperationException("Cannot take scheduler step without registering a valid LR scheduler.");
            }
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSchedulerStep(_nativeHandle));
        }

        /// <summary>
        /// Runs an optimizer step on the loaded model using the built-in RunOptions.
        /// The optimizer graph must be passed while TrainingSession creation.
        /// </summary>
        public void OptimizerStep()
        {
            OptimizerStep(_builtInRunOptions);
        }

        /// <summary>
        /// Runs an optimizer step on the loaded model. The optimizer graph must be passed while TrainingSession creation.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        public void OptimizerStep(RunOptions options)
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtOptimizerStep(_nativeHandle, options.Handle));
        }

        /// <summary>
        /// Saves a checkpoint to path. It can be loaded into <see cref="CheckpointState"/>
        /// </summary>
        /// <param name="path">Specify path for saving the checkpoint.</param>
        /// <param name="saveOptimizerState">Flag indicating whether to save optimizer state or not.</param>
        public void SaveCheckpoint(string path, bool saveOptimizerState = false)
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(path), _nativeHandle, saveOptimizerState));
        }

        #endregion
        #region private methods

        /// <summary>
        /// Creates the native training session and caches the train (and, when given, eval) output names.
        /// Any partially created state is released if a step fails.
        /// </summary>
        /// <param name="sessOptions">Caller session options, or null to use a built-in default instance.</param>
        /// <param name="state">Loaded checkpoint; must not be null.</param>
        /// <param name="trainModelPath">Platform-serialized path of the training graph.</param>
        /// <param name="evalModelPath">Platform-serialized path of the eval graph, or null.</param>
        /// <param name="optimizerModelPath">Platform-serialized path of the optimizer graph, or null.</param>
        private void Init(SessionOptions sessOptions, CheckpointState state, byte[] trainModelPath, byte[] evalModelPath, byte[] optimizerModelPath)
        {
            if (!NativeTrainingMethods.TrainingEnabled())
            {
                throw new InvalidOperationException("Training is disabled in the current build.");
            }
            if (state == null)
            {
                throw new ArgumentNullException(nameof(state));
            }

            try
            {
                // Created inside the try so it is disposed by CleanupHelper if anything below fails.
                var options = sessOptions;
                if (options == null)
                {
                    _builtInSessionOptions = new SessionOptions();
                    options = _builtInSessionOptions;
                }
                var envHandle = OrtEnv.Handle;

                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCreateTrainingSession(envHandle, options.Handle, state.Handle, trainModelPath,
                                                                                     evalModelPath, optimizerModelPath, out _nativeHandle));

                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelOutputCount(_nativeHandle, out UIntPtr trainCount));
                _trainOutputCount = trainCount.ToUInt64();

                _trainOutputNames = new List<string>();
                for (ulong i = 0; i < _trainOutputCount; i++)
                {
                    _trainOutputNames.Add(GetOutputName(i, true));
                }

                if (evalModelPath != null)
                {
                    NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelOutputCount(_nativeHandle, out UIntPtr evalCount));
                    _evalOutputCount = evalCount.ToUInt64();
                    _evalOutputNames = new List<string>();
                    for (ulong i = 0; i < _evalOutputCount; i++)
                    {
                        _evalOutputNames.Add(GetOutputName(i, false));
                    }
                }

                // Create a default built-in run option once, to avoid creating a new one on every step call.
                _builtInRunOptions = new RunOptions();
            }
            catch (Exception)
            {
                CleanupHelper(true);
                throw;
            }
        }

        /// <summary>
        /// Fetches the name of the output at <paramref name="index"/> from the train or eval graph.
        /// </summary>
        /// <param name="index">Zero-based output index.</param>
        /// <param name="training">true to query the training graph, false for the eval graph.</param>
        /// <returns>The output name decoded from native UTF-8.</returns>
        private string GetOutputName(ulong index, bool training)
        {
            var allocator = OrtAllocator.DefaultInstance;
            IntPtr nameHandle;
            if (training)
            {
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelOutputName(
                                           _nativeHandle,
                                           (UIntPtr)index,
                                           allocator.Pointer,
                                           out nameHandle));
            }
            else
            {
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelOutputName(
                                           _nativeHandle,
                                           (UIntPtr)index,
                                           allocator.Pointer,
                                           out nameHandle));
            }

            // The name buffer was obtained through `allocator`; wrapping it in OrtMemoryAllocation
            // hands it back for release when this scope exits.
            using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
            {
                return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
            }
        }

        /// <summary>
        /// Collects the native OrtValue handles of <paramref name="values"/> in enumeration order.
        /// </summary>
        /// <param name="values">Values to extract handles from.</param>
        /// <param name="input">false when the values are outputs; string-typed outputs are rejected.</param>
        /// <returns>An array with one handle per value.</returns>
        /// <exception cref="NotSupportedException">A string-typed value is used as an output.</exception>
        private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> values, bool input)
        {
            var valuesArray = new IntPtr[values.Count];
            int index = 0;
            // Single pass: ElementAt(i) in a loop is O(n^2) for collections without indexers.
            foreach (var v in values)
            {
                if (!input && v.ElementType == Tensors.TensorElementType.String)
                {
                    throw new NotSupportedException("Using string type FixedBufferOnnxValue in outputs is not supported.");
                }
                valuesArray[index++] = v.Value.Handle;
            }
            return valuesArray;
        }

        /// <summary>
        /// Native OrtTrainingSession handle, for access by other classes in this assembly.
        /// </summary>
        internal IntPtr Handle
        {
            get
            {
                return _nativeHandle;
            }
        }

        #endregion

        #region IDisposable

        /// <summary>
        /// Finalizer.
        /// </summary>
        ~TrainingSession()
        {
            Dispose(false);
        }

        /// <summary>
        /// IDisposable implementation
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// IDisposable implementation
        /// </summary>
        /// <param name="disposing">true if invoked from Dispose() method</param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }
            CleanupHelper(disposing);
            _disposed = true;
        }

        /// <summary>
        /// Releases owned resources. Managed option objects are only disposed when
        /// <paramref name="disposing"/> is true; the native session handle is always released.
        /// </summary>
        /// <param name="disposing">true when called from Dispose() or from a failed Init().</param>
        private void CleanupHelper(bool disposing)
        {
            if (disposing)
            {
                if (_builtInRunOptions != null)
                {
                    _builtInRunOptions.Dispose();
                    _builtInRunOptions = null;
                }

                if (_builtInSessionOptions != null)
                {
                    _builtInSessionOptions.Dispose();
                    _builtInSessionOptions = null;
                }
            }

            // cleanup unmanaged resources
            if (_nativeHandle != IntPtr.Zero)
            {
                NativeTrainingMethods.OrtReleaseTrainingSession(_nativeHandle);
                _nativeHandle = IntPtr.Zero;
            }
        }

        #endregion
    }
#endif
}