Commit 1a91fcc2 authored by gaoqiong

add files required by dtk

parent a144865d
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
template<bool inputk>
class TopK final : public RocmKernel {
public:
TopK(const OpKernelInfo&);
Status ComputeInternal(OpKernelContext*) const override;
private:
int64_t axis_;
int64_t largest_;
int64_t sorted_;
mutable int64_t K_;
};
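// Note (inferred, not stated in this header): axis_, largest_ and sorted_
// mirror the ONNX TopK attributes, while K_ is mutable so that the const
// ComputeInternal can refresh it when the kernel is instantiated with
// inputk == true, i.e. when K arrives as an input tensor rather than as an
// attribute.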
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "topk_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "hipcub/hipcub.hpp"
#include <hipcub/backend/rocprim/util_type.hpp>
#include <hipcub/util_allocator.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/backend/rocprim/device/device_radix_sort.hpp>
#include <limits>
// TODO: fix the warnings
#ifdef _MSC_VER
#pragma warning(disable : 4244)
#endif
namespace onnxruntime {
namespace rocm {
using namespace hipcub;
template <typename T>
struct KV {
T key;
int64_t val;
};
#define BT GridDim::maxThreadsPerBlock
#define ALIGN(N) static_cast<int64_t>(pow(2, ceil(log2(static_cast<double>(N)))))
#define FROM(idx) (left_dim + (idx)*mid_dim + right_dim)
#define TO(idx) (left_dim * K / dimension + (idx)*mid_dim + right_dim)
#define TRIVIAL (1 == largest ? type_min : type_max)
#define BIGGER(n, m) (n.key > m.key ? n : (n.key < m.key ? m : (n.val > m.val ? (1 == largest ? m : n) : (1 == largest ? n : m))))
#define SMALLER(n, m) (n.key < m.key ? n : (n.key > m.key ? m : (n.val < m.val ? (1 == largest ? m : n) : (1 == largest ? n : m))))
#define IS_SMALLER(n, m) (n.key < m.key || !(n.key > m.key) && (1 == largest ? n.val > m.val : n.val < m.val))
#define LESS(n, m) ((n) <= (m) ? (n) : (m))
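// Illustrative notes on the helpers above (inferred from their use below):
// elem_nums appears to hold the suffix products of the input shape, so
// elem_nums[axis + 1] is the stride of `axis`, and FROM(idx)/TO(idx) map the
// idx-th element of one top-k row to its flat input/output offset. ALIGN
// rounds up to the next power of two, e.g. ALIGN(3) == 4 and ALIGN(5) == 8;
// BitonicTopK below pads each row with TRIVIAL up to aligned_dimension
// entries so the bitonic network always works on a power-of-two size.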
template <typename T>
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
int64_t tid = threadIdx.x;
int64_t bid = blockIdx.x;
int64_t bdim = blockDim.x;
extern __shared__ char shared_mem[];
auto S = (KV<T>*)(shared_mem);
auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1];
auto left_dim = bid / mid_dim * elem_nums[axis];
auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1];
for (auto i = tid; i < aligned_dimension; i += bdim) {
S[i].key = i < dimension ? X[FROM(i)] : TRIVIAL;
S[i].val = i;
}
__syncthreads();
//sort each K
for (int64_t len = 1; len < aligned_K; len <<= 1) {
auto dir = len << 1;
for (auto inc = len; inc > 0; inc >>= 1) {
auto low = tid & (inc - 1);
auto i = (tid << 1) - low;
auto j = i + inc;
if (j < aligned_dimension) {
auto reverse = (dir & i) == 0;
auto swap = reverse ^ IS_SMALLER(S[i], S[j]);
if (swap) {
auto tmp = S[i];
S[i] = S[j];
S[j] = tmp;
}
}
__syncthreads();
}
__syncthreads();
}
//merge and rebuild K
for (int64_t len = aligned_K; len < aligned_dimension; len <<= 1) {
auto dir = len << 1;
auto i = (tid << 1) - (tid & (len - 1));
auto j = i + len;
if (i % dir < aligned_K && j < aligned_dimension) {
S[i] = 1 == largest ? BIGGER(S[i], S[j]) : SMALLER(S[i], S[j]);
}
__syncthreads();
for (auto inc = aligned_K >> 1; inc > 0; inc >>= 1) {
auto ii = (tid << 1) - (tid & (inc - 1));
auto jj = ii + inc;
if (ii % dir < aligned_K && jj < aligned_dimension) {
auto reverse = (dir & ii) == 0;
auto swap = reverse ^ IS_SMALLER(S[ii], S[jj]);
if (swap) {
auto tmp = S[ii];
S[ii] = S[jj];
S[jj] = tmp;
}
}
__syncthreads();
}
__syncthreads();
}
//save top K
if (1 == sorted) {
if (1 == largest) {
auto start = aligned_K - K;
if (tid >= start && tid < aligned_K) {
auto to = TO(aligned_K - 1 - tid);
V[to] = S[tid].key;
I[to] = S[tid].val;
}
} else {
if (tid < K) {
auto to = TO(tid);
V[to] = S[tid].key;
I[to] = S[tid].val;
}
}
} else {
if (1 == largest) {
auto start = aligned_K - K;
if (tid < start) {
S[tid].val = aligned_dimension;
}
} else {
if (tid >= K && tid < aligned_K) {
S[tid].val = aligned_dimension;
}
}
__syncthreads();
//sort by index ascending
for (int64_t len = 1; len < aligned_K; len <<= 1) {
auto dir = len << 1;
for (int64_t inc = len; inc > 0; inc >>= 1) {
auto low = tid & (inc - 1);
auto i = (tid << 1) - low;
auto j = i + inc;
if (j < aligned_K) {
auto reverse = (dir & i) == 0;
auto swap = reverse ^ (S[i].val < S[j].val);
if (swap) {
auto tmp = S[i];
S[i] = S[j];
S[j] = tmp;
}
}
__syncthreads();
}
__syncthreads();
}
if (tid < K) {
auto to = TO(tid);
V[to] = S[tid].key;
I[to] = S[tid].val;
}
}
}
template <typename T>
__device__ __forceinline__ bool Equal(const T& t0, const T& t1) {
return t0 == t1;
}
__device__ __forceinline__ bool Equal(const float& t0, const float& t1) {
return !(t0 > t1 || t1 > t0);
}
__device__ __forceinline__ bool Equal(const double& t0, const double& t1) {
return !(t0 > t1 || t1 > t0);
}
template<typename T>
__device__ __forceinline__ bool SamePrefix(const T* t0, const T* t1, int64_t skip) {
return ((*t0)^(*t1))>>skip == 0;
}
__device__ __forceinline__ bool SamePrefix(const half* f0, const half* f1, int64_t skip) {
return SamePrefix((const int16_t*)f0, (const int16_t*)f1, skip);
}
__device__ __forceinline__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) {
return SamePrefix((const int32_t*)f0, (const int32_t*)f1, skip);
}
__device__ __forceinline__ bool SamePrefix(const double* d0, const double* d1, int64_t skip) {
return SamePrefix((const int64_t*)d0, (const int64_t*)d1, skip);
}
template<typename T>
__device__ __forceinline__ int32_t Radix(const T* t, int64_t skip) {
return ((*t)>>skip)&255;
}
__device__ __forceinline__ int32_t Radix(const half* f, int64_t skip) {
return Radix((const int16_t*)f, skip);
}
__device__ __forceinline__ int32_t Radix(const float* f, int64_t skip) {
return Radix((const int32_t*)f, skip);
}
__device__ __forceinline__ int32_t Radix(const double* d, int64_t skip) {
return Radix((const int64_t*)d, skip);
}
template<typename T>
__device__ void SetByte(T* t, int64_t byte) {
(*t) |= byte;
}
__device__ __forceinline__ void SetByte(half* f, int64_t byte) {
SetByte((int16_t*)f, byte);
}
__device__ __forceinline__ void SetByte(float* f, int64_t byte) {
SetByte((int32_t*)f, byte);
}
__device__ __forceinline__ void SetByte(double* d, int64_t byte) {
SetByte((int64_t*)d, byte);
}
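// Sketch of the strategy in RadixTopK below, as read from the code:
// 1) count positive and negative elements to decide whether the Kth value is
//    positive, negative (handled by flipping the sign), or zero;
// 2) locate the Kth value byte by byte with a 256-bin histogram, reading
//    floating-point keys through their same-width integer representation via
//    Radix/SamePrefix/SetByte, e.g. for x = 10.0f (bits 0x41200000)
//    Radix(&x, 24) == 0x41;
// 3) emit every element strictly better than the Kth value plus enough equal
//    elements to fill K outputs, optionally sorting them with BlockRadixSort.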
template<typename T, int64_t THREADS, int64_t KPT>
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) {
auto tid = threadIdx.x;
auto bid = blockIdx.x;
extern __shared__ char shared_mem[];
auto H = (uint32_t*)shared_mem;
auto mid_dim = axis == size - 1 ? 1 : elem_nums[axis + 1];
auto left_dim = bid / mid_dim * elem_nums[axis];
auto right_dim = axis == size - 1 ? 0 : bid % elem_nums[axis + 1];
T Kth = (T)0, sign = (T)1;
typedef BlockScan<uint32_t, THREADS> BlockScan;
typedef BlockReduce<uint32_t, THREADS> BlockReduce;
typedef BlockRadixSort<T, THREADS, KPT, int64_t> BlockRadixSort;
__shared__ union {
typename BlockScan::TempStorage scan;
typename BlockReduce::TempStorage reduce;
typename BlockRadixSort::TempStorage sort;
} temp_storage;
uint32_t positive = 0, negative = 0;
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
T x = X[FROM(x_i)];
if (x > (T)0) {
++positive;
} else if (x < (T)0) {
++negative;
}
}
__syncthreads();
positive = BlockReduce(temp_storage.reduce).Sum(positive);
__syncthreads();
negative = BlockReduce(temp_storage.reduce).Sum(negative);
if (0 == tid) {
H[0] = positive;
H[1] = negative;
}
__syncthreads();
positive = H[0];
negative = H[1];
if ((1 == largest && (K <= positive || dimension - K + 1 <= negative)) ||
(0 == largest && (K <= negative || dimension - K + 1 <= positive))) {
auto KK = K;
if (1 == largest) {
if (KK > positive) {
KK = dimension - KK + 1;
sign = (T)-1;
}
} else {
if (KK > negative) {
KK = dimension - KK + 1;
} else {
sign = (T)-1;
}
}
__syncthreads();
#pragma unroll
for (int64_t byte = sizeof(T)-1; byte > -1; --byte) {
if (tid < 256) H[tid] = 0;
__syncthreads();
auto skip = 8 * byte, prev_skip = 8 * (byte + 1);
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
T x = sign*X[FROM(x_i)];
if (x > (T)0 && (byte == sizeof(T) - 1 || SamePrefix(&x, &Kth, prev_skip))) {
atomicAdd(&H[Radix(&x, skip)], 1);
}
}
__syncthreads();
for (int64_t radix = 255; radix > 0; --radix) {
if (H[radix] < KK) {
KK -= H[radix];
} else {
SetByte(&Kth, radix<<skip);
break;
}
}
__syncthreads();
}
Kth *= sign;
}
uint32_t superior = 0, equal = 0;
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
auto x = X[FROM(x_i)];
if ((1 == largest && x > Kth) || (0 == largest && x < Kth)) {
++superior;
} else if (Equal(x, Kth)) {
++equal;
}
}
__syncthreads();
auto all_superior = superior;
all_superior = BlockReduce(temp_storage.reduce).Sum(all_superior);
if (0 == tid) {
H[0] = all_superior;
}
__syncthreads();
all_superior = H[0];
BlockScan(temp_storage.scan).ExclusiveSum(superior, superior);
__syncthreads();
BlockScan(temp_storage.scan).ExclusiveSum(equal, equal);
__syncthreads();
auto equal_quota = K - all_superior - equal;
auto output_i = superior + LESS(K - all_superior, equal);
for (int64_t x_i = tid; x_i < dimension; x_i += blockDim.x) {
auto x = X[FROM(x_i)];
if ((1 == largest && x > Kth) || (0 == largest && x < Kth)) {
auto to_i = TO(output_i);
V[to_i] = x;
I[to_i] = x_i;
++output_i;
} else if (Equal(x, Kth) && equal_quota > 0) {
auto to_i = TO(output_i);
V[to_i] = x;
I[to_i] = x_i;
++output_i;
--equal_quota;
}
}
__syncthreads();
if (1 == sorted) {
T keys[KPT];
int64_t vals[KPT];
for (int64_t k_i = tid, k_c = 0; k_c < KPT; k_i += blockDim.x, ++k_c) {
if (k_i < K) {
auto to_i = TO(k_i);
keys[k_c] = V[to_i];
vals[k_c] = I[to_i];
} else {
if (1 == largest) {
keys[k_c] = type_min;
} else {
keys[k_c] = type_max;
}
}
}
__syncthreads();
if (1 == largest) {
BlockRadixSort(temp_storage.sort).SortDescending(keys, vals);
} else {
BlockRadixSort(temp_storage.sort).Sort(keys, vals);
}
__syncthreads();
#pragma unroll
for (int64_t k_c = 0; k_c < KPT; ++k_c) {
auto k_i = tid * KPT + k_c;
if (k_i < K) {
auto to_i = TO(k_i);
V[to_i] = keys[k_c];
I[to_i] = vals[k_c];
}
}
}
}
template <typename T>
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t offset, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis];
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
auto input_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
output_v[id] = input_x[input_offset];
output_i[id] = id;
}
template <typename T>
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t offset, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, K);
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis] * K / dimension;
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
auto output_offset = left + id * (axis == size - 1 ? 1 : elem_nums[axis + 1]) + right;
output_v[output_offset] = input_v[id];
output_i[output_offset] = input_i[id];
}
// A template is used to avoid a linking issue, since a __global__ function cannot be inlined
template <typename T>
__global__ void ExcludeOutput(T* output_i, T K, T dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
if (id >= K) {
output_i[id] = dimension;
}
}
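// TopKImpl below selects one of three paths:
// - BitonicTopK when the padded row fits in one block
//   (aligned_dimension <= GridDim::maxThreadsPerBlock), using
//   aligned_dimension * sizeof(KV<HipT>) bytes of shared memory per block;
// - RadixTopK when K is small relative to the block size (K <= BT * 16) or
//   the output does not need to be sorted, choosing KPT from {2, 4, 8, 16};
// - otherwise a full per-row hipcub::DeviceRadixSort::SortPairs(Descending),
//   followed by FillOutput (plus an extra sort by index when sorted == 0).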
template <typename T>
Status TopKImpl(const RocmKernel* kernel, hipStream_t stream, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
typedef typename ToHipType<T>::MappedType HipT;
const HipT* input_x_ptr = reinterpret_cast<const HipT*>(input_x);
HipT* output_v_ptr = reinterpret_cast<HipT*>(output_v);
auto aligned_K = ALIGN(K);
auto aligned_dimension = ALIGN(dimension);
if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
BitonicTopK<HipT><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<HipT>), stream>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension, aligned_dimension, NumericLimits<T>::Min(), NumericLimits<T>::Max());
} else if (K <= BT*16 || 0 == sorted) {
auto XPT = static_cast<int64_t>(ceil(static_cast<double>(dimension) / GridDim::maxThreadsPerBlock));
if (BT*2 >= K || 0 == sorted) {
RadixTopK<HipT, BT, 2><<<N, BT, 256 * sizeof(uint32_t), stream>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, NumericLimits<T>::Min(), NumericLimits<T>::Max());
} else if (BT*4>=K) {
RadixTopK<HipT, BT, 4><<<N, BT, 256 * sizeof(uint32_t), stream>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, NumericLimits<T>::Min(), NumericLimits<T>::Max());
} else if (BT*8>=K) {
RadixTopK<HipT, BT, 8><<<N, BT, 256 * sizeof(uint32_t), stream>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, NumericLimits<T>::Min(), NumericLimits<T>::Max());
} else {
RadixTopK<HipT, BT, 16><<<N, BT, 256 * sizeof(uint32_t), stream>>>(input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT, NumericLimits<T>::Min(), NumericLimits<T>::Max());
}
} else {
auto input_key_buffer = kernel->GetScratchBuffer<HipT>(dimension);
auto output_key_buffer = kernel->GetScratchBuffer<HipT>(dimension);
auto input_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto output_value_buffer = kernel->GetScratchBuffer<int64_t>(dimension);
auto* input_key = input_key_buffer.get();
auto* output_key = output_key_buffer.get();
auto* input_value = input_value_buffer.get();
auto* output_value = output_value_buffer.get();
size_t temp_bytes = 0;
HIP_RETURN_IF_ERROR(hipcub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension, 0, sizeof(T)*8, stream));
auto temp_storage_buffer = kernel->GetScratchBuffer<char>(temp_bytes);
auto* temp_storage = temp_storage_buffer.get();
auto blocks_per_grid_D = (int)(ceil(static_cast<float>(dimension) / BT));
auto blocks_per_grid_K = (int)(ceil(static_cast<float>(K) / BT));
for (int64_t i = 0; i < N; i++) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(FillInput<HipT>), blocks_per_grid_D, BT, 0, stream, input_x_ptr, input_key, input_value, elem_nums, size, axis, K, i, dimension);
HIP_RETURN_IF_ERROR(1 == largest ? hipcub::DeviceRadixSort::SortPairsDescending(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension, 0, sizeof(T)*8, stream)
: hipcub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, input_key, output_key, input_value, output_value, dimension, 0, sizeof(T)*8, stream));
if (1 == sorted) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(FillOutput<HipT>), blocks_per_grid_K, BT, 0, stream, output_key, output_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
} else { //reorder by ascending index
hipLaunchKernelGGL(HIP_KERNEL_NAME(ExcludeOutput<int64_t>), blocks_per_grid_D, BT, 0, stream, output_value, K, dimension);
HIP_RETURN_IF_ERROR(hipcub::DeviceRadixSort::SortPairs(temp_storage, temp_bytes, output_value, input_value, output_key, input_key, dimension, 0, sizeof(T)*8, stream));
hipLaunchKernelGGL(HIP_KERNEL_NAME(FillOutput<HipT>), blocks_per_grid_K, BT, 0, stream, input_key, input_value, output_v_ptr, output_i, elem_nums, size, axis, K, i, dimension);
}
}
}
return Status::OK();
}
#define TOPKIMPLE(T) template Status TopKImpl<T>(const RocmKernel* kernel, \
hipStream_t stream, \
const T* input_x, \
T* output_v, \
int64_t* output_i, \
const TArray<int64_t>& elem_nums, \
size_t size, \
int32_t axis, \
int64_t K, \
int64_t largest, \
int64_t sorted, \
int64_t N, \
int64_t dimension)
// This file causes excessively long compilation times in the ROCm EP, so its compilation is split
// into multiple translation units (one per element type) to speed it up.
TOPKIMPLE(TOPK_IMPL_TYPE);
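// Each topk_impl_*.cu file below defines TOPK_IMPL_TYPE to a single element
// type and then includes this header, so exactly one explicit instantiation
// of TopKImpl is compiled per translation unit.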
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/common/common.h"
namespace onnxruntime {
namespace rocm {
template <typename T>
Status TopKImpl(const RocmKernel* kernel, hipStream_t stream, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE MLFloat16
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE float
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE double
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int16_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int32_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int64_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE int8_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint16_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint32_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint64_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#define TOPK_IMPL_TYPE uint8_t
#include "topk_impl.cuh"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "unary_elementwise_ops.h"
#include "unary_elementwise_ops_impl.h"
namespace onnxruntime {
namespace rocm {
Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePreparation* p) const {
p->input_tensor = context->Input<Tensor>(0);
p->output_tensor = context->Output(0, p->input_tensor->Shape());
return Status::OK();
}
#define UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define UNARY_ELEMENTWISE_REGISTER_KERNEL(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>()), \
x<T>);
// 'Not' only has a 'T' type constraint. The other logical ops have T and T1.
#define UNARY_ELEMENTWISE_LOGICALOP_NOT_REGISTER_KERNEL_TYPED(ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Not, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Not<T>);
#define UNARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Impl_##x( \
Stream(), \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(p.input_tensor->Data<T>()), \
reinterpret_cast<typename ToHipType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
#define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \
UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T)
#define UNARY_OP_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_REGISTER_KERNEL(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
#define UNARY_LOGICALOP_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(name, ver, T) \
UNARY_ELEMENTWISE_COMPUTE(name, T)
#define UNARY_LOGICALOP_NOT_TYPED(ver, T) \
UNARY_ELEMENTWISE_LOGICALOP_NOT_REGISTER_KERNEL_TYPED(ver, T) \
UNARY_ELEMENTWISE_COMPUTE(Not, T)
// The suffix of each macro name indicates the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
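// Illustrative expansion (not present verbatim in this file): UNARY_OP_HFD(Sqrt, 13)
// below registers the Sqrt kernel for MLFloat16, float and double via
// UNARY_ELEMENTWISE_REGISTER_KERNEL and emits the matching ComputeInternal
// specializations via UNARY_ELEMENTWISE_COMPUTE, each of which forwards to
// Impl_Sqrt declared in unary_elementwise_ops_impl.h.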
#define UNARY_OP_VERSIONED_HFD(name, startver, endver) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, float) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, double)
#define UNARY_OP_VERSIONED_CSILHFD(name, startver, endver) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int8_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int16_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int32_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, int64_t) \
UNARY_OP_VERSIONED_HFD(name, startver, endver)
#define UNARY_OP_VERSIONED_BWUZCSILHFD(name, startver, endver) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint8_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint16_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint32_t) \
UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \
UNARY_OP_VERSIONED_CSILHFD(name, startver, endver)
#define UNARY_OP_HFD(name, ver) \
UNARY_OP_TYPED(name, ver, MLFloat16) \
UNARY_OP_TYPED(name, ver, float) \
UNARY_OP_TYPED(name, ver, double)
#define UNARY_OP_CSILHFD(name, ver) \
UNARY_OP_TYPED(name, ver, int8_t) \
UNARY_OP_TYPED(name, ver, int16_t) \
UNARY_OP_TYPED(name, ver, int32_t) \
UNARY_OP_TYPED(name, ver, int64_t) \
UNARY_OP_HFD(name, ver)
#define UNARY_OP_BWUZCSILHFD(name, ver) \
UNARY_OP_TYPED(name, ver, uint8_t) \
UNARY_OP_TYPED(name, ver, uint16_t) \
UNARY_OP_TYPED(name, ver, uint32_t) \
UNARY_OP_TYPED(name, ver, uint64_t) \
UNARY_OP_CSILHFD(name, ver)
UNARY_OP_VERSIONED_BWUZCSILHFD(Abs, 6, 12)
UNARY_OP_VERSIONED_CSILHFD(Neg, 6, 12)
UNARY_OP_VERSIONED_HFD(Floor, 6, 12)
UNARY_OP_VERSIONED_HFD(Ceil, 6, 12)
UNARY_OP_VERSIONED_HFD(Reciprocal, 6, 12)
UNARY_OP_VERSIONED_HFD(Sqrt, 6, 12)
UNARY_OP_VERSIONED_HFD(Log, 6, 12)
UNARY_OP_VERSIONED_HFD(Exp, 6, 12)
UNARY_OP_VERSIONED_HFD(Erf, 9, 12)
UNARY_OP_BWUZCSILHFD(Abs, 13)
UNARY_OP_CSILHFD(Neg, 13)
UNARY_OP_HFD(Floor, 13)
UNARY_OP_HFD(Ceil, 13)
UNARY_OP_HFD(Reciprocal, 13)
UNARY_OP_HFD(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
UNARY_OP_HFD(Exp, 13)
UNARY_OP_HFD(Erf, 13)
UNARY_LOGICALOP_NOT_TYPED(1, bool)
UNARY_OP_HFD(Round, 11)
UNARY_OP_HFD(Cos, 7)
UNARY_OP_HFD(Sin, 7)
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
struct UnaryElementwisePreparation {
const Tensor* input_tensor = nullptr;
Tensor* output_tensor = nullptr;
};
class UnaryElementwise : public RocmKernel {
protected:
UnaryElementwise(const OpKernelInfo& info) : RocmKernel(info) {}
Status ComputeInternal(OpKernelContext*) const override {
return Status(common::ONNXRUNTIME, common::FAIL); // should not reach here
}
Status Prepare(OpKernelContext* context, UnaryElementwisePreparation* p) const;
};
template <typename T>
class Abs final : public UnaryElementwise {
public:
Abs(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Neg final : public UnaryElementwise {
public:
Neg(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Floor final : public UnaryElementwise {
public:
Floor(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Ceil final : public UnaryElementwise {
public:
Ceil(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Reciprocal final : public UnaryElementwise {
public:
Reciprocal(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Sqrt final : public UnaryElementwise {
public:
Sqrt(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Log final : public UnaryElementwise {
public:
Log(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Exp final : public UnaryElementwise {
public:
Exp(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Erf final : public UnaryElementwise {
public:
Erf(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Not final : public UnaryElementwise {
public:
Not(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Round final : public UnaryElementwise {
public:
Round(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Sin final : public UnaryElementwise {
public:
Sin(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
template <typename T>
class Cos final : public UnaryElementwise {
public:
Cos(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "unary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"
namespace onnxruntime {
namespace rocm {
#define OP(name, expr) \
template <typename T> \
struct OP_##name { \
__device__ __inline__ T operator()(const T& a) const { \
return expr; \
} \
};
#define UNARY_ELEMENTWISE_IMPL(name) \
UNARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
input_data, \
output_data, \
OP_##name<T>(), \
count); \
}
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, T) \
template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, size_t count);
#define UNARY_OP_NAME_EXPR(name, expr) \
OP(name, expr) \
UNARY_ELEMENTWISE_IMPL(name)
UNARY_OPS()
#undef UNARY_OP_NAME_EXPR
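// Illustrative expansion of the X-macro pattern above: for the entry
// UNARY_OP_NAME_EXPR(Sqrt, _Sqrt(a)) in UNARY_OPS(), OP(Sqrt, _Sqrt(a))
// defines the functor OP_Sqrt<T> and UNARY_ELEMENTWISE_IMPL(Sqrt) defines
// Impl_Sqrt<T>, which forwards to UnaryElementWiseImpl with that functor;
// the SPECIALIZED_* macros below then emit explicit instantiations per type.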
// The suffix of each macro name indicates the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, half) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, float) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, double)
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, BFloat16)
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int8_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name)
#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint16_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint32_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint64_t) \
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(name)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Abs)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(Neg)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Exp)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool)
// When casting, half needs to be converted via float to or from most other types
template <typename T>
struct ViaTypeMap {
typedef T ViaT;
};
template <>
struct ViaTypeMap<half> {
typedef float ViaT;
};
template <>
struct ViaTypeMap<BFloat16> {
typedef float ViaT;
};
template <typename InT, typename OutT>
struct OP_Cast {
__device__ __inline__ OutT operator()(const InT& a) const {
const bool any_float16 = std::is_same<half, InT>::value || std::is_same<half, OutT>::value;
const bool any_bf16 = std::is_same<BFloat16, InT>::value || std::is_same<BFloat16, OutT>::value;
typedef typename std::conditional<any_bf16, BFloat16, OutT>::type T1;
typedef typename std::conditional<any_float16, half, T1>::type T;
typedef typename ViaTypeMap<T>::ViaT ViaT;
return (OutT)((ViaT)a);
}
};
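// Per the comment above, OP_Cast routes any conversion involving half or
// BFloat16 through float: e.g. half -> int32_t evaluates (int32_t)((float)a).
// When neither side is half or BFloat16, ViaT collapses to OutT and the
// functor reduces to a plain static cast.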
template <typename InT, typename OutT>
void Impl_Cast(
hipStream_t stream,
const InT* input_data,
OutT* output_data,
size_t count) {
UnaryElementWiseImpl(stream,
input_data,
output_data,
OP_Cast<InT, OutT>(),
count);
}
#define SPECIALIZED_CAST_IMPL2(InT, OutT) \
template void Impl_Cast<InT, OutT>(hipStream_t stream, const InT* input_data, OutT* output_data, size_t count);
#define SPECIALIZED_CAST_FROM(T) \
SPECIALIZED_CAST_IMPL2(T, half) \
SPECIALIZED_CAST_IMPL2(T, float) \
SPECIALIZED_CAST_IMPL2(T, double) \
SPECIALIZED_CAST_IMPL2(T, int8_t) \
SPECIALIZED_CAST_IMPL2(T, int16_t) \
SPECIALIZED_CAST_IMPL2(T, int32_t) \
SPECIALIZED_CAST_IMPL2(T, int64_t) \
SPECIALIZED_CAST_IMPL2(T, uint8_t) \
SPECIALIZED_CAST_IMPL2(T, uint16_t) \
SPECIALIZED_CAST_IMPL2(T, uint32_t) \
SPECIALIZED_CAST_IMPL2(T, uint64_t) \
SPECIALIZED_CAST_IMPL2(T, bool) \
SPECIALIZED_CAST_IMPL2(T, BFloat16)
SPECIALIZED_CAST_FROM(half)
SPECIALIZED_CAST_FROM(float)
SPECIALIZED_CAST_FROM(double)
SPECIALIZED_CAST_FROM(int8_t)
SPECIALIZED_CAST_FROM(int16_t)
SPECIALIZED_CAST_FROM(int32_t)
SPECIALIZED_CAST_FROM(int64_t)
SPECIALIZED_CAST_FROM(uint8_t)
SPECIALIZED_CAST_FROM(uint16_t)
SPECIALIZED_CAST_FROM(uint32_t)
SPECIALIZED_CAST_FROM(uint64_t)
SPECIALIZED_CAST_FROM(bool)
SPECIALIZED_CAST_FROM(BFloat16)
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace rocm {
// This macro simplifies adding a new op, which involves the following steps:
// 1. Add a new entry to the UNARY_OPS() list
// 2. (optional) Define the templated single-element operator in unary_elementwise_ops_impl.cu
// 3. (optional) Implement a specialized single-element operator
// 4. Add the op kernel class definition in unary_elementwise_ops.h
// 5. Add the op kernel registration and compute specialization in unary_elementwise_ops.cc
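// Hypothetical example of those steps for a new "Tanh" op (not part of this
// change; it assumes a device helper _Tanh exists in common.cuh):
//   step 1: UNARY_OP_NAME_EXPR(Tanh, _Tanh(a)) added to UNARY_OPS() below
//   step 3: SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Tanh) in unary_elementwise_ops_impl.cu
//   step 4: a Tanh class derived from UnaryElementwise in unary_elementwise_ops.h
//   step 5: UNARY_OP_HFD(Tanh, <opset>) in unary_elementwise_ops.cc
// Step 2 is only needed when the expression cannot be written inline.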
#define UNARY_OPS() \
UNARY_OP_NAME_EXPR(Abs, _Abs(a)) \
UNARY_OP_NAME_EXPR(Neg, -a) \
UNARY_OP_NAME_EXPR(Ceil, _Ceil(a)) \
UNARY_OP_NAME_EXPR(Floor, _Floor(a)) \
UNARY_OP_NAME_EXPR(Reciprocal, T(1) / a) \
UNARY_OP_NAME_EXPR(Sqrt, _Sqrt(a)) \
UNARY_OP_NAME_EXPR(Exp, _Exp(a)) \
UNARY_OP_NAME_EXPR(Log, _Log(a)) \
UNARY_OP_NAME_EXPR(Erf, _Erf(a)) \
UNARY_OP_NAME_EXPR(Not, !a) \
UNARY_OP_NAME_EXPR(Round, _Round(a)) \
UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \
UNARY_OP_NAME_EXPR(Cos, _Cos(a))
#define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
hipStream_t stream, \
const T* input_data, \
T* output_data, \
size_t count)
#define UNARY_OP_NAME_EXPR(name, expr) UNARY_ELEMENTWISE_IMPL_DECLARATION(name);
UNARY_OPS()
#undef UNARY_OP_NAME_EXPR
template <typename InT, typename OutT>
void Impl_Cast(
hipStream_t stream,
const InT* input_data,
OutT* output_data,
size_t count);
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/math/variadic_elementwise_ops.h"
#include <cassert>
#include <algorithm>
#include "core/framework/data_types_internal.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/math/variadic_elementwise_ops_impl.h"
#include "core/providers/rocm/math/variadic_elementwise_ops_tags.h"
namespace onnxruntime {
namespace rocm {
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::NoBroadcastBatchImplDispatchTarget<
T>::operator()(hipStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
using HipT = typename ToHipType<T>::MappedType;
size_t input_count = inputs.size();
assert(input_count > 1);
size_t index = std::min(input_count, static_cast<size_t>(k_max_input_batch_size));
InputBatchArray<HipT> input_data_batch{static_cast<int32_t>(index)};
for (size_t i = 0; i < index; ++i) {
input_data_batch[static_cast<int32_t>(i)] = reinterpret_cast<const HipT*>(inputs[i].get().Data<T>());
}
HipT* output_data = reinterpret_cast<HipT*>(output.MutableData<T>());
Impl_NoBroadcastInputBatch<HipT, VariadicElementwiseOpTag>(stream, input_data_batch, output_data,
output.Shape().Size());
while (index < input_count) {
size_t left_count = input_count - index + 1;
size_t batch = std::min(left_count, static_cast<size_t>(k_max_input_batch_size));
// Special case for 2 inputs left.
if (batch == 2) {
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[input_count - 1].get(), &output, &prepare));
Impl_General<HipT, VariadicElementwiseOpTag>(
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
// Must be the last.
break;
}
InputBatchArray<HipT> left_input_data_batch{static_cast<int32_t>(batch)};
left_input_data_batch[0] = reinterpret_cast<const HipT*>(output.Data<T>());
for (size_t i = 1; i < batch; ++i) {
left_input_data_batch[static_cast<int32_t>(i)] =
reinterpret_cast<const HipT*>(inputs[index].get().Data<T>());
index++;
}
Impl_NoBroadcastInputBatch<HipT, VariadicElementwiseOpTag>(stream, left_input_data_batch, output_data,
output.Shape().Size());
}
return Status::OK();
}
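// Reading of the batching scheme above: up to k_max_input_batch_size inputs
// are reduced in a single Impl_NoBroadcastInputBatch launch; remaining inputs
// are folded in later rounds that reuse the output tensor as the first
// operand of the next batch (hence left_count = input_count - index + 1),
// ending with a plain binary launch once only two operands remain.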
// special case for 2 tensors to avoid memset zero
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
BinaryImplDispatchTarget<T>::operator()(hipStream_t stream, const Tensor& lhs, const Tensor& rhs, Tensor& output) const {
using HipT = typename ToHipType<T>::MappedType;
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&lhs, &rhs, &output, &prepare));
Impl_General<HipT, VariadicElementwiseOpTag>(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
return Status::OK();
}
// For more than 2 inputs, we need to accumulate into the output tensor, since the shape of input0 + input1 might differ from the output shape
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status
VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::GeneralImplDispatchTarget<T>::operator()(
hipStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
assert(inputs.size() > 1);
using HipT = typename ToHipType<T>::MappedType;
// If any input has the same shape as the output, we don't need the memset.
size_t index_of_same_shape = 0;
for (; index_of_same_shape < inputs.size(); index_of_same_shape++) {
if (inputs[index_of_same_shape].get().Shape() == output.Shape()) {
break;
}
}
BinaryElementwisePreparation prepare;
// No input has the same shape as the output: memset the output, then add the 1st input as initialization.
if (index_of_same_shape == inputs.size()) {
HIP_RETURN_IF_ERROR(hipMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream));
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare));
Impl_Add(stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
} else {
// The first operation is between input[0] and input[index_of_same_shape] when index_of_same_shape is not 0.
size_t index = index_of_same_shape == 0 ? 1 : 0;
ORT_RETURN_IF_ERROR(
BinaryElementwiseBroadcastPrepare(&inputs[index_of_same_shape].get(), &inputs[index].get(), &output, &prepare));
Impl_General<HipT, VariadicElementwiseOpTag>(
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
}
for (size_t index = 1; index < inputs.size(); index++) {
// If index_of_same_shape is 0, the 1st and 2nd inputs have already been handled above.
if (index == index_of_same_shape || (index_of_same_shape == 0 && index == 1)) {
continue;
}
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare));
Impl_General<HipT, VariadicElementwiseOpTag>(
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.lhs_tensor->Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const HipT*>(prepare.rhs_tensor->Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<HipT*>(prepare.output_tensor->MutableData<T>()),
prepare.output_tensor->Shape().Size());
}
return Status::OK();
}
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::ComputeInternal(
OpKernelContext* context) const {
const auto& node = Node();
const auto& node_name = node.Name();
auto input_count = node.InputArgCount().front();
ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs");
const InputTensorVector input_tensors =
[&context, input_count]() {
InputTensorVector result{};
result.reserve(input_count);
for (int i = 0; i < input_count; ++i) {
const auto& tensor = context->RequiredInput<Tensor>(i);
result.push_back(std::cref(tensor));
}
return result;
}();
const auto& first_input_tensor = input_tensors[0].get();
// special case for 1 input
if (input_count == 1) {
auto& output_tensor = context->RequiredOutput(0, first_input_tensor.Shape());
if (first_input_tensor.DataRaw() != output_tensor.DataRaw()) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(
output_tensor.MutableDataRaw(), first_input_tensor.DataRaw(), first_input_tensor.SizeInBytes(),
hipMemcpyDeviceToDevice, Stream()));
}
return Status::OK();
}
const auto element_type = first_input_tensor.GetElementType();
utils::MLTypeCallDispatcher<SupportedElementTypes...> dispatcher(element_type);
// Special case for no broadcasting.
if (std::all_of(input_tensors.begin() + 1, input_tensors.end(),
[&first_input_tensor](InputTensorVector::value_type t) {
return first_input_tensor.Shape() == t.get().Shape();
})) {
auto& output_tensor = context->RequiredOutput(0, first_input_tensor.Shape());
// special case for no broadcasting and 2 inputs
if (input_count == 2) {
return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(Stream(), input_tensors[0],
input_tensors[1], output_tensor);
}
return dispatcher.template InvokeRet<Status, NoBroadcastBatchImplDispatchTarget>(Stream(), input_tensors,
output_tensor);
}
// compute output shape first, using broadcast rule
TensorShape output_shape;
TensorShape previous_output_shape = first_input_tensor.Shape();
for (int index = 1; index < input_count; index++) {
ORT_RETURN_IF_ERROR(ComputeOutputShape(
node_name, previous_output_shape, input_tensors[index].get().Shape(), output_shape));
previous_output_shape = output_shape;
}
Tensor& output_tensor = context->RequiredOutput(0, output_shape);
// special case for 2 inputs
if (input_count == 2) {
return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(
Stream(), input_tensors[0], input_tensors[1], output_tensor);
}
// general case for more than 2 inputs
return dispatcher.template InvokeRet<Status, GeneralImplDispatchTarget>(
Stream(), input_tensors, output_tensor);
}
namespace {
using SumOp = VariadicElementwiseOp<variadic_elementwise_ops::Sum, MLFloat16, float, double, BFloat16>;
using MinOp = VariadicElementwiseOp<variadic_elementwise_ops::Min, uint32_t, uint64_t, int32_t, int64_t, MLFloat16,
float, double, BFloat16>;
using MaxOp = VariadicElementwiseOp<variadic_elementwise_ops::Max, uint32_t, uint64_t, int32_t, int64_t, MLFloat16,
float, double, BFloat16>;
} // namespace
// kernel registration
#define REGISTER_KERNEL(name, impl_class, version, datatypes) \
ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, version, kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<datatypes>()), \
impl_class)
#define REGISTER_VERSIONED_KERNEL(name, impl_class, start_version, end_version, datatypes) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
name, kOnnxDomain, start_version, end_version, kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<datatypes>()), impl_class)
#define UZILHFD_TYPES uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16
#define HFD_TYPES MLFloat16, float, double, BFloat16
REGISTER_KERNEL(Sum, SumOp, 13, HFD_TYPES)
REGISTER_VERSIONED_KERNEL(Sum, SumOp, 8, 12, HFD_TYPES)
REGISTER_VERSIONED_KERNEL(Sum, SumOp, 6, 7, HFD_TYPES)
REGISTER_KERNEL(Min, MinOp, 13, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Min, MinOp, 12, 12, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Min, MinOp, 6, 11, HFD_TYPES)
REGISTER_KERNEL(Max, MaxOp, 13, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Max, MaxOp, 12, 12, UZILHFD_TYPES)
REGISTER_VERSIONED_KERNEL(Max, MaxOp, 6, 11, HFD_TYPES)
#undef HFD_TYPES
#undef UZILHFD_TYPES
#undef REGISTER_VERSIONED_KERNEL
#undef REGISTER_KERNEL
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include <vector>
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
using InputTensorVector = std::vector<std::reference_wrapper<const Tensor>>;
template <typename VariadicElementwiseOpTag,
typename... SupportedElementTypes>
class VariadicElementwiseOp : public RocmKernel {
public:
VariadicElementwiseOp(const OpKernelInfo& info) : RocmKernel(info) {}
private:
Status ComputeInternal(OpKernelContext* context) const override;
template <typename T>
struct NoBroadcastBatchImplDispatchTarget {
Status operator()(hipStream_t stream, const InputTensorVector& inputs, Tensor& output) const;
};
template <typename T>
struct BinaryImplDispatchTarget {
Status operator()(hipStream_t stream, const Tensor& lhs, const Tensor& rhs, Tensor& output) const;
};
template <typename T>
struct GeneralImplDispatchTarget {
Status operator()(hipStream_t stream, const InputTensorVector& inputs, Tensor& output) const;
};
};
} // namespace rocm
} // namespace onnxruntime