<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ output extension=".cs" #>
<#@ include file="TensorTemplate.ttinclude" #>// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/tests/TensorArithmetic.cs
// Original license statement below -

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;

namespace Microsoft.ML.OnnxRuntime.Tensors
{
    /// <summary>
    /// Arithmetic operations over tensors whose elements are of type <typeparamref name="T"/>.
    /// One member is generated for every entry of the template's method configuration.
    /// </summary>
    /// <typeparam name="T">Tensor element type.</typeparam>
    internal interface ITensorArithmetic<T>
    {
        /// <summary>The value configured as "one" for <typeparamref name="T"/>.</summary>
        T One { get; }

        /// <summary>The value configured as "zero" for <typeparamref name="T"/>.</summary>
        T Zero { get; }
<#
    // Declare one interface member per configured operation (Tensor<T> overloads only).
    foreach (MethodConfiguration method in methodConfiguration) {
#>
        <#= method.GetResultMethodSignature("Tensor", "T")#>;
<# } #>
    }

    /// <summary>
    /// Per-element-type access point for the arithmetic implementation;
    /// forwards to <see cref="TensorArithmetic.GetArithmetic{T}"/>.
    /// </summary>
    /// <typeparam name="T">Tensor element type.</typeparam>
    internal static class TensorArithmetic<T>
    {
        public static ITensorArithmetic<T> Instance => TensorArithmetic.GetArithmetic<T>();
    }

    /// <summary>
    /// Factory that maps an element type to its generated arithmetic class.
    /// </summary>
    internal static class TensorArithmetic
    {
        /// <summary>
        /// Creates the arithmetic implementation for <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Tensor element type.</typeparam>
        /// <returns>
        /// A new arithmetic instance, or <c>null</c> when none of the generated
        /// type checks below matches <typeparamref name="T"/>.
        /// </returns>
        public static ITensorArithmetic<T> GetArithmetic<T>()
        {
<#
    // One guarded return per configured element type. The guard text comes from
    // GenerateIfStatementHeader in TensorTemplate.ttinclude.
    foreach (TypeConfiguration type in typeConfiguration) {
#>
            <#=GenerateIfStatementHeader(type)#>
            {
                return (ITensorArithmetic<T>)new <#=type.ClassPrefix#>Arithmetic();
            }
<# } #>
            return null;
        }
    }

<#
    // Emit one concrete arithmetic class per configured element type.
    foreach (TypeConfiguration type in typeConfiguration) {
#>
    /// <summary>
    /// Generated arithmetic for <c><#=type.TypeName#></c> tensors. Operations that the
    /// template configuration marks as unsupported for this type throw
    /// <see cref="NotSupportedException"/>.
    /// </summary>
    internal class <#=type.ClassPrefix#>Arithmetic : ITensorArithmetic<<#=type.TypeName#>>
    {
        public <#=type.TypeName#> One => <#=type.OneLiteral#>;
        public <#=type.TypeName#> Zero => <#=type.ZeroLiteral#>;

<#
    // General Tensor<T> overloads: elements are addressed through multi-dimensional indices.
    foreach (MethodConfiguration method in methodConfiguration) {
#>
        public <#= method.GetResultMethodSignature("Tensor", type.TypeName)#>
        {
<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #>
            throw new NotSupportedException();
<# } else if (method.Operator != null) { #>
            // Visit every linear position of the result, expand it to per-dimension
            // indices, and apply the element operation through the tensor indexers.
            Span<int> indices = new Span<int>(new int[<#= method.ResultName #>.Rank]);
            for(int i = 0; i < <#= method.ResultName #>.Length; i++)
            {
                ArrayUtilities.GetIndices(<#= method.ResultName #>.strides, <#= method.ResultName #>.IsReversedStride, i, indices);
                <#=method.GetElementOperation(type.TypeName, "[indices]")#>;
            }
<# } else if (method.MethodName == "Contract") {#>
            var leftIndices = new int[left.Rank];
            var rightIndices = new int[right.Rank];
            var resultIndices = new int[result.Rank];

            // Sizes of the dimensions being summed over, taken from the left operand.
            var summingDimensions = new int[leftAxes.Length];
            for(int i = 0; i < leftAxes.Length; i++)
            {
                summingDimensions[i] = left.dimensions[leftAxes[i]];
            }

            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);

            var resultStrides = result.strides;

            // translates from result index to left non-summing dimensions' index portion
            // since left non-summing dimensions are given precedence in result, the end is zero-padded
            int[] leftNonSummingStrides = new int[result.Rank];

            // translates from summing index to left summing dimensions' index portion
            int[] leftSummingStrides = new int[leftAxes.Length];
            ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);

            // translates from result index to right non-summing dimensions' index portion
            int[] rightNonSummingStrides = new int[result.Rank];
            // right non-summing dimensions appear after left non-summing dimensions.
            int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length);

            // translates from summing index to right summing dimensions' index portion
            int[] rightSummingStrides = new int[rightAxes.Length];
            ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0);

            for (int resultIndex = 0; resultIndex < result.Length; resultIndex++)
            {
                <#=type.TypeName#> sum = (<#=type.TypeName#>)0;

                ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices);

                int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides);
                int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides);

                for (int summingIndex = 0; summingIndex < summingLength; summingIndex++)
                {
                    int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides);
                    int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides);

                    int leftIndex = leftIndexNonSumming + leftIndexSumming;
                    int rightIndex = rightIndexNonSumming + rightIndexSumming;

                    // todo, make this more efficient
                    ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices);
                    ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices);

