Commit 1a91fcc2 authored by gaoqiong

add files required by dtk

parent a144865d
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "activations_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"
namespace onnxruntime {
namespace rocm {
template <typename T>
struct OP_Elu : public CtxElu {
__device__ __inline__ T operator()(const T& a) const {
return a > (T)0 ? a : (T)alpha * (_Exp(a) - (T)1);
}
};
template <typename T>
struct OP_HardSigmoid : public CtxHardSigmoid {
__device__ __inline__ T operator()(const T& a) const {
return _Max(_Min((T)alpha * a + (T)beta, (T)1), (T)0);
}
};
template <typename T>
struct OP_LeakyRelu : public CtxLeakyRelu {
__device__ __inline__ T operator()(const T& a) const {
return a > (T)0 ? a : (T)alpha * a;
}
};
template <typename T>
struct OP_Relu : public CtxRelu {
__device__ __inline__ T operator()(const T& a) const {
return _Max(a, (T)0);
}
};
template <typename T>
struct OP_Selu : public CtxSelu {
__device__ __inline__ T operator()(const T& a) const {
return a > (T)0 ? (T)gamma * a : (T)gamma * (T)alpha * (_Exp(a) - (T)1);
}
};
template <typename T>
struct OP_Sigmoid : public CtxSigmoid {
__device__ __inline__ T operator()(const T& a) const {
return a > T(0) ? (T)1 / ((T)1. + _Exp(-_Abs(a))) : (T)1 - (T)1 / ((T)1 + _Exp(-_Abs(a)));
}
};
template <typename T>
struct OP_Softplus : public CtxSoftplus {
__device__ __inline__ T operator()(const T& a) const {
if (a > (T)0)
return a + _Log(_Exp(-a) + (T)1);
else
return _Log(_Exp(a) + (T)1);
}
};
template <typename T>
struct OP_Softsign : public CtxSoftsign {
__device__ __inline__ T operator()(const T& a) const {
return a / ((T)1. + _Abs(a));
}
};
template <typename T>
struct OP_Tanh : public CtxTanh {
__device__ __inline__ T operator()(const T& a) const {
return _Tanh(a);
}
};
template <typename T>
struct OP_ThresholdedRelu : public CtxThresholdedRelu {
__device__ __inline__ T operator()(const T& a) const {
return a > (T)alpha ? a : (T)0;
}
};
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
input_data, \
output_data, \
*reinterpret_cast<const OP_##name<T>*>(func_ctx), \
count); \
}
#define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \
template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, \
size_t count);
#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, BFloat16)
#define UNARY_ACTIVATION_OP_NAME(name) \
UNARY_ACTIVATION_IMPL(name); \
SPECIALIZED_UNARY_ACTIVATIONL_HFD(name)
UNARY_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
} // namespace rocm
} // namespace onnxruntime
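The Sigmoid and Softplus functors above are written so that _Exp is only ever applied to a non-positive argument, which avoids overflow for large |a|. A minimal host-side C++ sketch of the same rearrangement (illustrative only, not part of this commit):

#include <cmath>
#include <cstdio>

// Stable sigmoid: exp() only sees -|a|, so it cannot overflow.
float StableSigmoid(float a) {
  const float e = std::exp(-std::fabs(a));
  return a > 0.0f ? 1.0f / (1.0f + e) : 1.0f - 1.0f / (1.0f + e);
}

// Stable softplus: the rewrite keeps the exp() argument non-positive for large |a|.
float StableSoftplus(float a) {
  return a > 0.0f ? a + std::log(std::exp(-a) + 1.0f)
                  : std::log(std::exp(a) + 1.0f);
}

int main() {
  std::printf("%f %f\n", StableSigmoid(100.0f), StableSoftplus(100.0f));  // ~1, ~100
  return 0;
}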
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace rocm {
struct CtxAlpha {
float alpha;
};
struct CtxAlphaBeta {
float alpha;
float beta;
};
struct CtxAlphaGamma {
float alpha;
float gamma;
};
struct CtxNull {
};
typedef CtxAlpha CtxElu;
typedef CtxAlphaBeta CtxHardSigmoid;
typedef CtxAlpha CtxLeakyRelu;
typedef CtxNull CtxRelu;
typedef CtxAlphaGamma CtxSelu;
typedef CtxNull CtxSigmoid;
typedef CtxNull CtxSoftplus;
typedef CtxNull CtxSoftsign;
typedef CtxNull CtxTanh;
typedef CtxAlpha CtxThresholdedRelu;
#define UNARY_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(Elu) \
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
UNARY_ACTIVATION_OP_NAME(Relu) \
UNARY_ACTIVATION_OP_NAME(Selu) \
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
UNARY_ACTIVATION_OP_NAME(Softplus) \
UNARY_ACTIVATION_OP_NAME(Softsign) \
UNARY_ACTIVATION_OP_NAME(Tanh) \
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu)
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
hipStream_t stream, \
const T* input_data, \
T* output_data, \
const Ctx##name* func_ctx, \
size_t count)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
UNARY_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {
// broadcast by computing output coordinate from offset, using fast_divmod
template <typename T, typename T1, typename T2, typename FuncT,
bool lhs_need_compute, bool rhs_need_compute, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWise(
int32_t output_rank,
const TArray<int64_t> lhs_padded_strides,
const T1* lhs_data,
const TArray<int64_t> rhs_padded_strides,
const T2* rhs_data,
const TArray<fast_divmod> fdm_output_strides,
T* output_data,
const FuncT& functor,
HIP_LONG N) {
HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
T1 lvalue[NumElementsPerThread];
T2 rvalue[NumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
HIP_LONG lhs_index = (lhs_need_compute ? 0 : id);
HIP_LONG rhs_index = (rhs_need_compute ? 0 : id);
// compute indexes with broadcasting rules: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
HIP_LONG offset = id;
#pragma unroll
for (auto dim = 0; dim < fdm_output_strides.Capacity(); dim++) {
if (dim >= output_rank) {
break;
}
int q, r;
fdm_output_strides[dim].divmod(offset, q, r);
if (lhs_need_compute) {
lhs_index += static_cast<int>(lhs_padded_strides[dim]) * q;
}
if (rhs_need_compute) {
rhs_index += static_cast<int>(rhs_padded_strides[dim]) * q;
}
offset = r;
}
lvalue[i] = lhs_data[lhs_index];
rvalue[i] = rhs_data[rhs_index];
id += NumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = functor(lvalue[i], rvalue[i]);
id += NumThreadsPerBlock;
}
}
}
// for scalar broadcast or non-broadcast case
template <bool IncL, bool IncR, typename T, typename T1, typename T2, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWiseSimple(
const T1* lhs_data,
const T2* rhs_data,
T* output_data,
const FuncT func,
HIP_LONG N) {
HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
T1 lvalue[NumElementsPerThread];
T2 rvalue[NumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
lvalue[i] = lhs_data[IncL ? id : 0];
rvalue[i] = rhs_data[IncR ? id : 0];
id += NumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = func(lvalue[i], rvalue[i]);
id += NumThreadsPerBlock;
}
}
}
// for rhs per-channel broadcast case
template <typename T, typename T1, typename T2, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWiseRhsPerChannelBatch1(
const T1* lhs_data,
const T2* rhs_data,
const fast_divmod fdm_H,
T* output_data,
FuncT func,
HIP_LONG N) {
HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
T1 lvalue[NumElementsPerThread];
T2 rvalue[NumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
HIP_LONG rhs_id = fdm_H.div(id);
lvalue[i] = lhs_data[id];
rvalue[i] = rhs_data[rhs_id];
id += NumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = func(lvalue[i], rvalue[i]);
id += NumThreadsPerBlock;
}
}
}
template <typename T, typename T1, typename T2, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _BinaryElementWiseRhsPerChannelBatchN(
const T1* lhs_data,
const T2* rhs_data,
const fast_divmod fdm_H,
const fast_divmod fdm_C,
T* output_data,
FuncT func,
HIP_LONG N) {
HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
T1 lvalue[NumElementsPerThread];
T2 rvalue[NumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
HIP_LONG rhs_id = fdm_H.div(id);
int q, r;
fdm_C.divmod(rhs_id, q, r);
rhs_id = r;
lvalue[i] = lhs_data[id];
rvalue[i] = rhs_data[rhs_id];
id += NumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = func(lvalue[i], rvalue[i]);
id += NumThreadsPerBlock;
}
}
}
template <typename T, typename T1, typename T2, typename FuncT>
void BinaryElementWiseNoBroadcastImpl(
hipStream_t stream,
const T1* lhs_data,
const T2* rhs_data,
T* output_data,
const FuncT& func,
size_t count) {
if (count == 0) // special case where there's a dim value of 0 in the output shape
return;
#ifdef USE_ROCM
const int num_elements_per_thread = 2;
const int num_threads_per_block = 512;
#else
const int num_elements_per_thread = GridDim::maxElementsPerThread;
const int num_threads_per_block = GridDim::maxThreadsPerBlock;
#endif
int blocksPerGrid = static_cast<int>(CeilDiv(count, num_threads_per_block * num_elements_per_thread));
HIP_LONG N = static_cast<HIP_LONG>(count);
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<true, true, T, T1, T2, FuncT, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
lhs_data,
rhs_data,
output_data,
func,
N);
}
template <typename T, typename T1, typename T2, typename FuncT>
void BinaryElementWiseImpl(
hipStream_t stream,
int32_t output_rank_or_simple_broadcast,
const TArray<int64_t>* lhs_padded_strides,
const T1* lhs_data,
const TArray<int64_t>* rhs_padded_strides,
const T2* rhs_data,
const TArray<fast_divmod>* fdm_output_strides,
const fast_divmod& fdm_H,
const fast_divmod& fdm_C,
T* output_data,
const FuncT& func,
size_t count) {
if (count == 0) // special case where there's a dim value of 0 in the output shape
return;
#ifdef USE_ROCM
const int num_elements_per_thread = 2;
const int num_threads_per_block = 512;
#else
const int num_elements_per_thread = GridDim::maxElementsPerThread;
const int num_threads_per_block = GridDim::maxThreadsPerBlock;
#endif
int blocksPerGrid = static_cast<int>(CeilDiv(count, num_threads_per_block * num_elements_per_thread));
HIP_LONG N = static_cast<HIP_LONG>(count);
if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::NoBroadcast)) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<true, true, T, T1, T2, FuncT, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
lhs_data,
rhs_data,
output_data,
func,
N);
} else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::LeftScalar)) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<false, true, T, T1, T2, FuncT, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
lhs_data,
rhs_data,
output_data,
func,
N);
} else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::RightScalar)) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseSimple<true, false, T, T1, T2, FuncT, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
lhs_data,
rhs_data,
output_data,
func,
N);
} else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatch1)) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseRhsPerChannelBatch1<T, T1, T2, FuncT, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
lhs_data,
rhs_data,
fdm_H,
output_data,
func,
N);
} else if (output_rank_or_simple_broadcast == static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatchN)) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWiseRhsPerChannelBatchN<T, T1, T2, FuncT, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
lhs_data,
rhs_data,
fdm_H,
fdm_C,
output_data,
func,
N);
} else {
if (lhs_padded_strides && rhs_padded_strides && lhs_padded_strides->Size() && rhs_padded_strides->Size())
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWise<T, T1, T2, FuncT, true, true, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
output_rank_or_simple_broadcast,
*lhs_padded_strides,
lhs_data,
*rhs_padded_strides,
rhs_data,
*fdm_output_strides,
output_data,
func,
N);
else if (lhs_padded_strides && lhs_padded_strides->Size())
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWise<T, T1, T2, FuncT, true, false, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
output_rank_or_simple_broadcast,
*lhs_padded_strides,
lhs_data,
TArray<int64_t>(), // rhs is not computed, so no need to dereference rhs_padded_strides
rhs_data,
*fdm_output_strides,
output_data,
func,
N);
else if (rhs_padded_strides && rhs_padded_strides->Size())
hipLaunchKernelGGL(HIP_KERNEL_NAME(_BinaryElementWise<T, T1, T2, FuncT, false, true, num_threads_per_block, num_elements_per_thread>), blocksPerGrid, num_threads_per_block, 0, stream,
output_rank_or_simple_broadcast,
TArray<int64_t>(), // lhs is not computed, so no need to dereference lhs_padded_strides
lhs_data,
*rhs_padded_strides,
rhs_data,
*fdm_output_strides,
output_data,
func,
N);
}
}
} // namespace rocm
} // namespace onnxruntime
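The broadcast path of _BinaryElementWise above converts a flat output offset into per-dimension coordinates with fast_divmod and accumulates each input's padded stride, where broadcast dimensions carry a stride of 0. A host-side sketch of that index computation, using ordinary division in place of fast_divmod (illustrative only, not part of this commit):

#include <cstdint>
#include <cstdio>
#include <vector>

// Map a flat output offset to an input offset given the output strides and the
// input's "padded" strides (0 on broadcast dimensions).
int64_t InputIndexFromOutputOffset(int64_t output_offset,
                                   const std::vector<int64_t>& output_strides,
                                   const std::vector<int64_t>& padded_input_strides) {
  int64_t index = 0;
  for (size_t dim = 0; dim < output_strides.size(); ++dim) {
    const int64_t q = output_offset / output_strides[dim];  // coordinate in this dim
    output_offset %= output_strides[dim];
    index += padded_input_strides[dim] * q;
  }
  return index;
}

int main() {
  // Output shape [2, 3]: lhs shape [2, 3], rhs shape [1, 3] broadcast along dim 0.
  const std::vector<int64_t> output_strides = {3, 1};
  const std::vector<int64_t> lhs_strides = {3, 1};
  const std::vector<int64_t> rhs_strides = {0, 1};  // stride 0 on the broadcast dim
  for (int64_t off = 0; off < 6; ++off) {
    std::printf("out %lld -> lhs %lld, rhs %lld\n",
                (long long)off,
                (long long)InputIndexFromOutputOffset(off, output_strides, lhs_strides),
                (long long)InputIndexFromOutputOffset(off, output_strides, rhs_strides));
  }
  return 0;
}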
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
/**
* These functions MUST be called with an unroll factor that evenly divides the number of threads in a warp (32 for
* CUDA, 64 for ROCm). In addition, this kernel MUST be launched with a number of threads in a thread block which is
* evenly divisible by the number of threads in a warp.
*
* Taking an unroll factor of 4 and 32 threads in a warp as an example, we use the following approach (for threads in the first
* warp, that is):
*
* Thread 0 generates output booleans 0-3
* Thread 1 generates output booleans 4-7
* ...
* Thread 7 generates output booleans 28-31
*
* These threads all agree on the same thread mask by determining which output bitmask index they want to write to.
* Threads 0-7 will generate the same thread mask (for output index 0), threads 8-15 will generate the same thread mask
* (for output index 1), and so on.
*
* After (partially before) agreeing upon which threads will collaborate to write out a single index,
* each thread generates 4 random values, and shifts them into the right location in the output uint32_t.
* For instance:
*
* Thread 0 will perform a shift of 0
* Thread 1 will perform a shift of 4
* Thread 2 will perform a shift of 8
* ...
*
* For index 0, this gives us the following composition of random bits (number represents which thread generated it):
*
* 77776666555544443333222211110000
*
* After each thread shifts its bits into the right location, we broadcast the reduced value to all threads. Finally,
* we just choose a single thread (in our case, we choose the thread with 0 shift, but any thread from 0-7 would work
* for the 0-7 group).
*
* Keep in mind that this must not be conditionally called, as all threads in the warp (that haven't already exited)
* must reach these function calls an equal number of times, otherwise the code execution is likely to hang or produce
* unintended side effects.
*
* We conditionally update the local thread's mask (with the "li < N" check), but all active threads always collaborate
* on the reduced value.
*/
namespace onnxruntime {
namespace rocm {
template <int NumUnroll>
__device__ __forceinline__ void SetBitmask(const HIP_LONG id, const HIP_LONG mask_element_count,
const fast_divmod fdm_bits_per_element, BitmaskElementType thread_bitmask,
BitmaskElementType* mask_data) {
int bitmask_idx, bitmask_shift;
fdm_bits_per_element.divmod(id, bitmask_idx, bitmask_shift);
BitmaskElementType bitmask = (thread_bitmask << bitmask_shift);
#if defined(USE_CUDA) && __CUDA_ARCH__ >= 800  // CUDA-only intrinsics below; never taken in a ROCm build
// All threads that intend to write to the same output index will have the same thread mask.
BitmaskElementType thread_mask = __match_any_sync(0xFFFFFFFF, bitmask_idx);
// All threads with the same thread mask (threads which intend to write to the same output index) collaborate
// on a bitwise-or reduction.
bitmask = __reduce_or_sync(thread_mask, bitmask);
#else
#pragma unroll
for (int stride = kNumBitsPerBitmaskElement / (NumUnroll * 2); stride > 0; stride /= 2) {
bitmask |= WARP_SHFL_DOWN(bitmask, stride);
}
#endif
// Choose a single thread from the "thread mask" group to perform the output write.
if (bitmask_shift == 0 && bitmask_idx < mask_element_count) {
mask_data[bitmask_idx] = bitmask;
}
}
template <int NumUnroll>
__device__ __forceinline__ void GetMasks(HIP_LONG id, const fast_divmod fdm_bits_per_element,
const BitmaskElementType* mask_data, bool* mask_result) {
int bitmask_idx, bitmask_shift;
fdm_bits_per_element.divmod(id, bitmask_idx, bitmask_shift);
BitmaskElementType shifted_mask = mask_data[bitmask_idx] >> bitmask_shift;
#pragma unroll
for (int i = 0; i < NumUnroll; i++) {
mask_result[i] = (shifted_mask & (1 << i)) != 0;
}
}
} // namespace rocm
} // namespace onnxruntime
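SetBitmask and GetMasks above pack one boolean per element into 32-bit mask words: element id lands in word id / 32 at bit id % 32, each thread shifts its NumUnroll booleans into place, the lanes targeting the same word OR their contributions together, and the lane with shift 0 writes the combined word. A host-side sketch of the same packing and unpacking (illustrative only, not part of this commit):

#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int kBitsPerElement = 32;

// Pack one boolean per element into 32-bit words: bit (id % 32) of word (id / 32).
void Pack(const std::vector<bool>& flags, std::vector<uint32_t>& mask) {
  mask.assign((flags.size() + kBitsPerElement - 1) / kBitsPerElement, 0u);
  for (size_t id = 0; id < flags.size(); ++id) {
    if (flags[id]) mask[id / kBitsPerElement] |= (1u << (id % kBitsPerElement));
  }
}

// Recover the boolean for element id, mirroring GetMasks.
bool Unpack(const std::vector<uint32_t>& mask, size_t id) {
  return (mask[id / kBitsPerElement] >> (id % kBitsPerElement)) & 1u;
}

int main() {
  std::vector<bool> flags = {true, false, true, true, false};
  std::vector<uint32_t> mask;
  Pack(flags, mask);
  for (size_t id = 0; id < flags.size(); ++id)
    std::printf("bit %zu = %d\n", id, (int)Unpack(mask, id));
  return 0;
}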
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {
#ifdef USE_ROCM
constexpr int kElementsPerThread = 2;
constexpr int kThreadsPerBlock = 512;
#else
constexpr int kElementsPerThread = GridDim::maxElementsPerThread;
constexpr int kThreadsPerBlock = GridDim::maxThreadsPerBlock;
#endif
template <typename T, typename FuncT>
__global__ void ElementwiseKernel(T* output_data, const FuncT functor, HIP_LONG N) {
HIP_LONG start = kElementsPerThread * kThreadsPerBlock * blockIdx.x + threadIdx.x;
T value[kElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < kElementsPerThread; ++i) {
if (id < N) {
value[i] = functor(id);
id += kThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < kElementsPerThread; ++i) {
if (id < N) {
output_data[id] = value[i];
id += kThreadsPerBlock;
}
}
}
template <typename T, typename FuncT>
void LaunchElementwiseKernel(hipStream_t stream, T* output_data, const FuncT& functor, size_t output_size) {
if (output_size == 0) return;
HIP_LONG N = static_cast<HIP_LONG>(output_size);
int blocksPerGrid = CeilDiv(N, kThreadsPerBlock * kElementsPerThread);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ElementwiseKernel<T, FuncT>), blocksPerGrid, kThreadsPerBlock, 0, stream, output_data, functor, N);
}
} // namespace rocm
} // namespace onnxruntime
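LaunchElementwiseKernel above takes a functor that maps an output index to a value, so generator-style operations need no input buffer at all. A hypothetical functor sketch (IotaFunctor is not part of this commit; the commented usage line assumes the header above is included and that stream, output, and count are provided by the caller):

// Hypothetical functor: returns start + step * id for output index id.
struct IotaFunctor {
  float start;
  float step;
  __device__ float operator()(int id) const { return start + step * static_cast<float>(id); }
};
// Usage sketch: fill `output` with 0, 1, 2, ... on `stream`.
//   LaunchElementwiseKernel(stream, output, IotaFunctor{0.0f, 1.0f}, count);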
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {
template <typename InT, typename OutT, typename FuncT, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _UnaryElementWise(
const InT* input_data,
OutT* output_data,
const FuncT functor,
HIP_LONG N) {
HIP_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
InT value[NumElementsPerThread];
HIP_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
value[i] = input_data[id];
id += NumThreadsPerBlock;
}
}
id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output_data[id] = functor(value[i]);
id += NumThreadsPerBlock;
}
}
}
template <typename InT, typename OutT, typename FuncT>
void UnaryElementWiseImpl(
hipStream_t stream,
const InT* input_data,
OutT* output_data,
const FuncT& func,
size_t count) {
if (count == 0) // special case where there's a dim value of 0 in the shape
return;
int blocksPerGrid = static_cast<int>(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
HIP_LONG N = static_cast<HIP_LONG>(count);
hipLaunchKernelGGL(HIP_KERNEL_NAME(_UnaryElementWise<InT, OutT, FuncT, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread>), blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream,
input_data,
output_data,
func,
N);
}
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
#pragma once
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {
template <
typename T, typename Func,
int32_t max_input_batch_size, int32_t num_elements_per_thread>
__global__ void VariadicElementWiseNoBroadcastInputBatchKernel(
Func func,
size_t N,
TArray<const T*, max_input_batch_size> inputs,
T* output) {
const size_t base_idx = num_elements_per_thread * blockDim.x * blockIdx.x + threadIdx.x;
T inputs_buffer[num_elements_per_thread][max_input_batch_size];
int32_t element_count;
size_t element_idx;
#pragma unroll
for (element_count = 0, element_idx = base_idx;
element_count < num_elements_per_thread;
++element_count, element_idx += blockDim.x) {
if (element_idx < N) {
#pragma unroll
for (int32_t input_batch_idx = 0; input_batch_idx < max_input_batch_size; ++input_batch_idx) {
if (input_batch_idx < inputs.Size()) {
inputs_buffer[element_count][input_batch_idx] = inputs[input_batch_idx][element_idx];
}
}
}
}
#pragma unroll
for (element_count = 0, element_idx = base_idx;
element_count < num_elements_per_thread;
++element_count, element_idx += blockDim.x) {
if (element_idx < N) {
// first and second inputs
T output_value = func(
inputs_buffer[element_count][0], inputs_buffer[element_count][1]);
// remaining inputs
#pragma unroll
for (int32_t input_batch_idx = 2; input_batch_idx < max_input_batch_size; ++input_batch_idx) {
if (input_batch_idx < inputs.Size()) {
output_value = func(output_value, inputs_buffer[element_count][input_batch_idx]);
}
}
output[element_idx] = output_value;
}
}
}
// assumptions:
// - inputs.Size() > 1 && inputs.Size() <= max_input_batch_size
// - inputs and output have N elements
template <typename T, typename Func, int32_t max_input_batch_size>
void VariadicElementWiseNoBroadcastInputBatchImpl(
hipStream_t stream,
Func func,
size_t N,
TArray<const T*, max_input_batch_size> inputs,
T* output) {
constexpr int32_t elements_per_thread = GridDim::maxElementsPerThread;
constexpr int32_t threads_per_block = GridDim::maxThreadsPerBlock;
const int32_t blocks_per_grid = static_cast<int32_t>(CeilDiv(N, elements_per_thread * threads_per_block));
hipLaunchKernelGGL(HIP_KERNEL_NAME(VariadicElementWiseNoBroadcastInputBatchKernel<T, Func, max_input_batch_size, elements_per_thread>), blocks_per_grid, threads_per_block, 0, stream, func, N, inputs, output);
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "constant_of_shape.h"
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
ConstantOfShape,
kOnnxDomain,
9,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0)
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T2", DataTypeImpl::AllFixedSizeTensorTypes()),
ConstantOfShape);
Status ConstantOfShape::ComputeInternal(OpKernelContext* ctx) const {
Tensor* output_tensor = nullptr;
ORT_RETURN_IF_ERROR(PrepareCompute(ctx, &output_tensor));
auto output_data = output_tensor->MutableDataRaw();
const auto size = output_tensor->Shape().Size();
const void* value_ptr = GetValuePtr();
const auto element_size = output_tensor->DataType()->Size();
#define CASE(TYPE) \
case sizeof(TYPE): \
if (size > 0) { \
rocm::Fill(Stream(), reinterpret_cast<TYPE*>(output_data), *(reinterpret_cast<const TYPE*>(value_ptr)), size); \
} \
break;
switch (element_size) {
CASE(int8_t)
CASE(int16_t)
CASE(int32_t)
CASE(int64_t)
default:
ORT_THROW("Unsupported value attribute datatype with sizeof=: ", element_size);
break;
}
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime
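The switch above dispatches on sizeof(element), so every fixed-size type is filled by reinterpreting its raw bytes through the integer type of the same width (float goes through the int32_t case, double through int64_t, MLFloat16 through int16_t). A small host-side sketch of that reinterpretation (illustrative only, not part of this commit):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float value = 1.5f;
  int32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));   // same bit pattern the Fill kernel copies
  std::printf("0x%08x\n", (unsigned)bits);    // prints 0x3fc00000
  return 0;
}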
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/cpu/generator/constant_of_shape_base.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {
class ConstantOfShape final : public ConstantOfShapeBase<>, public RocmKernel {
public:
explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), RocmKernel(info) {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape);
Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/generator/random.h"
#include "core/providers/rocm/generator/random_impl.h"
namespace onnxruntime {
namespace rocm {
using namespace ONNX_NAMESPACE;
ONNX_OPERATOR_KERNEL_EX(RandomNormal, kOnnxDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
RandomNormal);
ONNX_OPERATOR_KERNEL_EX(RandomNormalLike, kOnnxDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
.TypeConstraint("T2", DataTypeImpl::AllIEEEFloatTensorTypes()),
RandomNormalLike);
ONNX_OPERATOR_KERNEL_EX(RandomUniform, kOnnxDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
RandomUniform);
ONNX_OPERATOR_KERNEL_EX(RandomUniformLike, kOnnxDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", DataTypeImpl::AllTensorTypes())
.TypeConstraint("T2", DataTypeImpl::AllIEEEFloatTensorTypes()),
RandomUniformLike);
#define RANDOM_COMPUTE_IMPL(name) \
template <typename T> \
struct name##ComputeImpl { \
void operator()(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, Tensor& Y) const { \
typedef typename ToHipType<T>::MappedType HipT; \
HipT* Y_data = reinterpret_cast<HipT*>(Y.MutableData<T>()); \
name##KernelImpl<HipT>(prop, stream, N, alpha, beta, generator, Y_data); \
} \
};
RANDOM_COMPUTE_IMPL(RandomNormal)
RANDOM_COMPUTE_IMPL(RandomUniform)
#undef RANDOM_COMPUTE_IMPL
Status RandomNormalBase::ComputeNormal(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape, int dtype) const {
Tensor& Y = *ctx.Output(0, shape);
const int64_t N = shape.Size();
PhiloxGenerator& generator = GetPhiloxGenerator();
utils::MLTypeCallDispatcher<float, MLFloat16, double> t_disp(dtype);
t_disp.Invoke<RandomNormalComputeImpl>(rocm_kernel.GetDeviceProp(), rocm_kernel.Stream(), N, scale_, mean_, generator, Y);
return Status::OK();
}
Status RandomNormalLike::ComputeInternal(OpKernelContext* p_ctx) const {
const Tensor* p_X = p_ctx->Input<Tensor>(0);
if (!p_X) {
return Status(common::ONNXRUNTIME, common::FAIL, "X Input is not available.");
}
int dtype = GetDType();
if (dtype == TensorProto_DataType_UNDEFINED && !p_X->IsDataType<float>() && !p_X->IsDataType<double>() &&
!p_X->IsDataType<MLFloat16>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Output data type is required to be one of float types, but got incompatible data type ",
p_X->DataType(), " from input tensor.");
}
if (dtype == TensorProto_DataType_UNDEFINED)
dtype = p_X->GetElementType();
return ComputeNormal(*this, *p_ctx, p_X->Shape(), dtype);
}
Status RandomUniformBase::ComputeUniform(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape, int dtype) const {
Tensor& Y = *ctx.Output(0, shape);
const int64_t N = shape.Size();
PhiloxGenerator& generator = GetPhiloxGenerator();
utils::MLTypeCallDispatcher<float, MLFloat16, double> t_disp(dtype);
t_disp.Invoke<RandomUniformComputeImpl>(rocm_kernel.GetDeviceProp(), rocm_kernel.Stream(), N, range_, from_, generator, Y);
return Status::OK();
}
Status RandomUniformLike::ComputeInternal(OpKernelContext* p_ctx) const {
const Tensor* p_X = p_ctx->Input<Tensor>(0);
if (!p_X) {
return Status(common::ONNXRUNTIME, common::FAIL, "X Input is not available.");
}
int dtype = GetDType();
if (dtype == TensorProto_DataType_UNDEFINED && !p_X->IsDataType<float>() && !p_X->IsDataType<double>() &&
!p_X->IsDataType<MLFloat16>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Output data type is required to be one of float types, but got incompatible data type ",
p_X->DataType(), " from input tensor.");
}
if (dtype == TensorProto_DataType_UNDEFINED)
dtype = p_X->GetElementType();
return ComputeUniform(*this, *p_ctx, p_X->Shape(), dtype);
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/random_generator.h"
#include "core/providers/rocm/rocm_kernel.h"
#include <optional>
namespace onnxruntime {
namespace rocm {
class RandomBase {
protected:
explicit RandomBase(const OpKernelInfo& info) {
float seed = 0.f;
if (info.GetAttr<float>("seed", &seed).IsOK()) {
generator_.emplace(static_cast<uint64_t>(seed));
}
int64_t dtype;
if (info.GetAttr<int64_t>("dtype", &dtype).IsOK()) {
ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(gsl::narrow<int>(dtype)) &&
dtype != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
"Invalid dtype of ", dtype);
dtype_ = static_cast<ONNX_NAMESPACE::TensorProto::DataType>(dtype);
}
}
protected:
void SetDTypeIfUndefined(ONNX_NAMESPACE::TensorProto::DataType dtype) noexcept {
if (dtype_ == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
dtype_ = dtype;
}
}
ONNX_NAMESPACE::TensorProto::DataType GetDType() const noexcept { return dtype_; }
PhiloxGenerator& GetPhiloxGenerator() const {
return (generator_.has_value()) ? *generator_ : PhiloxGenerator::Default();
}
private:
ONNX_NAMESPACE::TensorProto::DataType dtype_ =
ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; // optional and may be inferred
// Mutable so it can be used from const ComputeInternal; PhiloxGenerator performs its own synchronization.
mutable std::optional<PhiloxGenerator> generator_;
};
class RandomNormalBase : public RandomBase {
protected:
RandomNormalBase(const OpKernelInfo& info) : RandomBase(info) {
ORT_THROW_IF_ERROR(info.GetAttr<float>("scale", &scale_));
ORT_THROW_IF_ERROR(info.GetAttr<float>("mean", &mean_));
}
Status ComputeNormal(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape, int dtype) const;
private:
float scale_;
float mean_;
};
class RandomNormal final : public RocmKernel, protected RandomNormalBase {
public:
explicit RandomNormal(const OpKernelInfo& info) : RocmKernel(info), RandomNormalBase(info) {
SetDTypeIfUndefined(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
std::vector<int64_t> shape;
ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("shape", shape));
shape_ = TensorShape(shape);
}
Status ComputeInternal(OpKernelContext* p_ctx) const override {
return ComputeNormal(*this, *p_ctx, shape_, GetDType());
}
private:
TensorShape shape_;
};
class RandomNormalLike final : public RocmKernel, protected RandomNormalBase {
public:
explicit RandomNormalLike(const OpKernelInfo& info) : RocmKernel(info), RandomNormalBase(info) {}
Status ComputeInternal(OpKernelContext* p_ctx) const override;
};
class RandomUniformBase : public RandomBase {
protected:
explicit RandomUniformBase(const OpKernelInfo& info) : RandomBase(info) {
float low, high;
ORT_THROW_IF_ERROR(info.GetAttr<float>("low", &low));
ORT_THROW_IF_ERROR(info.GetAttr<float>("high", &high));
from_ = low;
range_ = high - low;
}
Status ComputeUniform(const RocmKernel& rocm_kernel, OpKernelContext& ctx, const TensorShape& shape, int dtype) const;
private:
float range_;
float from_;
};
class RandomUniform final : public RocmKernel, protected RandomUniformBase {
public:
explicit RandomUniform(const OpKernelInfo& info) : RocmKernel(info), RandomUniformBase(info) {
SetDTypeIfUndefined(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
std::vector<int64_t> shape;
ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("shape", shape));
shape_ = TensorShape(shape);
}
Status ComputeInternal(OpKernelContext* p_ctx) const override {
return ComputeUniform(*this, *p_ctx, shape_, GetDType());
}
private:
TensorShape shape_;
};
class RandomUniformLike final : public RocmKernel, protected RandomUniformBase {
public:
explicit RandomUniformLike(const OpKernelInfo& info) : RocmKernel(info), RandomUniformBase(info) {}
Status ComputeInternal(OpKernelContext* p_ctx) const override;
};
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/generator/random_impl.h"
#include <hiprand_kernel.h>
#include <algorithm>
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace rocm {
constexpr int UNROLL = 4;
struct DistFunc_RandomNormal {
__device__ __inline__ float4 operator()(hiprandStatePhilox4_32_10_t* state) const { return hiprand_normal4(state); }
};
struct DistFunc_RandomUniform {
__device__ __inline__ float4 operator()(hiprandStatePhilox4_32_10_t* state) const { return hiprand_uniform4(state); }
};
struct TransformFunc_RandomNormal {
__device__ __inline__ float operator()(const float value, const float scale, const float mean) const {
return value * scale + mean;
}
};
struct TransformFunc_RandomUniform {
__device__ __inline__ float operator()(const float value, const float range, const float from) const {
// reverse the bounds of hiprand4 from (0, 1] to [0, 1).
// ref: https://github.com/pytorch/pytorch/blob/e795315c638228d4170f3797356c09a70b2ed4cd/aten/src/ATen/native/rocm/DistributionTemplates.h#L464
float reverse_bound_value = value == 1.0f ? 0.0f : value;
return reverse_bound_value * range + from;
}
};
template <typename T, typename DistFuncT, typename TransformFuncT>
__global__ void RandomKernel(const int64_t N, const std::pair<uint64_t, uint64_t> seeds, const DistFuncT& dist_func,
const TransformFuncT& transform_func, const float alpha, const float beta, T* Y_data) {
HIP_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
HIP_LONG step_size = gridDim.x * blockDim.x * UNROLL;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seeds.first, idx, seeds.second, &state);
float4 rand;
// We ensure every thread generates the same number of random numbers (by rounding
// up the size) and at the same timestep (by syncing threads).
// From the hiprand documentation:
// The Philox_4x32_10 algorithm is closely tied to the thread and block count.
// Each thread computes 4 random numbers at a time, so the most efficient
// use of Philox_4x32_10 is to generate a multiple of 4 times the number of threads.
for (HIP_LONG id = idx * UNROLL; id < N; id += step_size) {
rand = dist_func(&state);
// actual computation
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
HIP_LONG li = id + i;
if (li < N) {
Y_data[li] = static_cast<T>(transform_func((&rand.x)[i], alpha, beta));
}
}
__syncthreads();
}
}
template <typename T, typename DistFuncT, typename TransformFuncT>
__global__ void RandomVectorizedKernel(const int64_t N, const std::pair<uint64_t, uint64_t> seeds,
const DistFuncT& dist_func, const TransformFuncT& transform_func,
const float alpha, const float beta, T* Y_data) {
HIP_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
HIP_LONG step_size = gridDim.x * blockDim.x * UNROLL;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seeds.first, idx, seeds.second, &state);
float4 rand;
// Use a vectorized store when N % 4 == 0, since this is the typical case for input shape sizes.
using LoadT = aligned_vector<T, UNROLL>;
for (HIP_LONG id = idx * UNROLL; id < N; id += step_size) {
rand = dist_func(&state);
T r[UNROLL];
// actual computation
#pragma unroll
for (int ii = 0; ii < UNROLL; ii++) {
r[ii] = static_cast<T>(transform_func((&rand.x)[ii], alpha, beta));
}
// Vectorized writes for Y_data
*(reinterpret_cast<LoadT*>(&Y_data[id])) = *reinterpret_cast<LoadT*>(&r[0]);
__syncthreads();
}
}
template <typename T, typename DistFuncT, typename TransformFuncT>
void RandomKernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const DistFuncT& dist_func,
const TransformFuncT& transform_func, float alpha, float beta, PhiloxGenerator& generator,
T* Y_data) {
const int block_size = 256;
const int blocks_per_sm = prop.maxThreadsPerMultiProcessor / block_size;
const int grid_size =
std::min(prop.multiProcessorCount * blocks_per_sm, static_cast<int>(CeilDiv(N, block_size * UNROLL)));
// Compute the number of random numbers generated by each thread, and increment philox generator offset by that
// amount.
const uint64_t counter_offset = static_cast<uint64_t>(((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL);
auto seeds = generator.NextPhiloxSeeds(counter_offset);
if (N % UNROLL != 0) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomKernel<T>), grid_size, block_size, 0, stream, N, seeds, dist_func, transform_func, alpha, beta, Y_data);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomVectorizedKernel<T>), grid_size, block_size, 0, stream, N, seeds, dist_func, transform_func, alpha, beta, Y_data);
}
}
#define RANDOM_KERNEL_IMPL(name) \
template <typename T> \
void name##KernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, T* Y_data) { \
RandomKernelImpl(prop, stream, N, DistFunc_##name(), TransformFunc_##name(), alpha, beta, generator, Y_data); \
}
RANDOM_KERNEL_IMPL(RandomNormal)
RANDOM_KERNEL_IMPL(RandomUniform)
#define SPECIALIZED_RANDOM_KERNEL(name, T) \
template void name##KernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, T* Y_data);
#define SPECIALIZED_RANDOM_KERNELS(T) \
SPECIALIZED_RANDOM_KERNEL(RandomNormal, T) \
SPECIALIZED_RANDOM_KERNEL(RandomUniform, T)
SPECIALIZED_RANDOM_KERNELS(float)
SPECIALIZED_RANDOM_KERNELS(double)
SPECIALIZED_RANDOM_KERNELS(half)
} // namespace rocm
} // namespace onnxruntime
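RandomKernelImpl above advances the Philox counter by the number of values each thread consumes: every grid-stride iteration draws UNROLL (4) values, so counter_offset is the per-thread iteration count, rounded up, times UNROLL. A small arithmetic sketch with made-up launch numbers (illustrative only, not part of this commit):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t N = 10000, block = 256, grid = 4, unroll = 4;
  const uint64_t iterations = (N - 1) / (block * grid * unroll) + 1;  // ceiling division
  const uint64_t counter_offset = iterations * unroll;
  std::printf("%llu\n", (unsigned long long)counter_offset);  // 3 iterations -> offset 12
  return 0;
}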
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/random_generator.h"
namespace onnxruntime {
namespace rocm {
#define RANDOM_KERNEL_DECLARE(name) \
template <typename T> \
void name##KernelImpl(const hipDeviceProp_t& prop, hipStream_t stream, const int64_t N, const float alpha, \
const float beta, PhiloxGenerator& generator, T* Y_data);
RANDOM_KERNEL_DECLARE(RandomNormal)
RANDOM_KERNEL_DECLARE(RandomUniform)
#undef RANDOM_KERNEL_DECLARE
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_common.h"
#include "range.h"
#include "range_impl.h"
using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
Range,
kOnnxDomain,
11,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // start
.InputMemoryType(OrtMemTypeCPUInput, 1) // limit
.InputMemoryType(OrtMemTypeCPUInput, 2) // delta
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()}),
Range);
template <typename T>
static Status ComputeRange(hipStream_t stream, OpKernelContext* ctx) {
const auto& start_tensor = *ctx->Input<Tensor>(0);
const auto& limit_tensor = *ctx->Input<Tensor>(1);
const auto* delta_tensor_ptr = ctx->Input<Tensor>(2);
if (!start_tensor.Shape().IsScalar()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"start in Range operator should be scalar like tensor, yet got shape:",
start_tensor.Shape());
}
if (!limit_tensor.Shape().IsScalar()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"limit in Range operator should be scalar like tensor, yet got shape:",
limit_tensor.Shape());
}
if (delta_tensor_ptr != nullptr && !delta_tensor_ptr->Shape().IsScalar()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"delta in Range operator should be scalar like tensor, yet got shape:",
delta_tensor_ptr->Shape());
}
// Start, Limit and Delta are stored in CPU.
T start = *(start_tensor.Data<T>());
T limit = *(limit_tensor.Data<T>());
T delta = T(1);
if (delta_tensor_ptr != nullptr) {
delta = *(delta_tensor_ptr->Data<T>());
}
if (delta == T(0)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "delta in Range operator can not be zero!");
}
double num = (static_cast<double>(limit) - static_cast<double>(start)) / static_cast<double>(delta);
int count = static_cast<int>(ceil(num));
if (count <= 0)
count = 0;
TensorShape shape = {static_cast<int64_t>(count)};
T* y = ctx->Output(0, shape)->MutableData<T>();
if (count > 0) {
ORT_RETURN_IF_ERROR(RangeImpl(stream, start, delta, count, y));
}
return Status::OK();
}
namespace rocm_range_internal {
template <class T>
struct CallCudaRangeImpl {
Status operator()(hipStream_t stream, OpKernelContext* ctx) const {
return ComputeRange<T>(stream, ctx);
}
};
} // namespace rocm_range_internal
Status Range::ComputeInternal(OpKernelContext* ctx) const {
const auto* input_tensor = ctx->Input<Tensor>(0);
if (input_tensor == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
}
utils::MLTypeCallDispatcher<int32_t, float, int64_t, double, int16_t>
t_disp(input_tensor->GetElementType());
return t_disp.InvokeRet<Status, rocm_range_internal::CallCudaRangeImpl>(Stream(), ctx);
}
} // namespace rocm
} // namespace onnxruntime
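ComputeRange above sizes the output as ceil((limit - start) / delta), clamped at zero, evaluated in double to avoid intermediate overflow. A host-side sketch of that rule (illustrative only, not part of this commit):

#include <cmath>
#include <cstdio>

// Number of elements produced by Range(start, limit, delta).
int RangeCount(double start, double limit, double delta) {
  const int count = static_cast<int>(std::ceil((limit - start) / delta));
  return count > 0 ? count : 0;
}

int main() {
  std::printf("%d %d %d\n",
              RangeCount(0, 10, 3),    // 4 -> {0, 3, 6, 9}
              RangeCount(10, 4, -2),   // 3 -> {10, 8, 6}
              RangeCount(5, 5, 1));    // 0 -> empty output
  return 0;
}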
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
class Range final : public RocmKernel {
public:
explicit Range(const OpKernelInfo& info) : RocmKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hipcub/hipcub.hpp>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "range_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace rocm {
template <typename T>
__global__ void RangeKernel(const T start, const T delta, const int count, T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < count) {
output[index] = start + delta * index;
}
}
template <typename T>
Status RangeImpl(hipStream_t stream, const T start, const T delta, const int count, T* output) {
constexpr int block_size = 256;
int grid_size = (count + block_size - 1) / block_size;
hipLaunchKernelGGL(HIP_KERNEL_NAME(RangeKernel<T>), grid_size, block_size, 0, stream, start, delta, count, output);
return HIP_CALL(hipGetLastError());
}
#define SPECIALIZED_IMPL(T) \
template Status RangeImpl<T>(hipStream_t stream, const T start, const T delta, const int count, T* output);
SPECIALIZED_IMPL(int16_t)
SPECIALIZED_IMPL(int32_t)
SPECIALIZED_IMPL(int64_t)
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace rocm {
template <typename T>
Status RangeImpl(hipStream_t stream, const T start, const T delta, const int count, T* output);
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h"
using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {
template <>
Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context, BinaryElementwisePreparation* p) const {
p->lhs_tensor = context->Input<Tensor>(0);
p->rhs_tensor = context->Input<Tensor>(1);
if (!(p->lhs_tensor->Shape() == p->rhs_tensor->Shape()))
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, Node().Name(), ": mismatching input shapes: ",
p->lhs_tensor->Shape().ToString(), " != ", p->rhs_tensor->Shape().ToString());
p->output_tensor = context->Output(0, p->lhs_tensor->Shape());
p->output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::NoBroadcast);
return Status::OK();
}
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
size_t lhs_rank = lhs_shape.NumDimensions();
size_t rhs_rank = rhs_shape.NumDimensions();
size_t out_rank = std::max(lhs_rank, rhs_rank);
std::vector<int64_t> output_dims(out_rank, 0);
for (size_t i = 0; i < out_rank; ++i) {
int64_t lhs_dim = 1;
if (i < lhs_rank)
lhs_dim = lhs_shape[lhs_rank - 1 - i];
int64_t rhs_dim = 1;
if (i < rhs_rank)
rhs_dim = rhs_shape[rhs_rank - 1 - i];
int64_t max = std::max(lhs_dim, rhs_dim);
int64_t min = std::min(lhs_dim, rhs_dim);
int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0.
if (lhs_dim != out_dim && lhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
if (rhs_dim != out_dim && rhs_dim != 1)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i,
" LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString());
output_dims[out_rank - 1 - i] = out_dim;
}
out_shape = TensorShape(output_dims);
return Status::OK();
}
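// Worked example of the broadcasting rule above (illustrative): dimensions are
// aligned from the right and missing leading dims are treated as 1, e.g. for
// lhs_shape = {2, 3, 4} and rhs_shape = {3, 1}:
//   lhs: 2 3 4
//   rhs: 1 3 1
//   out: 2 3 4
// Each output dim is the larger of the pair (or 0 if either side is 0); a pair
// such as lhs_dim = 3, rhs_dim = 2 matches neither the output dim nor 1, so the
// function returns a broadcast error for the right operand.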
Status BinaryElementwiseBroadcastPrepare(
const Tensor* lhs_tensor,
const Tensor* rhs_tensor,
Tensor* output_tensor,
BinaryElementwisePreparation* p,
const TensorShape* override_lhs_shape,
const TensorShape* override_rhs_shape) {
p->lhs_tensor = lhs_tensor;
p->rhs_tensor = rhs_tensor;
const auto& lhs_shape = override_lhs_shape ? *override_lhs_shape : lhs_tensor->Shape();
const auto& rhs_shape = override_rhs_shape ? *override_rhs_shape : rhs_tensor->Shape();
p->output_tensor = output_tensor;
const auto& output_shape = output_tensor->Shape();
ORT_RETURN_IF_ERROR(p->BinaryElementwiseBroadcastPrepareHelper(lhs_shape, rhs_shape, output_shape));
return Status::OK();
}
template <>
Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, BinaryElementwisePreparation* p) const {
auto lhs_tensor = context->Input<Tensor>(0);
auto rhs_tensor = context->Input<Tensor>(1);
const auto& lhs_shape = lhs_tensor->Shape();
const auto& rhs_shape = rhs_tensor->Shape();
TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
auto output_tensor = context->Output(0, output_shape);
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(lhs_tensor, rhs_tensor, output_tensor, p));
return Status::OK();
}
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED_V(x, class_name, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
class_name<T>);
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(x, ver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED_V(x, x, ver, T)
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_NONTEMP(x, class_name, ver, ...) \
ONNX_OPERATOR_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<__VA_ARGS__>()), \
class_name);
#define BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
#define BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED_CLASS(x, class_name, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
class_name<T>);
#define BINARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
BinaryElementwisePreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
Impl_##x<typename ToHipType<T>::MappedType>( \
Stream(), \
prepare.output_rank_or_simple_broadcast, \
&prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()), \
&prepare.rhs_padded_strides, \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.rhs_tensor->Data<T>()), \
&prepare.fdm_output_strides, \
prepare.fdm_H, \
prepare.fdm_C, \
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()), \
prepare.output_tensor->Shape().Size()); \
return Status::OK(); \
}
#define BINARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, T)
#define BINARY_OP_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_COMPUTE(name, T)
#define BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, T) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED_CLASS(name, class_name, startver, endver, T) \
BINARY_ELEMENTWISE_COMPUTE(class_name, T)
#define BINARY_LOGICALOP_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
BINARY_ELEMENTWISE_COMPUTE(name, T)
// since different ops support different types, we cannot use BINARY_OPS() directly
// the postfix denotes the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
#define BINARY_OP_VERSIONED_HFD(name, startver, endver) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, float) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, double)
#define BINARY_OP_VERSIONED_UZILHFD(name, startver, endver) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_OP_VERSIONED_HFD(name, startver, endver)
#define BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(name, startver, endver) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_OP_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_OP_VERSIONED_HFD(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, BFloat16)
#define BINARY_OP_HFD(name, ver) \
BINARY_OP_TYPED(name, ver, MLFloat16) \
BINARY_OP_TYPED(name, ver, float) \
BINARY_OP_TYPED(name, ver, double) \
BINARY_OP_TYPED(name, ver, BFloat16)
#define BINARY_OP_UZILHFD(name, ver) \
BINARY_OP_TYPED(name, ver, uint32_t) \
BINARY_OP_TYPED(name, ver, uint64_t) \
BINARY_OP_TYPED(name, ver, int32_t) \
BINARY_OP_TYPED(name, ver, int64_t) \
BINARY_OP_HFD(name, ver)
#define BINARY_OP_REGISTER_VERSIONED_OIL(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, bool) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int64_t)
#define BINARY_LOGICALOP_REGISTER_OIL(name, ver) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, bool) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int64_t)
#define BINARY_OP_REGISTER_HFD(name, ver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, float) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, double) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, BFloat16)
#define BINARY_OP_REGISTER_UZILHFD(name, ver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, uint32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, uint64_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, int64_t) \
BINARY_OP_REGISTER_HFD(name, ver)
#define BINARY_LOGICALOP_REGISTER_UZILHFD(name, ver) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, uint32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, uint64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, int64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, MLFloat16) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, float) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, double) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, BFloat16)
#define BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(name, startver, endver) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, float) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, double) \
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, BFloat16)
#define BINARY_OP_REGISTER_VERSIONED_HFD(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, float) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, double) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, BFloat16)
#define BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(name, class_name, startver, endver) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, MLFloat16) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, float) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, double) \
BINARY_OP_TYPED_VERSIONED_V(name, class_name, startver, endver, BFloat16)
#define BINARY_OP_REGISTER_VERSIONED_UZILHFD(name, startver, endver) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, uint64_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int32_t) \
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(name, startver, endver, int64_t) \
BINARY_OP_REGISTER_VERSIONED_HFD(name, startver, endver)
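// The type-suffix in these registration macro names follows the legend used in the .cu
// implementation file below: U: uint32_t, Z: uint64_t, I: int32_t, L: int64_t,
// H: MLFloat16, F: float, D: double, O: bool (BFloat16 rides along with the HFD groups).
// For example, BINARY_OP_REGISTER_HFD(name, ver) registers the opset-`ver` kernel for
// MLFloat16, float, double and BFloat16.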
BINARY_OP_VERSIONED_UZILHFD(Add, 7, 12)
BINARY_OP_VERSIONED_UZILHFD(Sub, 7, 12)
BINARY_OP_VERSIONED_UZILHFD(Mul, 7, 12)
BINARY_OP_VERSIONED_UZILHFD(Div, 7, 12)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Add, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Sub, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Mul, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Div, 13, 13)
BINARY_OP_UZILHFD(Add, 14)
BINARY_OP_UZILHFD(Sub, 14)
BINARY_OP_UZILHFD(Mul, 14)
BINARY_OP_UZILHFD(Div, 14)
BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11)
BINARY_LOGICALOP_TYPED(And, 7, bool)
BINARY_LOGICALOP_TYPED(Or, 7, bool)
BINARY_LOGICALOP_TYPED(Xor, 7, bool)
BINARY_OP_VERSIONED_HFD(PRelu, 7, 8)
BINARY_OP_VERSIONED_HFD(PRelu, 9, 15)
// Opset-16 adds BFloat16 to allowed types for the PRelu operator
BINARY_OP_HFD(PRelu, 16)
// Pow since version 12
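// Pow-12 and later allow the base (T) and exponent (T1) to have different element types,
// so these kernels are registered explicitly with separate T / T1 constraints instead of
// going through the single-type macros above.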
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Pow,
kOnnxDomain,
12, 12,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
Pow);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Pow,
kOnnxDomain,
13, 14,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
Pow);
ONNX_OPERATOR_KERNEL_EX(
Pow,
kOnnxDomain,
15,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t, float, double, MLFloat16>()),
Pow);
namespace pow12_internal {
template <class T>
Status DispatchOnFirstArg(hipStream_t stream, const BinaryElementwisePreparation& prepare) {
namespace on = ONNX_NAMESPACE;
Status s;
switch (prepare.rhs_tensor->GetElementType()) {
case on::TensorProto_DataType_INT32:
ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<int32_t>::MappedType>(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const typename ToHipType<int32_t>::MappedType*>(prepare.rhs_tensor->Data<int32_t>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
break;
case on::TensorProto_DataType_INT64:
ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<int64_t>::MappedType>(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const typename ToHipType<int64_t>::MappedType*>(prepare.rhs_tensor->Data<int64_t>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
break;
case on::TensorProto_DataType_FLOAT:
ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<float>::MappedType>(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const typename ToHipType<float>::MappedType*>(prepare.rhs_tensor->Data<float>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
break;
case on::TensorProto_DataType_DOUBLE:
ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<double>::MappedType>(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const typename ToHipType<double>::MappedType*>(prepare.rhs_tensor->Data<double>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
break;
case on::TensorProto_DataType_FLOAT16:
ImplT1_Pow<typename ToHipType<T>::MappedType, typename ToHipType<MLFloat16>::MappedType>(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const typename ToHipType<MLFloat16>::MappedType*>(prepare.rhs_tensor->Data<MLFloat16>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
break;
default:
s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported Y type: ",
DataTypeImpl::ToString(prepare.rhs_tensor->DataType()));
}
return s;
}
} // namespace pow12_internal
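// Pow::ComputeInternal dispatches in two steps: the switch below picks the base (X) element
// type, and pow12_internal::DispatchOnFirstArg then picks the exponent (Y) element type, so
// each supported (X, Y) combination resolves to the matching ImplT1_Pow<T, T1> instantiation.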
Status Pow::ComputeInternal(OpKernelContext* context) const {
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(Prepare(context, &prepare));
namespace on = ONNX_NAMESPACE;
using namespace pow12_internal;
Status s;
switch (prepare.lhs_tensor->GetElementType()) {
case on::TensorProto_DataType_INT32:
s = DispatchOnFirstArg<int32_t>(Stream(), prepare);
break;
case on::TensorProto_DataType_INT64:
s = DispatchOnFirstArg<int64_t>(Stream(), prepare);
break;
case on::TensorProto_DataType_FLOAT:
s = DispatchOnFirstArg<float>(Stream(), prepare);
break;
case on::TensorProto_DataType_DOUBLE:
s = DispatchOnFirstArg<double>(Stream(), prepare);
break;
case on::TensorProto_DataType_FLOAT16:
s = DispatchOnFirstArg<MLFloat16>(Stream(), prepare);
break;
default:
s = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported X type: ",
DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));
}
return s;
}
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Mod, kOnnxDomain, 10, 12, kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T",
BuildKernelDefConstraints<int32_t, int64_t, uint32_t, uint64_t, float, double, MLFloat16>()),
Mod);
ONNX_OPERATOR_KERNEL_EX(Mod, kOnnxDomain, 13, kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<int32_t, int64_t, uint32_t, uint64_t, float,
double, MLFloat16, BFloat16>()),
Mod);
Status Mod::ComputeInternal(OpKernelContext* context) const {
namespace on = ONNX_NAMESPACE;
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(Prepare(context, &prepare));
auto element_type = prepare.lhs_tensor->GetElementType();
ORT_ENFORCE(fmod_ || element_type == on::TensorProto_DataType_INT32 ||
element_type == on::TensorProto_DataType_INT64 || element_type == on::TensorProto_DataType_UINT32 ||
element_type == on::TensorProto_DataType_UINT64,
"Non-fmod can support integer types only.");
#define CASE_MOD_ELEMENT_TYPE(name, onnx_type, data_type) \
case onnx_type: { \
Impl_##name<typename ToHipType<data_type>::MappedType>( \
Stream(), prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToHipType<data_type>::MappedType*>(prepare.lhs_tensor->Data<data_type>()), \
&prepare.rhs_padded_strides, \
reinterpret_cast<const typename ToHipType<data_type>::MappedType*>(prepare.rhs_tensor->Data<data_type>()), \
&prepare.fdm_output_strides, prepare.fdm_H, prepare.fdm_C, \
reinterpret_cast<typename ToHipType<data_type>::MappedType*>( \
prepare.output_tensor->MutableData<data_type>()), \
prepare.output_tensor->Shape().Size()); \
} break
if (fmod_) {
switch (element_type) {
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_INT32, int32_t);
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_INT64, int64_t);
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_UINT32, uint32_t);
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_UINT64, uint64_t);
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_FLOAT, float);
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_DOUBLE, double);
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_FLOAT16, MLFloat16);
CASE_MOD_ELEMENT_TYPE(Fmod, on::TensorProto_DataType_BFLOAT16, BFloat16);
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Unsupported element type: ", DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));
}
} else {
switch (element_type) {
CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_INT32, int32_t);
CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_INT64, int64_t);
CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_UINT32, uint32_t);
CASE_MOD_ELEMENT_TYPE(Mod, on::TensorProto_DataType_UINT64, uint64_t);
default:
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Unsupported element type: ", DataTypeImpl::ToString(prepare.lhs_tensor->DataType()));
}
}
#undef CASE_MOD_ELEMENT_TYPE
return Status::OK();
}
// Comparison ops output a bool tensor, so they cannot directly fit in the macros
// for the other elementwise ops
template <typename T, typename HipT>
Status CompareFunction<T, HipT>::CompareMethod(OpKernelContext* context, ImplCompare Impl_Compare) const {
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(Prepare(context, &prepare));
Impl_Compare(
Stream(),
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<ToHipType<bool>::MappedType*>(prepare.output_tensor->MutableData<bool>()),
prepare.output_tensor->Shape().Size());
return Status::OK();
}
// Greater op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status Greater<T>::ComputeInternal(OpKernelContext* context) const {
return this->CompareMethod(context, &ImplT2_Greater);
}
template <typename T>
Status Equal<T>::ComputeInternal(OpKernelContext* context) const {
return this->CompareMethod(context, &ImplT2_Equal);
}
// Less op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status Less<T>::ComputeInternal(OpKernelContext* context) const {
return this->CompareMethod(context, &ImplT2_Less);
}
// GreaterOrEqual op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status GreaterOrEqual<T>::ComputeInternal(OpKernelContext* context) const {
return this->CompareMethod(context, &ImplT2_GreaterOrEqual);
}
// LessOrEqual op output tensor type is bool, so it cannot directly fit in the macros
// for other elementwise ops
template <typename T>
Status LessOrEqual<T>::ComputeInternal(OpKernelContext* context) const {
return this->CompareMethod(context, &ImplT2_LessOrEqual);
}
BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 13)
BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 13, bool)
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Equal, 11, 12)
BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 11, 12, bool)
BINARY_OP_REGISTER_VERSIONED_OIL(Equal, 7, 10)
BINARY_LOGICALOP_REGISTER_UZILHFD(Greater, 13)
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Greater, 9, 12)
BINARY_OP_REGISTER_VERSIONED_HFD(Greater, 7, 8)
BINARY_LOGICALOP_REGISTER_UZILHFD(Less, 13)
BINARY_OP_REGISTER_VERSIONED_UZILHFD(Less, 9, 12)
BINARY_OP_REGISTER_VERSIONED_HFD(Less, 7, 8)
BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(GreaterOrEqual, 12, 15)
BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(LessOrEqual, 12, 15)
// Opset-16 adds BFloat16 to allowed types for the GreaterOrEqual operator
BINARY_LOGICALOP_REGISTER_UZILHFD(GreaterOrEqual, 16)
// Opset-16 adds BFloat16 to allowed types for the LessOrEqual operator
BINARY_LOGICALOP_REGISTER_UZILHFD(LessOrEqual, 16)
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/shared_inc/fast_divmod.h"
#include "core/providers/cpu/tensor/utils.h"
namespace onnxruntime {
namespace rocm {
struct BinaryElementwisePreparation {
const Tensor* lhs_tensor = nullptr;
const Tensor* rhs_tensor = nullptr;
Tensor* output_tensor = nullptr;
int32_t output_rank_or_simple_broadcast = 0;  // holds the output rank, or a SimpleBroadcast enum value for the no_broadcast / left_scalar / right_scalar cases
TArray<int64_t> lhs_padded_strides;
TArray<int64_t> rhs_padded_strides;
TArray<fast_divmod> fdm_output_strides;
// these are for RightPerChannel case
fast_divmod fdm_H;
fast_divmod fdm_C;
BinaryElementwisePreparation() {}
Status BinaryElementwiseBroadcastPrepareHelper(const TensorShape& lhs_shape,
const TensorShape& rhs_shape,
const TensorShape& output_shape) {
int32_t lhs_rank = gsl::narrow_cast<int32_t>(lhs_shape.NumDimensions());
int32_t rhs_rank = gsl::narrow_cast<int32_t>(rhs_shape.NumDimensions());
int32_t out_rank = std::max(lhs_rank, rhs_rank);
// early return when shapes match
if (lhs_shape == rhs_shape) {
output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::NoBroadcast);
return Status::OK();
}
// early return if one operand is scalar
if (lhs_shape.Size() == 1 || rhs_shape.Size() == 1) {
output_rank_or_simple_broadcast = static_cast<int32_t>(lhs_shape.Size() == 1
? SimpleBroadcast::LeftScalar
: SimpleBroadcast::RightScalar);
return Status::OK();
}
// special case for lhs (N,C,H) and rhs (C,1), as used for the conv bias add:
// when N == 1: out[id] = op(lhs[id], rhs[id / H])
// when N >  1: out[id] = op(lhs[id], rhs[id / H % C])
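// Illustrative example: lhs shape (2,3,4) with rhs shape (3,1) gives C = 3, H = 4, N = 2,
// i.e. RightPerChannelBatchN with fdm_H = fast_divmod(4) and fdm_C = fast_divmod(3), and
// output element id reads rhs channel (id / 4) % 3.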
if (lhs_shape == output_shape) {
const auto& rhs_dims = rhs_shape.GetDims();
int64_t C = 0;
if (1 == std::count_if(rhs_dims.begin(), rhs_dims.end(),
[&C](int64_t dim) { if (dim != 1) C = dim; return (dim != 1); })) {
int32_t dim_C = gsl::narrow_cast<int32_t>(std::find(rhs_dims.begin(), rhs_dims.end(), C) - rhs_dims.begin() + output_shape.NumDimensions() - rhs_shape.NumDimensions());
int64_t N = output_shape.SizeToDimension(dim_C);
int64_t H = (dim_C < out_rank - 1 ? output_shape.SizeFromDimension(static_cast<size_t>(dim_C) + 1) : 1);
if (N == 1) {
output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatch1);
fdm_H = fast_divmod(gsl::narrow_cast<int>(H));
} else {
output_rank_or_simple_broadcast = static_cast<int32_t>(SimpleBroadcast::RightPerChannelBatchN);
fdm_H = fast_divmod(gsl::narrow_cast<int>(H));
fdm_C = fast_divmod(gsl::narrow_cast<int>(C));
}
return Status::OK();
}
}
output_rank_or_simple_broadcast = out_rank;
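// General broadcast path: each input gets its strides padded to out_rank, with a stride of 0
// for every broadcast (size-1 or missing) dimension so the same element is reused along it.
// Illustrative example: lhs shape (3,1,5) broadcast to output (2,3,4,5) gives out_rank = 4,
// offset = 1 and lhs_padded_strides = {0, 5, 0, 1}.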
if (lhs_shape != output_shape) {
TensorPitches original_lhs_padded_strides(lhs_shape.GetDims(), out_rank);
lhs_padded_strides.SetSize(out_rank);
auto offset = out_rank - lhs_rank;
for (auto i = offset; i < out_rank; ++i) {
// the stride for a broadcast dimension is kept as 0
if (lhs_shape.GetDims()[static_cast<size_t>(i) - offset] != 1) {
lhs_padded_strides[i] = original_lhs_padded_strides[i];
}
}
}
if (rhs_shape != output_shape) {
TensorPitches original_rhs_padded_strides(rhs_shape.GetDims(), out_rank);
rhs_padded_strides.SetSize(out_rank);
auto offset = out_rank - rhs_rank;
for (auto i = offset; i < out_rank; ++i) {
// the stride for a broadcast dimension is kept as 0
if (rhs_shape.GetDims()[static_cast<size_t>(i) - offset] != 1) {
rhs_padded_strides[i] = original_rhs_padded_strides[i];
}
}
}
TensorPitches original_output_strides(output_shape.GetDims());
fdm_output_strides.SetSize(out_rank);
for (auto i = 0; i < out_rank; ++i) {
fdm_output_strides[i] = fast_divmod(gsl::narrow_cast<int>(original_output_strides[i]));
}
return Status::OK();
}
};
Status ComputeOutputShape(
const std::string& node_name,
const TensorShape& lhs_shape,
const TensorShape& rhs_shape,
TensorShape& out_shape);
Status BinaryElementwiseBroadcastPrepare(
const Tensor* lhs_tensor,
const Tensor* rhs_tensor,
Tensor* output_tensor,
BinaryElementwisePreparation* p,
const TensorShape* override_lhs_shape = nullptr,
const TensorShape* override_rhs_shape = nullptr);
// trait classes to indicate if the kernel supports broadcast
class ShouldBroadcast {
};
class ShouldNotBroadcast {
};
template <typename BroadcastTrait>
class BinaryElementwise : public RocmKernel {
protected:
typedef BroadcastTrait broadcast_type;
BinaryElementwise(const OpKernelInfo& info) : RocmKernel(info) {}
Status ComputeInternal(OpKernelContext*) const override {
return Status(common::ONNXRUNTIME, common::FAIL); // should not reach here
}
Status Prepare(OpKernelContext* context, BinaryElementwisePreparation* p) const;
};
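// Derived kernels pick one of the traits above via the broadcast_type typedef; the out-of-line
// Prepare definition (in the accompanying .cc file) can thus be specialized per trait.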
template <typename T>
class Add final : public BinaryElementwise<ShouldBroadcast> {
public:
Add(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Sub final : public BinaryElementwise<ShouldBroadcast> {
public:
Sub(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Mul final : public BinaryElementwise<ShouldBroadcast> {
public:
Mul(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Div final : public BinaryElementwise<ShouldBroadcast> {
public:
Div(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Pow_7 final : public BinaryElementwise<ShouldBroadcast> {
public:
Pow_7(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
// Since version 12
class Pow final : public BinaryElementwise<ShouldBroadcast> {
public:
Pow(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class And final : public BinaryElementwise<ShouldBroadcast> {
public:
And(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Or final : public BinaryElementwise<ShouldBroadcast> {
public:
Or(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Xor final : public BinaryElementwise<ShouldBroadcast> {
public:
Xor(const OpKernelInfo& info) : BinaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
// PRelu is an activation function, but its implementation is closer to the binary elementwise ops
template <typename T>
class PRelu final : public BinaryElementwise<ShouldBroadcast> {
public:
PRelu(const OpKernelInfo& info) : BinaryElementwise(info) {
}
Status ComputeInternal(OpKernelContext* context) const override;
};
class Mod final : public BinaryElementwise<ShouldBroadcast> {
public:
Mod(const OpKernelInfo& info) : BinaryElementwise(info) {
int64_t fmod = info.GetAttrOrDefault<int64_t>("fmod", 0LL);
fmod_ = fmod != 0;
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
bool fmod_{false};
};
template <typename T, typename HipT>
class CompareFunction : public BinaryElementwise<ShouldBroadcast> {
public:
CompareFunction(const OpKernelInfo& info) : BinaryElementwise(info) {}
typedef void (*ImplCompare)(hipStream_t stream,
int32_t output_rank_or_simple_broadcast,
const TArray<int64_t>* lhs_padded_strides,
const HipT* lhs_data,
const TArray<int64_t>* rhs_padded_strides,
const HipT* rhs_data,
const TArray<fast_divmod>* fdm_output_strides,
const fast_divmod& fdm_H,
const fast_divmod& fdm_C,
bool* output_data,
size_t count);
Status CompareMethod(OpKernelContext* context, ImplCompare Impl_Compare) const;
};
template <typename T>
class Greater final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
public:
Greater(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Equal final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
public:
Equal(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Less final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
public:
Less(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class GreaterOrEqual final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
public:
GreaterOrEqual(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class LessOrEqual final : public CompareFunction<T, typename ToHipType<T>::MappedType> {
public:
LessOrEqual(const OpKernelInfo& info) : CompareFunction<T, typename ToHipType<T>::MappedType>(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"
namespace onnxruntime {
namespace rocm {
#define BINARY_ELEMENTWISE_IMPL(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
BinaryElementWiseImpl(stream, \
output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
rhs_data, \
fdm_output_strides, \
fdm_H, \
fdm_C, \
output_data, \
OP_##name<T, T, T>(), \
count); \
}
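// Example expansion: BINARY_ELEMENTWISE_IMPL(Add) defines Impl_Add<T>(...) as a thin wrapper
// that launches BinaryElementWiseImpl with the OP_Add<T, T, T> functor.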
#define BINARY_ELEMENTWISE_IMPL_T1(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION_T1(name) { \
BinaryElementWiseImpl(stream, \
output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
rhs_data, \
fdm_output_strides, \
fdm_H, \
fdm_C, \
output_data, \
OP_##name<T, T, T1>(), \
count); \
}
#define BINARY_ELEMENTWISE_IMPL_T2(name) \
BINARY_ELEMENTWISE_IMPL_DECLARATION_T2(name) { \
BinaryElementWiseImpl(stream, \
output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
rhs_data, \
fdm_output_strides, \
fdm_H, \
fdm_C, \
output_data, \
OP_##name<T, T1, T2>(), \
count); \
}
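// The _T1 variant lets the rhs type differ from the lhs/output type (used by Pow below);
// the _T2 variant additionally decouples the output type (used by the comparison ops,
// whose output is bool).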
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, T) \
template void Impl_##x<T>(hipStream_t stream, \
int32_t output_rank, \
const TArray<int64_t>* lhs_padded_strides, const T* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, const T* rhs_data, \
const TArray<fast_divmod>* fdm_output_strides, const fast_divmod& fdm_H, const fast_divmod& fdm_C, T* output_data, size_t count);
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, T1) \
template void ImplT1_##x<T, T1>(hipStream_t stream, \
int32_t output_rank, \
const TArray<int64_t>* lhs_padded_strides, const T* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, const T1* rhs_data, \
const TArray<fast_divmod>* fdm_output_strides, const fast_divmod& fdm_H, const fast_divmod& fdm_C, T* output_data, size_t count);
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(x, T, T1, T2) \
template void ImplT2_##x<T, T1, T2>(hipStream_t stream, \
int32_t output_rank, \
const TArray<int64_t>* lhs_padded_strides, const T1* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, const T2* rhs_data, \
const TArray<fast_divmod>* fdm_output_strides, const fast_divmod& fdm_H, const fast_divmod& fdm_C, T* output_data, size_t count);
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZIL(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(x, T) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1(x, T, double)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, bool) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t)
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
// create the Impl_<name> definitions for each binary op
#define BINARY_OP_NAME_EXPR(name, expr) \
BINARY_ELEMENTWISE_IMPL(name)
BINARY_OPS()
#undef BINARY_OP_NAME_EXPR
// create specialized impl
// the suffix of each specialization macro name lists the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Add)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Sub)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Mul)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Div)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Pow_7)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(And, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Or, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Xor, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(PRelu)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Max)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Min)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZIL(Mod)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Fmod)
// create the ImplT1_Pow definition and its explicit instantiations
BINARY_ELEMENTWISE_IMPL_T1(Pow)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int32_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, int64_t)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, float)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, double)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T1_ILHFD(Pow, half)
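// The specializations above cover every (base, exponent) pairing that
// pow12_internal::DispatchOnFirstArg can select for the opset-12+ Pow kernels.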
// create the ImplT2_<name> definitions (used by the comparison ops)
#define BINARY_OP_NAME_EXPR2(name, expr) \
BINARY_ELEMENTWISE_IMPL_T2(name)
BINARY_OPS2()
#undef BINARY_OP_NAME_EXPR2
#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(name) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, uint32_t, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, uint64_t, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, int32_t, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, int64_t, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, half, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, float, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, double, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(name, bool, BFloat16, BFloat16)
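// For the comparison ops the output type is fixed to bool and the two input types match,
// mirroring the bool* output_data parameter of CompareFunction::CompareMethod.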
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Greater)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Equal)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_T2(Equal, bool, bool, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(Less)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(GreaterOrEqual)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD2(LessOrEqual)
} // namespace rocm
} // namespace onnxruntime