Commit 1a91fcc2 authored by gaoqiong

add files required for DTK

parent a144865d
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/math/bias_softmax.h"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/math/bias_softmax_impl.h"
using namespace onnxruntime;
using namespace onnxruntime::rocm;
using namespace onnxruntime::contrib::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
namespace {
template <typename T>
struct DispatchBiasSoftmaxImpl {
Status operator()(hipStream_t stream, miopenHandle_t miopen_handle, Tensor* Y, const Tensor* X, const Tensor* B,
int element_count, int batch_count, bool is_inner_broadcast, int bias_broadcast_size) {
typedef typename ToHipType<T>::MappedType HipT;
HipT* output_data = reinterpret_cast<HipT*>(Y->template MutableData<T>());
const HipT* input_data = reinterpret_cast<const HipT*>(X->template Data<T>());
const HipT* bias_data = reinterpret_cast<const HipT*>(B->template Data<T>());
return BiasSoftmaxImpl<HipT>(stream, miopen_handle, output_data, input_data, bias_data, element_count, batch_count,
is_inner_broadcast, bias_broadcast_size);
}
};
} // namespace
// MIOpen doesn't support double so ROCm kernel doesn't have double support for now.
#ifdef USE_ROCM
#define BIAS_SOFTMAX_TYPES float, MLFloat16
#else
#define BIAS_SOFTMAX_TYPES float, MLFloat16, double
#endif
ONNX_OPERATOR_KERNEL_EX(
BiasSoftmax, kMSDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<BIAS_SOFTMAX_TYPES>()), BiasSoftmax);
Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X = ctx->Input<Tensor>(0);
const Tensor* B = ctx->Input<Tensor>(1);
const TensorShape& X_shape = X->Shape();
const TensorShape& B_shape = B->Shape();
Tensor* Y = ctx->Output(0, X_shape);
const int axis = static_cast<int>(HandleNegativeAxis(axis_, X_shape.NumDimensions()));
const int batch_count = static_cast<int>(X_shape.SizeToDimension(axis));
const int element_count = static_cast<int>(X_shape.SizeFromDimension(axis));
int bias_broadcast_size = static_cast<int>(B_shape.Size() / element_count);
if (is_inner_broadcast_) bias_broadcast_size = batch_count / bias_broadcast_size;
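// Illustrative example (hypothetical shapes): X = [8, 12, 128, 128], B = [8, 1, 1, 128], axis = 3 (last dim):
// batch_count = 8 * 12 * 128 = 12288, element_count = 128,
// bias_broadcast_size = B.Size() / element_count = 1024 / 128 = 8,
// and with is_inner_broadcast_ set, bias_broadcast_size becomes batch_count / 8 = 1536 (= 12 * 128).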
utils::MLTypeCallDispatcher<BIAS_SOFTMAX_TYPES> t_disp(X->GetElementType());
return t_disp.InvokeRet<Status, DispatchBiasSoftmaxImpl>(Stream(), MiopenHandle(), Y, X, B, element_count, batch_count,
is_inner_broadcast_, bias_broadcast_size);
}
#undef BIAS_SOFTMAX_TYPES
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// BiasSoftmax follows the OpSet-11 definition of the Softmax op, that is, the input is coerced to a 2D tensor
// using the axis attribute, and all dims from axis (inclusive) onward belong to the same batch. This differs from
// the definition since OpSet-13. To use BiasSoftmax, during the fusion, if Softmax is OpSet-13 or newer, you can
// only fuse it when the axis attribute is the last dim; otherwise, the computation result may be wrong.
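// For example, for an input of shape [2, 3, 4] with axis = 1, this kernel normalizes over 3 * 4 = 12 elements
// per batch row (OpSet-11 coercion), whereas an OpSet-13 Softmax with axis = 1 would normalize over only the
// 3 elements of that single dimension, hence the last-dim restriction for the fusion.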
class BiasSoftmax final : public onnxruntime::rocm::RocmKernel {
public:
BiasSoftmax(const OpKernelInfo& info) : RocmKernel{info} {
info.GetAttrOrDefault("axis", &axis_, static_cast<int64_t>(1));
int64_t is_inner_broadcast_value;
ORT_ENFORCE(info.GetAttr<int64_t>("is_inner_broadcast", &is_inner_broadcast_value).IsOK());
is_inner_broadcast_ = is_inner_broadcast_value != 0;
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
int64_t axis_;
bool is_inner_broadcast_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/math/bias_softmax_impl.h"
#include <limits>
#include <algorithm>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"
#include "core/providers/rocm/math/softmax_common.h"
#include "core/providers/rocm/math/softmax_warpwise_impl.cuh"
#include "core/providers/rocm/shared_inc/accumulation_type.h"
using namespace onnxruntime;
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Duplicated softmax_impl.cu here
// So far, attempts to use a shared kernel with an additional template parameter have resulted in lost performance
// Note: The intended case for 'input_bias' is the input sequence mask for transformer models
// As an additive mask, it should be zero for preserved tokens and -infinity for tokens to screen out
// The mask will broadcast from [batch_size, 1, 1, seq_len] to input [batch_size, num_heads, seq_len, seq_len]
// Here element_count = seq_len and bias_broadcast_size_per_batch = num_heads * seq_len
// The softmax + additive mask fusion follows NVIDIA apex's additive_masked_softmax_warp_forward
// see
// https://github.com/NVIDIA/apex/blob/4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a/apex/contrib/csrc/multihead_attn/softmax.h
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_inner_broadcast>
__global__ void BiasSoftmaxWarpForward(output_t* output, const input_t* input, const input_t* input_bias,
int element_count, int batch_count, fast_divmod bias_broadcast_fdm) {
// "WARP" refers to cooperative threads and might not equal 32 threads of GPU warp
// thread block is (WARP_SIZE, 128/WARP_SIZE)
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = next_power_of_two < GPU_WARP_SIZE ? next_power_of_two : GPU_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
#ifdef USE_ROCM
constexpr int WARP_BATCH = 1;
#else
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
#endif
// each "WARP" (<=32) processes WARP_BATCH(one of {1,2}) batches
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// last warp may have fewer batches
int local_batches = batch_count - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// thread will process elements (local_index + n * warp_size) within batch
int local_idx = threadIdx.x;
// advance the input and output pointers to the first batch this warp processes
input += first_batch * element_count + local_idx;
output += first_batch * element_count + local_idx;
// load from global memory and apply bias (likely an additive mask)
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
// If is_inner_broadcast, input shape is [x, broadcast_size, element_count], bias shape is [x, 1, element_count].
// Otherwise, input shape is [x, broadcast_size, element_count], bias shape is [1, broadcast_size, element_count].
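// Illustrative numbers: with bias_broadcast_size = 1536 and first_batch + i = 5000,
// inner broadcast selects bias row 5000 / 1536 = 3, outer broadcast selects row 5000 % 1536 = 392.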
int bias_batch_offset =
is_inner_broadcast ? bias_broadcast_fdm.div(first_batch + i) : bias_broadcast_fdm.mod(first_batch + i);
int bias_offset = bias_batch_offset * element_count + local_idx;
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
elements[i][it] =
(acc_t)input[i * element_count + it * WARP_SIZE] + (acc_t)input_bias[bias_offset + it * WARP_SIZE];
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
}
}
}
// find maximum value within batch for numerical stability
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
// normalization factor Z = Sum[ exp(element_i), for element_i in batch ]
acc_t sum[WARP_BATCH]{acc_t(0.0)};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = expf((acc_t)(elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// write back normalized value = exp(element_i)/Z to global memory
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
output[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
} else {
break;
}
}
}
}
template <typename T>
Status BiasSoftmaxImpl(hipStream_t stream, miopenHandle_t miopen_handle, T* output_data, const T* input_data,
const T* bias_data, int element_count, int batch_count, bool is_inner_broadcast,
int bias_broadcast_size) {
if (element_count == 0) return Status::OK();
if (element_count <= 1024 && element_count * static_cast<int>(sizeof(T)) <= 4096) {
typedef AccumulationType_t<T> AccT;
int log2_elements = log2_ceil(element_count);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_SIZE constexpr value computed inside BiasSoftmaxWarpForward.
int warp_size = std::min(next_power_of_two, GPU_WARP_SIZE_HOST);
// This value must match the WARP_BATCH constexpr value computed inside BiasSoftmaxWarpForward.
#ifdef USE_ROCM
int batches_per_warp = 1;
constexpr int threads_per_block = 256;
#else
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
#endif
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
fast_divmod bias_broadcast_fdm = fast_divmod(bias_broadcast_size);
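// Illustrative launch configuration (assuming GPU_WARP_SIZE_HOST == 64): element_count = 1000 gives
// log2_elements = 10, next_power_of_two = 1024, warp_size = 64; on ROCm threads_per_block = 256, so
// warps_per_block = 4, batches_per_block = 4, and blocks = ceil(batch_count / 4) with threads = (64, 4, 1).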
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
#define LAUNCHE_BIAS_SOFTMAX_KERNEL(log2_elements_value, is_inner_broadcast_value) \
hipLaunchKernelGGL(HIP_KERNEL_NAME(BiasSoftmaxWarpForward<T, T, AccT, log2_elements_value, is_inner_broadcast_value>), blocks, threads, 0, stream, \
output_data, input_data, bias_data, element_count, batch_count, bias_broadcast_fdm)
#define CASE_LOG2_ELEMENTS(log2_elements_value) \
case log2_elements_value: { \
if (is_inner_broadcast) { \
LAUNCHE_BIAS_SOFTMAX_KERNEL(log2_elements_value, true); \
} else { \
LAUNCHE_BIAS_SOFTMAX_KERNEL(log2_elements_value, false); \
} \
} break
CASE_LOG2_ELEMENTS(0); // 1
CASE_LOG2_ELEMENTS(1); // 2
CASE_LOG2_ELEMENTS(2); // 4
CASE_LOG2_ELEMENTS(3); // 8
CASE_LOG2_ELEMENTS(4); // 16
CASE_LOG2_ELEMENTS(5); // 32
CASE_LOG2_ELEMENTS(6); // 64
CASE_LOG2_ELEMENTS(7); // 128
CASE_LOG2_ELEMENTS(8); // 256
CASE_LOG2_ELEMENTS(9); // 512
CASE_LOG2_ELEMENTS(10); // 1024
#undef CASE_LOG2_ELEMENTS
#undef LAUNCHE_BIAS_SOFTMAX_KERNEL
}
return Status::OK();
}
// For large element counts we fall back to an explicit Add kernel + the MIOpen softmax.
// Note: this is an unhappy path! There is no performance benefit from the fusion here.
int output_rank_or_simple_broadcast = 3;
TArray<int64_t> rhs_strides;
TArray<fast_divmod> output_fdms;
const TArray<int64_t>* p_rhs_strides = nullptr;
const TArray<fast_divmod>* p_output_fdms = nullptr;
fast_divmod fdm_h(1);
fast_divmod fdm_c;
if ((is_inner_broadcast && bias_broadcast_size == 1) || (!is_inner_broadcast && bias_broadcast_size == batch_count)) {
// input and bias shapes are the same.
output_rank_or_simple_broadcast = static_cast<int>(SimpleBroadcast::NoBroadcast);
} else if (!is_inner_broadcast) {
output_rank_or_simple_broadcast = static_cast<int>(SimpleBroadcast::RightPerChannelBatchN);
fdm_c = fast_divmod(element_count * bias_broadcast_size);
} else {
rhs_strides.SetSize(3);
rhs_strides[0] = static_cast<int64_t>(element_count);
rhs_strides[1] = 0LL;
rhs_strides[2] = 1LL;
p_rhs_strides = &rhs_strides;
output_fdms.SetSize(3);
output_fdms[0] = fast_divmod(element_count * bias_broadcast_size);
output_fdms[1] = fast_divmod(element_count);
output_fdms[2] = fast_divmod(1);
p_output_fdms = &output_fdms;
}
BinaryElementWiseImpl(stream, output_rank_or_simple_broadcast, nullptr, input_data, p_rhs_strides, bias_data,
p_output_fdms, fdm_h, fdm_c, output_data, OP_Add<T, T, T>(),
static_cast<size_t>(batch_count * element_count));
// invoke the MIOpen library for Y = softmax(X)
const int64_t dims[]{batch_count, 1, 1, element_count};
const auto alpha = Consts<T>::One;
const auto beta = Consts<T>::Zero;
MiopenTensor input_tensor, output_tensor;
ORT_RETURN_IF_ERROR(input_tensor.Set(dims, MiopenTensor::GetDataType<T>()));
ORT_RETURN_IF_ERROR(output_tensor.Set(dims, MiopenTensor::GetDataType<T>()));
return SoftmaxForward(miopen_handle, &alpha, input_tensor, output_data, &beta, output_tensor, output_data);
}
#define SPECIALIZED_BIAS_SOFTMAX_IMPL(T) \
template Status BiasSoftmaxImpl<T>(hipStream_t stream, miopenHandle_t miopen_handle, T * output_data, \
const T* input_data, const T* bias_data, int element_count, int batch_count, \
bool is_inner_broadcast, int bias_broadcast_size);
// MIOpen doesn't support double so ROCm kernel doesn't have double support for now.
SPECIALIZED_BIAS_SOFTMAX_IMPL(float)
SPECIALIZED_BIAS_SOFTMAX_IMPL(half)
#ifndef USE_ROCM
SPECIALIZED_BIAS_SOFTMAX_IMPL(double)
#endif
#undef SPECIALIZED_BIAS_SOFTMAX_IMPL
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
Status BiasSoftmaxImpl(hipStream_t stream, miopenHandle_t miopen_handle, T* output_data, const T* input_data,
const T* bias_data, int element_count, int batch_count, bool is_inner_broadcast,
int bias_broadcast_size);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/math/binary_elementwise_ops.h"
#include "contrib_ops/rocm/math/binary_elementwise_ops_impl.h"
using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define CONTRIB_BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kMSDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define CONTRIB_BINARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
BinaryElementwisePreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
Impl_##x<typename ToHipType<T>::MappedType>( \
Stream(), \
prepare.output_rank_or_simple_broadcast, \
&prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()), \
&prepare.rhs_padded_strides, \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(prepare.rhs_tensor->Data<T>()), \
&prepare.fdm_output_strides, \
prepare.fdm_H, \
prepare.fdm_C, \
reinterpret_cast<typename ToHipType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()), \
prepare.output_tensor->Shape().Size()); \
return Status::OK(); \
}
#define CONTRIB_BINARY_OP_TYPED(name, ver, T) \
CONTRIB_BINARY_ELEMENTWISE_REGISTER_KERNEL_TYPED(name, ver, T) \
CONTRIB_BINARY_ELEMENTWISE_COMPUTE(name, T)
// since different ops have different type support, we cannot use BINARY_OPS() directly
// the postfix of the macro name indicates the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
#define CONTRIB_BINARY_OP_HFD(name, ver) \
CONTRIB_BINARY_OP_TYPED(name, ver, MLFloat16) \
CONTRIB_BINARY_OP_TYPED(name, ver, float) \
CONTRIB_BINARY_OP_TYPED(name, ver, double) \
CONTRIB_BINARY_OP_TYPED(name, ver, BFloat16)
CONTRIB_BINARY_OP_HFD(BiasGelu, 1)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fast_divmod.h"
#include "core/providers/cpu/tensor/utils.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// BiasGelu fuses Add + Gelu
template <typename T>
class BiasGelu final : public BinaryElementwise<ShouldBroadcast> {
public:
BiasGelu(const OpKernelInfo& info) : BinaryElementwise(info) {
}
Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "contrib_ops/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define OP(name, expr) \
template <class T> \
struct OP_##name { \
__device__ __inline__ T operator()(T a, T b) const { \
return (expr); \
} \
};
#define CONTRIB_BINARY_ELEMENTWISE_IMPL(name) \
CONTRIB_BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
BinaryElementWiseImpl(stream, \
output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
rhs_data, \
fdm_output_strides, \
fdm_H, \
fdm_C, \
output_data, \
OP_##name<T>(), \
count); \
}
#define CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, T) \
template void Impl_##x<T>(hipStream_t stream, \
int32_t output_rank, \
const TArray<int64_t>* lhs_padded_strides, \
const T* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, \
const T* rhs_data, \
const TArray<onnxruntime::rocm::fast_divmod>* fdm_output_strides, \
const onnxruntime::rocm::fast_divmod& fdm_H, \
const onnxruntime::rocm::fast_divmod& fdm_C, \
T* output_data, size_t count);
#define CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(x) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double)
#define CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL_OIL(x) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, bool) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t)
#define CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(x) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)
// create declarations for op and impl
#define CONTRIB_BINARY_OP_NAME_EXPR(name, expr) \
OP(name, expr) \
CONTRIB_BINARY_ELEMENTWISE_IMPL(name)
CONTRIB_BINARY_OPS()
#undef CONTRIB_BINARY_OP_NAME_EXPR
// create specialized impl
// the postfix of the macro name indicates the types supported by the op:
// B: uint8_t
// W: uint16_t
// U: uint32_t
// Z: uint64_t
// C: int8_t
// S: int16_t
// I: int32_t
// L: int64_t
// H: float16
// F: float
// D: double
// O: bool
CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(BiasGelu)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include <stdint.h>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
// These macros simplify coding. To add a new op, follow these steps (a hypothetical example follows the list):
// 1. Add a new entry in CONTRIB_BINARY_OPS() list
// 2. (optional) Define templated single element operator in binary_elementwise_ops_impl.cu
// 3. (optional) Implement specialized single element operator
// 4. Add op kernel class definition in binary_elementwise_ops.h
// 5. Add op kernel registration and compute specialization in binary_elementwise_ops.cc
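// For example (purely hypothetical, not part of this change), a fused Add + Sigmoid op could be added as
//   CONTRIB_BINARY_OP_NAME_EXPR(BiasSigmoid, _Sigmoid(a + b))
// assuming a device-side _Sigmoid helper existed alongside _Gelu, followed by the matching kernel class and
// registration in binary_elementwise_ops.h/.cc as described above.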
#define CONTRIB_BINARY_OPS() \
CONTRIB_BINARY_OP_NAME_EXPR(BiasGelu, _Gelu(a + b))
// NOTE that .cu files are compiled by the device compiler (hipcc here) and should not refer to any onnxruntime headers
// so struct BinaryElementwisePreparation cannot be used here
#define CONTRIB_BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
hipStream_t stream, \
int32_t output_rank_or_simple_broadcast, \
const TArray<int64_t>* lhs_padded_strides, \
const T* lhs_data, \
const TArray<int64_t>* rhs_padded_strides, \
const T* rhs_data, \
const TArray<onnxruntime::rocm::fast_divmod>* fdm_output_strides, \
const onnxruntime::rocm::fast_divmod& fdm_H, \
const onnxruntime::rocm::fast_divmod& fdm_C, \
T* output_data, \
size_t count)
#define CONTRIB_BINARY_OP_NAME_EXPR(name, expr) CONTRIB_BINARY_ELEMENTWISE_IMPL_DECLARATION(name);
CONTRIB_BINARY_OPS()
#undef CONTRIB_BINARY_OP_NAME_EXPR
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/nn/dropout.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(BitmaskDropout, kMSDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T1", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<onnxruntime::rocm::BitmaskElementType>())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.InputMemoryType(OrtMemTypeCPUInput, 2),
onnxruntime::rocm::Dropout<true>);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/math/matmul.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(op_name, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
op_name, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
onnxruntime::rocm::MatMul<T>);
// TransposeMatMul is kept here for backward compatibility
REGISTER_KERNEL_TYPED(TransposeMatMul, float)
REGISTER_KERNEL_TYPED(TransposeMatMul, double)
REGISTER_KERNEL_TYPED(TransposeMatMul, MLFloat16)
REGISTER_KERNEL_TYPED(TransposeMatMul, BFloat16)
REGISTER_KERNEL_TYPED(FusedMatMul, float)
REGISTER_KERNEL_TYPED(FusedMatMul, double)
REGISTER_KERNEL_TYPED(FusedMatMul, MLFloat16)
REGISTER_KERNEL_TYPED(FusedMatMul, BFloat16)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/math/isfinite.h"
#include "isfinite_impl.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {
#define REGISTER_ISALLFINITE_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
IsAllFinite, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("V", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<bool>()), \
IsAllFiniteOp<T>);
template <typename TSrc>
Status IsAllFiniteOp<TSrc>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToHipType<TSrc>::MappedType TSrcCuda;
// Get Input tensor count.
const auto total_tensor_count = context->InputCount();
// Initialize the output to true. GPU kernel will set it
// to false if any value in any tensor is non-finite.
Tensor& output = *context->Output(0, {});
auto* output_data = reinterpret_cast<ToHipType<bool>::MappedType*>(output.MutableData<bool>());
HIP_RETURN_IF_ERROR(hipMemsetAsync(output_data, int(true), sizeof(bool), Stream()));
std::vector<std::vector<void*>> grouped_tensor_pointers(total_tensor_count);
std::vector<int> tensor_sizes(total_tensor_count);
for (int i = 0; i < total_tensor_count; ++i) {
const auto& input = context->Input<Tensor>(i);
grouped_tensor_pointers[i] = {const_cast<TSrc*>(input->Data<TSrc>())};
tensor_sizes[i] = static_cast<int>(input->Shape().Size());
}
typedef IsAllFiniteFunctor<TSrcCuda> TFunctor;
TFunctor functor;
// Check if all values are finite and write true to output.
// Otherwise, false will be written.
launch_multi_tensor_functor<1, TFunctor>(
Stream(), 2048 * 32, tensor_sizes, grouped_tensor_pointers, functor, output_data, isinf_only_, isnan_only_);
return Status::OK();
}
REGISTER_ISALLFINITE_KERNEL_TYPED(MLFloat16)
REGISTER_ISALLFINITE_KERNEL_TYPED(float)
REGISTER_ISALLFINITE_KERNEL_TYPED(double)
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_fp16.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/math/isfinite.h"
namespace onnxruntime {
namespace rocm {
template <typename T>
__device__ __forceinline__ bool IsFiniteScalar(const T value) {
return isfinite(value);
}
template <typename T>
__device__ __forceinline__ bool IsInfScalar(const T value) {
return isinf(value);
}
template <typename T>
__device__ __forceinline__ bool IsNaNScalar(const T value) {
return isnan(value);
}
template <>
__device__ __forceinline__ bool IsFiniteScalar(const half value) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return !__hisinf(value) && !__hisnan(value);
#else
return isfinite(float(value));
#endif
}
template <>
__device__ __forceinline__ bool IsInfScalar(const half value) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return __hisinf(value);
#else
return isinf(float(value));
#endif
}
template <>
__device__ __forceinline__ bool IsNaNScalar(const half value) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return __hisnan(value);
#else
return isnan(float(value));
#endif
}
template <>
__device__ __forceinline__ bool IsFiniteScalar(const BFloat16 value) {
return isfinite(static_cast<float>(value));
}
template <>
__device__ __forceinline__ bool IsInfScalar(const BFloat16 value) {
return isinf(static_cast<float>(value));
}
template <>
__device__ __forceinline__ bool IsNaNScalar(const BFloat16 value) {
return isnan(static_cast<float>(value));
}
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace rocm {
template <typename TSrc>
class IsAllFiniteOp final : public RocmKernel {
public:
IsAllFiniteOp(const OpKernelInfo& info) : RocmKernel(info) {
int64_t isinf_only;
info.GetAttrOrDefault("isinf_only", &isinf_only, static_cast<int64_t>(0));
isinf_only_ = (isinf_only != 0);
int64_t isnan_only;
info.GetAttrOrDefault("isnan_only", &isnan_only, static_cast<int64_t>(0));
isnan_only_ = (isnan_only != 0);
ORT_ENFORCE(!(isinf_only_ && isnan_only_),
"Both attributes isinf_only and isnan_only cannot be set. Unset both to check for both conditions.");
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
bool isinf_only_, isnan_only_;
};
} // namespace rocm
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_fp16.h>
#include "isfinite_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/math/isfinite.cuh"
namespace onnxruntime {
namespace rocm {
template <typename TSrc, bool isinf_only, bool isnan_only>
__global__ void IsAllFiniteMultiTensorImpl(ChunkGroup<1> chunks, bool* output) {
const int block_idx = blockIdx.x;
const int tensor_idx = chunks.block_index_to_tensor_group_index[block_idx];
const int tensor_size = chunks.tensor_sizes[tensor_idx];
const TSrc* tensor_ptr = static_cast<TSrc*>(chunks.tensor_ptrs[0][tensor_idx]);
const int chunk_start_idx = chunks.block_index_to_chunk_start_index[block_idx];
// chunk_size is chunks.chunk_size if the loaded chunk is full. Otherwise (this
// chunk is the last one in the source tensor), the actual size is determined
// by the bound of the source tensor.
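// For example (illustrative values): with chunks.chunk_size = 512, tensor_size = 1000 and chunk_start_idx = 512,
// chunk_size = min(1000, 1024) - 512 = 488 for that final, partially filled chunk.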
const int chunk_size = min(tensor_size, chunk_start_idx + chunks.chunk_size) - chunk_start_idx;
const TSrc* chunk_ptr = tensor_ptr + chunk_start_idx;
bool result = true;
#pragma unroll 4
for (int i = threadIdx.x; i < chunk_size; i += blockDim.x) {
if (isinf_only) {
result &= !IsInfScalar(chunk_ptr[i]);
} else if (isnan_only) {
result &= !IsNaNScalar(chunk_ptr[i]);
} else {
result &= IsFiniteScalar(chunk_ptr[i]);
}
}
if (!result) {
*output = false;
}
}
template <typename T>
void IsAllFiniteFunctor<T>::operator()(hipStream_t stream,
ChunkGroup<1> chunks,
bool* output,
const bool isinf_only,
const bool isnan_only) {
const int block_count = chunks.chunk_count;
const int thread_count = ChunkGroup<1>::thread_count_per_block;
if (isinf_only) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(IsAllFiniteMultiTensorImpl<T, true, false>), block_count, thread_count, 0, stream, chunks, output);
} else if (isnan_only) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(IsAllFiniteMultiTensorImpl<T, false, true>), block_count, thread_count, 0, stream, chunks, output);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(IsAllFiniteMultiTensorImpl<T, false, false>), block_count, thread_count, 0, stream, chunks, output);
}
}
#define INSTANTIATE_ISALLFINITE_FUNCTOR(T) \
template void IsAllFiniteFunctor<T>::operator()(hipStream_t stream, \
ChunkGroup<1> chunks, \
bool* output, \
const bool isinf_only, \
const bool isnan_only);
INSTANTIATE_ISALLFINITE_FUNCTOR(half)
INSTANTIATE_ISALLFINITE_FUNCTOR(float)
INSTANTIATE_ISALLFINITE_FUNCTOR(double)
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include "core/providers/rocm/multi_tensor/common.cuh"
namespace onnxruntime {
namespace rocm {
template <typename T>
struct IsAllFiniteFunctor {
void operator()(hipStream_t stream, ChunkGroup<1> chunks, bool* output, const bool isinf_only, const bool isnan_only);
};
}  // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/tensor/trilu.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
Trilu,
kMSDomain,
1,
kRocmExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.MayInplace(0, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
onnxruntime::rocm::Trilu);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Modifications Copyright (c) Microsoft. */
#include "beam_search_topk.h"
#include <hipcub/hipcub.hpp>
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/cu_inc/common.cuh"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T, int max_k>
struct TopK {
int32_t key[max_k];
T value[max_k];
__device__ __forceinline__ void Insert(T elem, int elem_id) {
T v = value[max_k - 1];
if (v < elem ||
(key[max_k - 1] == -1) ||
((elem == value[max_k - 1]) && (elem_id < key[max_k - 1]))) {
value[max_k - 1] = elem;
key[max_k - 1] = elem_id;
}
for (int k = max_k - 2; k >= 0; --k) {
if (value[k + 1] > value[k] ||
key[k] == -1 ||
((value[k + 1] == value[k]) && (key[k + 1] < key[k]))) {
T u2 = value[k];
int p2 = key[k];
value[k] = value[k + 1];
key[k] = key[k + 1];
value[k + 1] = u2;
key[k + 1] = p2;
}
}
}
__device__ __forceinline__ void Init() {
for (int i = 0; i < max_k; i++) {
key[i] = -1;
value[i] = NumericLimits<T>::Min();
}
}
};
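// Illustrative trace of Insert (max_k = 4): after Init() every slot holds NumericLimits<T>::Min() with key -1;
// inserting values 5, 7, 6 (with arbitrary ids) keeps `value` sorted descending: {5}, {7, 5}, {7, 6, 5}.
// Once all max_k slots are filled, an Insert with a value smaller than value[max_k - 1] is dropped.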
template <typename T, int max_k>
__device__ __forceinline__ TopK<T, max_k> reduce_topk_op(const TopK<T, max_k>& a, const TopK<T, max_k>& b) {
TopK<T, max_k> res = a;
for (int i = 0; i < max_k; ++i)
res.Insert(b.value[i], b.key[i]);
return res;
}
// kernel to compute the top k on the last axis for a tensor with shape: [batch, beam_size, parts_of_vocab, vocab_part_size]
// Its grid is [batch * beam_size, parts_of_vocab]
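// Illustrative sizing (hypothetical values): batch = 4, beam_size = 4, vocab_size = 50000 and voc_parts = 15
// give vocab_part_size = ceil(50000 / 15) = 3334 and a grid of (16, 15); each block reduces a per-thread
// top-k over its ~3334-token slice of the vocabulary.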
template <typename T, int max_k, int thread_block_size>
__launch_bounds__(thread_block_size) __global__ void BeamSearchOnlineTopKStage1Kernel(
const T* input,
int32_t k,
int32_t vocab_size,
int32_t vocab_part_size,
T* output_values,
int32_t* output_token) {
TopK<T, max_k> top_k_thread;
top_k_thread.Init();
int batch_beam = blockIdx.x;
int voc_part_id = blockIdx.y;
int token_id_base = voc_part_id * vocab_part_size;
const T* input_block = input + batch_beam * vocab_size;
// each thread scans this block's slice of the vocabulary (vocab_part_size entries)
for (int i = threadIdx.x + token_id_base; i < vocab_part_size + token_id_base; i += blockDim.x) {
if (i < vocab_size) {
top_k_thread.Insert(input_block[i], i);
}
}
// reduce in thread block
typedef hipcub::BlockReduce<TopK<T, max_k>, thread_block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopK<T, max_k> top_k_block = BlockReduce(temp_storage).Reduce(top_k_thread, reduce_topk_op<T, max_k>);
__syncthreads();
output_values += batch_beam * gridDim.y * k + voc_part_id * k;
output_token += batch_beam * gridDim.y * k + voc_part_id * k;
if (threadIdx.x == 0) {
for (int i = 0; i < k; i++) {
output_values[i] = top_k_block.value[i];
output_token[i] = top_k_block.key[i];
}
}
}
template <typename T, int max_k, int thread_block_size>
__launch_bounds__(thread_block_size) __global__ void BeamSearchOnlineTopKStage2Kernel(
const T* input_values,
const int32_t* input_tokens,
int32_t k,
int32_t vocab_size,
int32_t parts_per_beam,
T* output_values,
int32_t* output_indices) {
const int vector_id = blockIdx.x;
const int thread_id = threadIdx.x;
extern __shared__ char shared_buf_extern[];
T* value_shared_buf = reinterpret_cast<T*>(shared_buf_extern);
int32_t* tokens_shared_buf =
reinterpret_cast<int32_t*>(shared_buf_extern + max_k * parts_per_beam * sizeof(int32_t));
typedef hipcub::BlockReduce<TopK<T, max_k>, thread_block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
input_values += vector_id * k * parts_per_beam;
input_tokens += vector_id * k * parts_per_beam;
TopK<T, max_k> thread_topk;
for (int i = 0; i < max_k; ++i) {
thread_topk.key[i] = -1;
thread_topk.value[i] = NumericLimits<T>::Min();
}
for (int idx = thread_id; idx < k * parts_per_beam; idx += thread_block_size) {
value_shared_buf[idx] = input_values[idx];
tokens_shared_buf[idx] = input_tokens[idx];
}
__syncthreads();
if (thread_id < parts_per_beam) {
T* b_v = value_shared_buf + thread_id * k;
int32_t* b_i = tokens_shared_buf + thread_id * k;
for (int i = 0; i < k; i++) {
thread_topk.Insert(b_v[i], b_i[i]);
}
}
TopK<T, max_k> topk_block = BlockReduce(temp_storage).Reduce(thread_topk, reduce_topk_op<T, max_k>);
if (thread_id == 0) {
output_values += vector_id * k;
output_indices += vector_id * k;
for (int i = 0; i < k; ++i) {
output_values[i] = topk_block.value[i];
output_indices[i] = topk_block.key[i];
}
}
}
template <typename T, int max_k>
void LaunchBeamSearchOnlineTopKStage2Kernel(
const T* topk_values_tmp,
const int32_t* topk_indices_tmp,
int32_t batch_size,
int32_t num_beams,
int32_t vocab_size,
int32_t parts_per_beam,
int32_t K,
T* output_values,
int32_t* output_indices,
hipStream_t stream) {
ORT_ENFORCE(parts_per_beam <= 128, "Parts per beam should not be greater than 128");
int smem_stage2_size = parts_per_beam * max_k * 2 * sizeof(int32_t);
if (parts_per_beam <= 32) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(BeamSearchOnlineTopKStage2Kernel<T, max_k, 32>), batch_size * num_beams, 32, smem_stage2_size, stream,
topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices);
return;
}
if (parts_per_beam <= 64) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(BeamSearchOnlineTopKStage2Kernel<T, max_k, 64>), batch_size * num_beams, 64, smem_stage2_size, stream,
topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices);
return;
}
hipLaunchKernelGGL(HIP_KERNEL_NAME(BeamSearchOnlineTopKStage2Kernel<T, max_k, 128>), batch_size * num_beams, 128, smem_stage2_size, stream,
topk_values_tmp, topk_indices_tmp, K, vocab_size, parts_per_beam, output_values, output_indices);
return;
}
template <typename T, int max_k>
void TopKLauncherMaxK(
const T* input,
int batch_size,
int num_beams,
int vocab_size,
int K,
T* output_values,
int32_t* output_indices,
T* output_values_tmp,
int32_t* output_indices_tmp,
hipStream_t stream) {
constexpr int kThreadBlockSize = (max_k < 16) ? (max_k < 8) ? 256 : 128 : 64;
int voc_parts = 4;
if (batch_size * num_beams < 256) {
// volta has 80 SMs, so we aim for three waves
voc_parts = (240 + batch_size * num_beams - 1) / (batch_size * num_beams);
voc_parts = std::min(128, voc_parts); // we implement up to 128
}
dim3 grid(batch_size * num_beams, voc_parts);
#ifndef USE_ROCM
hipFuncSetAttribute(BeamSearchOnlineTopKStage1Kernel<T, max_k, kThreadBlockSize>,
hipFuncAttributePreferredSharedMemoryCarveout,
rocmSharedmemCarveoutMaxL1);
#endif // !USE_ROCM
hipLaunchKernelGGL(HIP_KERNEL_NAME(BeamSearchOnlineTopKStage1Kernel<T, max_k, kThreadBlockSize>), grid, kThreadBlockSize, 0, stream, input, K, vocab_size, (vocab_size + voc_parts - 1) / voc_parts, output_values_tmp, output_indices_tmp);
LaunchBeamSearchOnlineTopKStage2Kernel<T, max_k>(
output_values_tmp,
output_indices_tmp,
batch_size,
num_beams,
vocab_size,
voc_parts,
K,
output_values,
output_indices,
stream);
}
template <typename T, typename I, int32_t max_k, int32_t thread_block_size>
__launch_bounds__(thread_block_size) __global__ void BatchTopKKernel(
const T* topk_scores,
const I* topk_tokens,
int32_t* next_indices,
int32_t* next_tokens,
T* next_scores,
int32_t batch_size,
int32_t num_beams,
int32_t k) {
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
TopK<T, max_k> thread_topk;
if (thread_id == 0) {
thread_topk.Init();
int index_block = block_id * num_beams * k;
for (int32_t i = 0; i < num_beams * k; i++) {
thread_topk.Insert(topk_scores[index_block + i], index_block + i);
}
int index_next = block_id * k;
for (int i = 0; i < k; i++) {
next_tokens[index_next + i] = topk_tokens[thread_topk.key[i]];
next_indices[index_next + i] = (thread_topk.key[i] - index_block) / k;
next_scores[index_next + i] = thread_topk.value[i];
}
}
}
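// Index math example (illustrative): with num_beams = 4 and k = 4 each batch owns 16 candidates; if the best
// candidate of batch 2 sits at flat position key = 2 * 16 + 9 = 41, then index_block = 32, the originating beam
// is (41 - 32) / 4 = 2, and the emitted token is topk_tokens[41].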
template <typename T, typename I>
void LaunchBatchTopKKernel(const T* topk_scores,
const I* topk_tokens,
int32_t* next_indices,
int32_t* next_tokens,
T* next_scores,
int32_t batch_size,
int32_t num_beams,
int32_t k,
hipStream_t stream) {
ORT_ENFORCE(k <= 256, "LaunchBatchTopKKernel doesn't support k > 256");
#define BatchTopKKernelLauncher(K) \
hipLaunchKernelGGL(HIP_KERNEL_NAME(BatchTopKKernel<T, I, K, 32>), batch_size, 32, 0, stream, topk_scores, \
topk_tokens, \
next_indices, \
next_tokens, \
next_scores, \
batch_size, \
num_beams, \
k);
if (k <= 4) {
BatchTopKKernelLauncher(4);
} else if (k <= 8) {
BatchTopKKernelLauncher(8);
} else if (k <= 16) {
BatchTopKKernelLauncher(16);
} else if (k <= 32) {
BatchTopKKernelLauncher(32);
} else if (k <= 64) {
BatchTopKKernelLauncher(64);
} else if (k <= 128) {
BatchTopKKernelLauncher(128);
} else {
BatchTopKKernelLauncher(256);
}
}
template void LaunchBatchTopKKernel(const float* topk_scores,
const int32_t* topk_tokens,
int32_t* next_indices,
int32_t* next_tokens,
float* next_scores,
int32_t batch_size,
int32_t num_beams,
int32_t k,
hipStream_t stream);
template void LaunchBatchTopKKernel(const float* topk_scores,
const int64_t* topk_tokens,
int32_t* next_indices,
int32_t* next_tokens,
float* next_scores,
int32_t batch_size,
int32_t num_beams,
int32_t k,
hipStream_t stream);
template void LaunchBatchTopKKernel(const half* topk_scores,
const int32_t* topk_tokens,
int32_t* next_indices,
int32_t* next_tokens,
half* next_scores,
int32_t batch_size,
int32_t num_beams,
int32_t k,
hipStream_t stream);
template void LaunchBatchTopKKernel(const half* topk_scores,
const int64_t* topk_tokens,
int32_t* next_indices,
int32_t* next_tokens,
half* next_scores,
int32_t batch_size,
int32_t num_beams,
int32_t k,
hipStream_t stream);
template <typename T>
void BeamSearchTopK(
const T* input,
int32_t batch_size,
int32_t num_beams,
int32_t vocab_size,
int32_t k,
T* tmp_values_1st_stage,
int32_t* tmp_indices_1st_stage,
T* tmp_values_2nd_stage,
int32_t* tmp_indices_2nd_stage,
T* output_values,
int32_t* output_tokens,
int32_t* output_indices,
hipStream_t stream) {
ORT_ENFORCE(k <= 64, "BeamSearchTopK doesn't support k > 64");
#define TopKLauncher(K) \
TopKLauncherMaxK<T, K>(input, \
batch_size, \
num_beams, \
vocab_size, \
k, tmp_values_2nd_stage, \
tmp_indices_2nd_stage, \
tmp_values_1st_stage, \
tmp_indices_1st_stage, \
stream);
if (k <= 4) {
TopKLauncher(4)
} else if (k <= 8) {
TopKLauncher(8)
} else if (k <= 16) {
TopKLauncher(16)
} else if (k <= 32) {
TopKLauncher(32)
} else {
TopKLauncher(64)
}
LaunchBatchTopKKernel(tmp_values_2nd_stage,
tmp_indices_2nd_stage,
output_indices,
output_tokens,
output_values,
batch_size,
num_beams,
k,
stream);
}
template void BeamSearchTopK(
const float* input,
int32_t batch_size,
int32_t num_beams,
int32_t vocab_size,
int32_t k,
float* tmp_values_1st_stage,
int32_t* tmp_indices_1st_stage,
float* tmp_values_2nd_stage,
int32_t* tmp_indices_2nd_stage,
float* output_values,
int32_t* output_tokens,
int32_t* output_indices,
hipStream_t stream);
template void BeamSearchTopK(
const half* input,
int32_t batch_size,
int32_t num_beams,
int32_t vocab_size,
int32_t k,
half* tmp_values_1st_stage,
int32_t* tmp_indices_1st_stage,
half* tmp_values_2nd_stage,
int32_t* tmp_indices_2nd_stage,
half* output_values,
int32_t* output_tokens,
int32_t* output_indices,
hipStream_t stream);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include <hip/hip_runtime.h>
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T, typename I>
void LaunchBatchTopKKernel(
const T* topk_scores,
const I* topk_indices,
int32_t* next_indices,
int32_t* next_tokens,
T* next_scores,
int32_t batch_size,
int32_t num_beams,
int32_t k,
hipStream_t stream);
template <typename T>
void BeamSearchTopK(
const T* input,
int32_t batch_size,
int32_t num_beams,
int32_t vocab_size,
int32_t k,
T* tmp_values_1st_stage,
int32_t* tmp_indices_1st_stage,
T* tmp_values_2nd_stage,
int32_t* tmp_indices_2nd_stage,
T* output_values,
int32_t* output_tokens,
int32_t* output_indices,
hipStream_t stream);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "activations.h"
namespace onnxruntime {
namespace rocm {
#define REGISTER_ACTIVATION_VERSIONED_KERNEL(x, startver, endver, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
startver, \
endver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.MayInplace(0, 0), \
x<T>);
#define REGISTER_ACTIVATION_KERNEL(x, ver, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.MayInplace(0, 0), \
x<T>);
#define UNARY_ACTIVATION_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToHipType<T>::MappedType>( \
Stream(), \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(p.input_tensor->Data<T>()), \
reinterpret_cast<typename ToHipType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
#define UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, T) \
REGISTER_ACTIVATION_VERSIONED_KERNEL(name, startver, endver, T)
#define UNARY_ACTIVATION_OP_VERSIONED_HFD(name, startver, endver) \
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, float) \
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, double)
#define UNARY_ACTIVATION_OP_TYPED(name, ver, T) \
REGISTER_ACTIVATION_KERNEL(name, ver, T) \
UNARY_ACTIVATION_COMPUTE(name, T)
#define UNARY_ACTIVATION_OP_VERSIONED_HFD_WITH_BF16(name, startver, endver) \
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, float) \
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, double) \
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, BFloat16)
#define UNARY_ACTIVATION_OP_HFD(name, ver) \
UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \
UNARY_ACTIVATION_OP_TYPED(name, ver, float) \
UNARY_ACTIVATION_OP_TYPED(name, ver, double) \
UNARY_ACTIVATION_OP_TYPED(name, ver, BFloat16)
UNARY_ACTIVATION_OP_HFD(Elu, 6);
UNARY_ACTIVATION_OP_HFD(HardSigmoid, 6);
UNARY_ACTIVATION_OP_VERSIONED_HFD(LeakyRelu, 6, 15);
UNARY_ACTIVATION_OP_HFD(Relu, 14);
UNARY_ACTIVATION_OP_VERSIONED_HFD_WITH_BF16(Relu, 13, 13);
UNARY_ACTIVATION_OP_VERSIONED_HFD(Relu, 6, 12);
UNARY_ACTIVATION_OP_HFD(Selu, 6);
UNARY_ACTIVATION_OP_HFD(Sigmoid, 13);
UNARY_ACTIVATION_OP_VERSIONED_HFD(Sigmoid, 6, 12);
UNARY_ACTIVATION_OP_HFD(Softplus, 1);
UNARY_ACTIVATION_OP_HFD(Softsign, 1);
UNARY_ACTIVATION_OP_HFD(Tanh, 13);
UNARY_ACTIVATION_OP_VERSIONED_HFD(Tanh, 6, 12);
UNARY_ACTIVATION_OP_HFD(ThresholdedRelu, 10);
// Opset-16 adds BFloat16 to allowed types for the LeakyRelu operator
UNARY_ACTIVATION_OP_HFD(LeakyRelu, 16);
} // namespace rocm
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/math/unary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "activations_impl.h"
namespace onnxruntime {
namespace rocm {
#define MAKE_FUNC_CTX_ALPHA() \
inline CtxAlpha MakeFuncCtx() const { \
CtxAlpha ctx; \
ctx.alpha = alpha_; \
return ctx; \
}
#define MAKE_FUNC_CTX_ALPHA_BETA() \
inline CtxAlphaBeta MakeFuncCtx() const { \
CtxAlphaBeta ctx; \
ctx.alpha = alpha_; \
ctx.beta = beta_; \
return ctx; \
}
#define MAKE_FUNC_CTX_ALPHA_GAMMA() \
inline CtxAlphaGamma MakeFuncCtx() const { \
CtxAlphaGamma ctx; \
ctx.alpha = alpha_; \
ctx.gamma = gamma_; \
return ctx; \
}
#define MAKE_FUNC_CTX_NULL() \
inline CtxNull MakeFuncCtx() const { \
CtxNull ctx; \
return ctx; \
}
template <typename T>
class Elu final : public UnaryElementwise {
public:
Elu(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
template <typename T>
class HardSigmoid final : public UnaryElementwise {
public:
HardSigmoid(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_BETA()
float alpha_;
float beta_;
};
template <typename T>
class LeakyRelu final : public UnaryElementwise {
public:
LeakyRelu(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
template <typename T>
class Relu final : public UnaryElementwise {
public:
Relu(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class Selu final : public UnaryElementwise {
public:
Selu(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("gamma", &gamma_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_GAMMA()
float alpha_;
float gamma_;
};
template <typename T>
class Sigmoid final : public UnaryElementwise {
public:
Sigmoid(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class Softplus final : public UnaryElementwise {
public:
Softplus(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class Softsign final : public UnaryElementwise {
public:
Softsign(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class Tanh final : public UnaryElementwise {
public:
Tanh(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class ThresholdedRelu final : public UnaryElementwise {
public:
ThresholdedRelu(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
} // namespace rocm
} // namespace onnxruntime