                    sum += (<#=type.TypeName#>)(left[leftIndices] * right[rightIndices]);
                }

                result[resultIndices] = sum;
            }
<# } #>
        }
<# } #>

<#
    // DenseTensor<T> overloads: operate directly on the backing buffer spans.
    foreach (MethodConfiguration method in methodConfiguration) {
#>
        public <#= method.GetResultMethodSignature("DenseTensor", type.TypeName)#>
        {
<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #>
            throw new NotSupportedException();
<# } else if (method.Operator != null) { #>
<# if (method.MethodType == MethodType.UnaryInPlace) { #>
            var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span;
            var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span;

            for(int i = 0; i < <#=method.ResultName #>Span.Length; i++)
            {
                <#=method.GetElementOperation(type.TypeName, "Span[i]")#>;
            }
<# } else {#>
            var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span;
            var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span;
<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#>
            var <#=method.Op2Name #>Span = <#=method.Op2Name #>.Buffer.Span;
<# } #>

            // When the generated check passes, every operand is read at the same buffer position.
            if <#= method.GetLinearOperationCheck() #>
            {
                for(int i = 0; i < <#= method.ResultName #>Span.Length; i++)
                {
                    <#=method.GetElementOperation(type.TypeName, "Span[i]")#>;
                }
            }
            else
            {
                // Otherwise two counters are kept in step. Each operand's index is a ref alias
                // of whichever counter matches its own IsReversedStride flag.
                int rowMajorIndex = 0;
                int colMajorIndex = 0;

                ref int resultIndex = ref <#= method.ResultName #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
                ref int op1Index = ref <#= method.Op1Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#>
                ref int op2Index = ref <#= method.Op2Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;

                // Take each stride set from the first operand that has the matching layout flag.
                var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
                                      !<#= method.Op1Name #>.IsReversedStride ? <#= method.Op1Name #>.strides :
                                      <#= method.Op2Name #>.strides;
                var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
                                         <#= method.Op1Name #>.IsReversedStride ? <#= method.Op1Name #>.strides :
                                         <#= method.Op2Name #>.strides;
<# } else {#>

                var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
                                      <#= method.Op1Name #>.strides;
                var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
                                         <#= method.Op1Name #>.strides;
<# } #>
                for(;rowMajorIndex < <#= method.ResultName #>Span.Length; rowMajorIndex++)
                {
                    colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);

                    <#=method.GetElementOperation(type.TypeName, "Span[resultIndex]", "Span[op1Index]", "Span[op2Index]")#>;
                }
            }
<# } #>
<# } else if (method.MethodName == "Contract") {#>
            // Sizes of the dimensions being summed over, taken from the left operand.
            var summingDimensions = new int[leftAxes.Length];
            for(int i = 0; i < leftAxes.Length; i++)
            {
                summingDimensions[i] = left.dimensions[leftAxes[i]];
            }

            var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
            int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);

            var resultStrides = result.strides;

            // translates from result index to left non-summing dimensions' index portion
            // since left non-summing dimensions are given precedence in result, the end is zero-padded
            int[] leftNonSummingStrides = new int[result.Rank];

            // translates from summing index to left summing dimensions' index portion
            int[] leftSummingStrides = new int[leftAxes.Length];
            ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);

            // translates from result index to right non-summing dimensions' index portion
            int[] rightNonSummingStrides = new int[result.Rank];
            // right non-summing dimensions appear after left non-summing dimensions.
            int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length);

            // translates from summing index to right summing dimensions' index portion
            int[] rightSummingStrides = new int[rightAxes.Length];
            ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0);

            var resultSpan = result.Buffer.Span;
            var leftSpan = left.Buffer.Span;
            var rightSpan = right.Buffer.Span;

            for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++)
            {
                <#=type.TypeName#> sum = (<#=type.TypeName#>)0;

                int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides);
                int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides);

                for (int summingIndex = 0; summingIndex < summingLength; summingIndex++)
                {
                    int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides);
                    int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides);

                    int leftIndex = leftIndexNonSumming + leftIndexSumming;
                    int rightIndex = rightIndexNonSumming + rightIndexSumming;

                    sum += (<#=type.TypeName#>)(leftSpan[leftIndex] * rightSpan[rightIndex]);
                }

                resultSpan[resultIndex] = sum;
            }
<# } #>
        }
<# } #>
    }
<# } #>
}