gaoqiong / onnxruntime_v14 / Commits / 1a91fcc2

Commit 1a91fcc2, authored Jul 25, 2023 by gaoqiong
Commit message: add files required by dtk
Parent: a144865d
Pipeline #492 failed in 0 seconds
Changes: 280 · Pipelines: 1
Showing 20 changed files with 2520 additions and 0 deletions (+2520, -0)
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/activation/activations_impl.cu  +112  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/activation/activations_impl.h  +62  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/binary_elementwise_impl.cuh  +315  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/bitmask.cuh  +93  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/elementwise_impl.cuh  +53  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/unary_elementwise_impl.cuh  +61  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/variadic_elementwise_impl.cuh  +78  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/constant_of_shape.cc  +50  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/constant_of_shape.h  +23  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random.cc  +109  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random.h  +131  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random_impl.cu  +145  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random_impl.h  +22  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range.cc  +104  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range.h  +19  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range_impl.cu  +43  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range_impl.h  +15  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc  +608  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.h  +290  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu  +187  -0
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/activation/activations_impl.cu  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "activations_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"
namespace onnxruntime {
namespace rocm {

template <typename T>
struct OP_Elu : public CtxElu {
  __device__ __inline__ T operator()(const T& a) const {
    return a > (T)0 ? a : (T)alpha * (_Exp(a) - (T)1);
  }
};

template <typename T>
struct OP_HardSigmoid : public CtxHardSigmoid {
  __device__ __inline__ T operator()(const T& a) const {
    return _Max(_Min((T)alpha * a + (T)beta, (T)1), (T)0);
  }
};

template <typename T>
struct OP_LeakyRelu : public CtxLeakyRelu {
  __device__ __inline__ T operator()(const T& a) const {
    return a > (T)0 ? a : (T)alpha * a;
  }
};

template <typename T>
struct OP_Relu : public CtxRelu {
  __device__ __inline__ T operator()(const T& a) const {
    return _Max(a, (T)0);
  }
};

template <typename T>
struct OP_Selu : public CtxSelu {
  __device__ __inline__ T operator()(const T& a) const {
    return a > (T)0 ? (T)gamma * a : (T)gamma * (T)alpha * (_Exp(a) - (T)1);
  }
};

template <typename T>
struct OP_Sigmoid : public CtxSigmoid {
  __device__ __inline__ T operator()(const T& a) const {
    return a > T(0) ? (T)1 / ((T)1. + _Exp(-_Abs(a))) : (T)1 - (T)1 / ((T)1 + _Exp(-_Abs(a)));
  }
};

template <typename T>
struct OP_Softplus : public CtxSoftplus {
  __device__ __inline__ T operator()(const T& a) const {
    if (a > (T)0)
      return a + _Log(_Exp(-a) + (T)1);
    else
      return _Log(_Exp(a) + (T)1);
  }
};

template <typename T>
struct OP_Softsign : public CtxSoftsign {
  __device__ __inline__ T operator()(const T& a) const {
    return a / ((T)1. + _Abs(a));
  }
};

template <typename T>
struct OP_Tanh : public CtxTanh {
  __device__ __inline__ T operator()(const T& a) const {
    return _Tanh(a);
  }
};

template <typename T>
struct OP_ThresholdedRelu : public CtxThresholdedRelu {
  __device__ __inline__ T operator()(const T& a) const {
    return a > (T)alpha ? a : (T)0;
  }
};
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
input_data, \
output_data, \
*reinterpret_cast<const OP_##name<T>*>(func_ctx), \
count); \
}
#define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \
template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, \
size_t count);
#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, BFloat16)
#define UNARY_ACTIVATION_OP_NAME(name) \
UNARY_ACTIVATION_IMPL(name); \
SPECIALIZED_UNARY_ACTIVATIONL_HFD(name)
UNARY_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/activation/activations_impl.h  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace rocm {

struct CtxAlpha {
  float alpha;
};

struct CtxAlphaBeta {
  float alpha;
  float beta;
};

struct CtxAlphaGamma {
  float alpha;
  float gamma;
};

struct CtxNull {
};

typedef CtxAlpha CtxElu;
typedef CtxAlphaBeta CtxHardSigmoid;
typedef CtxAlpha CtxLeakyRelu;
typedef CtxNull CtxRelu;
typedef CtxAlphaGamma CtxSelu;
typedef CtxNull CtxSigmoid;
typedef CtxNull CtxSoftplus;
typedef CtxNull CtxSoftsign;
typedef CtxNull CtxTanh;
typedef CtxAlpha CtxThresholdedRelu;
#define UNARY_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(Elu) \
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
UNARY_ACTIVATION_OP_NAME(Relu) \
UNARY_ACTIVATION_OP_NAME(Selu) \
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
UNARY_ACTIVATION_OP_NAME(Softplus) \
UNARY_ACTIVATION_OP_NAME(Softsign) \
UNARY_ACTIVATION_OP_NAME(Tanh) \
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu)
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
hipStream_t stream, \
const T* input_data, \
T* output_data, \
const Ctx##name* func_ctx, \
size_t count)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
UNARY_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
}  // namespace rocm
}  // namespace onnxruntime
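The two activation files above rely on an X-macro pattern: UNARY_ACTIVATION_OPS() applies whatever UNARY_ACTIVATION_OP_NAME currently expands to, so the header turns the op list into Impl_* declarations while the .cu file turns the same list into definitions plus explicit instantiations. As an illustration of the mechanism only (toy names, not the real onnxruntime macros), here is a minimal standalone sketch:

// Minimal standalone sketch of the X-macro pattern used above (toy names, not the ORT macros).
#include <cstddef>
#include <cstdio>

#define TOY_OPS()      \
  TOY_OP_NAME(Relu)    \
  TOY_OP_NAME(Sigmoid)

// First expansion: declarations (what the header does).
#define TOY_OP_NAME(name) void Impl_##name(float* data, std::size_t count);
TOY_OPS()
#undef TOY_OP_NAME

// Second expansion: definitions (what the .cu file does).
#define TOY_OP_NAME(name)                                   \
  void Impl_##name(float* /*data*/, std::size_t count) {    \
    std::printf(#name " over %zu elements\n", count);       \
  }
TOY_OPS()
#undef TOY_OP_NAME

int main() {
  Impl_Relu(nullptr, 4);
  Impl_Sigmoid(nullptr, 4);
}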
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/binary_elementwise_impl.cuh  (new file, mode 0 → 100644)
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {

// broadcast by computing output coordinate from offset, using fast_divmod
template <typename T, typename T1, typename T2, typename FuncT,
          bool lhs_need_compute, bool rhs_need_compute, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWise(
    int32_t output_rank,
    const TArray<int64_t> lhs_padded_strides,
    const T1* lhs_data,
    const TArray<int64_t> rhs_padded_strides,
    const T2* rhs_data,
    const TArray<fast_divmod> fdm_output_strides,
    T* output_data,
    const FuncT& functor,
    HIP_LONG N) {
  HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
  T1 lvalue[NumElementsPerThread];
  T2 rvalue[NumElementsPerThread];

  HIP_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      HIP_LONG lhs_index = (lhs_need_compute ? 0 : id);
      HIP_LONG rhs_index = (rhs_need_compute ? 0 : id);
      // compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
      HIP_LONG offset = id;
#pragma unroll
      for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
        if (dim >= output_rank) {
          break;
        }
        int q, r;
        fdm_output_strides[dim].divmod(offset, q, r);
        if (lhs_need_compute) {
          lhs_index += static_cast<int>(lhs_padded_strides[dim]) * q;
        }
        if (rhs_need_compute) {
          rhs_index += static_cast<int>(rhs_padded_strides[dim]) * q;
        }
        offset = r;
      }
      lvalue[i] = lhs_data[lhs_index];
      rvalue[i] = rhs_data[rhs_index];

      id += NumThreadsPerBlock;
    }
  }

  id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      output_data[id] = functor(lvalue[i], rvalue[i]);
      id += NumThreadsPerBlock;
    }
  }
}

// for scalar broadcast or non-broadcast case
template <bool IncL, bool IncR, typename T, typename T1, typename T2, typename FuncT,
          int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWiseSimple(
    const T1* lhs_data,
    const T2* rhs_data,
    T* output_data,
    const FuncT func,
    HIP_LONG N) {
  HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
  T1 lvalue[NumElementsPerThread];
  T2 rvalue[NumElementsPerThread];

  HIP_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      lvalue[i] = lhs_data[IncL ? id : 0];
      rvalue[i] = rhs_data[IncR ? id : 0];

      id += NumThreadsPerBlock;
    }
  }

  id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      output_data[id] = func(lvalue[i], rvalue[i]);

      id += NumThreadsPerBlock;
    }
  }
}

// for rhs per-channel broadcast case
template <typename T, typename T1, typename T2, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWiseRhsPerChannelBatch1(
    const T1* lhs_data,
    const T2* rhs_data,
    const fast_divmod fdm_H,
    T* output_data,
    FuncT func,
    HIP_LONG N) {
  HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
  T1 lvalue[NumElementsPerThread];
  T2 rvalue[NumElementsPerThread];

  HIP_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      HIP_LONG rhs_id = fdm_H.div(id);
      lvalue[i] = lhs_data[id];
      rvalue[i] = rhs_data[rhs_id];

      id += NumThreadsPerBlock;
    }
  }

  id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      output_data[id] = func(lvalue[i], rvalue[i]);

      id += NumThreadsPerBlock;
    }
  }
}

template <typename T, typename T1, typename T2, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWiseRhsPerChannelBatchN(
    const T1* lhs_data,
    const T2* rhs_data,
    const fast_divmod fdm_H,
    const fast_divmod fdm_C,
    T* output_data,
    FuncT func,
    HIP_LONG N) {
  HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
  T1 lvalue[NumElementsPerThread];
  T2 rvalue[NumElementsPerThread];

  HIP_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      HIP_LONG rhs_id = fdm_H.div(id);
      int q, r;
      fdm_C.divmod(rhs_id, q, r);
      rhs_id = r;

      lvalue[i] = lhs_data[id];
      rvalue[i] = rhs_data[rhs_id];

      id += NumThreadsPerBlock;
    }
  }

  id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      output_data[id] = func(lvalue[i], rvalue[i]);

      id += NumThreadsPerBlock;
    }
  }
}

template <typename T, typename T1, typename T2, typename FuncT>
void BinaryElementWiseNoBroadcastImpl(
    hipStream_t stream,
    const T1* lhs_data,
    const T2* rhs_data,
    T* output_data,
    const FuncT& func,
    size_t count) {
  if (count == 0)  // special case where there's a dim value of 0 in the output shape
    return;

#ifdef USE_ROCM
  const int num_elements_per_thread = 2;
  const int num_threads_per_block = 512;
#else
  const int num_elements_per_thread = GridDim::maxElementsPerThread;
  const int num_threads_per_block = GridDim::maxThreadsPerBlock;
#endif

  int blocksPerGrid = static_cast<int>(CeilDiv(count, num_threads_per_block * num_elements_per_thread));
  HIP_LONG N = static_cast<HIP_LONG>(count);
  hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<true, true, T, T1, T2, FuncT,
                                                              num_threads_per_block, num_elements_per_thread>),
                     blocksPerGrid, num_threads_per_block, 0, stream,
                     lhs_data,
                     rhs_data,
                     output_data,
                     func,
                     N);
}

template <typename T, typename T1, typename T2, typename FuncT>
void BinaryElementWiseImpl(
    hipStream_t stream,
    int32_t output_rank_or_simple_broadcast,
    const TArray<int64_t>* lhs_padded_strides,
    const T1* lhs_data,
    const TArray<int64_t>* rhs_padded_strides,
    const T2* rhs_data,
    const TArray<fast_divmod>* fdm_output_strides,
    const fast_divmod& fdm_H,
    const fast_divmod& fdm_C,
    T* output_data,
    const FuncT& func,
    size_t count) {
  if (count == 0)  // special case where there's a dim value of 0 in the output shape
    return;

#ifdef USE_ROCM
  const int num_elements_per_thread = 2;
  const int num_threads_per_block = 512;
#else
  const int num_elements_per_thread = GridDim::maxElementsPerThread;
  const int num_threads_per_block = GridDim::maxThreadsPerBlock;
#endif

  int blocksPerGrid = static_cast<int>(CeilDiv(count, num_threads_per_block * num_elements_per_thread));
  HIP_LONG N = static_cast<HIP_LONG>(count);
  if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::NoBroadcast)) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<true, true, T, T1, T2, FuncT,
                                                                num_threads_per_block, num_elements_per_thread>),
                       blocksPerGrid, num_threads_per_block, 0, stream,
                       lhs_data,
                       rhs_data,
                       output_data,
                       func,
                       N);
  } else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::LeftScalar)) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<false, true, T, T1, T2, FuncT,
                                                                num_threads_per_block, num_elements_per_thread>),
                       blocksPerGrid, num_threads_per_block, 0, stream,
                       lhs_data,
                       rhs_data,
                       output_data,
                       func,
                       N);
  } else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::RightScalar)) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<true, false, T, T1, T2, FuncT,
                                                                num_threads_per_block, num_elements_per_thread>),
                       blocksPerGrid, num_threads_per_block, 0, stream,
                       lhs_data,
                       rhs_data,
                       output_data,
                       func,
                       N);
  } else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatch1)) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseRhsPerChannelBatch1<T, T1, T2, FuncT,
                                                                             num_threads_per_block, num_elements_per_thread>),
                       blocksPerGrid, num_threads_per_block, 0, stream,
                       lhs_data,
                       rhs_data,
                       fdm_H,
                       output_data,
                       func,
                       N);
  } else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatchN)) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseRhsPerChannelBatchN<T, T1, T2, FuncT,
                                                                             num_threads_per_block, num_elements_per_thread>),
                       blocksPerGrid, num_threads_per_block, 0, stream,
                       lhs_data,
                       rhs_data,
                       fdm_H,
                       fdm_C,
                       output_data,
                       func,
                       N);
  } else {
    if (lhs_padded_strides && rhs_padded_strides && lhs_padded_strides->Size() && rhs_padded_strides->Size())
      hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWise<T, T1, T2, FuncT, true, true,
                                                            num_threads_per_block, num_elements_per_thread>),
                         blocksPerGrid, num_threads_per_block, 0, stream,
                         output_rank_or_simple_broadcast,
                         *lhs_padded_strides,
                         lhs_data,
                         *rhs_padded_strides,
                         rhs_data,
                         *fdm_output_strides,
                         output_data,
                         func,
                         N);
    else if (lhs_padded_strides && lhs_padded_strides->Size())
      hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWise<T, T1, T2, FuncT, true, false,
                                                            num_threads_per_block, num_elements_per_thread>),
                         blocksPerGrid, num_threads_per_block, 0, stream,
                         output_rank_or_simple_broadcast,
                         *lhs_padded_strides,
                         lhs_data,
                         TArray<int64_t>(),  // rhs is not computed, so no need to dereference rhs_padded_strides
                         rhs_data,
                         *fdm_output_strides,
                         output_data,
                         func,
                         N);
    else if (rhs_padded_strides && rhs_padded_strides->Size())
      hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWise<T, T1, T2, FuncT, false, true,
                                                            num_threads_per_block, num_elements_per_thread>),
                         blocksPerGrid, num_threads_per_block, 0, stream,
                         output_rank_or_simple_broadcast,
                         TArray<int64_t>(),  // lhs is not computed, so no need to dereference lhs_padded_strides
                         lhs_data,
                         *rhs_padded_strides,
                         rhs_data,
                         *fdm_output_strides,
                         output_data,
                         func,
                         N);
  }
}

}  // namespace rocm
}  // namespace onnxruntime
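The _BinaryElementWise kernel above recovers per-input offsets from a flat output offset by walking the output strides with fast_divmod and accumulating quotient times padded input stride. A minimal host-side sketch of that stride walk, using plain division in place of fast_divmod, with made-up shapes ([2,3] output, a broadcast [3] right-hand side) and the assumption that broadcast dimensions are padded to stride 0 so they contribute nothing:

// Host-side sketch of the stride walk in _BinaryElementWise, with illustrative shapes and
// ordinary div/mod standing in for fast_divmod.
#include <array>
#include <cstdio>

int main() {
  // Output shape [2, 3]; lhs shape [2, 3]; rhs shape [3] padded to [1, 3].
  const std::array<int, 2> output_strides = {3, 1};      // row-major strides of the output
  const std::array<int, 2> lhs_padded_strides = {3, 1};  // lhs strides
  const std::array<int, 2> rhs_padded_strides = {0, 1};  // broadcast dim assumed padded to stride 0

  for (int id = 0; id < 6; ++id) {  // flat output offset
    int offset = id, lhs_index = 0, rhs_index = 0;
    for (int dim = 0; dim < 2; ++dim) {
      int q = offset / output_strides[dim];
      int r = offset % output_strides[dim];
      lhs_index += lhs_padded_strides[dim] * q;
      rhs_index += rhs_padded_strides[dim] * q;
      offset = r;
    }
    std::printf("out %d <- lhs %d, rhs %d\n", id, lhs_index, rhs_index);
  }
}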
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/bitmask.cuh  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
/**
* These functions MUST be called with an unroll factor that evenly divides the number of threads in a warp (32 for
* CUDA, 64 for ROCm). In addition, this kernel MUST be launched with a number of threads in a thread block which is
* evenly divisible by the number of threads in a warp.
*
* Take unroll factor of 4 and 32 threads in a warp as example, we take the following approach (for threads in the first
* warp, that is):
*
* Thread 0 generates output booleans 0-3
* Thread 1 generates output booleans 4-7
* ...
* Thread 7 generates output booleans 28-31
*
* These threads all agree on the same thread mask by determining which output bitmask index they want to write to.
* Threads 0-7 will generate the same thread mask (for output index 0), threads 8-15 will generate the same thread mask
* (for output index 1), and so on.
*
* After (partially before) agreeing upon which threads will collaborate to write out a single index,
* each thread generates 4 random values, and shifts them into the right location in the output uint32_t.
* For instance:
*
* Thread 0 will perform a shift of 0
* Thread 1 will perform a shift of 4
* Thread 2 will perform a shift of 8
* ...
*
* For index 0, this gives us the following composition of random bits (number represents which thread generated it):
*
* 77776666555544443333222211110000
*
* After each thread shifts its bits into the right location, we broadcast the reduced value to all threads. Finally,
* we just choose a single thread (in our case, we choose the thread with 0 shift, but any thread from 0-7 would work
* for the 0-7 group).
*
* Keep in mind that this must not be conditionally called, as all threads in the warp (that haven't already exited)
* must reach these function calls an equal number of times, otherwise the code execution is likely to hang or produce
* unintended side effects.
*
* We conditionally update the local thread's mask (with the "li < N" check), but all active threads always collaborate
* on the reduced value.
*/
namespace onnxruntime {
namespace rocm {

template <int NumUnroll>
__device__ __forceinline__ void SetBitmask(const HIP_LONG id, const HIP_LONG mask_element_count,
                                           const fast_divmod fdm_bits_per_element, BitmaskElementType thread_bitmask,
                                           BitmaskElementType* mask_data) {
  int bitmask_idx, bitmask_shift;
  fdm_bits_per_element.divmod(id, bitmask_idx, bitmask_shift);
  BitmaskElementType bitmask = (thread_bitmask << bitmask_shift);
#if defined(USE_ROCM) && __CUDA_ARCH__ >= 800
  // All threads which intend to write to the same output index will have the same thread mask.
  BitmaskElementType thread_mask = __match_any_sync(0xFFFFFFFF, bitmask_idx);
  // All threads with the same thread mask (threads which intend to write to the same output index) collaborate
  // on a bitwise-or reduction.
  bitmask = __reduce_or_sync(thread_mask, bitmask);
#else
#pragma unroll
  for (int stride = kNumBitsPerBitmaskElement / (NumUnroll * 2); stride > 0; stride /= 2) {
    bitmask |= WARP_SHFL_DOWN(bitmask, stride);
  }
#endif

  // Choose a single thread from the "thread mask" group to perform the output write.
  if (bitmask_shift == 0 && bitmask_idx < mask_element_count) {
    mask_data[bitmask_idx] = bitmask;
  }
}

template <int NumUnroll>
__device__ __forceinline__ void GetMasks(HIP_LONG id, const fast_divmod fdm_bits_per_element,
                                         const BitmaskElementType* mask_data, bool* mask_result) {
  int bitmask_idx, bitmask_shift;
  fdm_bits_per_element.divmod(id, bitmask_idx, bitmask_shift);
  BitmaskElementType shifted_mask = mask_data[bitmask_idx] >> bitmask_shift;
#pragma unroll
  for (int i = 0; i < NumUnroll; i++) {
    mask_result[i] = (shifted_mask & (1 << i)) != 0;
  }
}

}  // namespace rocm
}  // namespace onnxruntime
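SetBitmask packs each thread's NumUnroll booleans into one 32-bit mask element: the bit offset is id modulo the number of bits per element, the per-thread bits are shifted into place, and a warp reduction ORs the contributions of all threads that target the same element. The host-side sketch below reproduces just the shift/OR packing and the GetMasks-style unpack, assuming 32 bits per element and an unroll factor of 4 (values are illustrative).

// Host-side sketch of the bit packing performed by SetBitmask/GetMasks on the device.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kBitsPerElement = 32;
  constexpr int kUnroll = 4;

  uint32_t mask_element = 0;

  // Each "thread" t owns the booleans for ids t*kUnroll .. t*kUnroll+3 and contributes its
  // 4 bits at shift = id % kBitsPerElement, exactly like thread_bitmask << bitmask_shift.
  for (int t = 0; t < kBitsPerElement / kUnroll; ++t) {
    uint32_t thread_bitmask = (t % 2 == 0) ? 0b1010u : 0b0101u;  // pretend per-thread booleans
    int id = t * kUnroll;
    int shift = id % kBitsPerElement;
    mask_element |= thread_bitmask << shift;  // the warp reduction ORs these together on the GPU
  }

  // GetMasks-style unpack for id = 8 (the third thread's group of booleans).
  int shift = 8 % kBitsPerElement;
  uint32_t shifted = mask_element >> shift;
  for (int i = 0; i < kUnroll; ++i) {
    std::printf("bool %d of the group at id 8: %d\n", i, (shifted & (1u << i)) != 0);
  }
}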
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/elementwise_impl.cuh  (new file, mode 0 → 100644)
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {

#ifdef USE_ROCM
constexpr int kElementsPerThread = 2;
constexpr int kThreadsPerBlock = 512;
#else
constexpr int kElementsPerThread = GridDim::maxElementsPerThread;
constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock;
#endif

template <typename T, typename FuncT>
__global__ void ElementwiseKernel(T* output_data, const FuncT functor, HIP_LONG N) {
  HIP_LONG start = kElementsPerThread * kThreadsPerBlock * blockIdx.x + threadIdx.x;
  T value[kElementsPerThread];

  HIP_LONG id = start;
#pragma unroll
  for (int i = 0; i < kElementsPerThread; ++i) {
    if (id < N) {
      value[i] = functor(id);
      id += kThreadsPerBlock;
    }
  }

  id = start;
#pragma unroll
  for (int i = 0; i < kElementsPerThread; ++i) {
    if (id < N) {
      output_data[id] = value[i];
      id += kThreadsPerBlock;
    }
  }
}

template <typename T, typename FuncT>
void LaunchElementwiseKernel(hipStream_t stream, T* output_data, const FuncT& functor, size_t output_size) {
  if (output_size == 0) return;
  HIP_LONG N = static_cast<HIP_LONG>(output_size);
  int blocksPerGrid = CeilDiv(N, kThreadsPerBlock * kElementsPerThread);
  hipLaunchKernelGGL(HIP_KERNEL_NAME(ElementwiseKernel<T, FuncT>), blocksPerGrid, kThreadsPerBlock, 0, stream,
                     output_data, functor, N);
}

}  // namespace rocm
}  // namespace onnxruntime
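ElementwiseKernel assigns each thread kElementsPerThread flat indices spaced kThreadsPerBlock apart, starting at kElementsPerThread * kThreadsPerBlock * blockIdx.x + threadIdx.x, buffers functor(id) in registers, and then writes the buffered values out. A host-side emulation of that index pattern with toy sizes (the real constants come from GridDim or the USE_ROCM branch):

// Host-side emulation of the indexing in ElementwiseKernel, with toy launch sizes.
#include <cstdio>
#include <vector>

int main() {
  constexpr int kElementsPerThread = 2;
  constexpr int kThreadsPerBlock = 4;
  const int N = 13;
  const int blocksPerGrid = (N + kThreadsPerBlock * kElementsPerThread - 1) / (kThreadsPerBlock * kElementsPerThread);

  std::vector<int> output(N, 0);
  auto functor = [](int id) { return id * id; };  // stands in for the device-side functor

  for (int block = 0; block < blocksPerGrid; ++block) {
    for (int thread = 0; thread < kThreadsPerBlock; ++thread) {
      int id = kElementsPerThread * kThreadsPerBlock * block + thread;
      for (int i = 0; i < kElementsPerThread; ++i) {
        if (id < N) {
          output[id] = functor(id);  // the kernel buffers values first, then writes; merged here
          id += kThreadsPerBlock;
        }
      }
    }
  }
  for (int v : output) std::printf("%d ", v);
  std::printf("\n");
}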
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/unary_elementwise_impl.cuh  (new file, mode 0 → 100644)
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {

template <typename InT, typename OutT, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _UnaryElementWise(
    const InT* input_data,
    OutT* output_data,
    const FuncT functor,
    HIP_LONG N) {
  HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
  InT value[NumElementsPerThread];

  HIP_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      value[i] = input_data[id];
      id += NumThreadsPerBlock;
    }
  }

  id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      output_data[id] = functor(value[i]);
      id += NumThreadsPerBlock;
    }
  }
}

template <typename InT, typename OutT, typename FuncT>
void UnaryElementWiseImpl(
    hipStream_t stream,
    const InT* input_data,
    OutT* output_data,
    const FuncT& func,
    size_t count) {
  if (count == 0)  // special case where there's a dim value of 0 in the shape
    return;

  int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
  HIP_LONG N = static_cast<HIP_LONG>(count);
  hipLaunchKernelGGL(HIP_KERNEL_NAME(_UnaryElementWise<InT, OutT, FuncT,
                                                       GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread>),
                     blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
                     input_data,
                     output_data,
                     func,
                     N);
}

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/cu_inc/variadic_elementwise_impl.cuh  (new file, mode 0 → 100644)
#include "hip/hip_runtime.h"
#pragma once
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {

template <typename T, typename Func, int32_t max_input_batch_size, int32_t num_elements_per_thread>
__global__ void VariadicElementWiseNoBroadcastInputBatchKernel(
    Func func,
    size_t N,
    TArray<const T*, max_input_batch_size> inputs,
    T* output) {
  const size_t base_idx = num_elements_per_thread * blockDim.x * blockIdx.x + threadIdx.x;

  T inputs_buffer[num_elements_per_thread][max_input_batch_size];

  int32_t element_count;
  size_t element_idx;

#pragma unroll
  for (element_count = 0, element_idx = base_idx;
       element_count < num_elements_per_thread;
       ++element_count, element_idx += blockDim.x) {
    if (element_idx < N) {
#pragma unroll
      for (int32_t input_batch_idx = 0; input_batch_idx < max_input_batch_size; ++input_batch_idx) {
        if (input_batch_idx < inputs.Size()) {
          inputs_buffer[element_count][input_batch_idx] = inputs[input_batch_idx][element_idx];
        }
      }
    }
  }

#pragma unroll
  for (element_count = 0, element_idx = base_idx;
       element_count < num_elements_per_thread;
       ++element_count, element_idx += blockDim.x) {
    if (element_idx < N) {
      // first and second inputs
      T output_value = func(inputs_buffer[element_count][0], inputs_buffer[element_count][1]);

      // remaining inputs
#pragma unroll
      for (int32_t input_batch_idx = 2; input_batch_idx < max_input_batch_size; ++input_batch_idx) {
        if (input_batch_idx < inputs.Size()) {
          output_value = func(output_value, inputs_buffer[element_count][input_batch_idx]);
        }
      }

      output[element_idx] = output_value;
    }
  }
}

// assumptions:
// - inputs.Size() > 1 && inputs.Size() <= max_input_batch_size
// - inputs and output have N elements
template <typename T, typename Func, int32_t max_input_batch_size>
void VariadicElementWiseNoBroadcastInputBatchImpl(
    hipStream_t stream,
    Func func,
    size_t N,
    TArray<const T*, max_input_batch_size> inputs,
    T* output) {
  constexpr int32_t elements_per_thread = GridDim::maxElementsPerThread;
  constexpr int32_t threads_per_block = GridDim::maxThreadsPerBlock;
  const int32_t blocks_per_grid = static_cast<int32_t>(CeilDiv(N, elements_per_thread * threads_per_block));
  hipLaunchKernelGGL(HIP_KERNEL_NAME(VariadicElementWiseNoBroadcastInputBatchKernel<T, Func, max_input_batch_size,
                                                                                    elements_per_thread>),
                     blocks_per_grid, threads_per_block, 0, stream,
                     func, N, inputs, output);
}

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/constant_of_shape.cc  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "constant_of_shape.h"
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace rocm {

ONNX_OPERATOR_KERNEL_EX(
    ConstantOfShape,
    kOnnxDomain,
    9,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 0)
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
        .TypeConstraint("T2", DataTypeImpl::AllFixedSizeTensorTypes()),
    ConstantOfShape);

Status ConstantOfShape::ComputeInternal(OpKernelContext* ctx) const {
  Tensor* output_tensor = nullptr;
  ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &output_tensor));
  auto output_data = output_tensor->MutableDataRaw();
  const auto size = output_tensor->Shape().Size();
  const void* value_ptr = GetValuePtr();
  const auto element_size = output_tensor->DataType()->Size();
#define CASE(TYPE) \
case sizeof(TYPE): \
if (size > 0) { \
rocm::Fill(Stream(), reinterpret_cast<TYPE*>(output_data), *(reinterpret_cast<const TYPE*>(value_ptr)), size); \
} \
break;
  switch (element_size) {
    CASE(int8_t)
    CASE(int16_t)
    CASE(int32_t)
    CASE(int64_t)
    default:
      ORT_THROW("Unsupported value attribute datatype with sizeof=: ", element_size);
      break;
  }

  return Status::OK();
}

}  // namespace rocm
}  // namespace onnxruntime
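Note that the CASE macro above dispatches on the element byte width rather than on the ONNX element type, so, for example, a 4-byte float constant is filled through the int32_t branch; the fill copies the value's bit pattern rather than converting it numerically, which is why one branch per width covers every fixed-size type. A tiny host-side check of that assumption (illustrative only):

// Host-side check of the assumption behind the sizeof-based CASE dispatch: filling by byte
// width reinterprets the value's bits, so any 4-byte type can use the int32_t branch.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  static_assert(sizeof(float) == sizeof(int32_t), "float reuses the 4-byte fill path");
  static_assert(sizeof(double) == sizeof(int64_t), "double reuses the 8-byte fill path");

  float value = 1.5f;
  int32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));      // what reinterpret_cast over value_ptr amounts to
  float round_trip;
  std::memcpy(&round_trip, &bits, sizeof(bits));
  std::printf("%f survives the int32_t round trip\n", round_trip);
}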
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/constant_of_shape.h  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/cpu/generator/constant_of_shape_base.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {

class ConstantOfShape final : public ConstantOfShapeBase<>, public RocmKernel {
 public:
  explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), RocmKernel(info) {}

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape);

  Status ComputeInternal(OpKernelContext* ctx) const override;
};

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random.cc  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/generator/random.h"
#include "core/providers/rocm/generator/random_impl.h"
namespace onnxruntime {
namespace rocm {

using namespace ONNX_NAMESPACE;

ONNX_OPERATOR_KERNEL_EX(RandomNormal, kOnnxDomain, 1, kRocmExecutionProvider,
                        (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
                        RandomNormal);

ONNX_OPERATOR_KERNEL_EX(RandomNormalLike, kOnnxDomain, 1, kRocmExecutionProvider,
                        (*KernelDefBuilder::Create())
                            .TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
                            .TypeConstraint("T2", DataTypeImpl::AllIEEEFloatTensorTypes()),
                        RandomNormalLike);

ONNX_OPERATOR_KERNEL_EX(RandomUniform, kOnnxDomain, 1, kRocmExecutionProvider,
                        (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
                        RandomUniform);

ONNX_OPERATOR_KERNEL_EX(RandomUniformLike, kOnnxDomain, 1, kRocmExecutionProvider,
                        (*KernelDefBuilder::Create())
                            .TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
                            .TypeConstraint("T2", DataTypeImpl::AllIEEEFloatTensorTypes()),
                        RandomUniformLike);
#define RANDOM_COMPUTE_IMPL(name) \
template <typename T> \
struct name##ComputeImpl { \
void operator()(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, Tensor& Y) const { \
typedef typename ToHipType<T>::MappedType HipT; \
HipT* Y_data = reinterpret_cast<HipT*>(Y.MutableData<T>()); \
name##KernelImpl<HipT>(prop, stream, N, alpha, beta, generator, Y_data); \
} \
};
RANDOM_COMPUTE_IMPL(RandomNormal)
RANDOM_COMPUTE_IMPL(RandomUniform)
#undef RANDOM_COMPUTE_IMPL
Status RandomNormalBase::ComputeNormal(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape,
                                       int dtype) const {
  Tensor& Y = *ctx.Output(0, shape);
  const int64_t N = shape.Size();
  PhiloxGenerator& generator = GetPhiloxGenerator();
  utils::MLTypeCallDispatcher<float, MLFloat16, double> t_disp(dtype);
  t_disp.Invoke<RandomNormalComputeImpl>(rocm_kernel.GetDeviceProp(), rocm_kernel.Stream(), N, scale_, mean_,
                                         generator, Y);
  return Status::OK();
}

Status RandomNormalLike::ComputeInternal(OpKernelContext* p_ctx) const {
  const Tensor* p_X = p_ctx->Input<Tensor>(0);
  if (!p_X) {
    return Status(common::ONNXRUNTIME, common::FAIL, "X Input is not available.");
  }

  int dtype = GetDType();
  if (dtype == TensorProto_DataType_UNDEFINED && !p_X->IsDataType<float>() && !p_X->IsDataType<double>() &&
      !p_X->IsDataType<MLFloat16>()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "Output data type is required to be one of float types, but got incompatible data type ",
                           p_X->DataType(), " from input tensor.");
  }

  if (dtype == TensorProto_DataType_UNDEFINED) dtype = p_X->GetElementType();
  return ComputeNormal(*this, *p_ctx, p_X->Shape(), dtype);
}

Status RandomUniformBase::ComputeUniform(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape,
                                         int dtype) const {
  Tensor& Y = *ctx.Output(0, shape);
  const int64_t N = shape.Size();
  PhiloxGenerator& generator = GetPhiloxGenerator();
  utils::MLTypeCallDispatcher<float, MLFloat16, double> t_disp(dtype);
  t_disp.Invoke<RandomUniformComputeImpl>(rocm_kernel.GetDeviceProp(), rocm_kernel.Stream(), N, range_, from_,
                                          generator, Y);
  return Status::OK();
}

Status RandomUniformLike::ComputeInternal(OpKernelContext* p_ctx) const {
  const Tensor* p_X = p_ctx->Input<Tensor>(0);
  if (!p_X) {
    return Status(common::ONNXRUNTIME, common::FAIL, "X Input is not available.");
  }

  int dtype = GetDType();
  if (dtype == TensorProto_DataType_UNDEFINED && !p_X->IsDataType<float>() && !p_X->IsDataType<double>() &&
      !p_X->IsDataType<MLFloat16>()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "Output data type is required to be one of float types, but got incompatible data type ",
                           p_X->DataType(), " from input tensor.");
  }

  if (dtype == TensorProto_DataType_UNDEFINED) dtype = p_X->GetElementType();
  return ComputeUniform(*this, *p_ctx, p_X->Shape(), dtype);
}

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random.h  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/random_generator.h"
#include "core/providers/rocm/rocm_kernel.h"
#include <optional>
namespace onnxruntime {
namespace rocm {

class RandomBase {
 protected:
  explicit RandomBase(const OpKernelInfo& info) {
    float seed = 0.f;
    if (info.GetAttr<float>("seed", &seed).IsOK()) {
      generator_.emplace(static_cast<uint64_t>(seed));
    }

    int64_t dtype;
    if (info.GetAttr<int64_t>("dtype", &dtype).IsOK()) {
      ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(gsl::narrow<int>(dtype)) &&
                      dtype != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
                  "Invalid dtype of ", dtype);
      dtype_ = static_cast<ONNX_NAMESPACE::TensorProto::DataType>(dtype);
    }
  }

 protected:
  void SetDTypeIfUndefined(ONNX_NAMESPACE::TensorProto::DataType dtype) noexcept {
    if (dtype_ == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
      dtype_ = dtype;
    }
  }

  ONNX_NAMESPACE::TensorProto::DataType GetDType() const noexcept { return dtype_; }

  PhiloxGenerator& GetPhiloxGenerator() const {
    return (generator_.has_value()) ? *generator_ : PhiloxGenerator::Default();
  }

 private:
  ONNX_NAMESPACE::TensorProto::DataType dtype_ =
      ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;  // optional and may be inferred
  // This member is thread-safe, ensuring proper synchronization
  mutable std::optional<PhiloxGenerator> generator_;
};

class RandomNormalBase : public RandomBase {
 protected:
  RandomNormalBase(const OpKernelInfo& info) : RandomBase(info) {
    ORT_THROW_IF_ERROR(info.GetAttr<float>("scale", &scale_));
    ORT_THROW_IF_ERROR(info.GetAttr<float>("mean", &mean_));
  }

  Status ComputeNormal(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape,
                       int dtype) const;

 private:
  float scale_;
  float mean_;
};

class RandomNormal final : public RocmKernel, protected RandomNormalBase {
 public:
  explicit RandomNormal(const OpKernelInfo& info) : RocmKernel(info), RandomNormalBase(info) {
    SetDTypeIfUndefined(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
    std::vector<int64_t> shape;
    ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("shape", shape));
    shape_ = TensorShape(shape);
  }

  Status ComputeInternal(OpKernelContext* p_ctx) const override {
    return ComputeNormal(*this, *p_ctx, shape_, GetDType());
  }

 private:
  TensorShape shape_;
};

class RandomNormalLike final : public RocmKernel, protected RandomNormalBase {
 public:
  explicit RandomNormalLike(const OpKernelInfo& info) : RocmKernel(info), RandomNormalBase(info) {}
  Status ComputeInternal(OpKernelContext* p_ctx) const override;
};

class RandomUniformBase : public RandomBase {
 protected:
  explicit RandomUniformBase(const OpKernelInfo& info) : RandomBase(info) {
    float low, high;
    ORT_THROW_IF_ERROR(info.GetAttr<float>("low", &low));
    ORT_THROW_IF_ERROR(info.GetAttr<float>("high", &high));
    from_ = low;
    range_ = high - low;
  }

  Status ComputeUniform(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape,
                        int dtype) const;

 private:
  float range_;
  float from_;
};

class RandomUniform final : public RocmKernel, protected RandomUniformBase {
 public:
  explicit RandomUniform(const OpKernelInfo& info) : RocmKernel(info), RandomUniformBase(info) {
    SetDTypeIfUndefined(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
    std::vector<int64_t> shape;
    ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("shape", shape));
    shape_ = TensorShape(shape);
  }

  Status ComputeInternal(OpKernelContext* p_ctx) const override {
    return ComputeUniform(*this, *p_ctx, shape_, GetDType());
  }

 private:
  TensorShape shape_;
};

class RandomUniformLike final : public RocmKernel, protected RandomUniformBase {
 public:
  explicit RandomUniformLike(const OpKernelInfo& info) : RocmKernel(info), RandomUniformBase(info) {}
  Status ComputeInternal(OpKernelContext* p_ctx) const override;
};

}  // namespace rocm
}  // namespace onnxruntime
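RandomUniformBase folds the ONNX low/high attributes into from_ = low and range_ = high - low, and the device-side TransformFunc_RandomUniform in random_impl.cu (next file) maps a raw sample into [low, high) as value * range + from after flipping the (0, 1] bound to [0, 1). A host-side sketch of that mapping with made-up sample values:

// Host-side sketch of the uniform transform wired up by RandomUniformBase and
// TransformFunc_RandomUniform; the raw samples below are invented for illustration.
#include <cstdio>

int main() {
  const float low = 2.0f, high = 5.0f;
  const float from = low, range = high - low;

  const float raw_samples[] = {0.25f, 0.999f, 1.0f};  // hiprand_uniform4 returns values in (0, 1]
  for (float value : raw_samples) {
    float reverse_bound_value = (value == 1.0f) ? 0.0f : value;  // flip (0, 1] to [0, 1)
    std::printf("raw %.3f -> %.3f\n", value, reverse_bound_value * range + from);
  }
}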
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random_impl.cu  (new file, mode 0 → 100644)
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/generator/random_impl.h"
#include <hiprand_kernel.h>
#include <algorithm>
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {

constexpr int UNROLL = 4;

struct DistFunc_RandomNormal {
  __device__ __inline__ float4 operator()(hiprandStatePhilox4_32_10_t* state) const {
    return hiprand_normal4(state);
  }
};

struct DistFunc_RandomUniform {
  __device__ __inline__ float4 operator()(hiprandStatePhilox4_32_10_t* state) const {
    return hiprand_uniform4(state);
  }
};

struct TransformFunc_RandomNormal {
  __device__ __inline__ float operator()(const float value, const float scale, const float mean) const {
    return value * scale + mean;
  }
};

struct TransformFunc_RandomUniform {
  __device__ __inline__ float operator()(const float value, const float range, const float from) const {
    // reverse the bounds of hiprand4 from (0, 1] to [0, 1).
    // ref: https://github.com/pytorch/pytorch/blob/e795315c638228d4170f3797356c09a70b2ed4cd/aten/src/ATen/native/rocm/DistributionTemplates.h#L464
    float reverse_bound_value = value == 1.0f ? 0.0f : value;
    return reverse_bound_value * range + from;
  }
};

template <typename T, typename DistFuncT, typename TransformFuncT>
__global__ void RandomKernel(const int64_t N, const std::pair<uint64_t, uint64_t> seeds, const DistFuncT& dist_func,
                             const TransformFuncT& transform_func, const float alpha, const float beta, T* Y_data) {
  HIP_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  HIP_LONG step_size = gridDim.x * blockDim.x * UNROLL;

  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seeds.first, idx, seeds.second, &state);

  float4 rand;

  // We ensure every thread generates the same number of random numbers (by rounding
  // up the size) and at the same timestep (by syncing threads).
  // From ROCM hiprand documentation:
  // The Philox_4x32_10 algorithm is closely tied to the thread and block count.
  // Each thread computes 4 random numbers in the same time thus the most efficient
  // use of Philox_4x32_10 is to generate a multiple of 4 times number of threads.
  for (HIP_LONG id = idx * UNROLL; id < N; id += step_size) {
    rand = dist_func(&state);

    // actual computation
#pragma unroll
    for (int i = 0; i < UNROLL; i++) {
      HIP_LONG li = id + i;
      if (li < N) {
        Y_data[li] = static_cast<T>(transform_func((&rand.x)[i], alpha, beta));
      }
    }

    __syncthreads();
  }
}

template <typename T, typename DistFuncT, typename TransformFuncT>
__global__ void RandomVectorizedKernel(const int64_t N, const std::pair<uint64_t, uint64_t> seeds,
                                       const DistFuncT& dist_func, const TransformFuncT& transform_func,
                                       const float alpha, const float beta, T* Y_data) {
  HIP_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  HIP_LONG step_size = gridDim.x * blockDim.x * UNROLL;

  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seeds.first, idx, seeds.second, &state);

  float4 rand;

  // Using vectorized data load/store approach when N % 4 == 0 since this is typical case for input shape size.
  using LoadT = aligned_vector<T, UNROLL>;
  for (HIP_LONG id = idx * UNROLL; id < N; id += step_size) {
    rand = dist_func(&state);

    T r[UNROLL];

    // actual computation
#pragma unroll
    for (int ii = 0; ii < UNROLL; ii++) {
      r[ii] = static_cast<T>(transform_func((&rand.x)[ii], alpha, beta));
    }

    // Vectorized writes for Y_data
    *(reinterpret_cast<LoadT*>(&Y_data[id])) = *reinterpret_cast<LoadT*>(&r[0]);

    __syncthreads();
  }
}

template <typename T, typename DistFuncT, typename TransformFuncT>
void RandomKernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const DistFuncT& dist_func,
                      const TransformFuncT& transform_func, float alpha, float beta, PhiloxGenerator& generator,
                      T* Y_data) {
  const int block_size = 256;
  const int blocks_per_sm = prop.maxThreadsPerMultiProcessor / block_size;
  const int grid_size = std::min(prop.multiProcessorCount * blocks_per_sm,
                                 static_cast<int>(CeilDiv(N, block_size * UNROLL)));

  // Compute the number of random numbers generated by each thread, and increment philox generator offset by that
  // amount.
  const uint64_t counter_offset = static_cast<uint64_t>(((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL);
  auto seeds = generator.NextPhiloxSeeds(counter_offset);

  if (N % UNROLL != 0) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomKernel<T>), grid_size, block_size, 0, stream, N, seeds, dist_func,
                       transform_func, alpha, beta, Y_data);
  } else {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomVectorizedKernel<T>), grid_size, block_size, 0, stream, N, seeds,
                       dist_func, transform_func, alpha, beta, Y_data);
  }
}
#define RANDOM_KERNEL_IMPL(name) \
template <typename T> \
void name##KernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, T* Y_data) { \
RandomKernelImpl(prop, stream, N, DistFunc_##name(), TransformFunc_##name(), alpha, beta, generator, Y_data); \
}
RANDOM_KERNEL_IMPL(RandomNormal)
RANDOM_KERNEL_IMPL(RandomUniform)
#define SPECIALIZED_RANDOM_KERNEL(name, T) \
template void name##KernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, T* Y_data);
#define SPECIALIZED_RANDOM_KERNELS(T) \
SPECIALIZED_RANDOM_KERNEL(RandomNormal, T) \
SPECIALIZED_RANDOM_KERNEL(RandomUniform, T)
SPECIALIZED_RANDOM_KERNELS(float)
SPECIALIZED_RANDOM_KERNELS(double)
SPECIALIZED_RANDOM_KERNELS(half)

}  // namespace rocm
}  // namespace onnxruntime
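RandomKernelImpl advances the Philox generator by counter_offset = ((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL, i.e. the per-thread loop iteration count rounded up, times the four samples produced per iteration, so consecutive launches do not reuse counter values. A worked example with made-up launch dimensions:

// Worked example of the counter_offset computation in RandomKernelImpl, using invented
// launch dimensions (block_size = 256, grid_size = 4) and N = 10000 elements.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t N = 10000;
  const int block_size = 256, grid_size = 4, UNROLL = 4;

  // The 1024 threads each produce UNROLL samples per loop iteration; the loop runs
  // ceil(N / (block_size * grid_size * UNROLL)) times, so every thread advances its
  // Philox counter by that iteration count times UNROLL.
  const uint64_t counter_offset =
      static_cast<uint64_t>(((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL);
  std::printf("counter_offset = %llu\n",
              static_cast<unsigned long long>(counter_offset));  // 3 iterations * 4 samples = 12
}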
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/random_impl.h  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/random_generator.h"
namespace onnxruntime {
namespace rocm {
#define RANDOM_KERNEL_DECLARE(name) \
template <typename T> \
void name##KernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, T* Y_data);
RANDOM_KERNEL_DECLARE(RandomNormal)
RANDOM_KERNEL_DECLARE(RandomUniform)
#undef RANDOM_KERNEL_DECLARE
}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range.cc  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_common.h"
#include "range.h"
#include "range_impl.h"
using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace rocm {

ONNX_OPERATOR_KERNEL_EX(
    Range,
    kOnnxDomain,
    11,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 0)  // start
        .InputMemoryType(OrtMemTypeCPUInput, 1)  // limit
        .InputMemoryType(OrtMemTypeCPUInput, 2)  // delta
        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
                              DataTypeImpl::GetTensorType<double>(),
                              DataTypeImpl::GetTensorType<int16_t>(),
                              DataTypeImpl::GetTensorType<int32_t>(),
                              DataTypeImpl::GetTensorType<int64_t>()}),
    Range);

template <typename T>
static Status ComputeRange(hipStream_t stream, OpKernelContext* ctx) {
  const auto& start_tensor = *ctx->Input<Tensor>(0);
  const auto& limit_tensor = *ctx->Input<Tensor>(1);
  const auto* delta_tensor_ptr = ctx->Input<Tensor>(2);

  if (!start_tensor.Shape().IsScalar()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "start in Range operator should be scalar like tensor, yet got shape:",
                           start_tensor.Shape());
  }
  if (!limit_tensor.Shape().IsScalar()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "limit in Range operator should be scalar like tensor, yet got shape:",
                           limit_tensor.Shape());
  }
  if (delta_tensor_ptr != nullptr && !delta_tensor_ptr->Shape().IsScalar()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "delta in Range operator should be scalar like tensor, yet got shape:",
                           delta_tensor_ptr->Shape());
  }

  // Start, Limit and Delta are stored in CPU.
  T start = *(start_tensor.Data<T>());
  T limit = *(limit_tensor.Data<T>());
  T delta = T(1);
  if (delta_tensor_ptr != nullptr) {
    delta = *(delta_tensor_ptr->Data<T>());
  }

  if (delta == T(0)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "delta in Range operator can not be zero!");
  }

  double num = (static_cast<double>(limit) - static_cast<double>(start)) / static_cast<double>(delta);
  int count = static_cast<int>(ceil(num));
  if (count <= 0) count = 0;
  TensorShape shape = {static_cast<int64_t>(count)};
  T* y = ctx->Output(0, shape)->MutableData<T>();

  if (count > 0) {
    ORT_RETURN_IF_ERROR(RangeImpl(stream, start, delta, count, y));
  }

  return Status::OK();
}

namespace rocm_range_internal {
template <class T>
struct CallCudaRangeImpl {
  Status operator()(hipStream_t stream, OpKernelContext* ctx) const {
    return ComputeRange<T>(stream, ctx);
  }
};
}  // namespace rocm_range_internal

Status Range::ComputeInternal(OpKernelContext* ctx) const {
  const auto* input_tensor = ctx->Input<Tensor>(0);
  if (input_tensor == nullptr) {
    return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
  }

  utils::MLTypeCallDispatcher<int32_t, float, int64_t, double, int16_t> t_disp(input_tensor->GetElementType());
  return t_disp.InvokeRet<Status, rocm_range_internal::CallCudaRangeImpl>(Stream(), ctx);
}

}  // namespace rocm
}  // namespace onnxruntime
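ComputeRange sizes the output as count = ceil((limit - start) / delta) computed in double precision and clamped at zero, and RangeKernel (in range_impl.cu below) then writes start + delta * index at each position. A small worked example (values chosen for illustration):

// Worked example of the element-count computation in ComputeRange:
// start = 1, limit = 10, delta = 3 gives ceil(9 / 3) = 3 elements: 1, 4, 7.
#include <cmath>
#include <cstdio>

int main() {
  const double start = 1, limit = 10, delta = 3;
  int count = static_cast<int>(std::ceil((limit - start) / delta));
  if (count <= 0) count = 0;
  for (int index = 0; index < count; ++index) {
    std::printf("%g ", start + delta * index);  // what RangeKernel writes at each index
  }
  std::printf("(count = %d)\n", count);
}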
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range.h  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {

class Range final : public RocmKernel {
 public:
  explicit Range(const OpKernelInfo& info) : RocmKernel(info) {}

  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range_impl.cu  (new file, mode 0 → 100644)
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hipcub/hipcub.hpp>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "range_impl.h"
using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace rocm {

template <typename T>
__global__ void RangeKernel(const T start, const T delta, const int count, T* output) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < count) {
    output[index] = start + delta * index;
  }
}

template <typename T>
Status RangeImpl(hipStream_t stream, const T start, const T delta, const int count, T* output) {
  constexpr int block_size = 256;
  int grid_size = (count + block_size - 1) / block_size;
  hipLaunchKernelGGL(HIP_KERNEL_NAME(RangeKernel<T>), grid_size, block_size, 0, stream, start, delta, count, output);
  return HIP_CALL(hipGetLastError());
}
#define SPECIALIZED_IMPL(T) \
template Status RangeImpl<T>(hipStream_t stream, const T start, const T delta, const int count, T* output);
SPECIALIZED_IMPL(int16_t)
SPECIALIZED_IMPL(int32_t)
SPECIALIZED_IMPL(int64_t)
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/generator/range_impl.h  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {

template <typename T>
Status RangeImpl(hipStream_t stream, const T start, const T delta, const int count, T* output);

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.cc  (new file, mode 0 → 100644)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h"
using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {

template <>
Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context, BinaryElementwisePreparation* p) const {
  p->lhs_tensor = context->Input<Tensor>(0);
  p->rhs_tensor = context->Input<Tensor>(1);
  if (!(p->lhs_tensor->Shape() == p->rhs_tensor->Shape()))
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, Node().Name(), ": mismatching input shapes: ",
                           p->lhs_tensor->Shape().ToString(), " != ", p->rhs_tensor->Shape().ToString());
  p->output_tensor = context->Output(0, p->lhs_tensor->Shape());
  p->output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::NoBroadcast);
  return Status::OK();
}

Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape,
                          TensorShape& out_shape) {
  size_t lhs_rank = lhs_shape.NumDimensions();
  size_t rhs_rank = rhs_shape.NumDimensions();
  size_t out_rank = std::max(lhs_rank, rhs_rank);

  std::vector<int64_t> output_dims(out_rank, 0);
  for (size_t i = 0; i < out_rank; ++i) {
    int64_t lhs_dim = 1;
    if (i < lhs_rank)
      lhs_dim = lhs_shape[lhs_rank - 1 - i];

    int64_t rhs_dim = 1;
    if (i < rhs_rank)
      rhs_dim = rhs_shape[rhs_rank - 1 - i];

    int64_t max = std::max(lhs_dim, rhs_dim);
    int64_t min = std::min(lhs_dim, rhs_dim);
    int64_t out_dim = (min == 0 ? min : max);  // special case a dim value of 0.

    if (lhs_dim != out_dim && lhs_dim != 1)
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ",
                             lhs_rank - 1 - i, " LeftShape: ", lhs_shape.ToString(),
                             ", RightShape: ", rhs_shape.ToString());
    if (rhs_dim != out_dim && rhs_dim != 1)
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ",
                             rhs_rank - 1 - i, " LeftShape: ", lhs_shape.ToString(),
                             ", RightShape: ", rhs_shape.ToString());

    output_dims[out_rank - 1 - i] = out_dim;
  }
  out_shape = TensorShape(output_dims);
  return Status::OK();
}

Status BinaryElementwiseBroadcastPrepare(
    const Tensor* lhs_tensor,
    const Tensor* rhs_tensor,
    Tensor* output_tensor,
    BinaryElementwisePreparation* p,
    const TensorShape* override_lhs_shape,
    const TensorShape* override_rhs_shape) {
  p->lhs_tensor = lhs_tensor;
  p->rhs_tensor = rhs_tensor;
  const auto& lhs_shape = override_lhs_shape ? *override_lhs_shape : lhs_tensor->Shape();
  const auto& rhs_shape = override_rhs_shape ? *override_rhs_shape : rhs_tensor->Shape();

  p->output_tensor = output_tensor;
  const auto& output_shape = output_tensor->Shape();

  ORT_RETURN_IF_ERROR(p->BinaryElementwiseBroadcastPrepareHelper(lhs_shape, rhs_shape, output_shape));

  return Status::OK();
}

template <>
Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, BinaryElementwisePreparation* p) const {
  auto lhs_tensor = context->Input<Tensor>(0);
  auto rhs_tensor = context->Input<Tensor>(1);
  const auto& lhs_shape = lhs_tensor->Shape();
  const auto& rhs_shape = rhs_tensor->Shape();

  TensorShape output_shape;
  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
  auto output_tensor = context->Output(0, output_shape);

  ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p));

  return Status::OK();
}
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED_V(x, class_name, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
class_name<T>);
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(x, ver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED_V(x, x, ver, T)
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_NONTEMP(x, class_name, ver, ...) \
  ONNX_OPERATOR_KERNEL_EX( \
      x, \
      kOnnxDomain, \
      ver, \
      kRocmExecutionProvider, \
      (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()), \
      class_name);
#define BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
#define BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED_CLASS(x, class_name, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
class_name<T>);
#define BINARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
BinaryElementwisePreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
Impl_##x<typename ToHipType<T>::MappedType>( \
Stream(), \
prepare.output_rank_or_simple_broadcast, \
&prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()), \
&prepare.rhs_padded_strides, \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.rhs_tensor->Data<T>()), \
&prepare.fdm_output_strides, \
prepare.fdm_H, \
prepare.fdm_C, \
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()), \
prepare.output_tensor->Shape().Size()); \
return Status::OK(); \
}
#define BINARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, T)
#define BINARY_OP_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_COMPUTE(name, T)
#define BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED_CLASS(name, class_name, startver, endver, T) \
BINARY_ELEMENTWISE_COMPUTE(class_name, T)
#define BINARY_LOGICALOP_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_COMPUTE(name, T)
// since different ops support different types, we cannot use BINARY_OPS() directly
// the postfix of the macro name indicates the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
#define BINARY_OP_VERSIONED_HFD(name, startver, endver) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, float) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, double)
#define BINARY_OP_VERSIONED_UZILHFD(name, startver, endver) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_OP_VERSIONED_HFD(name, startver, endver)
#define BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(name, startver, endver) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_OP_VERSIONED_HFD(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, BFloat16)
#define BINARY_OP_HFD(name, ver) \
BINARY_OP_TYPED(name, ver, MLFloat16) \
BINARY_OP_TYPED(name, ver, float) \
BINARY_OP_TYPED(name, ver, double) \
BINARY_OP_TYPED(name, ver, BFloat16)
#define BINARY_OP_UZILHFD(name, ver) \
BINARY_OP_TYPED(name, ver, uint32_t) \
BINARY_OP_TYPED(name, ver, uint64_t) \
BINARY_OP_TYPED(name, ver, int32_t) \
BINARY_OP_TYPED(name, ver, int64_t) \
BINARY_OP_HFD(name, ver)
#define BINARY_OP_REGISTER_VERSIONED_OIL(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, bool) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int64_t)
#define BINARY_LOGICALOP_REGISTER_OIL(name, ver) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, bool) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int64_t)
#define BINARY_OP_REGISTER_HFD(name, ver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, float) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, double) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, BFloat16)
#define BINARY_OP_REGISTER_UZILHFD(name, ver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, uint32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, uint64_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t) \
BINARY_OP_REGISTER_HFD(name, ver)
#define BINARY_LOGICALOP_REGISTER_UZILHFD(name, ver) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, uint32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, uint64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, float) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, double) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, BFloat16)
#define BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(name, startver, endver) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, float) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, double) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, BFloat16)
#define BINARY_OP_REGISTER_VERSIONED_HFD(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, float) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, double) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, BFloat16)
#define BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(name, class_name, startver, endver) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, MLFloat16) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, float) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, double) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, BFloat16)
#define BINARY_OP_REGISTER_VERSIONED_UZILHFD(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_OP_REGISTER_VERSIONED_HFD(name, startver, endver)
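To make these macros concrete, the following is a hand expansion of BINARY_OP_TYPED(Add, 14, float), written out from the definitions above purely for illustration; it is not literal preprocessor output and is not part of the original file.

// Hand expansion of BINARY_OP_TYPED(Add, 14, float), derived from the macros above
// (illustration only).
ONNX_OPERATOR_TYPED_KERNEL_EX(
    Add,
    kOnnxDomain,
    14,
    float,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    Add<float>);

template <>
Status Add<float>::ComputeInternal(OpKernelContext* context) const {
  BinaryElementwisePreparation prepare;
  ORT_RETURN_IF_ERROR(Prepare(context, &prepare));
  Impl_Add<typename ToHipType<float>::MappedType>(
      Stream(),
      prepare.output_rank_or_simple_broadcast,
      &prepare.lhs_padded_strides,
      reinterpret_cast<const typename ToHipType<float>::MappedType*>(prepare.lhs_tensor->Data<float>()),
      &prepare.rhs_padded_strides,
      reinterpret_cast<const typename ToHipType<float>::MappedType*>(prepare.rhs_tensor->Data<float>()),
      &prepare.fdm_output_strides,
      prepare.fdm_H,
      prepare.fdm_C,
      reinterpret_cast<typename ToHipType<float>::MappedType*>(prepare.output_tensor->MutableData<float>()),
      prepare.output_tensor->Shape().Size());
  return Status::OK();
}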
BINARY_OP_VERSIONED_UZILHFD(Add, 7, 12)
BINARY_OP_VERSIONED_UZILHFD(Sub, 7, 12)
BINARY_OP_VERSIONED_UZILHFD(Mul, 7, 12)
BINARY_OP_VERSIONED_UZILHFD(Div, 7, 12)

BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Add, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Sub, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Mul, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Div, 13, 13)

BINARY_OP_UZILHFD(Add, 14)
BINARY_OP_UZILHFD(Sub, 14)
BINARY_OP_UZILHFD(Mul, 14)
BINARY_OP_UZILHFD(Div, 14)

BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11)
BINARY_LOGICALOP_TYPED(And, 7, bool)
BINARY_LOGICALOP_TYPED(Or, 7, bool)
BINARY_LOGICALOP_TYPED(Xor, 7, bool)
BINARY_OP_VERSIONED_HFD(PRelu, 7, 8)
BINARY_OP_VERSIONED_HFD(PRelu, 9, 15)
// Opset-16 adds BFloat16 to allowed types for the PRelu operator
BINARY_OP_HFD(PRelu, 16)

// Pow since version 12
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Pow,
    kOnnxDomain,
    12, 12,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
        .TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
    Pow);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Pow,
    kOnnxDomain,
    13, 14,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
        .TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
    Pow);

ONNX_OPERATOR_KERNEL_EX(
    Pow,
    kOnnxDomain,
    15,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
        .TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
    Pow);
namespace pow12_internal {

template <class T>
Status DispatchOnFirstArg(hipStream_t stream, const BinaryElementwisePreparation& prepare) {
  namespace on = ONNX_NAMESPACE;
  Status s;
  switch (prepare.rhs_tensor->GetElementType()) {
    case on::TensorProto_DataType_INT32:
      ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<int32_t>::MappedType>(
          stream,
          prepare.output_rank_or_simple_broadcast,
          &prepare.lhs_padded_strides,
          reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
          &prepare.rhs_padded_strides,
          reinterpret_cast<const typename ToHipType<int32_t>::MappedType*>(prepare.rhs_tensor->Data<int32_t>()),
          &prepare.fdm_output_strides,
          prepare.fdm_H,
          prepare.fdm_C,
          reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
          prepare.output_tensor->Shape().Size());
      break;
    case on::TensorProto_DataType_INT64:
      ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<int64_t>::MappedType>(
          stream,
          prepare.output_rank_or_simple_broadcast,
          &prepare.lhs_padded_strides,
          reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
          &prepare.rhs_padded_strides,
          reinterpret_cast<const typename ToHipType<int64_t>::MappedType*>(prepare.rhs_tensor->Data<int64_t>()),
          &prepare.fdm_output_strides,
          prepare.fdm_H,
          prepare.fdm_C,
          reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
          prepare.output_tensor->Shape().Size());
      break;
    case on::TensorProto_DataType_FLOAT:
      ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<float>::MappedType>(
          stream,
          prepare.output_rank_or_simple_broadcast,
          &prepare.lhs_padded_strides,
          reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
          &prepare.rhs_padded_strides,
          reinterpret_cast<const typename ToHipType<float>::MappedType*>(prepare.rhs_tensor->Data<float>()),
          &prepare.fdm_output_strides,
          prepare.fdm_H,
          prepare.fdm_C,
          reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
          prepare.output_tensor->Shape().Size());
      break;
    case on::TensorProto_DataType_DOUBLE:
      ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<double>::MappedType>(
          stream,
          prepare.output_rank_or_simple_broadcast,
          &prepare.lhs_padded_strides,
          reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
          &prepare.rhs_padded_strides,
          reinterpret_cast<const typename ToHipType<double>::MappedType*>(prepare.rhs_tensor->Data<double>()),
          &prepare.fdm_output_strides,
          prepare.fdm_H,
          prepare.fdm_C,
          reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
          prepare.output_tensor->Shape().Size());
      break;
    case on::TensorProto_DataType_FLOAT16:
      ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<MLFloat16>::MappedType>(
          stream,
          prepare.output_rank_or_simple_broadcast,
          &prepare.lhs_padded_strides,
          reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
          &prepare.rhs_padded_strides,
          reinterpret_cast<const typename ToHipType<MLFloat16>::MappedType*>(prepare.rhs_tensor->Data<MLFloat16>()),
          &prepare.fdm_output_strides,
          prepare.fdm_H,
          prepare.fdm_C,
          reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
          prepare.output_tensor->Shape().Size());
      break;
    default:
      s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported Y type: ",
                          DataTypeImpl::ToString(prepare.rhs_tensor->DataType()));
  }
  return s;
}

}  // namespace pow12_internal
Status Pow::ComputeInternal(OpKernelContext* context) const {
  BinaryElementwisePreparation prepare;
  ORT_RETURN_IF_ERROR(Prepare(context, &prepare));
  namespace on = ONNX_NAMESPACE;
  using namespace pow12_internal;

  Status s;
  switch (prepare.lhs_tensor->GetElementType()) {
    case on::TensorProto_DataType_INT32:
      s = DispatchOnFirstArg<int32_t>(Stream(), prepare);
      break;
    case on::TensorProto_DataType_INT64:
      s = DispatchOnFirstArg<int64_t>(Stream(), prepare);
      break;
    case on::TensorProto_DataType_FLOAT:
      s = DispatchOnFirstArg<float>(Stream(), prepare);
      break;
    case on::TensorProto_DataType_DOUBLE:
      s = DispatchOnFirstArg<double>(Stream(), prepare);
      break;
    case on::TensorProto_DataType_FLOAT16:
      s = DispatchOnFirstArg<MLFloat16>(Stream(), prepare);
      break;
    default:
      s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported X type: ",
                          DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));
  }
  return s;
}
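Since opset 12, Pow allows the base (X) and exponent (Y) element types to differ, so the kernel dispatches twice at runtime: Pow::ComputeInternal switches on the base type, and pow12_internal::DispatchOnFirstArg then switches on the exponent type before launching ImplT1_Pow<Base, Exp>. A compact standalone sketch of that nested-dispatch pattern is shown below; the names LaunchPow, DispatchOnExponent, and DispatchPow are hypothetical stand-ins, and the real code launches a HIP kernel instead of printing.

// Standalone sketch of the two-level type dispatch used by Pow (hypothetical names;
// the real code launches ImplT1_Pow<Base, Exp> on a HIP stream instead of printing).
#include <cstdint>
#include <iostream>
#include <typeinfo>

enum class Elem { kFloat, kInt64 };

template <typename Base, typename Exp>
void LaunchPow() {  // stand-in for ImplT1_Pow<Base, Exp>(...)
  std::cout << "launch Pow<" << typeid(Base).name() << ", " << typeid(Exp).name() << ">\n";
}

template <typename Base>
void DispatchOnExponent(Elem exp_type) {  // second level: switch on the exponent (Y) type
  switch (exp_type) {
    case Elem::kFloat: LaunchPow<Base, float>(); break;
    case Elem::kInt64: LaunchPow<Base, int64_t>(); break;
  }
}

void DispatchPow(Elem base_type, Elem exp_type) {  // first level: switch on the base (X) type
  switch (base_type) {
    case Elem::kFloat: DispatchOnExponent<float>(exp_type); break;
    case Elem::kInt64: DispatchOnExponent<int64_t>(exp_type); break;
  }
}

int main() {
  DispatchPow(Elem::kFloat, Elem::kInt64);  // routes to LaunchPow<float, int64_t>
}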
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Mod,
    kOnnxDomain,
    10, 12,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, uint32_t, uint64_t, float, double, MLFloat16>()),
    Mod);

ONNX_OPERATOR_KERNEL_EX(
    Mod,
    kOnnxDomain,
    13,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, uint32_t, uint64_t, float, double, MLFloat16, BFloat16>()),
    Mod);

Status Mod::ComputeInternal(OpKernelContext* context) const {
  namespace on = ONNX_NAMESPACE;

  BinaryElementwisePreparation prepare;
  ORT_RETURN_IF_ERROR(Prepare(context, &prepare));

  auto element_type = prepare.lhs_tensor->GetElementType();
  ORT_ENFORCE(fmod_ ||
                  element_type == on::TensorProto_DataType_INT32 ||
                  element_type == on::TensorProto_DataType_INT64 ||
                  element_type == on::TensorProto_DataType_UINT32 ||
                  element_type == on::TensorProto_DataType_UINT64,
              "Non-fmod can support integer types only.");
#define CASE_MOD_ELEMENT_TYPE(name, onnx_type, data_type) \
case onnx_type: { \
Impl_##name<typename ToHipType<data_type>::MappedType>( \
Stream(), prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToHipType<data_type>::MappedType*>(prepare.lhs_tensor->Data<data_type>()), \
&prepare.rhs_padded_strides, \
reinterpret_cast<const typename ToHipType<data_type>::MappedType*>(prepare.rhs_tensor->Data<data_type>()), \
&prepare.fdm_output_strides, prepare.fdm_H, prepare.fdm_C, \
reinterpret_cast<typename ToHipType<data_type>::MappedType*>( \
prepare.output_tensor->MutableData<data_type>()), \
prepare.output_tensor->Shape().Size()); \
} break
  if (fmod_) {
    switch (element_type) {
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_INT32, int32_t);
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_INT64, int64_t);
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_UINT32, uint32_t);
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_UINT64, uint64_t);
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_FLOAT, float);
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_DOUBLE, double);
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_FLOAT16, MLFloat16);
      CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_BFLOAT16, BFloat16);
      default:
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported element type: ",
                               DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));
    }
  } else {
    switch (element_type) {
      CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_INT32, int32_t);
      CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_INT64, int64_t);
      CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_UINT32, uint32_t);
      CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_UINT64, uint64_t);
      default:
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported element type: ",
                               DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));
    }
  }
#undef CASE_MOD_ELEMENT_TYPE

  return Status::OK();
}
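The fmod attribute read in the Mod constructor (see the header further below) decides which functor family is launched: fmod=1 dispatches Impl_Fmod for both integer and floating-point types, while the default fmod=0 dispatches Impl_Mod and is restricted to integer inputs by the ORT_ENFORCE above. Per the ONNX specification the two variants also differ in sign convention (fmod=0 follows the divisor's sign, like Python's %, while fmod=1 follows the dividend's sign, like C's fmod); the host-side sketch below illustrates that difference with a hypothetical helper, PythonStyleMod, and does not reproduce the GPU functors in this file.

// Host-side illustration (hypothetical helper, not the GPU functors in this file) of the
// sign convention difference between ONNX Mod with fmod=0 (Python-style %) and fmod=1 (C fmod).
#include <cmath>
#include <cstdint>
#include <iostream>

int64_t PythonStyleMod(int64_t a, int64_t b) {   // result takes the sign of the divisor b
  int64_t r = a % b;
  if (r != 0 && ((r < 0) != (b < 0))) r += b;
  return r;
}

int main() {
  std::cout << PythonStyleMod(-7, 3) << "\n";    // 2  (fmod=0 behavior on integers)
  std::cout << std::fmod(-7.0, 3.0) << "\n";     // -1 (fmod=1 behavior, sign of the dividend)
}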
// Greater op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T, typename HipT>
Status CompareFunction<T, HipT>::CompareMethod(OpKernelContext* context, ImplCompare Impl_Compare) const {
  BinaryElementwisePreparation prepare;
  ORT_RETURN_IF_ERROR(Prepare(context, &prepare));

  Impl_Compare(
      Stream(),
      prepare.output_rank_or_simple_broadcast,
      &prepare.lhs_padded_strides,
      reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()),
      &prepare.rhs_padded_strides,
      reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()),
      &prepare.fdm_output_strides,
      prepare.fdm_H,
      prepare.fdm_C,
      reinterpret_cast<ToHipType<bool>::MappedType*>(prepare.output_tensor->MutableData<bool>()),
      prepare.output_tensor->Shape().Size());

  return Status::OK();
}

// Greater op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
  return this->CompareMethod(context, &ImplT2_Greater);
}

template <typename T>
Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
  return this->CompareMethod(context, &ImplT2_Equal);
}

// Less op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status Less<T>::ComputeInternal(OpKernelContext* context) const {
  return this->CompareMethod(context, &ImplT2_Less);
}

// GreaterOrEqual op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status GreaterOrEqual<T>::ComputeInternal(OpKernelContext* context) const {
  return this->CompareMethod(context, &ImplT2_GreaterOrEqual);
}

// LessOrEqual op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status LessOrEqual<T>::ComputeInternal(OpKernelContext* context) const {
  return this->CompareMethod(context, &ImplT2_LessOrEqual);
}

BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 13)
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 13, bool)
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Equal, 11, 12)
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 11, 12, bool)
BINARY_OP_REGISTER_VERSIONED_OIL(Equal, 7, 10)
BINARY_LOGICALOP_REGISTER_UZILHFD(Greater, 13)
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Greater, 9, 12)
BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8)
BINARY_LOGICALOP_REGISTER_UZILHFD(Less, 13)
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Less, 9, 12)
BINARY_OP_REGISTER_VERSIONED_HFD(Less, 7, 8)
BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(GreaterOrEqual, 12, 15)
BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(LessOrEqual, 12, 15)
// Opset-16 adds BFloat16 to allowed types for the GreaterOrEqual operator
BINARY_LOGICALOP_REGISTER_UZILHFD(GreaterOrEqual, 16)
// Opset-16 adds BFloat16 to allowed types for the LessOrEqual operator
BINARY_LOGICALOP_REGISTER_UZILHFD(LessOrEqual, 16)

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/binary_elementwise_ops.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/shared_inc/fast_divmod.h"
#include "core/providers/cpu/tensor/utils.h"
namespace onnxruntime {
namespace rocm {

struct BinaryElementwisePreparation {
  const Tensor* lhs_tensor = nullptr;
  const Tensor* rhs_tensor = nullptr;
  Tensor* output_tensor = nullptr;
  int32_t output_rank_or_simple_broadcast = 0;  // for no_broadcast|left_scalar|right_scalar cases, output_rank uses SimpleBroadcast enums

  TArray<int64_t> lhs_padded_strides;
  TArray<int64_t> rhs_padded_strides;
  TArray<fast_divmod> fdm_output_strides;

  // these are for RightPerChannel case
  fast_divmod fdm_H;
  fast_divmod fdm_C;

  BinaryElementwisePreparation() {}

  Status BinaryElementwiseBroadcastPrepareHelper(const TensorShape& lhs_shape,
                                                 const TensorShape& rhs_shape,
                                                 const TensorShape& output_shape) {
    int32_t lhs_rank = gsl::narrow_cast<int32_t>(lhs_shape.NumDimensions());
    int32_t rhs_rank = gsl::narrow_cast<int32_t>(rhs_shape.NumDimensions());
    int32_t out_rank = std::max(lhs_rank, rhs_rank);

    // early return when shapes match
    if (lhs_shape == rhs_shape) {
      output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::NoBroadcast);
      return Status::OK();
    }

    // early return if one operand is scalar
    if (lhs_shape.Size() == 1 || rhs_shape.Size() == 1) {
      output_rank_or_simple_broadcast = static_cast<int32_t>(lhs_shape.Size() == 1
                                                                 ? SimpleBroadcast::LeftScalar
                                                                 : SimpleBroadcast::RightScalar);
      return Status::OK();
    }

    // special case for lhs(N,C,H) and rhs (C,1) which is used in conv bias
    // when N == 1: out[id] = op(lhs[id], rhs[id / H])
    // when N > 1:  out[id] = op(lhs[id], rhs[id / H % C])
    // (see the sketch after this struct)
    if (lhs_shape == output_shape) {
      const auto& rhs_dims = rhs_shape.GetDims();
      int64_t C = 0;
      if (1 == std::count_if(rhs_dims.begin(), rhs_dims.end(), [&C](int64_t dim) {
            if (dim != 1) C = dim;
            return (dim != 1);
          })) {
        int32_t dim_C = gsl::narrow_cast<int32_t>(std::find(rhs_dims.begin(), rhs_dims.end(), C) - rhs_dims.begin() +
                                                  output_shape.NumDimensions() - rhs_shape.NumDimensions());
        int64_t N = output_shape.SizeToDimension(dim_C);
        int64_t H = (dim_C < out_rank - 1 ? output_shape.SizeFromDimension(static_cast<size_t>(dim_C) + 1) : 1);

        std::vector<int64_t> new_output_dims;
        if (N == 1) {
          output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatch1);
          fdm_H = fast_divmod(gsl::narrow_cast<int>(H));
        } else {
          output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatchN);
          fdm_H = fast_divmod(gsl::narrow_cast<int>(H));
          fdm_C = fast_divmod(gsl::narrow_cast<int>(C));
        }
        return Status::OK();
      }
    }

    output_rank_or_simple_broadcast = out_rank;

    if (lhs_shape != output_shape) {
      TensorPitches original_lhs_padded_strides(lhs_shape.GetDims(), out_rank);
      lhs_padded_strides.SetSize(out_rank);
      auto offset = out_rank - lhs_rank;
      for (auto i = offset; i < out_rank; ++i) {
        // the stride for broadcast dimension is kept as 0
        if (lhs_shape.GetDims()[static_cast<size_t>(i) - offset] != 1) {
          lhs_padded_strides[i] = original_lhs_padded_strides[i];
        }
      }
    }

    if (rhs_shape != output_shape) {
      TensorPitches original_rhs_padded_strides(rhs_shape.GetDims(), out_rank);
      rhs_padded_strides.SetSize(out_rank);
      auto offset = out_rank - rhs_rank;
      for (auto i = offset; i < out_rank; ++i) {
        // the stride for broadcast dimension is kept as 0
        if (rhs_shape.GetDims()[static_cast<size_t>(i) - offset] != 1) {
          rhs_padded_strides[i] = original_rhs_padded_strides[i];
        }
      }
    }

    TensorPitches original_output_strides(output_shape.GetDims());
    fdm_output_strides.SetSize(out_rank);
    for (auto i = 0; i < out_rank; ++i) {
      fdm_output_strides[i] = fast_divmod(gsl::narrow_cast<int>(original_output_strides[i]));
    }

    return Status::OK();
  }
};
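The RightPerChannel special case above is what makes patterns like adding a per-channel bias to an (N, C, H) activation cheap: instead of full stride bookkeeping, the kernel only needs the two fast_divmod values fdm_H and fdm_C to map a flat output index to the right rhs element. The following is a small host-side sketch of that index mapping; plain integer division stands in for fast_divmod, and the helper name AddPerChannelBias is ours, not part of the file.

// Host-side sketch of the RightPerChannelBatchN index mapping used above:
// out[id] = op(lhs[id], rhs[(id / H) % C]); plain '/' and '%' stand in for fast_divmod.
#include <cstdint>
#include <vector>

std::vector<float> AddPerChannelBias(const std::vector<float>& lhs,   // flattened (N, C, H)
                                     const std::vector<float>& bias,  // C entries, i.e. shape (C, 1)
                                     int64_t C, int64_t H) {
  std::vector<float> out(lhs.size());
  for (int64_t id = 0; id < static_cast<int64_t>(lhs.size()); ++id) {
    out[id] = lhs[id] + bias[(id / H) % C];  // same element selection the GPU kernel performs
  }
  return out;
}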
Status ComputeOutputShape(const std::string& node_name,
                          const TensorShape& lhs_shape,
                          const TensorShape& rhs_shape,
                          TensorShape& out_shape);

Status BinaryElementwiseBroadcastPrepare(const Tensor* lhs_tensor,
                                         const Tensor* rhs_tensor,
                                         Tensor* output_tensor,
                                         BinaryElementwisePreparation* p,
                                         const TensorShape* override_lhs_shape = nullptr,
                                         const TensorShape* override_rhs_shape = nullptr);
// trait classes to indicate if the kernel supports broadcast
class ShouldBroadcast {
};

class ShouldNotBroadcast {
};

template <typename BroadcastTrait>
class BinaryElementwise : public RocmKernel {
 protected:
  typedef BroadcastTrait broadcast_type;

  BinaryElementwise(const OpKernelInfo& info) : RocmKernel(info) {}
  Status ComputeInternal(OpKernelContext*) const override {
    return Status(common::ONNXRUNTIME, common::FAIL);  // should not reach here
  }
  Status Prepare(OpKernelContext* context, BinaryElementwisePreparation* p) const;
};
template <typename T>
class Add final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Add(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Sub final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Sub(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Mul final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Mul(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Div final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Div(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Pow_7 final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Pow_7(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

// Since version 12
class Pow final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Pow(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class And final : public BinaryElementwise<ShouldBroadcast> {
 public:
  And(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Or final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Or(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Xor final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Xor(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

// PRelu is an activation function, but its implementation is closer to the binary elementwise ops
template <typename T>
class PRelu final : public BinaryElementwise<ShouldBroadcast> {
 public:
  PRelu(const OpKernelInfo& info) : BinaryElementwise(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};
class Mod final : public BinaryElementwise<ShouldBroadcast> {
 public:
  Mod(const OpKernelInfo& info) : BinaryElementwise(info) {
    int64_t fmod = info.GetAttrOrDefault<int64_t>("fmod", 0LL);
    fmod_ = fmod != 0;
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  bool fmod_{false};
};
template <typename T, typename HipT>
class CompareFunction : public BinaryElementwise<ShouldBroadcast> {
 public:
  CompareFunction(const OpKernelInfo& info) : BinaryElementwise(info) {}

  typedef void (*ImplCompare)(hipStream_t stream,
                              int32_t output_rank_or_simple_broadcast,
                              const TArray<int64_t>* lhs_padded_strides,
                              const HipT* lhs_data,
                              const TArray<int64_t>* rhs_padded_strides,
                              const HipT* rhs_data,
                              const TArray<fast_divmod>* fdm_output_strides,
                              const fast_divmod& fdm_H,
                              const fast_divmod& fdm_C,
                              bool* output_data,
                              size_t count);

  Status CompareMethod(OpKernelContext* context, ImplCompare Impl_Compare) const;
};
template <typename T>
class Greater final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
 public:
  Greater(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Equal final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
 public:
  Equal(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class Less final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
 public:
  Less(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class GreaterOrEqual final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
 public:
  GreaterOrEqual(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

template <typename T>
class LessOrEqual final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
 public:
  LessOrEqual(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
  Status ComputeInternal(OpKernelContext* context) const override;
};

}  // namespace rocm
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/core/providers/rocm/math/binary_elementwise_ops_impl.cu
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"
namespace onnxruntime {
namespace rocm {
#define BINARY_ELEMENTWISE_IMPL(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
BinaryElementWiseImpl(stream, \
output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
rhs_data, \
fdm_output_strides, \
fdm_H, \
fdm_C, \
output_data, \
OP_##name<T, T, T>(), \
count); \
}
#define BINARY_ELEMENTWISE_IMPL_T1(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION_T1(name) { \
BinaryElementWiseImpl(stream, \
output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
rhs_data, \
fdm_output_strides, \
fdm_H, \
fdm_C, \
output_data, \
OP_##name<T, T, T1>(), \
count); \
}
#define BINARY_ELEMENTWISE_IMPL_T2(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION_T2(name) { \
BinaryElementWiseImpl(stream, \
output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
rhs_data, \
fdm_output_strides, \
fdm_H, \
fdm_C, \
output_data, \
OP_##name<T, T1, T2>(), \
count); \
}
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, T) \
template void Impl_##x<T>(hipStream_t stream, \
int32_t output_rank, \
const TArray<int64_t>* lhs_padded_strides, const T* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, const T* rhs_data, \
const TArray<fast_divmod>* fdm_output_strides, const fast_divmod& fdm_H, const fast_divmod& fdm_C, T* output_data, size_t count);
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, T1) \
template void ImplT1_##x<T, T1>(hipStream_t stream, \
int32_t output_rank, \
const TArray<int64_t>* lhs_padded_strides, const T* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, const T1* rhs_data, \
const TArray<fast_divmod>* fdm_output_strides, const fast_divmod& fdm_H, const fast_divmod& fdm_C, T* output_data, size_t count);
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(x, T, T1, T2) \
template void ImplT2_##x<T, T1, T2>(hipStream_t stream, \
int32_t output_rank, \
const TArray<int64_t>* lhs_padded_strides, const T1* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, const T2* rhs_data, \
const TArray<fast_divmod>* fdm_output_strides, const fast_divmod& fdm_H, const fast_divmod& fdm_C, T* output_data, size_t count);
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZIL(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(x, T) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, double)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, bool) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
// create declarations for impl
#define BINARY_OP_NAME_EXPR(name, expr) \
BINARY_ELEMENTWISE_IMPL(name)
BINARY_OPS()
#undef BINARY_OP_NAME_EXPR
// create specialized impl
// the postfix of the macro name indicates the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
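For reference, SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, float) expands to the explicit instantiation below (expanded by hand from the macro above; shown for illustration only, since the real instantiations come from the macro invocations that follow).

// Hand expansion of SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, float) -- illustration only;
// the real instantiations are generated by the macro invocations below.
template void Impl_Add<float>(hipStream_t stream,
                              int32_t output_rank,
                              const TArray<int64_t>* lhs_padded_strides, const float* lhs_data,
                              const TArray<int64_t>* rhs_padded_strides, const float* rhs_data,
                              const TArray<fast_divmod>* fdm_output_strides, const fast_divmod& fdm_H,
                              const fast_divmod& fdm_C, float* output_data, size_t count);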
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Add)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Sub)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Mul)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Div)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Pow_7)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(And, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Or, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Xor, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(PRelu)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Max)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Min)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZIL(Mod)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Fmod)

// create declarations for impl for Pow
BINARY_ELEMENTWISE_IMPL_T1(Pow)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int32_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int64_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, float)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, half)
// create declarations for impl2
#define BINARY_OP_NAME_EXPR2(name, expr) \
BINARY_ELEMENTWISE_IMPL_T2(name)
BINARY_OPS2()
#undef BINARY_OP_NAME_EXPR2
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(name) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, uint32_t, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, uint64_t, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, int32_t, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, int64_t, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, half, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, float, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, double, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, BFloat16, BFloat16)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Greater)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Equal)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(Equal, bool, bool, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Less)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(GreaterOrEqual)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(LessOrEqual)

}  // namespace rocm
}  // namespace onnxruntime